﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <cstring>
#include <nn/nn_Common.h>
#include <nn/nn_SdkAssert.h>
#include <nn/crypto/detail/crypto_XtsModeImpl.h>
#include <nn/crypto/detail/crypto_Clear.h>

namespace nn { namespace crypto { namespace detail {

namespace {

/* GF(2^128) 上での単位元を乗算する関数 */
void GfMult(uint64_t* pData64) NN_NOEXCEPT
{
    // 先にシフトしたときに桁あふれするか調べておく
    uint64_t carry = pData64[1] & 0x8000000000000000ULL;

    // 全体を1ビット左シフト
    pData64[1] = ((pData64[1] & 0xFFFFFFFFFFFFFFFFULL) << 1) | ((pData64[0] & 0xFFFFFFFFFFFFFFFFULL) >> 63);
    pData64[0] = (pData64[0] & 0xFFFFFFFFFFFFFFFFULL) << 1;

    if (carry)
    {
        // 最上位に桁上りがあったら最下位バイトに 0x87 を XOR する
        // a^128 = a^7 + a^2 + a^1 + 1 より
        pData64[0] ^= 0x87ULL;
    }
}

}   // anonymous namespace


XtsModeImpl::~XtsModeImpl() NN_NOEXCEPT
{
    ClearMemory(this, sizeof(*this));
}

size_t XtsModeImpl::ProcessPartialData(uint8_t* pDst8, const uint8_t* pSrc8, size_t srcSize) NN_NOEXCEPT
{
    size_t processed = 0;

    std::memcpy(m_TemporalBlockBuffer + m_BufferedByte, pSrc8, srcSize);
    m_BufferedByte += srcSize;

    if (m_BufferedByte == BlockSize)
    {
        // 1ブロック以上処理済みの場合はまずバッファされているブロックを処理する
        if (m_State == State_Processing)
        {
            ProcessBlock(pDst8, m_LastBlock);
            processed += BlockSize;
        }

        // 埋まった入力は平文のまま LastBlock に保存しておく
        std::memcpy(m_LastBlock, m_TemporalBlockBuffer, BlockSize);
        m_BufferedByte = 0;

        m_State = State_Processing;
    }

    return processed;
}

size_t XtsModeImpl::ProcessRemainingData(uint8_t* pDst8, const uint8_t* pSrc8, size_t srcSize) NN_NOEXCEPT
{
    NN_UNUSED(pDst8);

    std::memcpy(m_TemporalBlockBuffer, pSrc8, srcSize);
    m_BufferedByte = srcSize;

    return 0;
}

size_t XtsModeImpl::FinalizeEncryption(void* pDst, size_t dstSize) NN_NOEXCEPT
{
    NN_UNUSED(dstSize);
    NN_SDK_REQUIRES(m_State == State_Processing, "Invalid state. Please restart from Initialize().");

    uint8_t* pDst8 = static_cast<uint8_t*>(pDst);
    size_t processed = 0;

    // 最後の入力ブロックがブロックサイズの倍数かどうかで処理が変わる
    if (m_BufferedByte == 0)
    {
        // ブロックサイズの倍数の場合は保存してあった平文をそのまま処理して終了
        ProcessBlock(pDst8, m_LastBlock);
        processed = BlockSize;
    }
    else
    {
        // ブロックサイズの倍数ではない場合の暗号化の処理

        // まずは保存済み平文ブロックを暗号化する
        ProcessBlock(m_LastBlock, m_LastBlock);

        // 最後のブロックの末尾を前段の暗号化ブロックの末尾部分でパディングする
        std::memcpy(m_TemporalBlockBuffer + m_BufferedByte, m_LastBlock + m_BufferedByte, BlockSize - m_BufferedByte);

        // それを暗号化して最後から二番目の出力として書き出す
        ProcessBlock(pDst8, m_TemporalBlockBuffer);

        // 最後の出力ブロックは最後から二番目の暗号化ブロックを平文の長さにあわせたものになる
        std::memcpy(pDst8 + BlockSize, m_LastBlock, m_BufferedByte);

        processed = BlockSize + m_BufferedByte;
    }

    m_State = State_Done;
    return processed;
}

size_t XtsModeImpl::FinalizeDecryption(void* pDst, size_t dstSize) NN_NOEXCEPT
{
    NN_UNUSED(dstSize);
    NN_SDK_REQUIRES(m_State == State_Processing, "Invalid state. Please restart from Initialize().");

    uint8_t* pDst8 = static_cast<uint8_t*>(pDst);
    size_t processed = 0;

    // 最後の入力ブロックがブロックサイズの倍数かどうかで処理が変わる
    if (m_BufferedByte == 0)
    {
        // ブロックサイズの倍数の場合は保存してあった平文をそのまま処理して終了
        ProcessBlock(pDst8, m_LastBlock);
        processed = BlockSize;
    }
    else
    {
        // ブロックサイズの倍数ではない場合の復号化の処理

        // 最後から二番目の暗号文ブロックを「最後のブロック用の tweak」で復号化したい
        // そのために一度現在の tweak を保存しておき、インクリメントした上で復号化する
        uint8_t tweak[BlockSize];
        std::memcpy(tweak, m_Tweak, BlockSize);
        GfMult(reinterpret_cast<uint64_t*>(m_Tweak));

        ProcessBlock(m_LastBlock, m_LastBlock);

        // 最後のブロックの末尾を前段の暗号化ブロックの末尾部分でパディングする
        std::memcpy(m_TemporalBlockBuffer + m_BufferedByte, m_LastBlock + m_BufferedByte, BlockSize - m_BufferedByte);

        // それを tweak を巻き戻した上で復号化して最後から二番目の出力として書き出す
        std::memcpy(m_Tweak, tweak, BlockSize);
        ProcessBlock(pDst8, m_TemporalBlockBuffer);

        // 最後の出力ブロックは最後から二番目のブロックを平文の長さにあわせたものになる
        std::memcpy(static_cast<uint8_t*>(pDst) + BlockSize, m_LastBlock, m_BufferedByte);

        processed = BlockSize + m_BufferedByte;
    }

    m_State = State_Done;
    return processed;
}

/* XTS モード1ブロック分の処理 */
void XtsModeImpl::ProcessBlock(uint8_t* pDst8, const uint8_t* pSrc8) NN_NOEXCEPT
{
    uint8_t tmp[BlockSize];

    // 入力に tweak を XOR して
    for (int i = 0; i < static_cast<int>(BlockSize); i++)
    {
        tmp[i] = m_Tweak[i] ^ pSrc8[i];
    }

    // 暗号化/復号化する
    m_pCipherFunction(tmp, tmp, m_pCipherContext);

    // もう一度 tweak を XOR したものが出力になる
    for (int i = 0; i < static_cast<int>(BlockSize); i++)
    {
        pDst8[i] = m_Tweak[i] ^ tmp[i];
    }

    // 次の処理のために tweak を更新する
    GfMult(reinterpret_cast<uint64_t*>(m_Tweak));
}

}}} // namespace nn::crypto::detail
