﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/crypto.h>
#include <nn/fssystem/fs_AesCtrStorage.h>
#include <nn/fssystem/fs_AesCtrCounterExtendedStorage.h>
#include <nn/fssystem/fs_AsynchronousAccess.h>
#include <nn/fssystem/fs_NcaHeader.h>
#include <nn/fssystem/fs_ThreadPriorityChanger.h>
#include <nn/fssystem/fs_Utility.h>
#include <nn/fs/fs_QueryRange.h>

namespace nn { namespace fssystem {

namespace {

// 外部関数で復号化する (HW 復号に使われる)
class ExternalDecryptor : public AesCtrCounterExtendedStorage::IDecryptor
{
public:
    static const size_t BlockSize = AesCtrCounterExtendedStorage::BlockSize;
    static const size_t CounterSize = AesCtrCounterExtendedStorage::CounterSize;
    static const size_t KeySize = AesCtrCounterExtendedStorage::KeySize;

public:
    ExternalDecryptor(
        AesCtrCounterExtendedStorage::DecryptFunction pDecryptFunction,
        int keyIndex
        ) NN_NOEXCEPT;
    virtual ~ExternalDecryptor() NN_NOEXCEPT NN_OVERRIDE;

    virtual void Decrypt(
        void* buffer, size_t size,
        void* encryptedKey, size_t keySize,
        void* pIv, size_t ivSize
        ) NN_NOEXCEPT NN_OVERRIDE;

    virtual bool HasExternalDecryptionKey() const NN_NOEXCEPT NN_OVERRIDE
    {
        return m_KeyIndex < 0;
    }

private:
    AesCtrCounterExtendedStorage::DecryptFunction m_pDecryptFunction;
    int m_KeyIndex;
};

// SW 復号
class SoftwareDecryptor : public AesCtrCounterExtendedStorage::IDecryptor
{
public:
    virtual void Decrypt(
        void* buffer, size_t size,
        void* encryptedKey, size_t keySize,
        void* pIv, size_t ivSize) NN_NOEXCEPT NN_OVERRIDE;

    virtual bool HasExternalDecryptionKey() const NN_NOEXCEPT NN_OVERRIDE
    {
        return false;
    }
};

}

NN_DEFINE_STATIC_CONSTANT(const size_t AesCtrCounterExtendedStorage::BlockSize);
NN_DEFINE_STATIC_CONSTANT(const size_t AesCtrCounterExtendedStorage::KeySize);
NN_DEFINE_STATIC_CONSTANT(const size_t AesCtrCounterExtendedStorage::CounterSize);
NN_DEFINE_STATIC_CONSTANT(const size_t AesCtrCounterExtendedStorage::NodeSize);

NN_STATIC_ASSERT(AesCtrStorage::BlockSize == crypto::Aes128CtrDecryptor::BlockSize);
NN_STATIC_ASSERT(AesCtrCounterExtendedStorage::BlockSize == crypto::Aes128CtrDecryptor::BlockSize);
NN_STATIC_ASSERT(AesCtrCounterExtendedStorage::CounterSize == crypto::Aes128CtrDecryptor::IvSize);

/**
 * @brief   外部関数で復号化する Decryptor を作成
 *
 * @param[out]  ppOutValue  生成されたDecryptorへのポインタ
 * @param[in]   function    復号化関数
 * @param[in]   keyIndex    鍵インデックス
 *
 * @return  関数の実行結果を返します。
 */
Result AesCtrCounterExtendedStorage::CreateExternalDecryptor(
    std::unique_ptr<IDecryptor>* ppOutValue,
    DecryptFunction function, int keyIndex) NN_NOEXCEPT
{
    std::unique_ptr<IDecryptor> pDecryptor(new ExternalDecryptor(function, keyIndex));
    NN_RESULT_THROW_UNLESS(pDecryptor != nullptr,
                           fs::ResultAllocationMemoryFailedInAesCtrCounterExtendedStorageA());

    *ppOutValue = std::move(pDecryptor);
    NN_RESULT_SUCCESS;
}

/**
 * @brief   ソフトウェアで復号化する Decryptor を作成
 *
 * @param[out]  ppOutValue                          生成されたDecryptorへのポインタ
 *
 * @return  関数の実行結果を返します。
 */
Result AesCtrCounterExtendedStorage::CreateSoftwareDecryptor(std::unique_ptr<IDecryptor>* ppOutValue) NN_NOEXCEPT
{
    std::unique_ptr<IDecryptor> pDecryptor(new SoftwareDecryptor());
    NN_RESULT_THROW_UNLESS(pDecryptor != nullptr,
                           fs::ResultAllocationMemoryFailedInAesCtrCounterExtendedStorageB());

    *ppOutValue = std::move(pDecryptor);
    NN_RESULT_SUCCESS;
}

/**
 * @brief   コンストラクタです。
 */
AesCtrCounterExtendedStorage::AesCtrCounterExtendedStorage() NN_NOEXCEPT
    : m_Table()
    , m_DataStorage()
    , m_SecureValue(0)
    , m_CounterOffset(0)
    , m_pDecryptor()
{
}

/**
 * @brief   初期化をします。（テスト向け）
 *
 * @param[in]   pAllocator      アロケータのポインタ
 * @param[in]   pKey            復号に使用する鍵の先頭アドレス
 * @param[in]   keySize         復号に使用する鍵のサイズ
 * @param[in]   secureValue     復号に使用するセキュア数値
 * @param[in]   dataStorage     暗号化したデータを格納したストレージ
 * @param[in]   tableStorage    カウンタのテーブルを格納したストレージ
 *
 * @return  関数の実行結果を返します。
 *
 * @pre
 *      - pAllocator != nullptr
 *      - pKey != nullptr
 *      - keySize == KeySize
 *      - 未初期化
 */
Result AesCtrCounterExtendedStorage::Initialize(
                                         IAllocator* pAllocator,
                                         const void* pKey,
                                         size_t keySize,
                                         uint32_t secureValue,
                                         fs::SubStorage dataStorage,
                                         fs::SubStorage tableStorage
                                     ) NN_NOEXCEPT
{
    BucketTree::Header header;
    NN_RESULT_DO(tableStorage.Read(0, &header, sizeof(header)));
    NN_RESULT_DO(header.Verify());

    const auto nodeStorageSize = QueryNodeStorageSize(header.entryCount);
    const auto entryStorageSize = QueryEntryStorageSize(header.entryCount);
    const auto nodeStorageOffset = QueryHeaderStorageSize();
    const auto entryStorageOffset = nodeStorageOffset + nodeStorageSize;

    std::unique_ptr<IDecryptor> pDecryptor;
    NN_RESULT_DO(CreateSoftwareDecryptor(&pDecryptor));

    // 引数の事前検証はこっちに任せる
    return Initialize(
               pAllocator,
               pKey,
               keySize,
               secureValue,
               0,
               dataStorage,
               fs::SubStorage(&tableStorage, nodeStorageOffset, nodeStorageSize),
               fs::SubStorage(&tableStorage, entryStorageOffset, entryStorageSize),
               header.entryCount,
               std::move(pDecryptor)
           );
}

/**
 * @brief   初期化をします。
 *
 * @param[in]   pAllocator      アロケータのポインタ
 * @param[in]   pKey            復号に使用する鍵の先頭アドレス
 * @param[in]   keySize         復号に使用する鍵のサイズ
 * @param[in]   secureValue     復号に使用するセキュア数値
 * @param[in]   counterOffset   復号に使用するカウンタのオフセット
 * @param[in]   dataStorage     暗号化したデータを格納したストレージ
 * @param[in]   nodeStorage     テーブル検索のための木構造を格納したノードストレージ
 * @param[in]   entryStorage    カウンタのテーブルデータを格納したストレージ
 * @param[in]   entryCount      テーブルデータのエントリの総数
 * @param[in]   pDecryptor      AesCtr の復号に使う関数
 *
 * @return  関数の実行結果を返します。
 *
 * @pre
 *      - pAllocator != nullptr
 *      - pKey != nullptr
 *      - keySize == KeySize
 *      - 0 <= counterOffset
 *      - pDecryptor != nullptr
 *      - 未初期化
 */
Result AesCtrCounterExtendedStorage::Initialize(
                                         IAllocator* pAllocator,
                                         const void* pKey,
                                         size_t keySize,
                                         uint32_t secureValue,
                                         int64_t counterOffset,
                                         fs::SubStorage dataStorage,
                                         fs::SubStorage nodeStorage,
                                         fs::SubStorage entryStorage,
                                         int entryCount,
                                         std::unique_ptr<IDecryptor>&& pDecryptor
                                     ) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pKey);
    NN_SDK_REQUIRES_EQUAL(keySize, KeySize);
    NN_SDK_REQUIRES_GREATER_EQUAL(counterOffset, 0);
    NN_SDK_REQUIRES_NOT_NULL(pDecryptor);
    NN_UNUSED(keySize);
    // NOTE: pKey, keySize 以外の事前検証は BucketTree に任せる

    NN_RESULT_DO(
        m_Table.Initialize(
            pAllocator, nodeStorage, entryStorage, NodeSize, sizeof(Entry), entryCount
        )
    );

    m_DataStorage = dataStorage;
    std::memcpy(m_Key, pKey, KeySize);
    m_SecureValue = secureValue;
    m_CounterOffset = counterOffset;
    m_pDecryptor = std::move(pDecryptor);

    NN_RESULT_SUCCESS;
}

/**
 * @brief   終了処理をします。
 */
void AesCtrCounterExtendedStorage::Finalize() NN_NOEXCEPT
{
    if( IsInitialized() )
    {
        m_Table.Finalize();
        m_DataStorage = fs::SubStorage();
    }
}

/**
 * @brief   読み込みを行います。
 *
 * @param[in]   offset  読み込むオフセット
 * @param[out]  buffer  読み込んだデータを格納するバッファ
 * @param[in]   size    読み込むサイズ
 *
 * @return  関数の実行結果を返します。
 *
 * @pre
 *      - 0 <= offset
 *      - is_aligned(offset, BlockSize) != false
 *      - size == 0 || buffer != nullptr
 *      - is_aligned(size, BlockSize) != false
 *      - 初期化済み
 */
Result AesCtrCounterExtendedStorage::Read(int64_t offset, void* buffer, size_t size) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_LESS_EQUAL(0, offset);
    NN_SDK_REQUIRES(IsInitialized());

    if( size == 0 )
    {
        NN_RESULT_SUCCESS;
    }
    NN_RESULT_THROW_UNLESS(buffer != nullptr, fs::ResultNullptrArgument());
    NN_RESULT_THROW_UNLESS(util::is_aligned(offset, BlockSize), fs::ResultInvalidOffset());
    NN_RESULT_THROW_UNLESS(util::is_aligned(size, BlockSize), fs::ResultInvalidSize());
    NN_RESULT_THROW_UNLESS(m_Table.IsInclude(offset, size), fs::ResultOutOfRange());

    // 暗号化したデータを読み込む
    NN_RESULT_DO(m_DataStorage.Read(offset, buffer, size));

    ScopedThreadPriorityChanger changePriority(+1, ScopedThreadPriorityChanger::Mode::Relative);

    // Find() で offset < 0 を弾くので、0 <= offset は保障される
    BucketTree::Visitor visitor;
    NN_RESULT_DO(m_Table.Find(&visitor, offset));
    {
        const auto entryOffset = visitor.Get<Entry>()->GetOffset();
        NN_RESULT_THROW_UNLESS(
            util::is_aligned(entryOffset, BlockSize) &&
                0 <= entryOffset && m_Table.IsInclude(entryOffset),
            fs::ResultInvalidAesCtrCounterExtendedEntryOffset()
        );
    }

    util::BytePtr data(buffer);
    const auto endOffset = offset + static_cast<int64_t>(size);
    auto readOffset = offset;

    while( readOffset < endOffset )
    {
        const auto entry = *visitor.Get<Entry>();

        // entry1Offset(readOffset) が不正
        const auto entry1Offset = entry.GetOffset();
        NN_RESULT_THROW_UNLESS(
            entry1Offset <= readOffset,
            fs::ResultInvalidAesCtrCounterExtendedEntryOffset()
        );

        int64_t entry2Offset;
        if( visitor.CanMoveNext() )
        {
            NN_RESULT_DO(visitor.MoveNext());

            entry2Offset = visitor.Get<Entry>()->GetOffset();

            NN_RESULT_THROW_UNLESS(
                m_Table.IsInclude(entry2Offset),
                fs::ResultInvalidAesCtrCounterExtendedEntryOffset()
            );
        }
        else
        {
            entry2Offset = m_Table.GetEnd();
        }

        // entry2Offset が不正
        NN_RESULT_THROW_UNLESS(
            util::is_aligned(entry2Offset, BlockSize) && readOffset < entry2Offset,
            fs::ResultInvalidAesCtrCounterExtendedEntryOffset()
        );
        // この時点で 0 <= entry1Offset <= readOffset < entry2Offset が保証される

        const auto dataOffset = readOffset - entry1Offset;
        const auto dataSize = (entry2Offset - entry1Offset) - dataOffset;
        NN_SDK_ASSERT_LESS(0, dataSize);

        const auto remainingSize = endOffset - readOffset; // <= size <= size_t(-1)
        // 32bit 環境で size_t(-1) < dataSize になっても問題ない（↑が保証されるので）
        const auto readSize = static_cast<size_t>(std::min(remainingSize, dataSize));
        NN_SDK_ASSERT_LESS_EQUAL(readSize, size);

        const auto counterOffset = m_CounterOffset + entry1Offset + dataOffset;

        NcaAesCtrUpperIv upperIv;
        upperIv.part.generation = entry.generation;
        upperIv.part.secureValue = m_SecureValue;

        // カウンタ生成
        char counter[CounterSize];
        AesCtrStorage::MakeIv(counter, CounterSize, upperIv.value, counterOffset);

        // 復号する
        m_pDecryptor->Decrypt(
            data.Get(), readSize,
            m_Key, KeySize,
            counter, CounterSize);

        data.Advance(readSize);
        readOffset += readSize;
    }

    NN_RESULT_SUCCESS;
}

/**
* @brief       範囲指定処理を行います。
*
* @param[out]  outBuffer        範囲指定処理の結果を格納するバッファ
* @param[in]   outBufferSize    範囲指定処理の結果を格納するバッファのサイズ
* @param[in]   operationId      範囲指定処理の種類
* @param[in]   offset           範囲指定処理開始位置
* @param[in]   size             範囲指定処理を行うデータサイズ
* @param[in]   inBuffer         範囲指定処理に渡すバッファ
* @param[in]   inBufferSize     範囲指定処理に渡すバッファのサイズ
*
* @return      関数の処理結果を返します。
*/
Result AesCtrCounterExtendedStorage::OperateRange(
                                         void* outBuffer,
                                         size_t outBufferSize,
                                         fs::OperationId operationId,
                                         int64_t offset,
                                         int64_t size,
                                         const void* inBuffer,
                                         size_t inBufferSize
                                     ) NN_NOEXCEPT
{
    switch( operationId )
    {
    case fs::OperationId::Invalidate:
        {
            NN_SDK_REQUIRES_LESS_EQUAL(0, offset);
            NN_SDK_REQUIRES(IsInitialized());

            if( size == 0 )
            {
                NN_RESULT_SUCCESS;
            }
            NN_RESULT_THROW_UNLESS(util::is_aligned(offset, BlockSize), fs::ResultInvalidOffset());
            NN_RESULT_THROW_UNLESS(util::is_aligned(size, BlockSize), fs::ResultInvalidSize());
            NN_RESULT_THROW_UNLESS(m_Table.IsInclude(offset, size), fs::ResultOutOfRange());

            NN_RESULT_DO(m_Table.InvalidateCache());
            NN_RESULT_DO(m_DataStorage.OperateRange(
                outBuffer, outBufferSize, operationId, offset, size, inBuffer, inBufferSize));
            NN_RESULT_SUCCESS;
        }

    case nn::fs::OperationId::QueryRange:
        {
            NN_SDK_REQUIRES_LESS_EQUAL(0, offset);
            NN_SDK_REQUIRES(IsInitialized());

            NN_RESULT_THROW_UNLESS(outBuffer != nullptr, fs::ResultNullptrArgument());
            NN_RESULT_THROW_UNLESS(outBufferSize == sizeof(nn::fs::QueryRangeInfo), fs::ResultInvalidSize());

            if( size == 0 )
            {
                reinterpret_cast<nn::fs::QueryRangeInfo*>(outBuffer)->Clear();
                NN_RESULT_SUCCESS;
            }

            NN_RESULT_THROW_UNLESS(util::is_aligned(offset, BlockSize), fs::ResultInvalidOffset());
            NN_RESULT_THROW_UNLESS(util::is_aligned(size, BlockSize), fs::ResultInvalidSize());
            NN_RESULT_THROW_UNLESS(m_Table.IsInclude(offset, size), fs::ResultOutOfRange());

            NN_RESULT_DO(m_DataStorage.OperateRange(
                outBuffer, outBufferSize, operationId, offset, size, inBuffer, inBufferSize));

            nn::fs::QueryRangeInfo info;
            info.Clear();
            info.aesCtrKeyTypeFlag = static_cast<int32_t>(
                m_pDecryptor->HasExternalDecryptionKey()
                ? nn::fs::AesCtrKeyTypeFlag::ExternalKeyForHardwareAes
                : nn::fs::AesCtrKeyTypeFlag::InternalKeyForHardwareAes);

            reinterpret_cast<nn::fs::QueryRangeInfo*>(outBuffer)->Merge(info);

            NN_RESULT_SUCCESS;
        }

    default:
        return nn::fs::ResultUnsupportedOperation();
    }
}

NN_DEFINE_STATIC_CONSTANT(const size_t ExternalDecryptor::BlockSize);
NN_DEFINE_STATIC_CONSTANT(const size_t ExternalDecryptor::CounterSize);
NN_DEFINE_STATIC_CONSTANT(const size_t ExternalDecryptor::KeySize) NN_IS_UNUSED_MEMBER;

// 外部関数で復号化する (HW 復号に使われる)
ExternalDecryptor::ExternalDecryptor(
    AesCtrCounterExtendedStorage::DecryptFunction pDecryptFunction,
    int keyIndex
    ) NN_NOEXCEPT
    : m_pDecryptFunction(pDecryptFunction)
    , m_KeyIndex(keyIndex)
{
    NN_SDK_REQUIRES_NOT_NULL(pDecryptFunction);
}

ExternalDecryptor::~ExternalDecryptor() NN_NOEXCEPT
{
}

void ExternalDecryptor::Decrypt(
    void* buffer, size_t size,
    void* encryptedKey, size_t keySize,
    void* pIv, size_t ivSize
    ) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(buffer);
    NN_SDK_REQUIRES_NOT_NULL(encryptedKey);
    NN_SDK_REQUIRES(keySize == KeySize);
    NN_SDK_REQUIRES_NOT_NULL(pIv);
    NN_SDK_REQUIRES(ivSize == CounterSize);
    NN_UNUSED(ivSize);

    char counter[CounterSize];
    std::memcpy(counter, pIv, CounterSize);

    size_t restSize = size;
    int64_t currentOffset = 0;

    // TORIAEZU: sf の制限で InBuffer と OutBuffer に同じバッファを指定できないためワークバッファを通す
    PooledBuffer pooledBuffer;
    pooledBuffer.AllocateParticularlyLarge(size, BlockSize);

    // 割り当てできたサイズは 0 より大きければよい (size だけ取れている必要はない)
    // また、 BlockSize の倍数である必要がある
    NN_SDK_ASSERT(pooledBuffer.GetSize() > 0
                  && nn::util::is_aligned(pooledBuffer.GetSize(), BlockSize));

    while (restSize > 0)
    {
        size_t readSize = std::min(pooledBuffer.GetSize(), restSize);
        char* dstBuffer = static_cast<char*>(buffer) + currentOffset;

        m_pDecryptFunction(
            pooledBuffer.GetBuffer(),
            readSize,
            m_KeyIndex,
            encryptedKey,
            keySize,
            counter,
            CounterSize,
            dstBuffer,
            readSize
            );

        // TODO: size <= pooledBuffer.GetSize() の場合は pooledBuffer に Read すれば memcpy を無くせる
        memcpy(dstBuffer, pooledBuffer.GetBuffer(), readSize);

        currentOffset += readSize;
        restSize -= readSize;

        if (restSize > 0)
        {
            AddCounter(counter, CounterSize, readSize / BlockSize);
        }
    }
}

// SW 復号
void SoftwareDecryptor::Decrypt(
    void* buffer, size_t size,
    void* encryptedKey, size_t keySize,
    void* pIv, size_t ivSize) NN_NOEXCEPT
{
    crypto::DecryptAes128Ctr(
        buffer, size, encryptedKey, keySize, pIv, ivSize, buffer, size);
}


}}
