﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <cstring>

#include <nn/nn_Common.h>
#include <nn/nn_SdkAssert.h>

namespace nn { namespace spl {

    // Number of bits in one byte. Used to turn the byte-sized template
    // parameters (KeySize, BlockCipher::BlockSize) into the bit lengths the
    // DRBG parameters are expressed in.
    static const int BitsPerByte = 8;

    // CtrDrbg: counter-mode deterministic random bit generator built on a block
    // cipher. The step comments in the member definitions mirror the CTR_DRBG,
    // CTR_DRBG_Update and Block_Cipher_df pseudocode of NIST SP 800-90A.
    //
    // Template parameters:
    //   BlockCipher   - cipher type providing a BlockSize constant (in bytes) and
    //                   Initialize(key, keySize) / EncryptBlock(dst, dstSize,
    //                   src, srcSize) members, as called by the definitions.
    //   KeySize       - cipher key size in bytes.
    //   UseDerivation - true: inputs are condensed through DeriveSeed (the
    //                   derivation function); false: inputs are XORed in directly
    //                   and the entropy input must be exactly SeedSize bytes.
    //
    // The class has no constructor: the state (including m_ReseedCounter) is
    // only set by Initialize(), which must therefore be called before Generate().
    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    class CtrDrbg
    {
        friend class TestHelper;

    public:
        // Bit lengths (keylen, outlen, seedlen = outlen + keylen).
        static const int KeyLen         = KeySize                * BitsPerByte;
        static const int OutLen         = BlockCipher::BlockSize * BitsPerByte;
        static const int SeedLen        = OutLen + KeyLen;
        // 2^19 bits (64 KiB) per Generate() call.
        static const int MaxNumberOfBitsPerRequest  = (1 << 19);
        // Generate() refuses to run once m_ReseedCounter exceeds this value.
        static const int ReseedInterval             = 0x7FFFFFF0;

        // The same quantities in bytes.
        static const size_t OutSize     = OutLen  / BitsPerByte;
        static const size_t SeedSize    = SeedLen / BitsPerByte;
        static const size_t RequestSizeMax      = MaxNumberOfBitsPerRequest / BitsPerByte;

    public:
        // Instantiates the generator from entropy, nonce and personalization
        // string and sets the reseed counter to 1.
        void Initialize(
            const void* pEntropyInput, size_t entropyInputSize,
            const void* pNonce, size_t nonceSize,
            const void* pPersonalizationString, size_t personalizationStringSize);
        // Mixes fresh entropy (and optional additional input) into Key/V and
        // resets the reseed counter to 1.
        void Reseed(
            const void* pEntropyInput, size_t entropyInputSize,
            const void* pAdditionalInput, size_t addtionalInputSize );
        // Writes bufferSize pseudorandom bytes to pBuffer. Returns false when the
        // request exceeds RequestSizeMax or a reseed is required.
        bool Generate(
            void* pBuffer, size_t bufferSize,
            const void* pAdditionalInput, size_t addtionalInputSize);

    private:
        // CTR_DRBG_Update: (Key, V) = Update(provided_data, Key, V); SeedSize bytes
        // of provided data; uses m_Work2 as scratch.
        void UpdateStates(void* pKey, void* pV, const void* pProvidedData);
        // Block_Cipher_df: condenses pA || pB || pC into SeedSize bytes at pSeed.
        void DeriveSeed(
            void* pSeed,
            const void* pA, size_t aSize,
            const void* pB, size_t bSize,
            const void* pC, size_t cSize );

        // Big-endian +1 on an OutSize-byte counter.
        static void Increment(void* pV);
        // In-place byte-wise pA ^= pB over size bytes.
        static void Xor(void* pA, const void* pB, size_t size);

    private:
        // Incremental BCC (CBC-MAC style chaining) over a caller-owned
        // OutSize-byte chaining block: Process() XORs data into the block and
        // encrypts it each time it fills; Flush() encrypts a partial block.
        class Bcc
        {
        public:
            Bcc(Bit8* pBuffer, const BlockCipher* pCipher)
                : m_pBuffer(pBuffer), m_pCipher(pCipher), m_Offset(0)
            {
            }

            void Process(const void* pData, size_t dataSize);
            void Flush();

        private:
            Bit8*               m_pBuffer;  // chaining block (OutSize bytes, not owned)
            const BlockCipher*  m_pCipher;  // cipher already keyed by the caller
            size_t              m_Offset;   // bytes absorbed into the current block
        };

    // NOTE(review): the `private:` label below is immediately overridden by
    // `public:`, so the state is publicly accessible (TestHelper is already a
    // friend). Presumably left open for testing — confirm before tightening.
    private:
    public:
        BlockCipher m_BlockCipher;          // cipher instance, re-keyed by every operation
        Bit8        m_V[OutSize];           // counter V
        Bit8        m_Key[KeySize];         // current Key
        Bit8        m_Work1[SeedSize];      // seed material / padded or derived additional input
        Bit8        m_Work2[SeedSize];      // scratch for UpdateStates
        int         m_ReseedCounter;        // set to 1 by Initialize/Reseed; indeterminate before Initialize
    };

    // Namespace-scope definitions of the in-class-initialized static constants.
    // A `static const` data member with an in-class initializer still needs such
    // a definition whenever it is ODR-used (e.g. bound to a const reference).
    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    const int CtrDrbg<BlockCipher, KeySize, UseDerivation>::KeyLen;

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    const int CtrDrbg<BlockCipher, KeySize, UseDerivation>::OutLen;

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    const int CtrDrbg<BlockCipher, KeySize, UseDerivation>::SeedLen;

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    const int CtrDrbg<BlockCipher, KeySize, UseDerivation>::MaxNumberOfBitsPerRequest;

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    const int CtrDrbg<BlockCipher, KeySize, UseDerivation>::ReseedInterval;

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    const size_t CtrDrbg<BlockCipher, KeySize, UseDerivation>::OutSize;

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    const size_t CtrDrbg<BlockCipher, KeySize, UseDerivation>::SeedSize;

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    const size_t CtrDrbg<BlockCipher, KeySize, UseDerivation>::RequestSizeMax;


    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    void CtrDrbg<BlockCipher, KeySize, UseDerivation>::Bcc::Process(const void* pData, size_t dataSize)
    {
        // Absorbs dataSize bytes into the chaining block at m_pBuffer. Bytes are
        // XORed in starting at m_Offset; each time the block becomes full it is
        // encrypted in place with *m_pCipher and filling restarts at offset 0.
        // A partially filled block stays pending for later Process() or Flush().
        // pData may be NULL when dataSize is 0.
        const Bit8* pCursor = reinterpret_cast<const Bit8*>(pData);
        size_t      pending = dataSize;

        // `room` is how many bytes are still needed to complete the current block.
        for( size_t room = OutSize - m_Offset; pending >= room; room = OutSize - m_Offset )
        {
            Xor(m_pBuffer + m_Offset, pCursor, room);
            m_pCipher->EncryptBlock(m_pBuffer, OutSize, m_pBuffer, OutSize);

            pCursor  += room;
            pending  -= room;
            m_Offset  = 0;
        }

        // Fewer than `room` bytes remain: fold them into the open block.
        Xor(m_pBuffer + m_Offset, pCursor, pending);
        m_Offset += pending;
    }

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    void CtrDrbg<BlockCipher, KeySize, UseDerivation>::Bcc::Flush()
    {
        // Closes a partially filled chaining block by encrypting it as is. The
        // unfilled tail receives no XOR, which is the same as padding the absorbed
        // data with zero bytes. Nothing happens when no block is pending.
        if( m_Offset == 0 )
        {
            return;
        }

        m_pCipher->EncryptBlock(m_pBuffer, OutSize, m_pBuffer, OutSize);
        m_Offset = 0;
    }

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    void CtrDrbg<BlockCipher, KeySize, UseDerivation>::Increment(void* pV)
    {
        // Adds 1 to the OutSize-byte big-endian counter at pV, wrapping modulo
        // 2^(8 * OutSize). Bytes are bumped from the last (least significant)
        // toward the first; carrying stops at the first byte that does not wrap.
        Bit8* pCounter = reinterpret_cast<Bit8*>(pV);

        size_t index = OutSize;
        while( index > 0 )
        {
            --index;
            ++pCounter[index];
            if( pCounter[index] != 0 )
            {
                return;
            }
        }
    }

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    void CtrDrbg<BlockCipher, KeySize, UseDerivation>::Xor(void* pA, const void* pB, size_t size)
    {
        // In-place byte-wise XOR: pA[k] ^= pB[k] for every k in [0, size).
        // With size == 0 neither pointer is dereferenced, so either may be NULL.
        Bit8*             pDst = static_cast<Bit8*>(pA);
        const Bit8*       pSrc = static_cast<const Bit8*>(pB);
        const Bit8* const pEnd = pSrc + size;

        while( pSrc != pEnd )
        {
            *pDst++ ^= *pSrc++;
        }
    }

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    void CtrDrbg<BlockCipher, KeySize, UseDerivation>::DeriveSeed(
        void* pSeed,
        const void* pA, size_t aSize,
        const void* pB, size_t bSize,
        const void* pC, size_t cSize )
    {
        // Block_Cipher_df: condenses input_string = pA || pB || pC into SeedSize
        // bytes written to pSeed. pSeed must hold SeedSize bytes; it also serves as
        // scratch for the fixed df key. Any input may be NULL when its size is 0.
        // Side effect: m_BlockCipher is left keyed with the intermediate key K.
        NN_STATIC_ASSERT( SeedSize % OutSize == 0 );

        // L is encoded below as a 32-bit field, so the combined input length must
        // fit in 32 bits. The sum is widened first so it cannot wrap before the
        // comparison.
        NN_SDK_ASSERT( static_cast<uint64_t>(aSize) + bSize + cSize <= 0xFFFFFFFFu );

        // 2.  L = len (input_string)/8.
        // 3.  N = number_of_bits_to_return/8.
        uint32_t inSize  = static_cast<uint32_t>(aSize + bSize + cSize);
        uint32_t outSize = static_cast<uint32_t>(SeedSize);

        // 4.  S = L || N || input_string || 0x80.
        //     S is never built in memory; its pieces are streamed through Bcc below.
        Bit32 header[2];
        util::StoreBigEndian(&header[0], inSize);
        util::StoreBigEndian(&header[1], outSize);
        Bit8 footer = 0x80;

        Bit8* p8 = reinterpret_cast<Bit8*>(pSeed);

        // 8.  K = Leftmost keylen bits of 0x00010203...1D1E1F.
        for( size_t i = 0; i < KeySize; ++i )
        {
            p8[i] = static_cast<Bit8>(i);
        }
        m_BlockCipher.Initialize(p8, KeySize);

        // 6.  temp = the Null string.
        // 7.  i = 0.
        // 9.  While len (temp) < keylen + outlen, do
        for( uint32_t b = 0; b < SeedSize / OutSize; ++b )
        {
            // 9.2  temp = temp || BCC (K, (IV || S)).
            //      Each BCC result is written straight into block b of pSeed,
            //      overwriting the df key bytes. NOTE(review): this relies on
            //      BlockCipher::Initialize copying/expanding the key — confirm.
            Bit32 bb;
            Bit8* pTarget = p8 + b * OutSize;

            util::StoreBigEndian(&bb, b);
            std::memset(pTarget, 0, OutSize);

            Bcc bcc(pTarget, &m_BlockCipher);

            // 9.1  IV = i || 0^(outlen - len (i)).
            //      Flushing after the 4-byte counter zero-pads it to one block.
            bcc.Process(&bb, sizeof(bb));
            bcc.Flush();

            bcc.Process(header, sizeof(header));
            bcc.Process(pA, aSize);
            bcc.Process(pB, bSize);
            bcc.Process(pC, cSize);
            bcc.Process(&footer, sizeof(footer));
            // 5.  While (len (S) mod outlen) != 0, S = S || 0x00.
            //     (Flush encrypts the partial block, equivalent to zero padding.)
            bcc.Flush();

            // 9.3  i = i + 1.
        }

        // 10. K = Leftmost keylen bits of temp.
        // 11. X = Next outlen bits of temp.
        // 12. temp = the Null string.
        m_BlockCipher.Initialize(p8, KeySize);
        m_BlockCipher.EncryptBlock(p8, OutSize, p8 + KeySize, OutSize);
        // 13. While len (temp) < number_of_bits_to_return, do
        for( size_t offset = 0; offset < SeedSize - OutSize; offset += OutSize )
        {
            // 13.1  X = Block_Encrypt (K, X).
            // 13.2  temp = temp || X.
            m_BlockCipher.EncryptBlock(p8 + offset + OutSize, OutSize, p8 + offset, OutSize);
        }
        // 14. requested_bits = Leftmost number_of_bits_to_return of temp.
    }

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    void CtrDrbg<BlockCipher, KeySize, UseDerivation>::UpdateStates(void* pKey, void* pV, const void* pProvidedData)
    {
        // CTR_DRBG_Update: (Key, V) = Update (provided_data, Key, V).
        // pProvidedData must supply SeedSize bytes; m_Work2 is used as scratch and
        // m_BlockCipher is left keyed with the *old* Key.
        NN_STATIC_ASSERT( SeedSize % OutSize == 0 );

        // 1-3.  temp = SeedSize bytes of keystream: key the cipher with the current
        //       Key and encrypt V, pre-incremented, once per OutSize-byte block.
        m_BlockCipher.Initialize(pKey, KeySize);

        Bit8* const  pTemp      = m_Work2;
        const size_t blockCount = SeedSize / OutSize;
        for( size_t block = 0; block < blockCount; ++block )
        {
            Increment(pV);
            m_BlockCipher.EncryptBlock(pTemp + block * OutSize, OutSize, pV, OutSize);
        }

        // 4.  temp = temp XOR provided_data.
        Xor(pTemp, pProvidedData, SeedSize);

        // 5-6.  The first KeySize bytes become the new Key; the remaining OutSize
        //       bytes (SeedSize == KeySize + OutSize) become the new V.
        std::memcpy(pKey, pTemp,           KeySize);
        std::memcpy(pV,   pTemp + KeySize, OutSize);
    }

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    void CtrDrbg<BlockCipher, KeySize, UseDerivation>::Initialize(
        const void* pEntropyInput, size_t entropyInputSize,
        const void* pNonce, size_t nonceSize,
        const void* pPersonalizationString, size_t personalizationStringSize)
    {
        // Instantiate: build seed_material from the inputs, reset Key and V to
        // zero, mix the seed in with CTR_DRBG_Update and set the counter to 1.
        // Without derivation the preconditions (checked with NN_SDK_ASSERT) are:
        // entropy input exactly SeedSize bytes, no nonce, and a personalization
        // string of at most SeedSize bytes (implicitly zero-padded).
        if( !UseDerivation )
        {
            NN_SDK_ASSERT( entropyInputSize          == SeedSize );
            NN_SDK_ASSERT( nonceSize                 == 0        );
            NN_SDK_ASSERT( personalizationStringSize <= SeedSize );

            // 1-3.  seed_material = entropy_input XOR (personalization_string || 0^(seedlen - len)).
            std::memcpy(m_Work1, pEntropyInput, SeedSize);
            Xor(m_Work1, pPersonalizationString, personalizationStringSize);
        }
        else
        {
            // 1-2.  seed_material = Block_Cipher_df (entropy_input || nonce || personalization_string, seedlen).
            DeriveSeed(m_Work1,
                       pEntropyInput,          entropyInputSize,
                       pNonce,                 nonceSize,
                       pPersonalizationString, personalizationStringSize);
        }

        // 4-5.  Key = 0^keylen, V = 0^outlen.
        std::memset(m_V,   0, sizeof(m_V));
        std::memset(m_Key, 0, sizeof(m_Key));

        // 6.  (Key, V) = CTR_DRBG_Update (seed_material, Key, V).
        UpdateStates(m_Key, m_V, m_Work1);

        // 7.  reseed_counter = 1.
        m_ReseedCounter = 1;
    }

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    void CtrDrbg<BlockCipher, KeySize, UseDerivation>::Reseed(
        const void* pEntropyInput, size_t entropyInputSize,
        const void* pAdditionalInput, size_t addtionalInputSize )
    {
        // Reseed: fold fresh entropy (plus optional additional input) into the
        // existing Key/V via CTR_DRBG_Update and reset the counter to 1.
        // Without derivation the entropy input must be exactly SeedSize bytes and
        // the additional input at most SeedSize bytes (checked with NN_SDK_ASSERT).
        if( !UseDerivation )
        {
            NN_SDK_ASSERT( entropyInputSize   == SeedSize );
            NN_SDK_ASSERT( addtionalInputSize <= SeedSize );

            // 1-3.  seed_material = entropy_input XOR (additional_input || 0^(seedlen - len)).
            std::memcpy(m_Work1, pEntropyInput, SeedSize);
            Xor(m_Work1, pAdditionalInput, addtionalInputSize);
        }
        else
        {
            // 1-2.  seed_material = Block_Cipher_df (entropy_input || additional_input, seedlen).
            DeriveSeed(m_Work1,
                       pEntropyInput,    entropyInputSize,
                       pAdditionalInput, addtionalInputSize,
                       NULL,             0);
        }

        // 4.  (Key, V) = CTR_DRBG_Update (seed_material, Key, V).
        UpdateStates(m_Key, m_V, m_Work1);

        // 5.  reseed_counter = 1.
        m_ReseedCounter = 1;
    }

    template <typename BlockCipher, size_t KeySize, bool UseDerivation>
    bool CtrDrbg<BlockCipher, KeySize, UseDerivation>::Generate(
        void* pBuffer, size_t bufferSize,
        const void* pAdditionalInput, size_t addtionalInputSize)
    {
        // Generate: writes bufferSize pseudorandom bytes to pBuffer, then advances
        // Key/V with CTR_DRBG_Update and increments the reseed counter.
        // Returns false, leaving the state and pBuffer untouched, when:
        //  - bufferSize exceeds RequestSizeMax,
        //  - a reseed is required (m_ReseedCounter > ReseedInterval), or
        //  - UseDerivation is false and addtionalInputSize exceeds SeedSize
        //    (the input would not fit in m_Work1).
        // Returns true otherwise.
        if( bufferSize > RequestSizeMax )
        {
            return false;
        }

        // 1.  If reseed_counter > reseed_interval, then return an indication that a reseed is required.
        if( m_ReseedCounter > ReseedInterval )
        {
            return false;
        }

        // Without derivation the additional input is copied verbatim into the
        // SeedSize-byte m_Work1, so a longer input would overflow it. The assert
        // flags the misuse; the runtime check keeps the copy in bounds even if
        // NN_SDK_ASSERT is compiled out.
        NN_SDK_ASSERT( UseDerivation || addtionalInputSize <= SeedSize );
        if( !UseDerivation && addtionalInputSize > SeedSize )
        {
            return false;
        }

        // m_Work1 doubles as the (zero) additional_input used in step 6 when none
        // is supplied.
        std::memset(m_Work1, 0, sizeof(m_Work1));

        // 2.  If (additional_input != Null), then
        if( addtionalInputSize > 0 )
        {
            if( UseDerivation )
            {
                // 2.1  additional_input = Block_Cipher_df (additional_input, seedlen).
                DeriveSeed(m_Work1, pAdditionalInput, addtionalInputSize, NULL, 0, NULL, 0);
            }
            else
            {
                // 2.1  temp = len (additional_input).
                // 2.2  If (temp < seedlen), then additional_input = additional_input || 0^(seedlen - temp).
                //      m_Work1 was zeroed above, so the copy is already padded; the
                //      size was validated against SeedSize above.
                std::memcpy(m_Work1, pAdditionalInput, addtionalInputSize);
            }

            // 2.2  (Key, V) = CTR_DRBG_Update (additional_input, Key, V).
            UpdateStates(m_Key, m_V, m_Work1);
        }

        Bit8* p8 = reinterpret_cast<Bit8*>(pBuffer);
        size_t alignedSize = util::align_down(bufferSize, OutSize);

        // 3.  temp = Null.
        // 4.  While (len (temp) < requested_number_of_bits) do:
        m_BlockCipher.Initialize(m_Key, KeySize);
        for( size_t offset = 0; offset < alignedSize; offset += OutSize )
        {
            // 4.1  V = (V + 1) mod 2^outlen.
            Increment(m_V);

            // 4.2  output_block = Block_Encrypt (Key, V).
            // 4.3  temp = temp || output_block.
            m_BlockCipher.EncryptBlock(p8 + offset, OutSize, m_V, OutSize);
        }
        // 5.  returned_bits = Leftmost requested_number_of_bits of temp.
        //     A trailing partial block is produced in a local buffer and only its
        //     leading bytes are copied out.
        if( bufferSize > alignedSize )
        {
            Bit8 temp[OutSize];
            Increment(m_V);
            m_BlockCipher.EncryptBlock(temp, sizeof(temp), m_V, OutSize);
            std::memcpy(p8 + alignedSize, temp, bufferSize - alignedSize);
        }

        // 6.  (Key, V) = CTR_DRBG_Update (additional_input, Key, V).
        UpdateStates(m_Key, m_V, m_Work1);

        // 7.  reseed_counter = reseed_counter + 1.
        m_ReseedCounter++;
        return true;
    }



}} // namespace nn::spl
