﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <nn/nn_Common.h>
#include <nn/nn_SdkAssert.h>
#include <nn/util/util_Endian.h>
#include <nn/crypto/crypto_HmacGenerator.h>
#include <nn/crypto/crypto_Sha1Generator.h>
#include <nn/crypto/crypto_Sha256Generator.h>
#include <nn/crypto/detail/crypto_Clear.h>
#include <algorithm>

namespace nn { namespace crypto {

/*!
    @brief      パスワードベースの鍵生成機能の実装クラスです。

    @details
    本クラスは RFC 2898 で定義されている PBKDF2 を利用して鍵を生成します。
*/
template <typename HashFunction> class PasswordBasedKeyGenerator
{
private:
    NN_DISALLOW_COPY(PasswordBasedKeyGenerator);
    NN_DISALLOW_MOVE(PasswordBasedKeyGenerator);

public:
    static const size_t HashSize = HashFunction::HashSize; //!< ハッシュ値のサイズを表す定数です。

public:
    /*!
        @brief      デフォルトコンストラクタです。
    */
    PasswordBasedKeyGenerator() NN_NOEXCEPT :
        m_pPassphrase(nullptr),
        m_PassphraseSize(0),
        m_pSalt(nullptr),
        m_SaltSize(0),
        m_IterationCount(0),
        m_RemainSize(0),
        m_BlockIndex(1)
    {
    }

    /*!
        @brief      デストラクタです。
    */
    ~PasswordBasedKeyGenerator() NN_NOEXCEPT
    {
        detail::ClearMemory(this, sizeof(*this));
    }

    /*!
        @brief      初期化します。

        @param[in]  pPassphrase     パスフレーズへのポインタ。
                                    ディープコピーは行わないため、本オブジェクトの生存期間中に破棄しないで下さい。
        @param[in]  passphraseSize  パスフレーズのバイトサイズ。
        @param[in]  pSalt           ソルトへのポインタ。
                                    ディープコピーは行わないため、本オブジェクトの生存期間中に破棄しないで下さい。
        @param[in]  saltSize        ソルトのバイトサイズ。
        @param[in]  iterationCount  反復回数。

        @post
        - インスタンスは処理可能状態になる。
    */
    void Initialize(const Bit8* pPassphrase, size_t passphraseSize,
        const Bit8* pSalt, size_t saltSize, uint32_t iterationCount) NN_NOEXCEPT
    {
        NN_SDK_REQUIRES_NOT_NULL(pPassphrase);
        NN_SDK_REQUIRES_GREATER(passphraseSize, 0U);
        NN_SDK_REQUIRES_NOT_NULL(pSalt);
        NN_SDK_REQUIRES_GREATER(saltSize, 0U);

        m_pPassphrase = pPassphrase;
        m_PassphraseSize = passphraseSize;
        m_pSalt = pSalt;
        m_SaltSize = saltSize;
        m_IterationCount = iterationCount;

        m_RemainSize = 0;
        m_BlockIndex = 1;
    }

    /*!
        @brief      導出された鍵のバイト列を取得します。

        @param[out] pBuffer 導出された鍵のバイト列を格納するバッファへのポインタ。
        @param[in]  size    pBuffer が指すバッファのバイトサイズ。

        @pre
        - インスタンスは処理可能状態である。

        @details
        本関数は呼び出し毎に異なるバイト列を生成します。
        生成されたバイト列を順番に結合することで、大きい値を指定した場合と同一のバイト列が取得可能です。
        例えば、GetBytes(&bytes[0], 10), GetBytes(&bytes[10], 20) と GetBytes(&bytes[0], 30) は同一のバイト列になります。
    */
    void GetBytes(Bit8* pBuffer, size_t size) NN_NOEXCEPT
    {
        NN_SDK_REQUIRES_NOT_NULL(pBuffer);
        NN_SDK_REQUIRES_GREATER(size, 0U);

        NN_SDK_ASSERT_NOT_NULL(m_pPassphrase);

        do
        {
            if (m_RemainSize == 0)
            {
                // T = F(Passphrase, Salt, IterationCount, BlockIndex)
                GenerateBytes(m_BlockT, m_pPassphrase, m_PassphraseSize, m_pSalt, m_SaltSize, m_IterationCount, m_BlockIndex++);
                m_RemainSize = HashSize;
            }

            size_t copySize = std::min(m_RemainSize, size);

            std::memcpy(pBuffer, &m_BlockT[HashSize - m_RemainSize], copySize);
            pBuffer += copySize;
            size -= copySize;

            m_RemainSize -= copySize;
        }
        while (size > 0);
    }

private:
    //
    const Bit8* m_pPassphrase;
    size_t m_PassphraseSize;
    //
    const Bit8* m_pSalt;
    size_t m_SaltSize;
    //
    uint32_t m_IterationCount;
    //
    Bit8 m_BlockT[HashSize];
    size_t m_RemainSize;
    //
    uint32_t m_BlockIndex;

private:
    static void GenerateBytes(Bit8* blockT,
        const Bit8* pPassphrase, size_t passphraseSize,
        const Bit8* pSalt, size_t saltSize,
        uint32_t iterationCount, uint32_t blockIndex) NN_NOEXCEPT
    {
        nn::crypto::HmacGenerator<HashFunction> generator;

        Bit8 blockU[HashSize];

        // U_1 = PRF(Passphrase, Salt || INT_32_BE(BlockIndex))
        nn::util::StoreBigEndian(&blockIndex, blockIndex);
        generator.Initialize(pPassphrase, passphraseSize);
        generator.Update(pSalt, saltSize);
        generator.Update(&blockIndex, 4);
        generator.GetMac(blockU, HashSize);

        // T = U_1
        memcpy(blockT, blockU, HashSize);

        for (size_t i = 1; i < iterationCount; i++)
        {
            // U_2 = PRF(P, U_1), ..., U_c = PRF(P, U_{c-1})
            generator.Initialize(pPassphrase, passphraseSize);
            generator.Update(blockU, HashSize);
            generator.GetMac(blockU, HashSize);

            // T = U_1 xor U_2 xor ... xor U_c
            for (size_t x = 0; x < HashSize; x++)
            {
                blockT[x] ^= blockU[x];
            }
        }
    }
};

NN_DEFINE_STATIC_CONSTANT(template<typename T> const size_t PasswordBasedKeyGenerator<T>::HashSize);

}}
