﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <nn/nn_Common.h>
#include <nn/nn_Log.h>
#include <nn/crypto.h>
#include <nn/migration/idc/migration_MessageEncryptor.h>
#include <nnt/result/testResult_Assert.h>
#include <nnt/nntest.h>

#if defined(NN_BUILD_TARGET_PLATFORM_NX)
#include <nn/spl/spl_Api.h>
#endif

using namespace nn;

template <typename E, size_t MessageSize>
void TestEncryptDecrypt(E& e0, E& e1)
{
    // e0 暗号化 → e1 復号化 → e1 暗号化 → e0 復号化。

    Bit8 plain[MessageSize + 1]; // MessageSize = 0 もできるように + 1。
    for( size_t i = 0; i < sizeof(plain); i++ )
    {
        plain[i] = static_cast<Bit8>(i);
    }

    Bit8 encrypted[MessageSize + migration::idc::MessageEncryptorConfig::SequenceManagementDataSize + 1] = {};
    Bit8 mac[crypto::Aes128GcmEncryptor::MacSize] = {};

    size_t encryptedSize;
    e0.Encrypt(&encryptedSize, encrypted, sizeof(encrypted), mac, sizeof(mac), NN_STATIC_CONDITION(MessageSize > 0) ? plain : nullptr, MessageSize);
    EXPECT_EQ(MessageSize + migration::idc::MessageEncryptorConfig::SequenceManagementDataSize, encryptedSize);

    Bit8 decrypted[MessageSize + 1] = {};
    size_t decryptedSize;
    EXPECT_TRUE(e1.Decrypt(&decryptedSize, decrypted, sizeof(decrypted), encrypted, encryptedSize, mac, sizeof(mac)));
    EXPECT_EQ(MessageSize, decryptedSize);
    EXPECT_EQ(0, std::memcmp(plain, decrypted, decryptedSize));

    e1.Encrypt(&encryptedSize, encrypted, sizeof(encrypted), mac, sizeof(mac), NN_STATIC_CONDITION(MessageSize > 0) ? plain : nullptr, MessageSize);
    EXPECT_TRUE(e0.Decrypt(&decryptedSize, decrypted, sizeof(decrypted), encrypted, encryptedSize, mac, sizeof(mac)));
    EXPECT_EQ(MessageSize, decryptedSize);
    EXPECT_EQ(0, std::memcmp(plain, decrypted, decryptedSize));
}

template <typename E>
void TestDecryptFailureWithCorruptedMac(E& e0, E& e1)
{
    Bit8 plain[128] = {};

    for( size_t i = 0; i < sizeof(plain); i++ )
    {
        plain[i] = static_cast<Bit8>(i);
    }

    Bit8 encrypted[256] = {};
    Bit8 mac[crypto::Aes128GcmEncryptor::MacSize] = {};

    size_t encryptedSize;
    e0.Encrypt(&encryptedSize, encrypted, sizeof(encrypted), mac, sizeof(mac), plain, sizeof(plain));

    Bit8 decrypted[128] = {};
    size_t decryptedSize;

    // MAC を適当に操作し、Decrypt が MAC 検証で失敗することを確認。
    std::swap(mac[0], mac[1]);
    EXPECT_FALSE(e1.Decrypt(&decryptedSize, decrypted, sizeof(decrypted), encrypted, encryptedSize, mac, sizeof(mac)));
}

template <typename E>
void TestDecryptFailureWithCorruptedMessage(E& e0, E& e1)
{
    Bit8 plain[128] = {};

    for( size_t i = 0; i < sizeof(plain); i++ )
    {
        plain[i] = static_cast<Bit8>(i);
    }

    Bit8 encrypted[256] = {};
    Bit8 mac[crypto::Aes128GcmEncryptor::MacSize] = {};

    size_t encryptedSize;
    e0.Encrypt(&encryptedSize, encrypted, sizeof(encrypted), mac, sizeof(mac), plain, sizeof(plain));

    Bit8 decrypted[128] = {};
    size_t decryptedSize;

    // 暗号化されたメッセージを適当に操作し、Decrypt が MAC 検証で失敗することを確認。
    std::swap(encrypted[0], encrypted[1]);
    EXPECT_FALSE(e1.Decrypt(&decryptedSize, decrypted, sizeof(decrypted), encrypted, encryptedSize, mac, sizeof(mac)));
}

template <typename E>
void TestInvalidSequence(E& e0, E& e1)
{
    // e0 暗号化（使わない） → e0 暗号化 → e1 復号化

    Bit8 plain[128] = {};

    for( size_t i = 0; i < sizeof(plain); i++ )
    {
        plain[i] = static_cast<Bit8>(i);
    }

    Bit8 encrypted[256] = {};
    Bit8 mac[crypto::Aes128GcmEncryptor::MacSize] = {};

    size_t encryptedSize;
    e0.Encrypt(&encryptedSize, encrypted, sizeof(encrypted), mac, sizeof(mac), plain, sizeof(plain));
    e0.Encrypt(&encryptedSize, encrypted, sizeof(encrypted), mac, sizeof(mac), plain, sizeof(plain));

    Bit8 decrypted[128] = {};
    size_t decryptedSize;

    EXPECT_FALSE(e1.Decrypt(&decryptedSize, decrypted, sizeof(decrypted), encrypted, encryptedSize, mac, sizeof(mac)));
}

class MigrationIdcMessageEncryptorTest : public testing::Test
{
public:
    typedef migration::idc::MessageEncryptor EncryptorType;

    static void SetUpTestCase()
    {
#if defined(NN_BUILD_TARGET_PLATFORM_NX)
        nn::spl::InitializeForCrypto();
#endif
    }

    static void TearDownTestCase()
    {
#if defined(NN_BUILD_TARGET_PLATFORM_NX)
        nn::spl::Finalize();
#endif
    }
};

class MigrationIdcDebugMessageEncryptorTest : public testing::Test
{
public:
    typedef migration::idc::DebugMessageEncryptor EncryptorType;

    static void SetUpTestCase()
    {
#if defined(NN_BUILD_TARGET_PLATFORM_NX)
        nn::spl::InitializeForCrypto();
#endif
    }

    static void TearDownTestCase()
    {
#if defined(NN_BUILD_TARGET_PLATFORM_NX)
        nn::spl::Finalize();
#endif
    }
};

// Initializes *pOut with fixed, deterministic parameters.
//
// Every call passes identical passphrase, salt, iteration and counter values. The tests
// rely on this to pair two encryptors built by this helper in a round trip.
template <typename EncryptorType>
void GetEncryptor(EncryptorType* pOut)
{
    int iterationCount = 1000; // presumably the key-derivation iteration count — confirm in Initialize's docs.
    int initialCounter = 0;

    // Passphrase and salt are just 0, 1, 2, ... so the setup is reproducible.
    Bit8 passphrase[40] = {};
    for( size_t pos = 0; pos != sizeof(passphrase); ++pos )
    {
        passphrase[pos] = static_cast<Bit8>(pos);
    }

    Bit8 salt[32] = {};
    for( size_t pos = 0; pos != sizeof(salt); ++pos )
    {
        salt[pos] = static_cast<Bit8>(pos);
    }

    pOut->Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iterationCount, initialCounter);
}

/// Production encryptor (MessageEncryptor).

TEST_F(MigrationIdcMessageEncryptorTest, EncryptDecrypt0)
{
    // Round trip in both directions with an empty message.
    EncryptorType sender;
    EncryptorType receiver;
    GetEncryptor(&sender);
    GetEncryptor(&receiver);
    TestEncryptDecrypt<EncryptorType, 0>(sender, receiver);
}

TEST_F(MigrationIdcMessageEncryptorTest, EncryptDecrypt128)
{
    // Round trip in both directions with a 128-byte message.
    EncryptorType sender;
    EncryptorType receiver;
    GetEncryptor(&sender);
    GetEncryptor(&receiver);
    TestEncryptDecrypt<EncryptorType, 128>(sender, receiver);
}

TEST_F(MigrationIdcMessageEncryptorTest, DecryptFailureWithCorruptedMac)
{
    // A tampered MAC must make Decrypt fail.
    EncryptorType sender;
    EncryptorType receiver;
    GetEncryptor(&sender);
    GetEncryptor(&receiver);
    TestDecryptFailureWithCorruptedMac<EncryptorType>(sender, receiver);
}

TEST_F(MigrationIdcMessageEncryptorTest, DecryptFailureWithCorruptedMessage)
{
    // A tampered encrypted payload must make Decrypt fail.
    EncryptorType sender;
    EncryptorType receiver;
    GetEncryptor(&sender);
    GetEncryptor(&receiver);
    TestDecryptFailureWithCorruptedMessage<EncryptorType>(sender, receiver);
}

TEST_F(MigrationIdcMessageEncryptorTest, InvalidSequence)
{
    // An out-of-sequence message must make Decrypt fail.
    EncryptorType sender;
    EncryptorType receiver;
    GetEncryptor(&sender);
    GetEncryptor(&receiver);
    TestInvalidSequence<EncryptorType>(sender, receiver);
}

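/// Debug encryptor (DebugMessageEncryptor). MAC verification and similar checks are not functional in it, so the corruption tests are not run here.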

TEST_F(MigrationIdcDebugMessageEncryptorTest, EncryptDecrypt0)
{
    // Round trip in both directions with an empty message.
    EncryptorType sender;
    EncryptorType receiver;
    GetEncryptor(&sender);
    GetEncryptor(&receiver);
    TestEncryptDecrypt<EncryptorType, 0>(sender, receiver);
}

TEST_F(MigrationIdcDebugMessageEncryptorTest, EncryptDecrypt128)
{
    // Round trip in both directions with a 128-byte message.
    EncryptorType sender;
    EncryptorType receiver;
    GetEncryptor(&sender);
    GetEncryptor(&receiver);
    TestEncryptDecrypt<EncryptorType, 128>(sender, receiver);
}

TEST_F(MigrationIdcDebugMessageEncryptorTest, InvalidSequence)
{
    // An out-of-sequence message must make Decrypt fail.
    EncryptorType sender;
    EncryptorType receiver;
    GetEncryptor(&sender);
    GetEncryptor(&receiver);
    TestInvalidSequence<EncryptorType>(sender, receiver);
}
