﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <algorithm>
#include <memory>
#include <string>
#include <typeinfo>
#include <nn/nn_Common.h>
#include <nn/nn_Log.h>
#include <nn/crypto.h>
#include <nn/migration/idc/migration_CommandApi.h>
#include <nn/migration/idc/migration_CommandTypes.h>
#include <nn/migration/idc/migration_KeyExchangeCommandEncryptor.h>
#include <nn/migration/idc/migration_MessageEncryptor.h>
#include <nn/migration/idc/migration_SharedBufferConnection.h>
#include <nn/migration/idc/migration_SharedBufferConnectionManager.h>
#include <nn/migration/idc/migration_UserCommandMediator.h>
#include <nnt/result/testResult_Assert.h>
#include <nnt/nntest.h>

#if defined(NN_BUILD_TARGET_PLATFORM_NX)
#include <nn/spl/spl_Api.h>
#endif

// Names such as crypto:: and spl:: are written without the nn:: prefix below.
using namespace nn;
// Shorthand for nn::migration::idc, used throughout this file.
namespace midc = nn::migration::idc;

// Fixture shared by every test in this file.
// On NX targets the test case is bracketed by spl::InitializeForFs() and
// spl::Finalize(); on every other target both hooks are empty.
class MigrationIdcCommandTest : public testing::Test
{
protected:
    // Test-case level set-up hook.
    static void SetUpTestCase()
    {
#ifdef NN_BUILD_TARGET_PLATFORM_NX
        spl::InitializeForFs();
#endif
    }

    // Test-case level tear-down hook; undoes SetUpTestCase().
    static void TearDownTestCase()
    {
#ifdef NN_BUILD_TARGET_PLATFORM_NX
        spl::Finalize();
#endif
    }
};

// Initiate0Request -> Iniiate0Response -> Initiate1Request -> Initiate1Response
template <typename KE, typename ME>
void TestInitiate()
{
    // 共通パラメータ。
    midc::KeyExchangeCommandConfig::Salt salt = {};
    int iteration = 1000;
    midc::MessageEncryptorConfig::Counter counter = 0;

    // Client 側のパラメータ。
    midc::KeyExchangeCommandConfig::Challenge clientChallenge = { 1,2,3,4,5 }; // 適当。
    KE clientKeyExchangeCommandEncryptor(salt);

    // Server 側のパラメータ。
    midc::KeyExchangeCommandConfig::Challenge serverChallenge = { 6,7,8,9,0 }; // 適当。
    KE serverKeyExchangeCommandEncryptor(salt);

    // Client : Initiate0Request の作成
    midc::Initiate0Request initiate0Request;
    midc::CreateInitiate0Request(&initiate0Request, clientChallenge, clientKeyExchangeCommandEncryptor);

    // Server : Initiate0Request の検証
    midc::KeyExchangeCommandConfig::Challenge receivedClientChallenge;
    EXPECT_TRUE(midc::ParseInitiate0Request(receivedClientChallenge, initiate0Request, serverKeyExchangeCommandEncryptor));
    EXPECT_TRUE(crypto::IsSameBytes(receivedClientChallenge, clientChallenge, sizeof(clientChallenge)));

    // Server : Initiate0Response の作成
    midc::Initiate0Response initiate0Response;
    midc::KeyExchangeCommandConfig::Iv0 iv0 = {};
    midc::KeyExchangeCommandConfig::Passphrase passphrase = {};
    midc::CreateInitiate0Response(&initiate0Response, iv0, passphrase, initiate0Request.data, serverChallenge, serverKeyExchangeCommandEncryptor);

    // Client : Initiate0Response の検証
    midc::KeyExchangeCommandConfig::Passphrase receivedPassphrase;
    midc::KeyExchangeCommandConfig::Challenge receivedServerChallenge;
    EXPECT_TRUE(midc::ParseInitiate0Response(receivedPassphrase, receivedServerChallenge, initiate0Response, clientChallenge, clientKeyExchangeCommandEncryptor));
    //          Client 側で受け取ったパスフレーズとチャレンジが Server 側のものと一致していることの検証
    EXPECT_TRUE(crypto::IsSameBytes(passphrase, receivedPassphrase, sizeof(passphrase)));
    EXPECT_TRUE(crypto::IsSameBytes(serverChallenge, receivedServerChallenge, sizeof(serverChallenge)));

    // Client : Initiate1Request の作成
    midc::Initiate1Request initiate1Request;
    ME clientEncryptor;
    clientEncryptor.Initialize(receivedPassphrase, sizeof(receivedPassphrase), salt, sizeof(salt), iteration, counter);
    midc::CreateInitiate1Request(&initiate1Request, receivedServerChallenge, clientEncryptor);

    // Server : Initiate1Request の検証
    ME serverEncryptor;
    serverEncryptor.Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iteration, counter);
    EXPECT_TRUE(midc::ParseInitiate1Request(initiate1Request, serverChallenge, serverEncryptor));

    // Server : Initiate1Response の作成
    midc::Initiate1Response initiate1Response;
    midc::CreateInitiate1Response(&initiate1Response, serverEncryptor);

    // Client : Initiate1Response の検証
    EXPECT_TRUE(midc::ParseInitiate1Response(initiate1Response, clientEncryptor));
}

// Runs the initiate handshake for each encryptor combination. The non-debug
// key exchange encryptor is only exercised when building for Horizon.
TEST_F(MigrationIdcCommandTest, Initiate)
{
    typedef midc::DebugKeyExchangeCommandEncryptor DebugKeyExchangeEnc;
    typedef midc::DebugMessageEncryptor DebugMessageEnc;
    typedef midc::MessageEncryptor MessageEnc;

    TestInitiate<DebugKeyExchangeEnc, DebugMessageEnc>();
    TestInitiate<DebugKeyExchangeEnc, MessageEnc>();
#if defined(NN_BUILD_CONFIG_OS_HORIZON)
    typedef midc::KeyExchangeCommandEncryptor KeyExchangeEnc;

    TestInitiate<KeyExchangeEnc, DebugMessageEnc>();
    TestInitiate<KeyExchangeEnc, MessageEnc>();
#endif
}

// Resume0Request -> Resume0Response -> Resume1Request -> Resume1Response
template <typename ME>
void TestResume()
{
    // 共通パラメータ
    midc::KeyExchangeCommandConfig::Salt salt = {};
    int iteration = 1000;
    midc::MessageEncryptorConfig::Counter counter = 0;
    midc::KeyExchangeCommandConfig::Passphrase passphrase = {}; // 交換済み

    // Client 側のパラメータ
    midc::KeyExchangeCommandConfig::Challenge clientChallenge = { 1,2,3,4,5 };
    ME clientEncryptor;
    clientEncryptor.Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iteration, counter);

    // Server 側のパラメータ
    midc::KeyExchangeCommandConfig::Challenge serverChallenge = { 6,7,8,9,0 };
    ME serverEncryptor;
    serverEncryptor.Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iteration, counter);

    // Client : Resume0Request の作成
    midc::Resume0Request resume0Request;
    midc::CreateResume0Request(&resume0Request, clientChallenge, clientEncryptor);

    // Server : Resume0Request の検証
    midc::KeyExchangeCommandConfig::Challenge receivedClientChallenge;
    EXPECT_TRUE(midc::ParseResume0Request(receivedClientChallenge, resume0Request, serverEncryptor));
    EXPECT_TRUE(crypto::IsSameBytes(clientChallenge, receivedClientChallenge, midc::KeyExchangeCommandConfig::ChallengeSize));

    // Server : Resume0Response の作成
    midc::Resume0Response resume0Response;
    midc::CreateResume0Response(&resume0Response, receivedClientChallenge, serverChallenge, serverEncryptor);

    // Client : Resume0Response の検証
    midc::KeyExchangeCommandConfig::Challenge receivedServerChallenge;
    EXPECT_TRUE(midc::ParseResume0Response(receivedServerChallenge, resume0Response, clientChallenge, clientEncryptor));

    // Client : Resume1Request の作成
    midc::Resume1Request resume1Request;
    midc::CreateResume1Request(&resume1Request, receivedServerChallenge, clientEncryptor);

    // Server : Resume1Request の検証
    EXPECT_TRUE(midc::ParseResume1Reqeust(resume1Request, serverChallenge, serverEncryptor));

    // Server : Resume1Response の作成
    midc::Resume1Response resume1Response;
    midc::CreateResume1Response(&resume1Response, serverEncryptor);

    // Client : Resume1Response の検証
    EXPECT_TRUE(midc::ParseResume1Response(resume1Response, clientEncryptor));
}

// Runs the resume handshake with the debug message encryptor first, then the regular one.
TEST_F(MigrationIdcCommandTest, Resume)
{
    typedef midc::DebugMessageEncryptor DebugMessageEnc;
    typedef midc::MessageEncryptor MessageEnc;

    TestResume<DebugMessageEnc>();
    TestResume<MessageEnc>();
}

// TerminateRequest -> TerminateResponse
template <typename ME>
void TestTerminate()
{
    // 共通パラメータ
    midc::KeyExchangeCommandConfig::Salt salt = {};
    int iteration = 1000;
    midc::MessageEncryptorConfig::Counter counter = 0;
    midc::KeyExchangeCommandConfig::Passphrase passphrase = {}; // 交換済み

    // Client 側のパラメータ
    ME clientEncryptor;
    clientEncryptor.Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iteration, counter);

    // Server 側のパラメータ
    ME serverEncryptor;
    serverEncryptor.Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iteration, counter);

    // Client : TerminateRequest の作成
    midc::TerminateRequest request;
    midc::CreateTerminateRequest(&request, clientEncryptor);

    // Server : TerminateRequest の検証
    EXPECT_TRUE(midc::ParseTerminateRequest(request, serverEncryptor));

    // Server : TerminateResponse の作成
    midc::TerminateResponse response;
    midc::CreateTerminateResponse(&response, serverEncryptor);

    // Client : TerminateResponse の検証
    EXPECT_TRUE(midc::ParseTerminateResponse(response, clientEncryptor));
}

// Runs the terminate exchange with the debug message encryptor first, then the regular one.
TEST_F(MigrationIdcCommandTest, Terminate)
{
    typedef midc::DebugMessageEncryptor DebugMessageEnc;
    typedef midc::MessageEncryptor MessageEnc;

    TestTerminate<DebugMessageEnc>();
    TestTerminate<MessageEnc>();
}

// * -> ErrorResponse
template <typename ME>
void TestError()
{
    // 共通パラメータ
    midc::KeyExchangeCommandConfig::Salt salt = {};
    int iteration = 1000;
    midc::MessageEncryptorConfig::Counter counter = 0;
    midc::KeyExchangeCommandConfig::Passphrase passphrase = {}; // 交換済み

    // Client 側のパラメータ
    ME clientEncryptor;
    clientEncryptor.Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iteration, counter);

    // Server 側のパラメータ
    ME serverEncryptor;
    serverEncryptor.Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iteration, counter);

    // Client 送信 -> Server 受信は両方省く。両方省けば encryptor の状態は一致する。

    // Server : ErrorResponse の作成
    midc::ErrorInfo errorInfo = {};
    errorInfo.errorId = 3; // 適当な値。
    midc::ErrorResponse response;
    midc::CreateErrorResponse(&response, errorInfo, serverEncryptor);

    // Client : ErrorResponse の検証
    midc::ErrorInfo receivedErrorInfo;
    EXPECT_TRUE(midc::ParseErrorResponse(&receivedErrorInfo, response, clientEncryptor));
    EXPECT_EQ(errorInfo.errorId, receivedErrorInfo.errorId);
}

// Runs the error response check with the debug message encryptor first, then the regular one.
TEST_F(MigrationIdcCommandTest, Error)
{
    typedef midc::DebugMessageEncryptor DebugMessageEnc;
    typedef midc::MessageEncryptor MessageEnc;

    TestError<DebugMessageEnc>();
    TestError<MessageEnc>();
}

// Invalid Command. 間違った Challenge/Salt などを使用して Parse を失敗させる。
template <typename KE, typename ME>
void TestInvalid()
{
    // 共通パラメータ。
    midc::KeyExchangeCommandConfig::Salt salt = {};
    midc::KeyExchangeCommandConfig::Salt wrongSalt = { 1 };
    int iteration = 1000;
    midc::MessageEncryptorConfig::Counter counter = 0;

    // Client 側のパラメータ。
    midc::KeyExchangeCommandConfig::Challenge clientChallenge = { 1,2,3,4,5 }; // 適当。
    midc::KeyExchangeCommandConfig::Challenge wrongClientChallenge = { 9,2,3,4,5 }; // 適当。
    KE clientKeyExchangeCommandEncryptor(salt);

    // Server 側のパラメータ。
    midc::KeyExchangeCommandConfig::Challenge serverChallenge = { 6,7,8,9,0 }; // 適当。
    midc::KeyExchangeCommandConfig::Challenge wrongServerChallenge = { 1,7,8,9,0 }; // 適当。
    KE serverKeyExchangeCommandEncryptor(salt);
    KE wrongServerKeyExchangeCommandEncryptor(wrongSalt);

    // Client : Initiate0Request の作成
    midc::Initiate0Request initiate0Request;
    midc::CreateInitiate0Request(&initiate0Request, clientChallenge, clientKeyExchangeCommandEncryptor);

    // Server : Initiate0Request の検証（異なる Salt を使用）。Debug* は検証情報がなく Salt 不一致でも true になってしまうので飛ばす。
    if( typeid(KE) != typeid(midc::DebugKeyExchangeCommandEncryptor) )
    {
        midc::KeyExchangeCommandConfig::Challenge receivedClientChallenge;
        EXPECT_FALSE(midc::ParseInitiate0Request(receivedClientChallenge, initiate0Request, wrongServerKeyExchangeCommandEncryptor));
    }

    // Server : Initiate0Response の作成
    midc::Initiate0Response initiate0Response;
    midc::KeyExchangeCommandConfig::Iv0 iv0 = {};
    midc::KeyExchangeCommandConfig::Passphrase passphrase = {};
    midc::CreateInitiate0Response(&initiate0Response, iv0, passphrase, initiate0Request.data, serverChallenge, serverKeyExchangeCommandEncryptor);

    // Client : Initiate0Response の検証（1度異なる challenge を使用）
    midc::KeyExchangeCommandConfig::Passphrase receivedPassphrase;
    midc::KeyExchangeCommandConfig::Challenge receivedServerChallenge;
    EXPECT_FALSE(midc::ParseInitiate0Response(receivedPassphrase, receivedServerChallenge, initiate0Response, wrongClientChallenge, clientKeyExchangeCommandEncryptor));
    EXPECT_TRUE(midc::ParseInitiate0Response(receivedPassphrase, receivedServerChallenge, initiate0Response, clientChallenge, clientKeyExchangeCommandEncryptor));
    EXPECT_TRUE(crypto::IsSameBytes(passphrase, receivedPassphrase, sizeof(passphrase)));
    EXPECT_TRUE(crypto::IsSameBytes(serverChallenge, receivedServerChallenge, sizeof(serverChallenge)));

    // Client : Initiate1Request の作成
    midc::Initiate1Request initiate1Request;
    ME clientEncryptor;
    clientEncryptor.Initialize(receivedPassphrase, sizeof(receivedPassphrase), salt, sizeof(salt), iteration, counter);
    midc::CreateInitiate1Request(&initiate1Request, receivedServerChallenge, clientEncryptor);

    // Server : Initiate1Request の検証（1度異なる challenge を使用）
    ME serverEncryptor;
    serverEncryptor.Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iteration, counter);
    EXPECT_FALSE(midc::ParseInitiate1Request(initiate1Request, wrongServerChallenge, serverEncryptor));
}

// Runs the invalid-command checks for each encryptor combination. The non-debug
// key exchange encryptor is only exercised when building for Horizon.
TEST_F(MigrationIdcCommandTest, Invalid)
{
    typedef midc::DebugKeyExchangeCommandEncryptor DebugKeyExchangeEnc;
    typedef midc::DebugMessageEncryptor DebugMessageEnc;
    typedef midc::MessageEncryptor MessageEnc;

    TestInvalid<DebugKeyExchangeEnc, DebugMessageEnc>();
    TestInvalid<DebugKeyExchangeEnc, MessageEnc>();
#if defined(NN_BUILD_CONFIG_OS_HORIZON)
    typedef midc::KeyExchangeCommandEncryptor KeyExchangeEnc;

    TestInvalid<KeyExchangeEnc, DebugMessageEnc>();
    TestInvalid<KeyExchangeEnc, MessageEnc>();
#endif
}

template <typename ME>
void TestUser(size_t commandSize, size_t blockSize)
{
    NN_LOG("CommandSize : %llu, BlockSize : %llu.\n", commandSize, blockSize);
    // 共通パラメータ
    midc::KeyExchangeCommandConfig::Salt salt = {};
    int iteration = 1000;
    midc::MessageEncryptorConfig::Counter counter = 0;
    midc::KeyExchangeCommandConfig::Passphrase passphrase = {}; // 交換済み

    // Client 側のパラメータ
    ME clientEncryptor;
    clientEncryptor.Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iteration, counter);
    midc::UserCommandEncryptor<ME> userCommandEncryptor(clientEncryptor);

    // Server 側のパラメータ
    ME serverEncryptor;
    serverEncryptor.Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iteration, counter);
    midc::UserCommandDecryptor<ME> userCommandDecryptor(serverEncryptor);

    // 送受信するユーザーコマンド
    Bit8 userCommand[1024];
    for( size_t i = 0; i < sizeof(userCommand); i++ )
    {
        userCommand[i] = static_cast<Bit8>(i);
    }

    // 暗号化後のユーザーコマンド（長さは適当に余裕を持たせている）
    Bit8 encryptedUserCommand[sizeof(userCommand) * 2];

    // 復号化後のユーザーコマンド
    Bit8 decryptedUserCommand[sizeof(userCommand)];

    // Client : ユーザーコマンド生成
    midc::UserCommandHeader header;
    userCommandEncryptor.Initialize(&header, commandSize, blockSize);

    size_t totalEncryptedSize = 0;
    size_t totalUsedCommandSize = 0;
    while( NN_STATIC_CONDITION(true) )
    {
        size_t encryptedSize;
        userCommandEncryptor.Update(
            &encryptedSize,
            encryptedUserCommand + totalEncryptedSize, sizeof(encryptedUserCommand) - totalEncryptedSize,
            userCommand + totalUsedCommandSize, std::min(blockSize, commandSize - totalUsedCommandSize));

        totalEncryptedSize += encryptedSize;
        totalUsedCommandSize += std::min(blockSize, commandSize - totalUsedCommandSize);

        if( totalUsedCommandSize == commandSize )
        {
            break;
        }
    }

    NN_LOG("Total Encrypted Size : %llu\n", totalEncryptedSize);

    // Server : ユーザーコマンド解釈
    EXPECT_TRUE(userCommandDecryptor.Initialize(header));

    size_t totalDecryptedSize = 0;
    size_t totalUsedEncryptedSize = 0;
    while( NN_STATIC_CONDITION(true) )
    {
        size_t decryptedSize = 0;
        EXPECT_TRUE(userCommandDecryptor.Update(
            &decryptedSize,
            decryptedUserCommand + totalDecryptedSize, sizeof(decryptedUserCommand) - totalDecryptedSize,
            encryptedUserCommand + totalUsedEncryptedSize, std::min(blockSize + midc::MessageEncryptorConfig::SequenceManagementDataSize + 16, totalEncryptedSize - totalUsedEncryptedSize)));

        totalDecryptedSize += decryptedSize;
        totalUsedEncryptedSize += midc::GetEncryptedUserCommandBlockSize(decryptedSize);

        if( totalDecryptedSize == commandSize )
        {
            break;
        }
    }
    EXPECT_EQ(totalEncryptedSize, totalUsedEncryptedSize);
    EXPECT_TRUE(crypto::IsSameBytes(userCommand, decryptedUserCommand, commandSize));
}

// User command round trip over a grid of command sizes (16..1024 inclusive) and
// block sizes (64 up to, but not including, 128), both in steps of 16.
TEST_F(MigrationIdcCommandTest, User)
{
    const size_t SizeStep = 16;
    const size_t LargestCommandSize = 1024;
    const size_t SmallestBlockSize = 64;
    const size_t BlockSizeLimit = 128; // Exclusive upper bound.

    for( size_t cmdSize = SizeStep; cmdSize <= LargestCommandSize; cmdSize += SizeStep )
    {
        for( size_t blkSize = SmallestBlockSize; blkSize < BlockSizeLimit; blkSize += SizeStep )
        {
            TestUser<midc::DebugMessageEncryptor>(cmdSize, blkSize);
            TestUser<midc::MessageEncryptor>(cmdSize, blkSize);
        }
    }
}

template <typename ME>
void TestUserInvalidSize()
{
    // 共通パラメータ
    midc::KeyExchangeCommandConfig::Salt salt = {};
    int iteration = 1000;
    midc::MessageEncryptorConfig::Counter counter = 0;
    midc::KeyExchangeCommandConfig::Passphrase passphrase = {}; // 交換済み

    // Client 側のパラメータ
    ME clientEncryptor;
    clientEncryptor.Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iteration, counter);
    midc::UserCommandEncryptor<ME> userCommandEncryptor(clientEncryptor);

    // Server 側のパラメータ
    ME serverEncryptor;
    serverEncryptor.Initialize(passphrase, sizeof(passphrase), salt, sizeof(salt), iteration, counter);
    midc::UserCommandDecryptor<ME> userCommandDecryptor(serverEncryptor);

    // 送受信するユーザーコマンド
    Bit8 userCommand[1024] = {};

    // 暗号化後のユーザーコマンド（長さは適当に余裕を持たせている）
    Bit8 encryptedUserCommand[sizeof(userCommand) * 2];

    // 復号化後のユーザーコマンド
    Bit8 decryptedUserCommand[sizeof(userCommand)];

    const size_t CommandSize = 128;
    const size_t BlockSize = 64;

    midc::UserCommandHeader header;

    userCommandEncryptor.Initialize(&header, CommandSize, BlockSize);

    // Client : ユーザーコマンド生成
    size_t encryptedSize;
    userCommandEncryptor.Update(
        &encryptedSize,
        encryptedUserCommand, sizeof(encryptedUserCommand),
        userCommand, CommandSize);

    // Server : ユーザーコマンド解釈（適当な値を代入してデータを破壊して失敗させる）
    EXPECT_TRUE(userCommandDecryptor.Initialize(header));

    encryptedUserCommand[0] = static_cast<Bit8>(0xFF);
    size_t decryptedSize;
    EXPECT_FALSE(userCommandDecryptor.Update(
        &decryptedSize,
        decryptedUserCommand, sizeof(decryptedUserCommand),
        encryptedUserCommand, encryptedSize));
}

// User command with damaged data: regular message encryptor first, then the debug one.
TEST_F(MigrationIdcCommandTest, UserInvalidSize)
{
    typedef midc::MessageEncryptor MessageEnc;
    typedef midc::DebugMessageEncryptor DebugMessageEnc;

    TestUserInvalidSize<MessageEnc>();
    TestUserInvalidSize<DebugMessageEnc>();
}
