﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <nn/crypto/crypto_Sha256Generator.h>
#include <nn/crypto/crypto_RsaOaepEncryptor.h>
#include <nn/crypto/crypto_Aes128CtrEncryptor.h>
#include <nn/crypto/crypto_Csrng.h>
#include <nn/erpt/server/erpt_ServerTypes.h>

#include "erptsrv_Formatter.h"
#include "erptsrv_Keys.h"

namespace nn   {
namespace erpt {
namespace srv  {

// Size in bytes of the RSA modulus used to wrap the per-report key (2048-bit RSA).
const uint32_t RsaKeyLength = 2048 / 8;
// Size in bytes of the random OAEP seed; matches the SHA-256 digest size.
const uint32_t SeedLength   = 256 / 8;

class Cipher : private Formatter
{
private:
    // Per-report symmetric key material, regenerated by Begin() and wiped by End().
    // Layout: [AES-128 key: BlockSize bytes][CTR IV: IvSize bytes][IvSize trailing bytes].
    // The trailing bytes are random in release builds; non-release builds overwrite
    // them with the pattern 0, 1, ..., IvSize-1 (see Begin()).
    // End() RSA-OAEP encrypts the whole buffer into the CipherKey field.
    static uint8_t s_Key[nn::crypto::Aes128CtrEncryptor::BlockSize + 2 * nn::crypto::Aes128CtrEncryptor::IvSize];

    // Set once any field of the current report has been encrypted; End() only wraps
    // the real key when this is true (otherwise an all-zero placeholder is stored).
    static bool    s_NeedToStoreCipher;

    // Encrypted fields are stored in the report as this header followed by the
    // AES-128-CTR ciphertext of the array, zero-padded to a multiple of BlockSize:
    // 'CRPT'|field_type|element_count|00000000|encrypted_data
    #if defined(NN_BUILD_CONFIG_OS_WIN)
    #pragma warning(push)
    #pragma warning(disable : 4200)
    #endif
    struct Header
    {
        uint32_t magic;
        uint32_t fieldType;
        uint32_t elementCount;
        uint32_t reserved;
        uint8_t  data[0];
    };
    #if defined(NN_BUILD_CONFIG_OS_WIN)
    #pragma warning(pop)
    #endif

    // Rounds size up to the next multiple of alignment (alignment must be non-zero).
    // Computed in 64 bits so callers can detect results that do not fit in 32 bits
    // instead of silently wrapping around.
    static uint64_t RoundUp(uint64_t size, uint64_t alignment)
    NN_NOEXCEPT
    {
        return ((size + alignment - 1) / alignment) * alignment;
    }

    // Encrypts count elements of pArray with the per-report AES-128-CTR key and adds
    // the resulting Header + ciphertext to pReport as field id.
    // Returns ResultOutOfMemory when the encrypted field size cannot be represented
    // or the buffer cannot be allocated; otherwise the result of Formatter::AddField.
    //
    // NOTE(review): every field in a report is encrypted with the same key AND the
    // same IV, so the CTR keystream is reused across fields. Changing this needs a
    // matching change on the decrypting side, so it is left as-is here.
    template <typename T>
    static nn::Result EncryptArray(Report* pReport, FieldId id, T* pArray, uint32_t count)
    NN_NOEXCEPT
    {
        // Largest total field length representable in the 32-bit lengths used here.
        const uint64_t MaxFieldLength = 0xFFFFFFFFu;

        // Work in 64 bits so that count * sizeof(T), the block round-up and the header
        // addition cannot wrap and under-size the allocation relative to the memcpy.
        const uint64_t payloadLength = static_cast<uint64_t>(count) * sizeof(T);
        const uint64_t paddedLength  = RoundUp(payloadLength, nn::crypto::Aes128CtrEncryptor::BlockSize);

        if (paddedLength + sizeof(Header) > MaxFieldLength)
        {
            return nn::erpt::ResultOutOfMemory();
        }

        const uint32_t dataLength  = static_cast<uint32_t>(paddedLength);
        const uint32_t fieldLength = static_cast<uint32_t>(sizeof(Header) + dataLength);

        Header* pHeader = reinterpret_cast<Header*>(lmem::AllocateFromExpHeap(
            g_HeapHandle,
            fieldLength,
            nn::crypto::Aes128CtrEncryptor::BlockSize));

        if (pHeader == nullptr)
        {
            return nn::erpt::ResultOutOfMemory();
        }

        pHeader->magic        = 'C' | 'R' << 8 | 'P' << 16 | 'T' << 24;
        pHeader->fieldType    = static_cast<uint32_t>(erpt::srv::FieldToTypeMap[id]);
        pHeader->elementCount = count;
        pHeader->reserved     = 0;

        // Zero the whole data area (covers the padding), then copy the payload in.
        // memcpy is skipped for an empty payload because its source may be null.
        std::memset(pHeader->data, 0x0, dataLength);
        if (payloadLength != 0)
        {
            std::memcpy(pHeader->data, pArray, static_cast<size_t>(payloadLength));
        }

        // Encrypt in place: the key is the first BlockSize bytes of s_Key and the IV
        // immediately follows it.
        nn::crypto::EncryptAes128Ctr(
            pHeader->data, dataLength,
            s_Key,  nn::crypto::Aes128CtrEncryptor::BlockSize,
            s_Key + nn::crypto::Aes128CtrEncryptor::BlockSize,  nn::crypto::Aes128CtrEncryptor::IvSize,
            pHeader->data, dataLength);

        nn::Result result = Formatter::AddField(pReport, id, reinterpret_cast<uint8_t*>(pHeader), fieldLength);

        // Clear the buffer before handing it back to the heap.
        std::memset(pHeader, 0x0, fieldLength);
        lmem::FreeToExpHeap(g_HeapHandle, pHeader);
        s_NeedToStoreCipher = true;

        return result;
    }

public:
    // bool fields are never encrypted; forwarded directly to Formatter.
    static nn::Result AddField(Report* pReport, FieldId id, bool value)
    NN_NOEXCEPT
    {
        return Formatter::AddField(pReport, id, value);
    }

    // Scalar fields are never encrypted; forwarded directly to Formatter.
    template <typename T>
    static nn::Result AddField(Report* pReport, FieldId id, T value)
    NN_NOEXCEPT
    {
        return Formatter::AddField<T>(pReport, id, value);
    }

    // String field: encrypted when FieldToFlagMap marks id with FieldFlag_Encrypt,
    // otherwise stored as-is by Formatter.
    static nn::Result AddField(Report* pReport, FieldId id, char* pString, uint32_t length)
    NN_NOEXCEPT
    {
        return FieldToFlagMap[id] == FieldFlag_Encrypt ?
               EncryptArray<char>(pReport,  id, pString, length) :
               Formatter::AddField(pReport, id, pString, length);
    }

    // Byte-array field: encrypted when FieldToFlagMap marks id with FieldFlag_Encrypt,
    // otherwise stored as-is by Formatter.
    static nn::Result AddField(Report* pReport, FieldId id, uint8_t* pArray, uint32_t count)
    NN_NOEXCEPT
    {
        return FieldToFlagMap[id] == FieldFlag_Encrypt ?
               EncryptArray<uint8_t>(pReport, id, pArray, count) :
               Formatter::AddField(pReport, id, pArray, count);
    }

    // Typed-array field: encrypted when FieldToFlagMap marks id with FieldFlag_Encrypt,
    // otherwise stored as-is by Formatter.
    template <typename T>
    static nn::Result AddField(Report* pReport, FieldId id, T* pArray, uint32_t count)
    NN_NOEXCEPT
    {
        return FieldToFlagMap[id] == FieldFlag_Encrypt ?
               EncryptArray<T>(pReport, id, pArray, count) :
               Formatter::AddField<T>(pReport, id, pArray, count);
    }

    // Starts a new report: resets s_NeedToStoreCipher, fills s_Key with fresh random
    // key material, and begins the underlying report with one extra record reserved
    // for the CipherKey field that End() always writes.
    static nn::Result Begin(Report* pReport, uint32_t recordCount)
    NN_NOEXCEPT
    {
        s_NeedToStoreCipher = false;
        nn::crypto::GenerateCryptographicallyRandomBytes(s_Key, sizeof(s_Key));

        #if !defined(NN_SDK_BUILD_RELEASE)
        // for verification of key decryption operation, store recognizable pattern after IV
        for (uint8_t i = 0; i < nn::crypto::Aes128CtrEncryptor::IvSize; i++)
        {
            s_Key[nn::crypto::Aes128CtrEncryptor::BlockSize + nn::crypto::Aes128CtrEncryptor::IvSize + i] = i;
        }
        #endif

        return Formatter::Begin(pReport, recordCount + 1 /* +1 for cipher key */);
    }

    // Finishes the report. If any field was encrypted, s_Key is wrapped with RSA-OAEP
    // (SHA-256) under the key from GetPublicKeyModulus()/GetPublicKeyExponent();
    // otherwise an all-zero placeholder of the same size is used. The result is
    // stored as the CipherKey field, s_Key is wiped, and Formatter::End is always
    // called. Returns the failure from adding CipherKey if that failed, otherwise
    // the result of Formatter::End.
    static nn::Result End(Report* pReport)
    NN_NOEXCEPT
    {
        uint8_t cipher[RsaKeyLength] = {0};

        if (s_NeedToStoreCipher)
        {
            uint8_t seed[SeedLength];
            nn::crypto::RsaOaepEncryptor<RsaKeyLength, nn::crypto::Sha256Generator> rsa;
            nn::crypto::GenerateCryptographicallyRandomBytes(seed, sizeof(seed));

            rsa.Initialize(GetPublicKeyModulus(), GetPublicKeyModulusSize(), GetPublicKeyExponent(), GetPublicKeyExponentSize());
            rsa.Encrypt(cipher, sizeof(cipher), s_Key, sizeof(s_Key), seed, sizeof(seed));
        }

        // Always written: Begin() reserved a record for it.
        nn::Result keyResult = Formatter::AddField(pReport, nn::erpt::CipherKey, cipher, sizeof(cipher));
        std::memset(s_Key, 0x0, sizeof(s_Key));

        nn::Result endResult = Formatter::End(pReport);

        // Without the CipherKey field the encrypted fields cannot be recovered, so a
        // failure to add it must not be reported as success.
        return keyResult.IsFailure() ? keyResult : endResult;
    }
};

}}}
