﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/nn_Common.h>
#include <nn/nn_SdkAssert.h>
#include <nn/crypto.h>
#include <nn/crypto/crypto_Compare.h>
#include <nn/crypto/crypto_PasswordBasedKeyGenerator.h>
#include <nn/migration/detail/migration_Log.h>
#include <nn/migration/idc/migration_MessageEncryptor.h>
#include <nn/migration/idc/detail/migration_Util.h>

namespace nn { namespace migration { namespace idc {

// Constructs an encryptor with no sequence counter set.
// m_Counter is initialized from nullptr, which is the "not set" state that Initialize() requires
// (NN_SDK_REQUIRES(!m_Counter)). Until Initialize() is called, Encrypt()/Decrypt() abort inside
// GenerateIv2() because it checks NN_ABORT_UNLESS(m_Counter).
MessageEncryptor::MessageEncryptor() NN_NOEXCEPT
    : m_Counter(nullptr)
{
}

// Destroys the encryptor and wipes the passphrase-derived material it holds.
// Initialize() fills both m_P2 (the AES-GCM key) and m_Iv1 (the fixed prefix of every IV2)
// from the same key-derivation output, so both are cleared here, not just the key.
MessageEncryptor::~MessageEncryptor() NN_NOEXCEPT
{
    detail::SecureMemoryZero(m_P2, sizeof(m_P2));
    detail::SecureMemoryZero(m_Iv1, sizeof(m_Iv1));
}

// Derives the AES-GCM key (m_P2) and the IV prefix (m_Iv1) from a passphrase and salt, and sets
// the initial value of the per-message sequence counter.
//
// passphrase / passphraseSize : passphrase bytes (non-null, non-empty).
// salt / saltSize             : salt bytes (non-null, non-empty).
// iteration                   : iteration count for the password-based key generator (> 0).
// counterInitialValue         : first counter value embedded in IV2 by GenerateIv2().
//
// Must be called exactly once, on an encryptor whose counter is not yet set.
void MessageEncryptor::Initialize(const Bit8 passphrase[], size_t passphraseSize, const Bit8 salt[], size_t saltSize, int iteration, MessageEncryptorConfig::Counter counterInitialValue) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(passphrase);
    NN_SDK_REQUIRES_GREATER(passphraseSize, 0u);
    NN_SDK_REQUIRES_NOT_NULL(salt);
    NN_SDK_REQUIRES_GREATER(saltSize, 0u);
    NN_SDK_REQUIRES_GREATER(iteration, 0);
    NN_SDK_REQUIRES(!m_Counter);

    // Password-based key derivation over SHA-256; one output stream supplies both the key and IV1.
    crypto::PasswordBasedKeyGenerator<crypto::Sha256Generator> keyGenerator;
    keyGenerator.Initialize(passphrase, passphraseSize, salt, saltSize, iteration);
    Bit8 bytes[sizeof(m_P2) + sizeof(m_Iv1)];
    keyGenerator.GetBytes(bytes, sizeof(bytes));

    // Layout of the derived bytes: [ m_P2 (AES-GCM key) | m_Iv1 (IV2 prefix) ]
    std::memcpy(m_P2, bytes, sizeof(m_P2));
    std::memcpy(m_Iv1, bytes + sizeof(m_P2), sizeof(m_Iv1));

    // `bytes` is a stack copy of the derived key; wipe it with the same helper the destructor uses
    // for m_P2 so the key material does not linger after this function returns.
    detail::SecureMemoryZero(bytes, sizeof(bytes));

    m_Counter = counterInitialValue;
}

// Encrypts one message with AES-128-GCM.
//
// Output layout in outEncrypted: [ encrypted IV2 | encrypted plain ].
// The GCM MAC covering both parts is written to outMac (outMacSize must equal the GCM MAC size).
// *pOutEncryptedSize receives the total number of bytes written to outEncrypted.
// plain may be nullptr only when plainSize is 0.
//
// Each call consumes one counter value via GenerateIv2().
void MessageEncryptor::Encrypt(
    size_t* pOutEncryptedSize,
    void* outEncrypted, size_t outEncryptedSize,
    void* outMac, size_t outMacSize,
    const void* plain, size_t plainSize) NN_NOEXCEPT
{
    const auto SequenceManagementDataSize = MessageEncryptorConfig::SequenceManagementDataSize;
    NN_SDK_REQUIRES_NOT_NULL(pOutEncryptedSize);
    NN_SDK_REQUIRES_NOT_NULL(outEncrypted);
    NN_SDK_REQUIRES_GREATER_EQUAL(outEncryptedSize, plainSize + SequenceManagementDataSize);
    NN_SDK_REQUIRES_NOT_NULL(outMac);
    NN_SDK_REQUIRES_EQUAL(outMacSize, nn::crypto::GcmEncryptor<nn::crypto::AesEncryptor128>::MacSize);
    NN_SDK_REQUIRES(plain != nullptr || (plain == nullptr && plainSize == 0u));
    NN_UNUSED(SequenceManagementDataSize);

    // A fresh IV2 (IV1 followed by the big-endian counter) is used as the GCM nonce and is also
    // encrypted as the first part of the message so the receiver can check the sequence.
    MessageEncryptorConfig::Iv2 nonce;
    GenerateIv2(nonce);

    crypto::Aes128GcmEncryptor gcm;
    gcm.Initialize(m_P2, sizeof(m_P2), &nonce, sizeof(nonce));

    Bit8* const pCursor = reinterpret_cast<Bit8*>(outEncrypted);

    const size_t headerSize = gcm.Update(pCursor, sizeof(nonce), &nonce, sizeof(nonce));
    NN_SDK_ASSERT_EQUAL(headerSize, sizeof(nonce));

    const size_t bodySize = gcm.Update(pCursor + headerSize, outEncryptedSize - sizeof(nonce), plain, plainSize);
    NN_SDK_ASSERT_EQUAL(bodySize, plainSize);

    gcm.GetMac(outMac, outMacSize);

    *pOutEncryptedSize = headerSize + bodySize;
}

// Decrypts and verifies one message produced by Encrypt() on the peer.
//
// encrypted / encryptedSize : [ encrypted IV2 | encrypted plain ]; encryptedSize must be at least
//                             SequenceManagementDataSize.
// mac / macSize             : GCM MAC received with the message (macSize must equal the GCM MAC size).
// outPlain / outPlainSize   : destination for the plaintext; must hold encryptedSize - SequenceManagementDataSize bytes.
// *pOutPlainSize            : number of plaintext bytes produced.
//
// Returns true only if the MAC matches AND the decrypted IV2 equals the locally expected one.
// On failure the plaintext written to outPlain is cleared and *pOutPlainSize is set to 0, so no
// unauthenticated data is handed back. The counter is advanced in either case (via GenerateIv2()).
bool MessageEncryptor::Decrypt(
    size_t* pOutPlainSize,
    void* outPlain, size_t outPlainSize,
    const void* encrypted, size_t encryptedSize,
    const void* mac, size_t macSize) NN_NOEXCEPT
{
    const auto SequenceManagementDataSize = MessageEncryptorConfig::SequenceManagementDataSize;
    NN_SDK_REQUIRES_NOT_NULL(pOutPlainSize);
    NN_SDK_REQUIRES_NOT_NULL(outPlain);
    NN_SDK_REQUIRES_NOT_NULL(encrypted);
    // Check the lower bound first: the following requirement subtracts SequenceManagementDataSize
    // from encryptedSize, which would wrap around for too-short input and report a misleading failure.
    NN_SDK_REQUIRES_GREATER_EQUAL(encryptedSize, SequenceManagementDataSize);
    NN_SDK_REQUIRES_GREATER_EQUAL(outPlainSize, encryptedSize - SequenceManagementDataSize);
    NN_SDK_REQUIRES_NOT_NULL(mac);
    NN_SDK_REQUIRES_EQUAL(macSize, nn::crypto::GcmEncryptor<nn::crypto::AesEncryptor128>::MacSize);
    NN_UNUSED(SequenceManagementDataSize);

    // The IV2 the peer should have used for this message; also the GCM nonce.
    MessageEncryptorConfig::Iv2 iv2;
    GenerateIv2(iv2);

    crypto::Aes128GcmDecryptor decryptor;
    decryptor.Initialize(m_P2, sizeof(m_P2), &iv2, sizeof(iv2));

    // The leading SequenceManagementDataSize bytes carry the peer's encrypted IV2.
    Bit8 decryptedSequenceManagementData[MessageEncryptorConfig::SequenceManagementDataSize];
    size_t decryptedSequenceManagementDataSize = decryptor.Update(decryptedSequenceManagementData, sizeof(decryptedSequenceManagementData), encrypted, sizeof(decryptedSequenceManagementData));
    NN_SDK_ASSERT_EQUAL(decryptedSequenceManagementDataSize, sizeof(decryptedSequenceManagementData));
    NN_UNUSED(decryptedSequenceManagementDataSize);

    *pOutPlainSize = decryptor.Update(outPlain, outPlainSize,
        reinterpret_cast<const Bit8*>(encrypted) + sizeof(decryptedSequenceManagementData), encryptedSize - sizeof(decryptedSequenceManagementData));
    NN_SDK_ASSERT_EQUAL(*pOutPlainSize, encryptedSize - sizeof(decryptedSequenceManagementData));

    Bit8 calculatedMac[crypto::Aes128GcmEncryptor::MacSize] = {};
    decryptor.GetMac(calculatedMac, sizeof(calculatedMac));

    // Evaluate both checks unconditionally so the failure path does not depend on which check failed.
    const bool isMacValid = crypto::IsSameBytes(mac, calculatedMac, macSize);
    const bool isIv2Valid = crypto::IsSameBytes(&iv2, decryptedSequenceManagementData, sizeof(iv2));
    if( !(isMacValid && isIv2Valid) )
    {
        NN_DETAIL_MIGRATION_ERROR("MessageEncryptor::Decrypt : Verification Failure\n");
        detail::TraceByteArray("   Self.IV2 : ", iv2);
        detail::TraceByteArray<MessageEncryptorConfig::SequenceManagementDataSize>("   Peer.IV2 : ", decryptedSequenceManagementData);
        detail::TraceByteArray("   Self.MAC : ", calculatedMac);
        detail::TraceByteArray<crypto::Aes128GcmEncryptor::MacSize>("   Peer.MAC : ", mac);

        // GCM decrypts before the tag can be checked; do not leave unverified plaintext behind.
        std::memset(outPlain, 0, *pOutPlainSize);
        *pOutPlainSize = 0;
        return false;
    }

    return true;
}

// Builds the next IV2 and advances the sequence counter.
//
// Layout of outValue: [ m_Iv1 | counter (big-endian) ].
// Aborts if Initialize() has not set the counter, or if the counter has reached its maximum
// value (so incrementing it can never wrap around to a previously used value).
void MessageEncryptor::GenerateIv2(MessageEncryptorConfig::Iv2& outValue) NN_NOEXCEPT
{
    using Counter = MessageEncryptorConfig::Counter;

    NN_ABORT_UNLESS(m_Counter);
    NN_ABORT_UNLESS_NOT_EQUAL(*m_Counter, std::numeric_limits<Counter>::max());

    std::memcpy(outValue, m_Iv1, sizeof(m_Iv1));

    Counter* const pCounterField = reinterpret_cast<Counter*>(outValue + sizeof(m_Iv1));
    util::StoreBigEndian(pCounterField, *m_Counter);

    // The next call will produce a different IV2.
    ++(*m_Counter);
}

// DebugMessageEncryptor

// Constructs a debug encryptor with no sequence counter set.
// m_Counter is initialized from nullptr, the "not set" state that Initialize() requires
// (NN_SDK_REQUIRES(!m_Counter)). Until Initialize() is called, Encrypt()/Decrypt() abort inside
// GenerateManagementData() because it checks NN_ABORT_UNLESS(m_Counter).
DebugMessageEncryptor::DebugMessageEncryptor() NN_NOEXCEPT
    : m_Counter(nullptr)
{
}

// Sets the initial sequence counter of the debug encryptor.
//
// The parameter list and preconditions match MessageEncryptor::Initialize, but no key is derived:
// the debug encryptor copies plaintext as-is, so passphrase, salt and iteration are only
// validated and otherwise ignored. Must be called once, before the counter is set.
void DebugMessageEncryptor::Initialize(const Bit8 passphrase[], size_t passphraseSize, const Bit8 salt[], size_t saltSize, int iteration, MessageEncryptorConfig::Counter counterInitialValue) NN_NOEXCEPT
{
    NN_UNUSED(passphrase);
    NN_UNUSED(passphraseSize);
    NN_UNUSED(salt);
    NN_UNUSED(saltSize);
    NN_UNUSED(iteration);

    NN_SDK_REQUIRES_NOT_NULL(passphrase);
    NN_SDK_REQUIRES_GREATER(passphraseSize, 0u);
    NN_SDK_REQUIRES_NOT_NULL(salt);
    NN_SDK_REQUIRES_GREATER(saltSize, 0u);
    NN_SDK_REQUIRES_GREATER(iteration, 0);
    NN_SDK_REQUIRES(!m_Counter);

    // The counter is the only state the debug variant keeps.
    m_Counter = counterInitialValue;
}

// Debug "encryption": emits [ sequence management data | plain ] without any encryption.
//
// The sequence management data is produced by GenerateManagementData() (counter + zero padding).
// outMac is filled with zeros; no MAC is computed, and Decrypt() does not verify one.
// *pOutEncryptedSize receives plainSize + SequenceManagementDataSize.
// plain may be nullptr only when plainSize is 0.
void DebugMessageEncryptor::Encrypt(
    size_t* pOutEncryptedSize,
    void* outEncrypted, size_t outEncryptedSize,
    void* outMac, size_t outMacSize,
    const void* plain, size_t plainSize) NN_NOEXCEPT
{
    const auto SequenceManagementDataSize = MessageEncryptorConfig::SequenceManagementDataSize;
    NN_SDK_REQUIRES_NOT_NULL(pOutEncryptedSize);
    NN_SDK_REQUIRES_NOT_NULL(outEncrypted);
    NN_SDK_REQUIRES_GREATER_EQUAL(outEncryptedSize, plainSize + SequenceManagementDataSize);
    NN_SDK_REQUIRES_NOT_NULL(outMac);
    NN_SDK_REQUIRES_EQUAL(outMacSize, nn::crypto::GcmEncryptor<nn::crypto::AesEncryptor128>::MacSize);
    NN_SDK_REQUIRES(plain != nullptr || (plain == nullptr && plainSize == 0u));
    NN_UNUSED(SequenceManagementDataSize);
    NN_UNUSED(outEncryptedSize);

    *pOutEncryptedSize = plainSize + MessageEncryptorConfig::SequenceManagementDataSize;

    Bit8 sequenceManagementData[SequenceManagementDataSize] = {};
    GenerateManagementData(sequenceManagementData, sizeof(sequenceManagementData));

    Bit8* const pOut = reinterpret_cast<Bit8*>(outEncrypted);
    std::memcpy(pOut, sequenceManagementData, sizeof(sequenceManagementData));
    // The precondition allows plain == nullptr when plainSize == 0, and passing a null pointer to
    // memcpy is undefined behavior even for a zero length, so skip the copy in that case.
    if( plainSize > 0 )
    {
        std::memcpy(pOut + sizeof(sequenceManagementData), plain, plainSize);
    }

    // No MAC value is created here, and Decrypt does not verify it either.
    std::memset(outMac, 0, outMacSize);
}

// Debug "decryption": checks the leading sequence management data and copies out the plaintext.
//
// encrypted / encryptedSize : [ sequence management data | plain ]; encryptedSize must be at least
//                             SequenceManagementDataSize.
// outPlain / outPlainSize   : destination; must hold encryptedSize - SequenceManagementDataSize bytes.
// mac / macSize             : validated as parameters only; the debug variant ignores the MAC.
//
// Returns false (leaving outPlain and *pOutPlainSize untouched) if the received sequence management
// data differs from the locally generated one. The counter is advanced in either case.
bool DebugMessageEncryptor::Decrypt(
    size_t* pOutPlainSize,
    void* outPlain, size_t outPlainSize,
    const void* encrypted, size_t encryptedSize,
    const void* mac, size_t macSize) NN_NOEXCEPT
{
    const auto SequenceManagementDataSize = MessageEncryptorConfig::SequenceManagementDataSize;
    NN_SDK_REQUIRES_NOT_NULL(pOutPlainSize);
    NN_SDK_REQUIRES_NOT_NULL(outPlain);
    NN_SDK_REQUIRES_NOT_NULL(encrypted);
    // Check the lower bound first: the following requirement subtracts SequenceManagementDataSize
    // from encryptedSize, which would wrap around for too-short input and report a misleading failure.
    NN_SDK_REQUIRES_GREATER_EQUAL(encryptedSize, SequenceManagementDataSize);
    NN_SDK_REQUIRES_GREATER_EQUAL(outPlainSize, encryptedSize - SequenceManagementDataSize);
    NN_SDK_REQUIRES_NOT_NULL(mac);
    NN_SDK_REQUIRES_EQUAL(macSize, nn::crypto::GcmEncryptor<nn::crypto::AesEncryptor128>::MacSize);
    NN_UNUSED(SequenceManagementDataSize);
    NN_UNUSED(outPlainSize);
    NN_UNUSED(mac);
    NN_UNUSED(macSize);

    Bit8 sequenceManagementData[SequenceManagementDataSize] = {};
    GenerateManagementData(sequenceManagementData, sizeof(sequenceManagementData));

    if( !crypto::IsSameBytes(sequenceManagementData, encrypted, sizeof(sequenceManagementData)) )
    {
        NN_DETAIL_MIGRATION_ERROR("DebugMessageEncryptor::Decrypt : Verification Failure\n");
        detail::TraceByteArray("   Self.SequenceManagementData : ", sequenceManagementData);
        detail::TraceByteArray<MessageEncryptorConfig::SequenceManagementDataSize>("   Peer.SequenceManagementData : ", encrypted);
        return false;
    }

    const size_t plainSize = encryptedSize - SequenceManagementDataSize;
    std::memcpy(outPlain, reinterpret_cast<const Bit8*>(encrypted) + SequenceManagementDataSize, plainSize);
    *pOutPlainSize = plainSize;

    return true;
}

// Writes the next sequence management data block and advances the counter.
//
// Layout of outManagementData: [ counter (big-endian) | zero padding up to SequenceManagementDataSize ].
// outManagementDataSize must equal SequenceManagementDataSize.
void DebugMessageEncryptor::GenerateManagementData(Bit8 outManagementData[], size_t outManagementDataSize) NN_NOEXCEPT
{
    using Counter = MessageEncryptorConfig::Counter;
    const auto SequenceManagementDataSize = MessageEncryptorConfig::SequenceManagementDataSize;
    NN_SDK_REQUIRES_EQUAL(outManagementDataSize, SequenceManagementDataSize);
    NN_UNUSED(SequenceManagementDataSize);
    NN_UNUSED(outManagementDataSize);

    // Perform the same checks as the regular MessageEncryptor.
    NN_ABORT_UNLESS(m_Counter);
    NN_ABORT_UNLESS_NOT_EQUAL(*m_Counter, std::numeric_limits<Counter>::max());

    const size_t counterFieldSize = sizeof(*m_Counter);
    util::StoreBigEndian(reinterpret_cast<Counter*>(outManagementData), *m_Counter);
    std::memset(&outManagementData[counterFieldSize], 0, SequenceManagementDataSize - counterFieldSize);

    ++(*m_Counter);
}

}}};
