﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/nn_Common.h>
#include <nn/crypto/crypto_AesDecryptor.h>
#include <nn/crypto/crypto_AesEncryptor.h>
#include <nn/crypto/crypto_XtsDecryptor.h>
#include <nn/crypto/crypto_XtsEncryptor.h>
#include <cstring>
#include <cstdlib>

#include <nnt/nntest.h>
#include "CavsParser.h"
#include "ProfileCheck.h"
#include "OctetString.h"

class Xts256Parser : public CavsParser
{
public:
    enum
    {
        Encrypt,
        Decrypt,
        Count,
        DataUnitLen,
        Key,
        Tweak,
        PlainText,
        CipherText
    };

    Xts256Parser()
    {
        m_AllowedTokens[Encrypt]     = AllowedToken("[ENCRYPT]",      false);
        m_AllowedTokens[Decrypt]     = AllowedToken("[DECRYPT]",      false);
        m_AllowedTokens[Count]       = AllowedToken("COUNT = ",       false);
        m_AllowedTokens[DataUnitLen] = AllowedToken("DataUnitLen = ", false);
        m_AllowedTokens[Key]         = AllowedToken("Key = ",         true);
        m_AllowedTokens[Tweak]       = AllowedToken("i = ",           true);
        m_AllowedTokens[PlainText]   = AllowedToken("PT = ",          true);
        m_AllowedTokens[CipherText]  = AllowedToken("CT = ",          true);
        m_TokenForUnitComplete       = Key;
    }

    virtual void TestUnit() override
    {
        ASSERT_EQ(m_UnitTokens.count(DataUnitLen), 1);
        ASSERT_EQ(m_UnitTokens.count(Key), 1);
        ASSERT_EQ(m_UnitTokens.count(Tweak), 1);
        ASSERT_EQ(m_UnitTokens.count(PlainText), 1);
        ASSERT_EQ(m_UnitTokens.count(CipherText), 1);

        const std::string& keyStr   = m_UnitTokens[Key];
        const std::string& tweakStr = m_UnitTokens[Tweak];
        const std::string& ptStr    = m_UnitTokens[PlainText];
        const std::string& ctStr    = m_UnitTokens[CipherText];
        const int lengthInBits      = std::atoi(m_UnitTokens[DataUnitLen].c_str());

        size_t keySize = keyStr.size() / 2;
        const char* key1 = keyStr.c_str();
        const char* key2 = key1 + keySize;

        // Some test vector length is not byte size aligned (e.g. 140 bits).
        // Such test vectors cannot be supported in the current interface.
        // (It cannot differenciate 130 bits in 17 bytes v.s. full 136 bits in 17 bytes)
        if (lengthInBits % 8 != 0)
            m_Skipped = true;
        else
        {
            if (Has(Decrypt))
            {
                // Decryption
                {
                    std::string resPt(ctStr.size(), char(0));

                    nn::crypto::AesDecryptor256               aes1;
                    nn::crypto::AesEncryptor256               aes2;
                    nn::crypto::XtsDecryptor<nn::crypto::AesDecryptor256> aes256Xts;

                    aes1.Initialize(key1, keySize);
                    aes2.Initialize(key2, keySize);
                    aes256Xts.Initialize(&aes1, &aes2, tweakStr.c_str(), tweakStr.size());
                    size_t ret = aes256Xts.Update(const_cast<char*>(resPt.c_str()), resPt.size(), ctStr.c_str(), ctStr.size());
                    aes256Xts.Finalize(const_cast<char*>(resPt.c_str()) + ret, resPt.size() - ret);

                    ASSERT_TRUE(resPt == ptStr);
                }

                // In-place decryption
                {
                    std::string resPt(ctStr.begin(), ctStr.end());

                    nn::crypto::AesDecryptor256               aes1;
                    nn::crypto::AesEncryptor256               aes2;
                    nn::crypto::XtsDecryptor<nn::crypto::AesDecryptor256> aes256Xts;

                    aes1.Initialize(key1, keySize);
                    aes2.Initialize(key2, keySize);
                    aes256Xts.Initialize(&aes1, &aes2, tweakStr.c_str(), tweakStr.size());
                    size_t ret = aes256Xts.Update(const_cast<char*>(resPt.c_str()), resPt.size(), resPt.c_str(), resPt.size());
                    aes256Xts.Finalize(const_cast<char*>(resPt.c_str()) + ret, resPt.size() - ret);

                    ASSERT_TRUE(resPt == ptStr);
                }
            }
            else
            {
                // Encryption
                {
                    std::string resCt(ptStr.size(), char(0));

                    nn::crypto::AesEncryptor256               aes1;
                    nn::crypto::AesEncryptor256               aes2;
                    nn::crypto::XtsEncryptor<nn::crypto::AesEncryptor256> aes256Xts;

                    aes1.Initialize(key1, keySize);
                    aes2.Initialize(key2, keySize);
                    aes256Xts.Initialize(&aes1, &aes2, tweakStr.c_str(), tweakStr.size());
                    size_t ret = aes256Xts.Update(const_cast<char*>(resCt.c_str()), resCt.size(), ptStr.c_str(), ptStr.size());
                    aes256Xts.Finalize(const_cast<char*>(resCt.c_str()) + ret, resCt.size() - ret);

                    ASSERT_TRUE(resCt == ctStr);
                }

                // In-place encryption
                {
                    std::string resCt(ptStr.begin(), ptStr.end());

                    nn::crypto::AesEncryptor256               aes1;
                    nn::crypto::AesEncryptor256               aes2;
                    nn::crypto::XtsEncryptor<nn::crypto::AesEncryptor256> aes256Xts;

                    aes1.Initialize(key1, keySize);
                    aes2.Initialize(key2, keySize);
                    aes256Xts.Initialize(&aes1, &aes2, tweakStr.c_str(), tweakStr.size());
                    size_t ret = aes256Xts.Update(const_cast<char*>(resCt.c_str()), resCt.size(), resCt.c_str(), resCt.size());
                    aes256Xts.Finalize(const_cast<char*>(resCt.c_str()) + ret, resCt.size() - ret);

                    ASSERT_TRUE(resCt == ctStr);
                }
            }
        }

        m_UnitTokens.erase( Key );
    }
};

// Feeds every vector in the XTS-AES-256 CAVS response file through Xts256Parser.
// The second TestFile argument (600) is passed through unchanged; its meaning is
// defined by CavsParser::TestFile (NOTE(review): confirm what the count bounds).
TEST(Aes256Xts, rsp)
{
    Xts256Parser xtsVectorParser;
    xtsVectorParser.TestFile("XTSTestVectors/128_hex_str/XTSGenAES256.rsp", 600);
}
