﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <algorithm>
#include <cstring>
#include <nn/nn_Common.h>
#include <nn/nn_SdkAssert.h>

namespace nn { namespace crypto {

// Forward declarations of the AES block-cipher templates. Only the names are
// needed in this header (for the Update() specializations declared below),
// so the full class definitions are not included here.
template <size_t KeySize> class AesEncryptor;
template <size_t KeySize> class AesDecryptor;

// Aliases for the 16-byte (128-bit) key instantiations.
typedef AesEncryptor<16> AesEncryptor128;
typedef AesDecryptor<16> AesDecryptor128;


namespace detail {

/**
 * @brief Stateful core of the XTS block-cipher mode, shared by the encryption
 *        and decryption front ends.
 *
 * Two block ciphers are involved:
 *  - cipher 1 processes the data blocks. InitializeEncryption() binds its
 *    EncryptBlock, InitializeDecryption() binds its DecryptBlock;
 *  - cipher 2 is used exactly once, to encrypt the initial tweak in
 *    Initialize().
 *
 * Cipher 1 is stored by pointer and invoked through a type-erased callback
 * (m_pCipherContext / m_pCipherFunction). The object passed as pBlockCipher1
 * must therefore stay alive for as long as this instance processes data.
 *
 * Update() accepts input in arbitrary sizes. The number of bytes currently
 * held back internally is reported by GetBufferedDataSize().
 */
class XtsModeImpl
{
public:
    //! Block size in bytes used for all XTS processing (one AES block).
    static const size_t BlockSize = 16;

public:
    //! Creates an instance in State_None. InitializeEncryption() or
    //! InitializeDecryption() prepares it for processing.
    XtsModeImpl() NN_NOEXCEPT
        : m_TemporalBlockBuffer()
        , m_Tweak()
        , m_LastBlock()
        , m_BufferedByte(0)
        , m_pCipherContext(nullptr)
        , m_pCipherFunction(nullptr)
        , m_State(State_None)
    {
        // Every member is given a defined value (buffers zero-filled, pointers
        // null) so no indeterminate value can be read, even on misuse.
    }
    ~XtsModeImpl() NN_NOEXCEPT;

    /**
     * @brief Sets this instance up for XTS encryption.
     *
     * @param[in] pBlockCipher1 Cipher whose EncryptBlock processes the data blocks.
     *                          Stored by pointer; must outlive use of this instance.
     * @param[in] pBlockCipher2 Cipher whose EncryptBlock is applied once to the tweak.
     * @param[in] pTweak        Initial tweak; its first BlockSize bytes are read.
     * @param[in] tweakSize     Size of the pTweak buffer; must be >= BlockSize.
     */
    template <typename BlockCipher1, typename BlockCipher2>
    void InitializeEncryption(const BlockCipher1* pBlockCipher1, const BlockCipher2* pBlockCipher2,
                              const void* pTweak, size_t tweakSize) NN_NOEXCEPT
    {
        // The block callback hands BlockCipher1::BlockSize to the cipher, while
        // this class's block buffers are BlockSize bytes; they must agree.
        static_assert(BlockCipher1::BlockSize == BlockSize,
                      "XtsModeImpl requires a block cipher with a 16-byte block");
        NN_SDK_REQUIRES(pBlockCipher1 != nullptr);

        // Bind the data cipher in the encrypt direction.
        m_pCipherContext = pBlockCipher1;
        m_pCipherFunction = &EncryptBlockCallback<BlockCipher1>;

        // Shared setup: encrypt the tweak with cipher 2 and reset the state.
        Initialize(pBlockCipher2, pTweak, tweakSize);
    }

    /**
     * @brief Sets this instance up for XTS decryption.
     *
     * @param[in] pBlockCipher1 Cipher whose DecryptBlock processes the data blocks.
     *                          Stored by pointer; must outlive use of this instance.
     * @param[in] pBlockCipher2 Cipher whose EncryptBlock is applied once to the tweak
     *                          (the tweak is encrypted in both directions).
     * @param[in] pTweak        Initial tweak; its first BlockSize bytes are read.
     * @param[in] tweakSize     Size of the pTweak buffer; must be >= BlockSize.
     */
    template <typename BlockCipher1, typename BlockCipher2>
    void InitializeDecryption(const BlockCipher1* pBlockCipher1, const BlockCipher2* pBlockCipher2,
                              const void* pTweak, size_t tweakSize) NN_NOEXCEPT
    {
        // See InitializeEncryption(): callback block size must match our buffers.
        static_assert(BlockCipher1::BlockSize == BlockSize,
                      "XtsModeImpl requires a block cipher with a 16-byte block");
        NN_SDK_REQUIRES(pBlockCipher1 != nullptr);

        // Bind the data cipher in the decrypt direction.
        m_pCipherContext = pBlockCipher1;
        m_pCipherFunction = &DecryptBlockCallback<BlockCipher1>;

        // Shared setup: encrypt the tweak with cipher 2 and reset the state.
        Initialize(pBlockCipher2, pTweak, tweakSize);
    }

    /**
     * @brief Feeds srcSize bytes from pSrc, writing output into pDst (capacity dstSize).
     *
     * The primary template simply forwards to UpdateGeneric(). The template
     * argument only selects an overload; explicit specializations for
     * AesEncryptor128 / AesDecryptor128 are declared after this class.
     *
     * @return Value returned by the selected implementation (presumably the
     *         number of bytes written to pDst; confirm against the definition).
     */
    template <typename BlockCipher>
    size_t Update(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) NN_NOEXCEPT
    {
        return UpdateGeneric(pDst, dstSize, pSrc, srcSize);
    }

    //! Processes numBlocks whole blocks from pSrc8 into pDst8. The primary
    //! template forwards to ProcessBlocksGeneric(); BlockCipher only selects
    //! the overload.
    template <typename BlockCipher>
    size_t ProcessBlocks(uint8_t* pDst8, const uint8_t* pSrc8, int numBlocks) NN_NOEXCEPT
    {
        return ProcessBlocksGeneric(pDst8, pSrc8, numBlocks);
    }

    //! Returns the number of input bytes currently held in the internal buffer.
    size_t GetBufferedDataSize() const NN_NOEXCEPT
    {
        return m_BufferedByte;
    }

    //! Returns the block size in bytes (always BlockSize).
    size_t GetBlockSize() const NN_NOEXCEPT
    {
        return BlockSize;
    }

    // Out-of-line operations; their behavior is defined in the implementation file.
    size_t FinalizeEncryption(void* pDst, size_t dstSize) NN_NOEXCEPT;
    size_t FinalizeDecryption(void* pDst, size_t dstSize) NN_NOEXCEPT;

    size_t UpdateGeneric(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) NN_NOEXCEPT;
    size_t ProcessBlocksGeneric(uint8_t* pDst8, const uint8_t* pSrc8, int numBlocks) NN_NOEXCEPT;
    size_t ProcessPartialData(uint8_t* pDst8, const uint8_t* pSrc8, size_t dataSize) NN_NOEXCEPT;
    size_t ProcessRemainingData(uint8_t* pDst8, const uint8_t* pSrc8, size_t dataSize) NN_NOEXCEPT;

private:
    /**
     * @brief Initialization shared by both directions.
     *
     * Encrypts the first BlockSize bytes of pTweak with pBlockCipher (cipher 2)
     * into m_Tweak; this is the only place cipher 2 is used. Then resets the
     * buffered byte count and moves to State_Initialized.
     */
    template <typename BlockCipher>
    void Initialize(const BlockCipher* pBlockCipher, const void* pTweak, size_t tweakSize) NN_NOEXCEPT
    {
        NN_SDK_REQUIRES(pBlockCipher != nullptr);
        NN_SDK_REQUIRES(pTweak != nullptr);
        // EncryptBlock below always reads BlockSize bytes from pTweak, so a
        // shorter buffer would be over-read.
        NN_SDK_REQUIRES(tweakSize >= BlockSize);
        NN_UNUSED(tweakSize); // only referenced by the assertion above

        pBlockCipher->EncryptBlock(m_Tweak, BlockSize, pTweak, BlockSize);

        m_BufferedByte = 0;

        m_State = State_Initialized;
    }

    //! Type-erased adapter: calls EncryptBlock on `self` (a const BlockCipher*)
    //! for one block of BlockCipher::BlockSize bytes.
    template <typename BlockCipher>
    static void EncryptBlockCallback(void* outBlock, const void* inBlock, const void* self) NN_NOEXCEPT
    {
        static_cast<const BlockCipher*>(self)->EncryptBlock(outBlock, BlockCipher::BlockSize,
                                                            inBlock, BlockCipher::BlockSize);
    }

    //! Type-erased adapter: calls DecryptBlock on `self` (a const BlockCipher*)
    //! for one block of BlockCipher::BlockSize bytes.
    template <typename BlockCipher>
    static void DecryptBlockCallback(void* outBlock, const void* inBlock, const void* self) NN_NOEXCEPT
    {
        static_cast<const BlockCipher*>(self)->DecryptBlock(outBlock, BlockCipher::BlockSize,
                                                            inBlock, BlockCipher::BlockSize);
    }

    // Processes a single block; defined out of line.
    void ProcessBlock(uint8_t* pDst8, const uint8_t* pSrc8) NN_NOEXCEPT;

private:
    enum State
    {
        State_None,        // Constructed, not yet initialized.
        State_Initialized, // Initialized; no block processed yet.
        State_Processing,  // At least one block has been processed.
        State_Done
    };

private:
    uint8_t      m_TemporalBlockBuffer[BlockSize]; // Block-sized scratch buffer used by the out-of-line code.
    uint8_t      m_Tweak[BlockSize];               // Tweak; set to the cipher-2 encryption of pTweak by Initialize().
    uint8_t      m_LastBlock[BlockSize];           // NOTE(review): presumably the most recent block, kept for finalization; confirm.
    size_t       m_BufferedByte;                   // Bytes currently buffered (see GetBufferedDataSize()).
    const void*  m_pCipherContext;                 // Cipher 1 object, passed as `self` to m_pCipherFunction.
    void      (* m_pCipherFunction)(void* outBlock, const void* inBlock, const void* cipherContext);
    State        m_State;
};

// Explicit specializations of XtsModeImpl::Update() for the AES-128 ciphers.
// Only declarations appear here; the definitions live elsewhere.
template <>
size_t XtsModeImpl::Update<AesDecryptor128>(void* pDst, size_t dstSize,
                                            const void* pSrc, size_t srcSize) NN_NOEXCEPT;

template <>
size_t XtsModeImpl::Update<AesEncryptor128>(void* pDst, size_t dstSize,
                                            const void* pSrc, size_t srcSize) NN_NOEXCEPT;

}}} // namespace nn::crypto::detail
