﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <algorithm>
#include <cstring>
#include <limits>
#include <nn/nn_Common.h>
#include <nn/nn_SdkAssert.h>
#include <nn/nn_Abort.h>
#include <nn/crypto/detail/crypto_Clear.h>
#include "../crypto_AesDecryptor.h"
#include "../crypto_AesEncryptor.h"

namespace nn { namespace crypto { namespace detail {

/*
 * CBC-mode encryption core.
 *
 * Encrypts numBlocks consecutive BlockCipher::BlockSize-byte blocks from pSrc
 * into pDst:  C[i] = Encrypt(P[i] XOR C[i-1]), with C[-1] = *pIv.
 *
 * pIv   In: the IV for the first block.  Out: the last ciphertext block
 *       written, i.e. the IV to continue with on the next call.
 * pDst  May be identical to pSrc: each input block is consumed (XORed into
 *       tmp) before the corresponding output block is written.
 *
 * A numBlocks of zero or less is a no-op and leaves *pIv unchanged.
 */
template <typename BlockCipher>
inline void EncryptBlocks(void* pDst, void* pIv, const void* pSrc, int numBlocks, const BlockCipher* pEncryptor) NN_NOEXCEPT
{
    // With no blocks the final IV copy would be memcpy(pIv, pIv, ...), an
    // overlapping copy; a negative count would never terminate the loop.
    if (numBlocks <= 0)
    {
        return;
    }

    const uint8_t* pSrc8 = static_cast<const uint8_t*>(pSrc);
    uint8_t* pDst8 = static_cast<uint8_t*>(pDst);
    const uint8_t* pIv8 = static_cast<const uint8_t*>(pIv);
    uint8_t  tmp[BlockCipher::BlockSize];

    while (numBlocks-- > 0)
    {
        // XOR the plaintext block with the chaining value ...
        for (int i = 0; i < static_cast<int>(BlockCipher::BlockSize); i++)
        {
            tmp[i] = pSrc8[i] ^ pIv8[i];
        }

        // ... and encrypt it.
        pEncryptor->EncryptBlock(pDst8, BlockCipher::BlockSize, tmp, BlockCipher::BlockSize);

        // The ciphertext just produced chains into the next block.
        pIv8 = pDst8;
        pSrc8 += BlockCipher::BlockSize;
        pDst8 += BlockCipher::BlockSize;
    }

    // Hand the last ciphertext block back as the IV for the next call.
    std::memcpy(pIv, pIv8, BlockCipher::BlockSize);

    // tmp holds plaintext XOR (public) IV, i.e. recoverable plaintext; wipe it.
    ClearMemory(tmp, sizeof(tmp));
}

/*
 * CBC-mode decryption core.
 *
 * Decrypts numBlocks consecutive BlockCipher::BlockSize-byte blocks from pSrc
 * into pDst:  P[i] = Decrypt(C[i]) XOR C[i-1], with C[-1] = *pIv.
 *
 * pIv   In: the IV for the first block.  Out: the last ciphertext block
 *       processed, i.e. the IV to continue with on the next call.
 * pDst  May be identical to pSrc (in-place decryption). Buffers that overlap
 *       without being identical are not handled by this code.
 *
 * A numBlocks of zero or less is a no-op and leaves *pIv unchanged.
 */
template <typename BlockCipher>
inline void DecryptBlocks(void* pDst, void* pIv, const void* pSrc, int numBlocks, const BlockCipher* pDecryptor) NN_NOEXCEPT
{
    // Without this guard an empty request is harmful: the out-of-place path
    // would read the "last" ciphertext block from before pSrc
    // (numBlocks - 1 == -1) into *pIv, and the in-place path would copy the
    // uninitialized nextIv into *pIv. Negative counts would never terminate.
    if (numBlocks <= 0)
    {
        return;
    }

    const uint8_t* pSrc8 = static_cast<const uint8_t*>(pSrc);
    const uint8_t* pIv8 = static_cast<const uint8_t*>(pIv);
    uint8_t* pDst8 = static_cast<uint8_t*>(pDst);
    uint8_t  tmp[BlockCipher::BlockSize];
    uint8_t  nextIv[BlockCipher::BlockSize];

    if (pDst == pSrc)
    {
        // In-place: writing a plaintext block destroys the ciphertext block
        // that serves as the IV of the following block, so save it first.
        uint8_t  currIv[BlockCipher::BlockSize];

        while (numBlocks-- > 0)
        {
            // pIv8 may point at nextIv, which is overwritten next, so copy the
            // current IV out before saving this ciphertext block.
            std::memcpy(currIv, pIv8, BlockCipher::BlockSize);
            std::memcpy(nextIv, pSrc8, BlockCipher::BlockSize);

            // Decrypt ...
            pDecryptor->DecryptBlock(tmp, BlockCipher::BlockSize, pSrc8, BlockCipher::BlockSize);

            // ... and XOR with the chaining value.
            for (int i = 0; i < static_cast<int>(BlockCipher::BlockSize); i++)
            {
                pDst8[i] = tmp[i] ^ currIv[i];
            }

            // The saved ciphertext block chains into the next block.
            pIv8 = nextIv;
            pSrc8 += BlockCipher::BlockSize;
            pDst8 += BlockCipher::BlockSize;
        }
    }
    else
    {
        // Out-of-place: each ciphertext block in pSrc can be used directly as
        // the next IV. The last block (the IV returned to the caller) is saved
        // before any output is written.
        std::memcpy(nextIv, pSrc8 + static_cast<size_t>(numBlocks - 1) * BlockCipher::BlockSize, BlockCipher::BlockSize);

        while (numBlocks-- > 0)
        {
            // Decrypt ...
            pDecryptor->DecryptBlock(tmp, BlockCipher::BlockSize, pSrc8, BlockCipher::BlockSize);

            // ... and XOR with the chaining value.
            for (int i = 0; i < static_cast<int>(BlockCipher::BlockSize); i++)
            {
                pDst8[i] = tmp[i] ^ pIv8[i];
            }

            // This ciphertext block chains into the next block.
            pIv8 = pSrc8;
            pSrc8 += BlockCipher::BlockSize;
            pDst8 += BlockCipher::BlockSize;
        }
    }

    // Both paths leave the last ciphertext block in nextIv.
    std::memcpy(pIv, nextIv, BlockCipher::BlockSize);

    // tmp holds Decrypt(C) of the last block; XORed with the (public) previous
    // ciphertext it yields plaintext, so wipe it.
    ClearMemory(tmp, sizeof(tmp));
}


// Optimized encryption & decryption for AES-128
#ifdef NN_BUILD_CONFIG_CPU_ARM_V8A

// Entry points of the AES-128 CBC implementation used on ARMv8-A builds.
// Only declared here; the definitions are not part of this header.
// Both follow the same contract as the generic EncryptBlocks/DecryptBlocks
// templates above (pIv is read as the input IV and receives the next IV) —
// NOTE(review): confirm against the out-of-header definitions.
struct CbcModeAes128Helper
{
    static void EncryptBlocks(void* pDst, void* pIv, const void* pSrc, int numBlocks, const AesEncryptor<16>* pEncryptor) NN_NOEXCEPT;
    static void DecryptBlocks(void* pDst, void* pIv, const void* pSrc, int numBlocks, const AesDecryptor<16>* pDecryptor) NN_NOEXCEPT;
};

// AES-128 CBC encryption routed to the ARMv8-A optimized helper. As a
// non-template overload it wins overload resolution over the generic
// EncryptBlocks template when the cipher is an AesEncryptor128.
inline void EncryptBlocks(void* pDst, void* pIv, const void* pSrc, int numBlocks, const AesEncryptor128* pEncryptor) NN_NOEXCEPT
{
    return CbcModeAes128Helper::EncryptBlocks(pDst, pIv, pSrc, numBlocks, pEncryptor);
}

// AES-128 CBC decryption routed to the ARMv8-A optimized helper. As a
// non-template overload it wins overload resolution over the generic
// DecryptBlocks template when the cipher is an AesDecryptor128.
inline void DecryptBlocks(void* pDst, void* pIv, const void* pSrc, int numBlocks, const AesDecryptor128* pDecryptor) NN_NOEXCEPT
{
    return CbcModeAes128Helper::DecryptBlocks(pDst, pIv, pSrc, numBlocks, pDecryptor);
}

#endif


/*
 * Streaming CBC-mode engine.
 *
 * Initialize() sets the block cipher and IV; EncryptUpdate()/DecryptUpdate()
 * can then be called repeatedly with input of any length. Only whole blocks
 * are written to the output: a trailing partial block is kept in
 * m_TemporalBlockBuffer until a later call completes it, and the chaining
 * value is carried in m_Iv between calls.
 *
 * The block cipher is referenced through a raw pointer and is not owned; the
 * object passed to Initialize() must stay alive for all subsequent Update
 * calls.
 */
template <typename BlockCipher>
class CbcModeImpl
{
public:
    static const int BlockSize = BlockCipher::BlockSize;
    static const int IvSize = BlockSize;

public:
    // Scalar members get defined values so that GetBufferedDataSize() (which
    // the size checks in EncryptUpdate()/DecryptUpdate() also read) never
    // returns an indeterminate value, even before Initialize() is called.
    CbcModeImpl() NN_NOEXCEPT
        : m_pBlockCipher(nullptr)
        , m_BufferedByte(0)
        , m_State(State_None)
    {
    }

    // Wipes the whole object, including the IV and any buffered input.
    ~CbcModeImpl() NN_NOEXCEPT
    {
        ClearMemory(this, sizeof(*this));
    }

    // Sets the cipher and IV and discards any buffered input.
    void   Initialize(const BlockCipher* pBlockCipher, const void* pIv, size_t ivSize) NN_NOEXCEPT;

    // Process the next chunk of a stream; each returns the number of bytes
    // written to pDst (a multiple of BlockSize).
    size_t EncryptUpdate(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) NN_NOEXCEPT;
    size_t DecryptUpdate(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) NN_NOEXCEPT;

    // Number of input bytes held back because they do not yet form a whole block.
    size_t GetBufferedDataSize() const NN_NOEXCEPT
    {
        return m_BufferedByte;
    }

private:
    // Shared body of EncryptUpdate()/DecryptUpdate(); T is Encryptor or Decryptor.
    template <typename T>
    size_t Update(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize, const T* p) NN_NOEXCEPT;

private:
    enum State
    {
        State_None,
        State_Initialized,
    };

    // Adapts the block cipher to the ProcessBlocks() interface used by
    // Update(), dispatching to the CBC encryption routine.
    class Encryptor
    {
    public:
        explicit Encryptor(const BlockCipher* p) NN_NOEXCEPT : m_Encryptor(p) {}

        void ProcessBlocks(void* pDst, void* pIv, const void* pSrc, int numBlocks) const NN_NOEXCEPT
        {
            EncryptBlocks(pDst, pIv, pSrc, numBlocks, m_Encryptor);
        }
    private:
        const BlockCipher* m_Encryptor;
    };

    // Adapts the block cipher to the ProcessBlocks() interface used by
    // Update(), dispatching to the CBC decryption routine.
    class Decryptor
    {
    public:
        explicit Decryptor(const BlockCipher* p) NN_NOEXCEPT : m_Decryptor(p) {}

        void ProcessBlocks(void* pDst, void* pIv, const void* pSrc, int numBlocks) const NN_NOEXCEPT
        {
            DecryptBlocks(pDst, pIv, pSrc, numBlocks, m_Decryptor);
        }
    private:
        const BlockCipher* m_Decryptor;
    };

private:
    const BlockCipher* m_pBlockCipher;                 // Not owned.
    uint8_t            m_Iv[IvSize];                   // Chaining value for the next block.
    uint8_t            m_TemporalBlockBuffer[BlockSize]; // Partial block awaiting more input.
    size_t             m_BufferedByte;                 // Valid bytes in m_TemporalBlockBuffer.
    State              m_State;
};

/*
 * Starts a new CBC stream.
 *
 * pBlockCipher  Cipher used by later Update calls; not owned, must outlive them.
 * pIv, ivSize   Initial vector; ivSize must equal IvSize.
 *
 * Any input buffered from a previous stream is discarded.
 */
template <typename BlockCipher>
inline void CbcModeImpl<BlockCipher>::Initialize(const BlockCipher* pBlockCipher, const void* pIv, size_t ivSize) NN_NOEXCEPT
{
    NN_SDK_REQUIRES(ivSize == static_cast<size_t>(IvSize));

    m_pBlockCipher = pBlockCipher;

    // NN_SDK_REQUIRES is presumably not enforced in every build configuration
    // (this file uses NN_ABORT_UNLESS where a check must always run), so never
    // copy more than m_Iv can hold, even if ivSize is wrong.
    std::memcpy(m_Iv, pIv, std::min(ivSize, sizeof(m_Iv)));
    m_BufferedByte = 0;

    m_State = State_Initialized;
}

/*
 * Encrypts the next srcSize bytes of the stream.
 *
 * Writes every whole block formed by the previously buffered bytes plus pSrc
 * to pDst and buffers the trailing partial block. Returns the number of bytes
 * written. Aborts if dstSize cannot hold that many bytes.
 */
template <typename BlockCipher>
inline size_t CbcModeImpl<BlockCipher>::EncryptUpdate(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) NN_NOEXCEPT
{
    // Check the state first: the size check below reads the buffered byte
    // count, which is only meaningful once the object is initialized.
    NN_SDK_REQUIRES(m_State == State_Initialized, "Invalid state. Please restart from Initialize().");

    // Same condition as  dstSize >= ((srcSize + buffered) / BlockSize) * BlockSize
    // but without forming srcSize + buffered or the product, so it cannot wrap:
    //   (s + b) / B == s / B + (s % B + b) / B   and   x >= k * B  <=>  x / B >= k.
    // (s % B + b) stays below 2 * B because the buffered count is below B.
    NN_ABORT_UNLESS(dstSize / BlockSize >= srcSize / BlockSize + (srcSize % BlockSize + GetBufferedDataSize()) / BlockSize, "Precondition does not met.");

    Encryptor encryptor(m_pBlockCipher);

    return Update(pDst, dstSize, pSrc, srcSize, &encryptor);
}

/*
 * Decrypts the next srcSize bytes of the stream.
 *
 * Writes every whole block formed by the previously buffered bytes plus pSrc
 * to pDst and buffers the trailing partial block. Returns the number of bytes
 * written. Aborts if dstSize cannot hold that many bytes.
 */
template <typename BlockCipher>
inline size_t CbcModeImpl<BlockCipher>::DecryptUpdate(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) NN_NOEXCEPT
{
    // Check the state first: the size check below reads the buffered byte
    // count, which is only meaningful once the object is initialized.
    NN_SDK_REQUIRES(m_State == State_Initialized, "Invalid state. Please restart from Initialize().");

    // Same condition as  dstSize >= ((srcSize + buffered) / BlockSize) * BlockSize
    // but without forming srcSize + buffered or the product, so it cannot wrap:
    //   (s + b) / B == s / B + (s % B + b) / B   and   x >= k * B  <=>  x / B >= k.
    // (s % B + b) stays below 2 * B because the buffered count is below B.
    NN_ABORT_UNLESS(dstSize / BlockSize >= srcSize / BlockSize + (srcSize % BlockSize + GetBufferedDataSize()) / BlockSize, "Precondition does not met.");

    Decryptor decryptor(m_pBlockCipher);

    return Update(pDst, dstSize, pSrc, srcSize, &decryptor);
}

/*
 * Shared streaming body of EncryptUpdate() / DecryptUpdate().
 *
 * Feeds srcSize bytes of pSrc through pCipherModeOperator (an Encryptor or a
 * Decryptor), writing only whole blocks to pDst, and keeps any trailing
 * partial block in m_TemporalBlockBuffer for the next call.
 *
 * Returns the number of bytes written to pDst. dstSize has already been
 * validated by the callers and is not used here.
 *
 * NOTE(review): when pDst == pSrc and bytes are already buffered, the first
 * output block overwrites input bytes that have not been consumed yet, so
 * in-place use appears safe only while GetBufferedDataSize() == 0 — confirm
 * against the public API contract.
 */
template <typename BlockCipher> template <typename T>
inline size_t CbcModeImpl<BlockCipher>::Update(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize, const T* pCipherModeOperator) NN_NOEXCEPT
{
    NN_UNUSED(dstSize);

    const uint8_t* pSrc8 = static_cast<const uint8_t*>(pSrc);
    uint8_t* pDst8 = static_cast<uint8_t*>(pDst);
    size_t remaining = srcSize;
    size_t processed = 0;

    // 1. Top up the partial block left over from the previous call and process
    //    it once complete. Skipped for empty input so memcpy is never handed a
    //    possibly-null pSrc.
    if (m_BufferedByte > 0 && remaining > 0)
    {
        const size_t fillSize = std::min(BlockSize - m_BufferedByte, remaining);

        std::memcpy(m_TemporalBlockBuffer + m_BufferedByte, pSrc8, fillSize);
        pSrc8 += fillSize;
        remaining -= fillSize;
        m_BufferedByte += fillSize;

        if (m_BufferedByte == BlockSize)
        {
            pCipherModeOperator->ProcessBlocks(pDst8, m_Iv, m_TemporalBlockBuffer, 1);
            pDst8 += BlockSize;
            processed += BlockSize;
            m_BufferedByte = 0;
        }
    }

    // 2. Process all remaining whole blocks straight from pSrc to pDst.
    //    ProcessBlocks takes the block count as an int, so very large inputs
    //    are split into chunks of at most INT_MAX blocks, and byte counts are
    //    computed in size_t so they cannot overflow int for large inputs.
    const size_t maxBlocksPerCall = static_cast<size_t>(std::numeric_limits<int>::max());
    while (remaining >= BlockSize)
    {
        const size_t numBlocks = std::min(remaining / BlockSize, maxBlocksPerCall);
        const size_t chunkSize = numBlocks * BlockSize;

        pCipherModeOperator->ProcessBlocks(pDst8, m_Iv, pSrc8, static_cast<int>(numBlocks));
        pSrc8 += chunkSize;
        pDst8 += chunkSize;
        remaining -= chunkSize;
        processed += chunkSize;
    }

    // 3. Keep the trailing partial block for the next call. The buffer is
    //    empty here whenever remaining > 0: step 1 only leaves bytes buffered
    //    when it consumed all of the input.
    if (remaining > 0)
    {
        std::memcpy(m_TemporalBlockBuffer, pSrc8, remaining);
        m_BufferedByte = remaining;
    }

    return processed;
}

}}} // namespace nn::crypto::detail

