﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <mutex>
#include <nn/nn_Common.h>
#include <nn/nn_SdkLog.h>
#include <nn/fs/fs_Result.h>
#include <nn/fs/fs_ResultPrivate.h>
#include <nn/os/os_Mutex.h>
#include <nn/result/result_HandlingUtility.h>

#include <nn/fs/detail/fs_Newable.h>
#include <nn/fs/fs_IStorage.h>
#include <nn/fssystem/fs_Assert.h>
#include <nn/fssystem/fs_AesXtsStorage.h>
#include <nn/fssystem/fs_AsynchronousAccess.h>
#include <nn/fssystem/fs_ThreadPriorityChanger.h>
#include <nn/fssystem/fs_Utility.h>

#include <nn/crypto/crypto_Aes128XtsEncryptor.h>
#include <nn/crypto/crypto_Aes128XtsDecryptor.h>
#include <nn/util/util_BitUtil.h>

namespace nn { namespace fssystem {

// Definitions for AesXtsStorage's static constants AesBlockSize, KeySize and IvSize.
// NOTE(review): presumably these supply the out-of-line definitions required for constants
// initialized in the class declaration (needed if they are ODR-used) — confirm against the macro.
NN_DEFINE_STATIC_CONSTANT(const size_t AesXtsStorage::AesBlockSize);
NN_DEFINE_STATIC_CONSTANT(const size_t AesXtsStorage::KeySize);
NN_DEFINE_STATIC_CONSTANT(const size_t AesXtsStorage::IvSize);


/**
 * @brief Constructs an AES-128-XTS storage layered over @p pBaseStorage.
 *
 * The keys and IV are copied into the object; only the base storage pointer is kept.
 * NOTE(review): the base storage is presumably required to outlive this object — confirm.
 *
 * @param[in] pBaseStorage  Storage holding the ciphertext. Must not be null.
 * @param[in] pKey1         First XTS key (passed as the first key to the crypto XTS functions); KeySize bytes.
 * @param[in] pKey2         Second XTS key (passed as the second key to the crypto XTS functions); KeySize bytes.
 * @param[in] keySize       Must equal KeySize.
 * @param[in] pIv           Initial tweak value for data unit 0; IvSize bytes.
 * @param[in] ivSize        Must equal IvSize.
 * @param[in] blockSize     Size of one XTS data unit in bytes; must be non-zero and a multiple of AesBlockSize.
 */
AesXtsStorage::AesXtsStorage(
    fs::IStorage* pBaseStorage,
    const void* pKey1,
    const void* pKey2,
    size_t keySize,
    const void* pIv,
    size_t ivSize,
    size_t blockSize
) NN_NOEXCEPT
    : m_pBaseStorage(pBaseStorage),
      m_BlockSize(blockSize),
      m_Mutex(false)
{
    NN_SDK_REQUIRES(pBaseStorage != nullptr);
    NN_SDK_REQUIRES(pKey1 != nullptr);
    NN_SDK_REQUIRES(pKey2 != nullptr);
    NN_SDK_REQUIRES(pIv != nullptr);
    NN_SDK_REQUIRES(keySize == KeySize);
    NN_SDK_REQUIRES(ivSize  == IvSize);
    // Zero is trivially "aligned", but Read/Write divide by m_BlockSize, so reject it explicitly.
    NN_SDK_REQUIRES(blockSize > 0);
    NN_SDK_REQUIRES(nn::util::is_aligned(blockSize, AesBlockSize));
    NN_UNUSED(keySize);
    NN_UNUSED(ivSize);

    std::memcpy(m_Key[0], pKey1, KeySize);
    std::memcpy(m_Key[1], pKey2, KeySize);
    std::memcpy(m_Iv,  pIv,  IvSize);
}

/**
 * @brief Reads @p size bytes at @p offset from the base storage and decrypts them in place (AES-128-XTS).
 *
 * The storage is treated as a sequence of m_BlockSize-byte data units. The tweak for unit k is
 * m_Iv advanced by k via AddCounter, and each unit is decrypted starting from its own tweak.
 *
 * @param[in]  offset  Byte offset to read from; must be a multiple of KeySize.
 * @param[out] buffer  Receives the plaintext. It is first filled with ciphertext by the base
 *                     storage and then decrypted in place.
 * @param[in]  size    Number of bytes to read; must be a multiple of KeySize.
 *
 * @return ResultSuccess on success (immediately when size == 0);
 *         fs::ResultNullptrArgument if buffer is null;
 *         fs::ResultInvalidArgument (via NN_FSP_REQUIRES) if offset or size is misaligned;
 *         otherwise any failure returned by the base storage's Read.
 */
Result AesXtsStorage::Read(int64_t offset, void* buffer, size_t size) NN_NOEXCEPT
{
    if( size == 0 )
    {
        NN_RESULT_SUCCESS;
    }
    NN_RESULT_THROW_UNLESS(buffer != nullptr, fs::ResultNullptrArgument());

    // Callers are expected to access only in KeySize-aligned units.
    NN_FSP_REQUIRES(nn::util::is_aligned(offset, KeySize), nn::fs::ResultInvalidArgument());
    NN_FSP_REQUIRES(nn::util::is_aligned(size,   KeySize), nn::fs::ResultInvalidArgument());

    // Read the ciphertext straight into the caller's buffer; it is decrypted in place below.
    NN_RESULT_DO(m_pBaseStorage->Read(offset, buffer, size));

    ScopedThreadPriorityChanger changeThreadPriority(+1, ScopedThreadPriorityChanger::Mode::Relative);

    // Tweak of the data unit that contains `offset` (unit index = offset / m_BlockSize).
    char counter[IvSize];
    memcpy(counter, m_Iv, IvSize);
    AddCounter(counter, IvSize, offset / m_BlockSize);

    size_t doneSize = 0;
    if (offset % m_BlockSize != 0)
    {
        // Head: the portion of the first data unit that starts in the middle of the unit.
        size_t skipSize = static_cast<size_t>(offset - nn::util::align_down(offset, m_BlockSize));
        size_t headSize = std::min(size, m_BlockSize - skipSize);

        {
            // Aes128XtsDecryptor cannot start from a tweak in the middle of a data unit, so rebuild
            // a whole unit in a work buffer: zero padding for the skipped prefix followed by the
            // ciphertext. Only the headSize bytes after the padding are copied back out.
            PooledBuffer tmpBuffer(m_BlockSize, m_BlockSize);
            NN_SDK_ASSERT(tmpBuffer.GetSize() >= m_BlockSize);

            memset(tmpBuffer.GetBuffer(), 0x00, skipSize); // padding
            memcpy(tmpBuffer.GetBuffer() + skipSize, buffer, headSize);

            auto decryptedSize = crypto::DecryptAes128Xts(tmpBuffer.GetBuffer(), m_BlockSize, m_Key[0], m_Key[1], KeySize, counter, IvSize, tmpBuffer.GetBuffer(), m_BlockSize);
            NN_SDK_ASSERT(decryptedSize == m_BlockSize);
            NN_UNUSED(decryptedSize);

            memcpy(buffer, tmpBuffer.GetBuffer() + skipSize, headSize);
        }

        AddCounter(counter, IvSize, 1);
        doneSize += headSize;
        NN_SDK_ASSERT(doneSize == std::min(size, m_BlockSize - skipSize));
    }

    // Remaining data starts on a unit boundary: decrypt it one unit at a time, in place.
    // The final unit may be partial.
    char* srcDstBuffer = static_cast<char*>(buffer) + doneSize;
    size_t restSize = size - doneSize;
    while(restSize > 0)
    {
        auto decryptSize = std::min(m_BlockSize, restSize);
        // Source and destination are the same decryptSize-byte region. Do not advertise
        // m_BlockSize bytes of destination capacity: on the last, partial unit that would run
        // past the end of the caller's buffer.
        auto decryptedSize = crypto::DecryptAes128Xts(srcDstBuffer, decryptSize, m_Key[0], m_Key[1], KeySize, counter, IvSize, srcDstBuffer, decryptSize);
        NN_ABORT_UNLESS(decryptedSize == decryptSize);

        restSize     -= decryptedSize;
        srcDstBuffer += decryptedSize;

        AddCounter(counter, IvSize, 1);
    }

    NN_RESULT_SUCCESS;
}

/**
 * @brief Encrypts @p size bytes from @p buffer (AES-128-XTS) and writes them to the base storage at @p offset.
 *
 * The storage is treated as a sequence of m_BlockSize-byte data units. The tweak for unit k is
 * m_Iv advanced by k via AddCounter.
 *
 * @param[in] offset  Byte offset to write to; must be a multiple of KeySize.
 * @param[in] buffer  Plaintext to write. NOTE: when IsPooledBuffer(buffer) is true the data is
 *                    encrypted in place, so the caller's buffer is overwritten with ciphertext
 *                    even though it is passed as const.
 * @param[in] size    Number of bytes to write; must be a multiple of KeySize.
 *
 * @return ResultSuccess on success (immediately when size == 0);
 *         fs::ResultNullptrArgument if buffer is null;
 *         fs::ResultInvalidArgument (via NN_FSP_REQUIRES) if offset or size is misaligned;
 *         otherwise any failure returned by the base storage's Write. A failure after the head
 *         or an earlier chunk has been written leaves those bytes already written.
 */
Result AesXtsStorage::Write(int64_t offset, const void* buffer, size_t size) NN_NOEXCEPT
{
    if( size == 0 )
    {
        NN_RESULT_SUCCESS;
    }
    NN_RESULT_THROW_UNLESS(buffer != nullptr, fs::ResultNullptrArgument());

    // Callers are expected to access only in KeySize-aligned units.
    NN_FSP_REQUIRES(nn::util::is_aligned(offset, KeySize), nn::fs::ResultInvalidArgument());
    NN_FSP_REQUIRES(nn::util::is_aligned(size,   KeySize), nn::fs::ResultInvalidArgument());

    // Encrypt into a separate work buffer unless the caller already handed us a pooled buffer,
    // in which case it is encrypted in place (see the const_cast below).
    const auto useWorkBuffer = !IsPooledBuffer(buffer);
    PooledBuffer pooledBuffer;
    // Bytes written per chunk from the work buffer, rounded down to whole data units. The inner
    // loop advances the tweak after every unit, including a partial one, so a chunk ending
    // mid-unit would make the next chunk start with the wrong tweak.
    size_t workChunkSize = 0;
    if( useWorkBuffer )
    {
        pooledBuffer.Allocate(size, m_BlockSize);
        workChunkSize = pooledBuffer.GetSize() - pooledBuffer.GetSize() % m_BlockSize;
        // A zero-sized chunk would never make progress in the loop below.
        NN_ABORT_UNLESS(workChunkSize > 0);
    }

    // Tweak of the data unit that contains `offset` (unit index = offset / m_BlockSize).
    char counter[IvSize];
    memcpy(counter, m_Iv, IvSize);
    AddCounter(counter, IvSize, offset / m_BlockSize);

    size_t doneSize = 0;
    if (offset % m_BlockSize != 0)
    {
        // Head: the portion of the first data unit that starts in the middle of the unit.
        size_t skipSize = static_cast<size_t>(offset - nn::util::align_down(offset, m_BlockSize));
        size_t headSize = std::min(size, m_BlockSize - skipSize);

        {
            // Aes128XtsEncryptor cannot start from a tweak in the middle of a data unit, so rebuild
            // a whole unit in a work buffer: zero padding for the skipped prefix followed by the
            // plaintext. Only the headSize ciphertext bytes after the padding are written out.
            const auto tmpBufferSize = m_BlockSize;
            PooledBuffer tmpBuffer(tmpBufferSize, tmpBufferSize);
            NN_SDK_ASSERT(tmpBuffer.GetSize() >= tmpBufferSize);

            memset(tmpBuffer.GetBuffer(), 0x00, skipSize); // padding
            memcpy(tmpBuffer.GetBuffer() + skipSize, buffer, headSize);

            auto encryptedSize = crypto::EncryptAes128Xts(tmpBuffer.GetBuffer(), m_BlockSize, m_Key[0], m_Key[1], KeySize, counter, IvSize, tmpBuffer.GetBuffer(), m_BlockSize);
            NN_SDK_ASSERT(encryptedSize == m_BlockSize);
            NN_UNUSED(encryptedSize);

            NN_RESULT_DO(m_pBaseStorage->Write(offset, tmpBuffer.GetBuffer() + skipSize, headSize));
        }

        AddCounter(counter, IvSize, 1);
        doneSize += headSize;
        NN_SDK_ASSERT(doneSize == std::min(size, m_BlockSize - skipSize));
    }

    // Remaining data starts on a unit boundary. Encrypt and write it in chunks: one chunk of the
    // whole remainder when encrypting in place, otherwise workChunkSize bytes at a time.
    size_t restSize = size - doneSize;
    int64_t currentOffset = offset + doneSize;
    while(restSize > 0)
    {
        size_t writeSize = useWorkBuffer ? std::min(workChunkSize, restSize) : restSize;

        {
            ScopedThreadPriorityChanger changeThreadPriority(+1, ScopedThreadPriorityChanger::Mode::Relative);

            size_t restEncryptSize = writeSize;
            size_t encryptOffset = 0;

            while( restEncryptSize > 0 )
            {
                size_t encryptSize = std::min(restEncryptSize, m_BlockSize);
                auto srcBuffer = static_cast<const char*>(buffer) + doneSize + encryptOffset;
                auto dstBuffer = useWorkBuffer ? pooledBuffer.GetBuffer() + encryptOffset : const_cast<char*>(srcBuffer);
                auto encryptedSize = crypto::EncryptAes128Xts(dstBuffer, encryptSize, m_Key[0], m_Key[1], KeySize, counter, IvSize, srcBuffer, encryptSize);
                NN_ABORT_UNLESS(encryptedSize == encryptSize);

                AddCounter(counter, IvSize, 1);

                encryptOffset += encryptedSize;
                restEncryptSize -= encryptedSize;
            }
        }

        const char* writeBuffer = useWorkBuffer ? pooledBuffer.GetBuffer() : static_cast<const char*>(buffer) + doneSize;
        NN_RESULT_DO(m_pBaseStorage->Write(currentOffset, writeBuffer, writeSize));

        currentOffset += writeSize;
        doneSize      += writeSize;
        restSize      -= writeSize;
    }

    NN_RESULT_SUCCESS;
}

/**
 * @brief Flushes the underlying storage.
 *
 * This layer holds no pending data of its own; the request goes straight to the base storage.
 *
 * @return Whatever the base storage's Flush returns.
 */
Result AesXtsStorage::Flush() NN_NOEXCEPT
{
    const Result flushResult = m_pBaseStorage->Flush();
    return flushResult;
}

/**
 * @brief Resizes the underlying storage.
 *
 * @param[in] size  New size in bytes; must be a multiple of KeySize.
 * @return fs::ResultInvalidArgument (via NN_FSP_REQUIRES) if size is misaligned,
 *         otherwise whatever the base storage's SetSize returns.
 */
Result AesXtsStorage::SetSize(int64_t size) NN_NOEXCEPT
{
    // Same KeySize alignment contract as Read/Write/OperateRange, reported the same way:
    // a misaligned request is rejected with an error instead of aborting the process.
    NN_FSP_REQUIRES(nn::util::is_aligned(size, KeySize), nn::fs::ResultInvalidArgument());

    return m_pBaseStorage->SetSize(size);
}

/**
 * @brief Reports the size of the underlying storage.
 *
 * Read/Write map offsets one-to-one onto the base storage, so its size is reported unchanged.
 *
 * @param[out] outValue  Receives the size, as filled in by the base storage.
 * @return Whatever the base storage's GetSize returns.
 */
Result AesXtsStorage::GetSize(int64_t* outValue) NN_NOEXCEPT
{
    const Result sizeResult = m_pBaseStorage->GetSize(outValue);
    return sizeResult;
}

/**
 * @brief Forwards a range operation to the underlying storage after validating the range.
 *
 * @param[out] outBuffer      Output buffer for the operation, passed through unchanged.
 * @param[in]  outBufferSize  Size of outBuffer.
 * @param[in]  operationId    Operation to perform.
 * @param[in]  offset         Start of the range; must be a multiple of KeySize.
 * @param[in]  size           Length of the range; must be a multiple of KeySize.
 * @param[in]  inBuffer       Input buffer for the operation, passed through unchanged.
 * @param[in]  inBufferSize   Size of inBuffer.
 *
 * @return ResultSuccess for an empty range (the base storage is not called and outBuffer is
 *         left untouched); fs::ResultInvalidArgument (via NN_FSP_REQUIRES) for a misaligned
 *         range; otherwise whatever the base storage's OperateRange returns.
 */
Result AesXtsStorage::OperateRange(
                          void* outBuffer,
                          size_t outBufferSize,
                          fs::OperationId operationId,
                          int64_t offset,
                          int64_t size,
                          const void* inBuffer,
                          size_t inBufferSize
                      ) NN_NOEXCEPT
{
    if( size != 0 )
    {
        // Callers are expected to access only in KeySize-aligned units.
        NN_FSP_REQUIRES(nn::util::is_aligned(offset, KeySize), nn::fs::ResultInvalidArgument());
        NN_FSP_REQUIRES(nn::util::is_aligned(size,   KeySize), nn::fs::ResultInvalidArgument());

        return m_pBaseStorage->OperateRange(
            outBuffer, outBufferSize, operationId, offset, size, inBuffer, inBufferSize);
    }

    // NOTE(review): an empty range succeeds without filling outBuffer; confirm that is acceptable
    // for every operationId (e.g. query-style operations).
    NN_RESULT_SUCCESS;
}

}}
