﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <algorithm>
#include <cstring>

#include <nn/nn_Common.h>
#include <nn/nn_SdkAssert.h>
#include <nn/nn_Result.h>
#include <nn/migration/idc/migration_CommandTypes.h>
#include <nn/util/util_Endian.h>
#include <nn/util/util_IntUtil.h>

namespace nn { namespace migration { namespace idc {

// Sum of the sequence management data size and the AES-128-GCM MAC size.
// UserCommandEncryptor::Update writes exactly this many bytes in addition to the payload for every
// block it encrypts (sequence management data in front of the ciphertext, MAC behind it).
const size_t UserCommandAuthenticationSize = MessageEncryptorConfig::SequenceManagementDataSize + crypto::Aes128GcmEncryptor::MacSize;

// Size helpers shared by UserCommandEncryptor and UserCommandDecryptor.
// NOTE(review): the definitions are not part of this header; the descriptions below are derived from
// how this file calls them and should be confirmed against the implementation.

// Size of one encrypted block for a plaintext block of blockSize bytes.
// This file uses the result as the distance between consecutive encrypted blocks.
size_t GetEncryptedUserCommandBlockSize(size_t blockSize) NN_NOEXCEPT;
// Encrypted size corresponding to plainSize bytes of plaintext split into blocks of blockSize bytes.
// UserCommandEncryptor::Update reports this value through *pOutEncryptedSize.
size_t GetEncryptedUserCommandSize(size_t blockSize, size_t plainSize) NN_NOEXCEPT;
// Plaintext size corresponding to encryptedSize bytes of encrypted data.
// UserCommandDecryptor::Update reports this value through *pOutPlainSize.
size_t GetDecryptedUserCommandSize(size_t blockSize, size_t encryptedSize) NN_NOEXCEPT;
// Returns whether the blockSize / commandSize pair may be used.
// UserCommandDecryptor::Initialize rejects the header when this returns false.
bool IsAcceptableUserCommandSize(size_t blockSize, size_t commandSize) NN_NOEXCEPT;

/**
* @brief        Encrypts a user command block by block using a MsgEncryptor.
*
* @tparam       MsgEncryptor Type whose Encrypt() member is called for the header and for every block.
*
* @details      Only a reference to the encryptor is stored, so the encryptor passed to the constructor
*               must outlive this object.
*/
template <typename MsgEncryptor>
class UserCommandEncryptor
{
    NN_DISALLOW_COPY(UserCommandEncryptor);
    NN_DISALLOW_MOVE(UserCommandEncryptor);
public:
    /**
    * @brief        Binds the encryptor used for all subsequent operations.
    *
    * @param[in]    encryptor Encryptor to use. It is held by reference.
    */
    explicit UserCommandEncryptor(MsgEncryptor& encryptor) NN_NOEXCEPT;

    /**
    * @brief        Starts a new user command and fills in its header.
    *
    * @param[out]   pOutHeader Header to fill in. It receives the User command ID, zeroed padding and the
    *                          encrypted block size / user command size together with their MAC.
    * @param[in]    userCommandSize Total size of the (plaintext) user command.
    * @param[in]    blockSize Size of one plaintext block.
    *
    * @pre
    *               - pOutHeader != nullptr
    *               - userCommandSize > 0
    *               - blockSize > sequence management data size + MAC size
    */
    void Initialize(UserCommandHeader* pOutHeader, size_t userCommandSize, size_t blockSize) NN_NOEXCEPT;
    /**
    * @brief        Encrypts the input data. Sequence management data and a MAC used for verification are attached to the encrypted data.
    *
    * @param[out]   pOutEncryptedSize Size of the encrypted data written to outEncrypted.
    * @param[out]   outEncrypted Buffer that receives the encrypted data.
    * @param[in]    outEncryptedSize Size of outEncrypted.
    * @param[in]    plain Input data.
    * @param[in]    plainSize Size of the input data.
    *
    * @pre
    *               - Initialize has been called
    *               - pOutEncryptedSize != nullptr
    *               - outEncrypted != nullptr
    *               - outEncryptedSize >= GetEncryptedUserCommandSize(blockSize, plainSize)
    *               - plain != nullptr
    *               - plainSize is either a natural-number multiple of the block size or equal to the size of the data not yet encrypted.
    *
    * @details      After calling Initialize, call this function repeatedly until the whole user command has been passed in.
    *               NOTE(review): the original comment said "until Completed is returned", but this function returns void.
    */
    void Update(
        size_t* pOutEncryptedSize,
        void* outEncrypted, size_t outEncryptedSize,
        const void* plain, size_t plainSize) NN_NOEXCEPT;

    /**
    * @brief        Gets the sum of the *pOutEncryptedSize values output by Update (i.e. the size of the user command after encryption).
    * @pre
    *               - Initialize has been called
    */
    size_t GetTotalOutputSize() const NN_NOEXCEPT;

private:
    size_t m_BlockSize;                 // Block size.
    size_t m_UserCommandSize;           // Size of the user command.
    size_t m_EncryptedUserCommandSize;  // Size of the user command encrypted so far. Eventually equals m_UserCommandSize.
    MsgEncryptor& m_Encryptor;          // Encryptor supplied at construction (not owned).
};

/**
* @brief        Verifies and decrypts a user command block by block using a MsgEncryptor.
*
* @tparam       MsgEncryptor Type whose Decrypt() member is called for the header and for every block.
*
* @details      Only a reference to the encryptor is stored, so the encryptor passed to the constructor
*               must outlive this object.
*/
template <typename MsgEncryptor>
class UserCommandDecryptor
{
    NN_DISALLOW_COPY(UserCommandDecryptor);
    NN_DISALLOW_MOVE(UserCommandDecryptor);
public:
    /**
    * @brief        Binds the encryptor used for all subsequent operations.
    *
    * @param[in]    encryptor Encryptor to use. It is held by reference.
    */
    explicit UserCommandDecryptor(MsgEncryptor& encryptor) NN_NOEXCEPT;

    /**
    * @brief        Verifies and decrypts a user command header and starts a new decryption.
    *
    * @param[in]    header Header of the user command to decrypt.
    *
    * @return       true when the header was decrypted successfully and the block size and user command size
    *               it carries fit in size_t, pass IsAcceptableUserCommandSize and are both non-zero;
    *               false otherwise.
    *
    * @pre
    *               - header.commandId == CommandKind::User
    */
    bool Initialize(const UserCommandHeader& header) NN_NOEXCEPT;
    /**
    * @brief        Verifies and decrypts the input data.
    *
    * @param[out]   pOutPlainSize Size of the decrypted data written to outPlain.
    * @param[out]   outPlain Buffer that receives the decrypted data.
    * @param[in]    outPlainSize Size of outPlain.
    * @param[in]    encrypted Input data.
    * @param[in]    encryptedSize Size of the input data.
    *
    * @return       Verification result.
    *
    * @pre
    *               - Initialize has been called
    *               - pOutPlainSize != nullptr
    *               - outPlain != nullptr
    *               - outPlainSize >= GetDecryptedUserCommandSize(blockSize, encryptedSize)
    *               - encrypted != nullptr
    *               - encryptedSize is either a natural-number multiple of the encrypted block size or equal to the size of the data not yet decrypted.
    *
    * @details      After calling Initialize, call this function repeatedly until the whole user command has been passed in.
    *               NOTE(review): the original comment said "until Completed is returned", but this function returns bool.
    */
    bool Update(
        size_t* pOutPlainSize,
        void* outPlain, size_t outPlainSize,
        const void* encrypted, size_t encryptedSize) NN_NOEXCEPT;

    /**
    * @brief        Gets a value that is valid as Update's encryptedSize and does not exceed upperSize.
    * @param[in]    upperSize Upper limit of the size.
    * @return       The size.
    */
    size_t GetAcceptableUpdateSize(size_t upperSize) const NN_NOEXCEPT;

    /**
    * @brief        Gets the sum of the encryptedSize values to be passed to Update (i.e. the size of the user command after encryption).
    * @pre
    *               - Initialize has been called
    */
    size_t GetExpectedTotalInputSize() const NN_NOEXCEPT;

    /**
    * @brief        Gets the block size.
    * @pre
    *               - Initialize has been called.
    */
    size_t GetBlockSize() const NN_NOEXCEPT;

    /**
    * @brief        Gets the command size.
    * @pre
    *               - Initialize has been called.
    */
    size_t GetCommandSize() const NN_NOEXCEPT;

private:
    size_t m_BlockSize;                 // Block size.
    size_t m_UserCommandSize;           // Size of the user command.
    size_t m_DecryptedUserCommandSize;  // Size of the user command decrypted so far. Eventually equals m_UserCommandSize.
    MsgEncryptor& m_Encryptor;          // Encryptor supplied at construction (not owned).
};

}}}

// 以下実装。

#include <nn/crypto.h>
#include <nn/crypto/crypto_Compare.h>
#include <nn/migration/detail/migration_Log.h>

namespace nn { namespace migration { namespace idc {

// UserCommandEncryptor

// Binds the encryptor and clears all size bookkeeping.
// The sizes stay zero until Initialize is called; Update asserts on a non-zero block size.
template <typename MsgEncryptor>
UserCommandEncryptor<MsgEncryptor>::UserCommandEncryptor(MsgEncryptor& encryptor) NN_NOEXCEPT
    : m_BlockSize(0u), m_UserCommandSize(0u), m_EncryptedUserCommandSize(0u), m_Encryptor(encryptor)
{
}

/**
* @brief        Starts encrypting a new user command and fills in its header.
*
* @param[out]   pOutHeader Header to fill in: command ID is set to CommandKind::User, the padding is
*                          zeroed, and data / mac receive the encrypted (blockSize, userCommandSize)
*                          pair and its MAC.
* @param[in]    userCommandSize Total plaintext size of the user command.
* @param[in]    blockSize Plaintext block size used by subsequent Update calls.
*
* @pre
*               - pOutHeader != nullptr
*               - userCommandSize > 0
*               - blockSize > sequence management data size + MAC size
*
* @details      Both sizes are serialized as big-endian 64-bit integers, block size first.
*               Calling this function resets the progress of any previous command.
*/
template <typename MsgEncryptor>
void UserCommandEncryptor<MsgEncryptor>::Initialize(UserCommandHeader* pOutHeader, size_t userCommandSize, size_t blockSize) NN_NOEXCEPT
{
    const auto SequenceManagementDataSize = MessageEncryptorConfig::SequenceManagementDataSize;
    NN_SDK_REQUIRES_NOT_NULL(pOutHeader);
    NN_SDK_REQUIRES_GREATER(userCommandSize, 0u);
    NN_SDK_REQUIRES_GREATER(blockSize, SequenceManagementDataSize + nn::crypto::GcmEncryptor<nn::crypto::AesEncryptor128>::MacSize);
    NN_UNUSED(SequenceManagementDataSize);

    m_BlockSize = blockSize;
    m_UserCommandSize = userCommandSize;
    m_EncryptedUserCommandSize = 0;

    pOutHeader->commandId = CommandKind::User;
    std::memset(pOutHeader->padding, 0, sizeof(pOutHeader->padding));

    // Build the big-endian fields in properly aligned uint64_t storage first. Writing uint64_t values
    // straight into a Bit8 array through a reinterpret_cast is a misaligned, aliasing-violating access.
    uint64_t fields[2];
    util::StoreBigEndian(&fields[0], static_cast<uint64_t>(blockSize));
    util::StoreBigEndian(&fields[1], static_cast<uint64_t>(userCommandSize));

    Bit8 data[sizeof(fields)];
    std::memcpy(data, fields, sizeof(data));

    size_t encryptedSize = 0u;
    m_Encryptor.Encrypt(
        &encryptedSize,
        pOutHeader->data, sizeof(pOutHeader->data),
        pOutHeader->mac, sizeof(pOutHeader->mac),
        data, sizeof(data));
    // The header's data field is expected to be filled completely.
    NN_SDK_ASSERT_EQUAL(encryptedSize, sizeof(pOutHeader->data));

    NN_DETAIL_MIGRATION_TRACE("[UserCommandEncryptor::Initialize] BlockSize = %zu, UserCommandSize = %zu\n",
        m_BlockSize, m_UserCommandSize);
}

/**
* @brief        Encrypts plainSize bytes of the user command, one block at a time.
*
* @param[out]   pOutEncryptedSize Receives GetEncryptedUserCommandSize(blockSize, plainSize).
* @param[out]   outEncrypted Buffer that receives the encrypted blocks.
* @param[in]    outEncryptedSize Size of outEncrypted (only checked by the precondition assert).
* @param[in]    plain Input data.
* @param[in]    plainSize Size of the input data.
*
* @pre
*               - Initialize has been called
*               - pOutEncryptedSize != nullptr, outEncrypted != nullptr, plain != nullptr
*               - outEncryptedSize >= GetEncryptedUserCommandSize(blockSize, plainSize)
*               - plainSize is a multiple of the block size or equals the size of the data not yet encrypted
*
* @details      Every output block is laid out as [sequence management data + ciphertext][MAC], and
*               consecutive output blocks are GetEncryptedUserCommandBlockSize(blockSize) bytes apart.
*/
template <typename MsgEncryptor>
void UserCommandEncryptor<MsgEncryptor>::Update(
    size_t* pOutEncryptedSize,
    void* outEncrypted, size_t outEncryptedSize,
    const void* plain, size_t plainSize) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_GREATER(m_BlockSize, 0u);
    NN_SDK_REQUIRES_NOT_NULL(pOutEncryptedSize);
    NN_SDK_REQUIRES_NOT_NULL(outEncrypted);
    NN_SDK_REQUIRES_GREATER_EQUAL(outEncryptedSize, GetEncryptedUserCommandSize(m_BlockSize, plainSize));
    NN_SDK_REQUIRES_NOT_NULL(plain);
    // plainSize is unsigned, so no ">= 0" check is needed (it would always be true).
    NN_SDK_REQUIRES((plainSize % m_BlockSize == 0u) || (plainSize == m_UserCommandSize - m_EncryptedUserCommandSize));
    NN_UNUSED(outEncryptedSize);

    const auto SequenceManagementDataSize = MessageEncryptorConfig::SequenceManagementDataSize;
    const size_t encryptedBlockSize = GetEncryptedUserCommandBlockSize(m_BlockSize);

    for( size_t i = 0; (i * m_BlockSize) < plainSize; i++ )
    {
        // The last block may be shorter than m_BlockSize.
        const auto sizeToEncrypt = std::min(m_BlockSize, plainSize - i * m_BlockSize);
        Bit8* const pOutBlock = reinterpret_cast<Bit8*>(outEncrypted) + i * encryptedBlockSize;
        const auto cipherSize = SequenceManagementDataSize + sizeToEncrypt;

        size_t encryptedSize = 0u;
        m_Encryptor.Encrypt(
            &encryptedSize,
            pOutBlock, cipherSize,
            pOutBlock + cipherSize, crypto::Aes128GcmEncryptor::MacSize,   // MAC directly follows the ciphertext.
            reinterpret_cast<const Bit8*>(plain) + i * m_BlockSize, sizeToEncrypt);
        NN_SDK_ASSERT_EQUAL(encryptedSize, sizeToEncrypt + SequenceManagementDataSize);
    }

    *pOutEncryptedSize = GetEncryptedUserCommandSize(m_BlockSize, plainSize);
    m_EncryptedUserCommandSize += plainSize;
}

// Returns the encrypted size of the whole user command, i.e. the total that all Update calls
// report through *pOutEncryptedSize. Requires a prior Initialize.
template <typename MsgEncryptor>
size_t UserCommandEncryptor<MsgEncryptor>::GetTotalOutputSize() const NN_NOEXCEPT
{
    NN_SDK_REQUIRES_GREATER(m_BlockSize, 0u);
    NN_SDK_REQUIRES_GREATER(m_UserCommandSize, 0u);

    const size_t totalEncryptedSize = GetEncryptedUserCommandSize(m_BlockSize, m_UserCommandSize);
    return totalEncryptedSize;
}

// UserCommandDecryptor

// Binds the encryptor and clears all size bookkeeping.
// The sizes stay zero until Initialize succeeds; Update asserts on a non-zero block size.
template <typename MsgEncryptor>
UserCommandDecryptor<MsgEncryptor>::UserCommandDecryptor(MsgEncryptor& encryptor) NN_NOEXCEPT
    : m_BlockSize(0u), m_UserCommandSize(0u), m_DecryptedUserCommandSize(0u), m_Encryptor(encryptor)
{
}

/**
* @brief        Verifies and decrypts a user command header and starts a new decryption.
*
* @param[in]    header Header received for the user command.
*
* @return       false when the header fails verification, when the block size or user command size
*               does not fit in size_t, when IsAcceptableUserCommandSize rejects the pair, or when
*               either value is zero; true otherwise.
*
* @pre
*               - header.commandId == CommandKind::User
*
* @details      The decrypted header holds two big-endian 64-bit integers: the block size followed by
*               the user command size.
*/
template <typename MsgEncryptor>
bool UserCommandDecryptor<MsgEncryptor>::Initialize(const UserCommandHeader& header) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_EQUAL(header.commandId, CommandKind::User);

    size_t decryptedSize = 0u;
    Bit8 decrypted[sizeof(uint64_t) + sizeof(uint64_t)];
    if( !m_Encryptor.Decrypt(
        &decryptedSize,
        decrypted, sizeof(decrypted),
        header.data, sizeof(header.data),
        header.mac, sizeof(header.mac)) )
    {
        return false;
    }
    // Both fields are read below, so the whole buffer must have been produced.
    NN_SDK_ASSERT_EQUAL(decryptedSize, sizeof(decrypted));

    // Copy into properly aligned uint64_t storage before decoding. Reading uint64_t values straight
    // out of a Bit8 array through a reinterpret_cast is a misaligned, aliasing-violating access.
    uint64_t fields[2];
    std::memcpy(fields, decrypted, sizeof(decrypted));
    const auto blockSize = util::LoadBigEndian(&fields[0]);
    const auto userCommandSize = util::LoadBigEndian(&fields[1]);

    if( !(util::IsIntValueRepresentable<size_t>(blockSize) && util::IsIntValueRepresentable<size_t>(userCommandSize)) )
    {
        // uint64_t is not necessarily unsigned long long, so cast explicitly for %llu.
        NN_DETAIL_MIGRATION_TRACE("[UserCommandDecryptor::Initialize] BlockSize(%llu) or UserCommandSize(%llu) is too large(over size_t limit).\n",
            static_cast<unsigned long long>(blockSize), static_cast<unsigned long long>(userCommandSize));
        return false;
    }
    if( !IsAcceptableUserCommandSize(static_cast<size_t>(blockSize), static_cast<size_t>(userCommandSize)) )
    {
        NN_DETAIL_MIGRATION_TRACE("[UserCommandDecryptor::Initialize] BlockSize(%llu) is too small or UserCommandSize(%llu) is too large(encrypted user command size exceeds size_t limit).\n",
            static_cast<unsigned long long>(blockSize), static_cast<unsigned long long>(userCommandSize));
        return false;
    }
    m_BlockSize = static_cast<size_t>(blockSize);
    m_UserCommandSize = static_cast<size_t>(userCommandSize);

    m_DecryptedUserCommandSize = 0u;

    NN_DETAIL_MIGRATION_TRACE("[UserCommandDecryptor::Initialize] BlockSize = %zu, UserCommandSize = %zu\n",
        m_BlockSize, m_UserCommandSize);

    return (m_BlockSize > 0) && (m_UserCommandSize > 0);
}

/**
* @brief        Verifies and decrypts encryptedSize bytes of the user command, one block at a time.
*
* @param[out]   pOutPlainSize Receives GetDecryptedUserCommandSize(blockSize, encryptedSize) on success.
*                             Left untouched when false is returned.
* @param[out]   outPlain Buffer that receives the decrypted data.
* @param[in]    outPlainSize Size of outPlain (only checked by the precondition assert).
* @param[in]    encrypted Input data.
* @param[in]    encryptedSize Size of the input data.
*
* @return       true when every block passed verification; false as soon as one block fails or is too
*               short to hold the sequence management data and the MAC.
*
* @pre
*               - Initialize has been called
*               - pOutPlainSize != nullptr, outPlain != nullptr, encrypted != nullptr
*               - outPlainSize >= GetDecryptedUserCommandSize(blockSize, encryptedSize)
*               - encryptedSize is a multiple of the encrypted block size or equals the size of the data not yet decrypted
*
* @details      Every input block is laid out as [sequence management data + ciphertext][MAC].
*/
template <typename MsgEncryptor>
bool UserCommandDecryptor<MsgEncryptor>::Update(
    size_t* pOutPlainSize,
    void* outPlain, size_t outPlainSize,
    const void* encrypted, size_t encryptedSize) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_GREATER(m_BlockSize, 0u);
    NN_SDK_REQUIRES_NOT_NULL(pOutPlainSize);
    NN_SDK_REQUIRES_NOT_NULL(outPlain);
    NN_SDK_REQUIRES_GREATER_EQUAL(outPlainSize, GetDecryptedUserCommandSize(m_BlockSize, encryptedSize));
    NN_SDK_REQUIRES_NOT_NULL(encrypted);
    NN_SDK_REQUIRES(((encryptedSize % GetEncryptedUserCommandBlockSize(m_BlockSize)) == 0u) || (encryptedSize == GetEncryptedUserCommandSize(m_BlockSize, m_UserCommandSize - m_DecryptedUserCommandSize)));
    NN_UNUSED(outPlainSize);

    const size_t encryptedBlockSize = GetEncryptedUserCommandBlockSize(m_BlockSize);

    for( size_t i = 0; (i * encryptedBlockSize) < encryptedSize; i++ )
    {
        // The last block may be shorter than encryptedBlockSize.
        const size_t sizeToDecrypt = std::min(encryptedBlockSize, encryptedSize - i * encryptedBlockSize);

        // A chunk that cannot even hold the sequence management data and the MAC would make the
        // unsigned subtractions below wrap around; treat it as a verification failure instead.
        if( sizeToDecrypt < UserCommandAuthenticationSize )
        {
            return false;
        }

        const Bit8* const pBlock = reinterpret_cast<const Bit8*>(encrypted) + i * encryptedBlockSize;
        const size_t cipherSize = sizeToDecrypt - crypto::Aes128GcmEncryptor::MacSize;
        // Number of plaintext bytes this block yields. It is also passed as the output capacity so
        // that the capacity is never overstated for the final block.
        const size_t expectedPlainSize = sizeToDecrypt - UserCommandAuthenticationSize;

        size_t decryptedSize = 0u;
        const bool isValid = m_Encryptor.Decrypt(
            &decryptedSize,
            reinterpret_cast<Bit8*>(outPlain) + i * m_BlockSize, expectedPlainSize,
            pBlock, cipherSize,
            pBlock + cipherSize, crypto::Aes128GcmEncryptor::MacSize);   // MAC directly follows the ciphertext.
        if( !isValid )
        {
            return false;
        }
        NN_SDK_ASSERT_EQUAL(decryptedSize, expectedPlainSize);
    }

    *pOutPlainSize = GetDecryptedUserCommandSize(m_BlockSize, encryptedSize);
    m_DecryptedUserCommandSize += *pOutPlainSize;

    return true;
}

// Returns the largest encryptedSize not exceeding upperSize that Update accepts:
// the whole remaining encrypted data when it fits, otherwise as many whole encrypted blocks as fit.
template <typename MsgEncryptor>
size_t UserCommandDecryptor<MsgEncryptor>::GetAcceptableUpdateSize(size_t upperSize) const NN_NOEXCEPT
{
    NN_SDK_REQUIRES_GREATER(m_BlockSize, 0u);

    const size_t pendingEncryptedSize = GetEncryptedUserCommandSize(m_BlockSize, m_UserCommandSize - m_DecryptedUserCommandSize);
    const size_t encryptedStride = GetEncryptedUserCommandBlockSize(m_BlockSize);

    if( upperSize >= pendingEncryptedSize )
    {
        // Everything that is left fits within the limit.
        return pendingEncryptedSize;
    }

    // Round the limit down to a whole number of encrypted blocks.
    const size_t wholeBlocks = upperSize / encryptedStride;
    return encryptedStride * wholeBlocks;
}

// Returns the encrypted size of the whole user command, i.e. the total of all encryptedSize
// values that have to be passed to Update. Requires a prior successful Initialize.
template <typename MsgEncryptor>
size_t UserCommandDecryptor<MsgEncryptor>::GetExpectedTotalInputSize() const NN_NOEXCEPT
{
    NN_SDK_REQUIRES_GREATER(m_BlockSize, 0u);
    NN_SDK_REQUIRES_GREATER(m_UserCommandSize, 0u);

    const size_t totalEncryptedSize = GetEncryptedUserCommandSize(m_BlockSize, m_UserCommandSize);
    return totalEncryptedSize;
}

// Returns the plaintext block size taken from the header by Initialize (0 before that).
template <typename MsgEncryptor>
size_t UserCommandDecryptor<MsgEncryptor>::GetBlockSize() const NN_NOEXCEPT
{
    return this->m_BlockSize;
}

// Returns the plaintext user command size taken from the header by Initialize (0 before that).
template <typename MsgEncryptor>
size_t UserCommandDecryptor<MsgEncryptor>::GetCommandSize() const NN_NOEXCEPT
{
    return this->m_UserCommandSize;
}

}}}
