﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <cstdlib>
#include <cstring>
#include <memory>
#include <mutex>
#include <new>
#include <nn/nn_SdkAssert.h>
#include <nn/crypto/crypto_Sha256Generator.h>
#include <nn/es.h>
#include <nn/ncm/ncm_ContentIdUtil.h>
#include <nn/ncm/ncm_ContentMetaDatabase.h>
#include <nn/ncm/ncm_ContentStorage.h>
#include <nn/ncm/ncm_Service.h>
#include <nn/nim/nim_Result.h>
#include <nn/nim/detail/nim_Log.h>
#include <nn/nim/srv/nim_LocalCommunicationDeliveryProtocol.h>
#include <nn/nim/srv/nim_LocalCommunicationSendContentTaskBase.h>
#include <nn/result/result_HandlingUtility.h>
#include <nn/socket.h>

#include "nim_DebugUtil.h"

namespace nn { namespace nim { namespace srv {
namespace {

// Hook for mapping results coming back from the local-communication connection.
// Currently every failure is caught and rethrown unchanged, and success is
// returned as success, so the function is effectively a pass-through.
Result HandleLocalCommunicationResult(Result result) NN_NOEXCEPT
{
    NN_RESULT_TRY(result)
        NN_RESULT_CATCH_ALL
        {
            // TODO: Change the Result depending on the cause, e.g. the session was dropped or unexpected communication occurred.
            NN_RESULT_RETHROW;
        }
    NN_RESULT_END_TRY;

    NN_RESULT_SUCCESS;
}

// Computes the SHA-256 hash over the entire content identified by contentId in
// the content storage of storageId and writes the digest into outValue->data.
// Intended for hashing the content-meta content itself.
//
// buffer / bufferSize: caller-provided scratch space. The content is read through
//                      it in chunks of at most bufferSize bytes, so its previous
//                      contents are overwritten.
// Returns ResultBufferNotEnough() when bufferSize is 0; otherwise the first failure
// reported while opening, sizing, or reading the content.
Result CalculateContentMetaHash(ncm::Hash* outValue, void* buffer, size_t bufferSize, const ncm::ContentId& contentId, ncm::StorageId storageId) NN_NOEXCEPT
{
    // Compare in the unsigned domain: once bufferSize is cast to int64_t it can
    // never exceed INT64_MAX, which made the previous form of this check vacuous.
    NN_SDK_ASSERT_LESS_EQUAL(static_cast<uint64_t>(bufferSize), static_cast<uint64_t>(INT64_MAX));
    // With an empty buffer the read loop below would never advance offset and would spin forever.
    NN_RESULT_THROW_UNLESS(bufferSize > 0, ResultBufferNotEnough());
    const int64_t chunkLimit = static_cast<int64_t>(bufferSize);

    ncm::ContentStorage storage;
    NN_RESULT_DO(ncm::OpenContentStorage(&storage, storageId));

    crypto::Sha256Generator sha256;
    sha256.Initialize();
    int64_t size;
    NN_RESULT_DO(storage.GetSize(&size, contentId));
    int64_t offset = 0;
    while (offset < size)
    {
        // size - offset is only chosen when it is smaller than bufferSize, so the cast to size_t is safe.
        const size_t readSize = ((size - offset) > chunkLimit) ? bufferSize : static_cast<size_t>(size - offset);
        NN_RESULT_DO(storage.ReadContentIdFile(buffer, readSize, contentId, offset));
        sha256.Update(buffer, readSize);
        offset += static_cast<int64_t>(readSize);
    }
    sha256.GetHash(outValue->data, sizeof(outValue->data));

    NN_RESULT_SUCCESS;
}


}

    // Sets up the underlying connection for the given IPv4 address and port.
    // Any failure from the connection is returned unchanged.
    Result LocalCommunicationSendContentTaskBase::Initialize(uint32_t ipv4, uint16_t port) NN_NOEXCEPT
    {
        const Result connectionResult = m_Connection.Initialize(ipv4, port);
        NN_RESULT_DO(connectionResult);
        NN_RESULT_SUCCESS;
    }

    // Requests cancellation: cancels the connection first, then raises the
    // cancel flag while holding m_CancelMutex.
    void LocalCommunicationSendContentTaskBase::Cancel() NN_NOEXCEPT
    {
        m_Connection.Cancel();
        {
            std::lock_guard<os::Mutex> cancelLock(m_CancelMutex);
            m_CancelRequested = true;
        }
    }

    // Clears a previous cancellation: resets the connection's cancel state first,
    // then lowers the cancel flag while holding m_CancelMutex.
    void LocalCommunicationSendContentTaskBase::ResetCancel() NN_NOEXCEPT
    {
        m_Connection.ResetCancel();
        {
            std::lock_guard<os::Mutex> cancelLock(m_CancelMutex);
            m_CancelRequested = false;
        }
    }

    // Reports whether Cancel() has been called since the last ResetCancel().
    // The flag is read under m_CancelMutex.
    bool LocalCommunicationSendContentTaskBase::IsCancelRequested() const NN_NOEXCEPT
    {
        std::lock_guard<os::Mutex> cancelLock(m_CancelMutex);
        const bool requested = m_CancelRequested;
        return requested;
    }

    // Records a failed result as the task's last result and returns it.
    // m_LastResult is written under m_ProgressMutex. A successful result is
    // passed through as success without touching m_LastResult.
    Result LocalCommunicationSendContentTaskBase::SetAndThrowLastResult(Result result) NN_NOEXCEPT
    {
        if (!result.IsSuccess())
        {
            std::lock_guard<os::Mutex> progressLock(m_ProgressMutex);
            m_LastResult = result;
            return result;
        }

        NN_RESULT_SUCCESS;
    }

    // Runs the request-serving loop. A failure is recorded as the last result
    // before being returned.
    Result LocalCommunicationSendContentTaskBase::Execute() NN_NOEXCEPT
    {
        const Result executeResult = ExecuteImpl();
        NN_RESULT_DO(SetAndThrowLastResult(executeResult));
        NN_RESULT_SUCCESS;
    }

    // Waits for the client (PrepareImpl), then serves its requests (ExecuteImpl).
    // The serving step only runs if preparation succeeded. The first failure is
    // recorded as the last result and returned.
    Result LocalCommunicationSendContentTaskBase::PrepareAndExecute() NN_NOEXCEPT
    {
        const Result prepareResult = PrepareImpl();
        NN_RESULT_DO(SetAndThrowLastResult(prepareResult));

        const Result executeResult = ExecuteImpl();
        NN_RESULT_DO(SetAndThrowLastResult(executeResult));

        NN_RESULT_SUCCESS;
    }

    // Prepares the session by waiting for a client on m_Connection.
    // GetLocalCommunicationErrorResultForDebug() is checked first; presumably a
    // debug hook (nim_DebugUtil.h) that can inject a failure. TODO confirm.
    Result LocalCommunicationSendContentTaskBase::PrepareImpl() NN_NOEXCEPT
    {
        NN_RESULT_DO(GetLocalCommunicationErrorResultForDebug());

        // TORIAEZU (provisional): communicate over a plain socket for now.
        const Result waitResult = m_Connection.WaitClient();
        NN_RESULT_DO(waitResult);

        NN_RESULT_SUCCESS;
    }

    // Serves requests from the connected peer until it sends EndSessionTag.
    // Each iteration:
    //   1. receives one protocol header;
    //   2. checks that the declared payload size matches the request type;
    //   3. answers the request on m_Connection.
    // Errors:
    //   - a payload-size mismatch yields ResultLocalCommunicationInvalidDataSize;
    //   - an unknown tag yields ResultLocalCommunicationUnexpectedTag;
    //   - otherwise the first failure from the connection, ncm, or es calls is returned.
    Result LocalCommunicationSendContentTaskBase::ExecuteImpl() NN_NOEXCEPT
    {
        // TORIAEZU (provisional): communicate over a plain socket for now.
        bool needsProcess = true;
        while (needsProcess)
        {
            LocalCommunicationDeliveryProtocolHeader header;
            NN_RESULT_DO(HandleLocalCommunicationResult(m_Connection.ReceiveHeader(&header)));

            switch (header.tag)
            {
            case LocalCommunicationDeliveryProtocolTag::GetPackagedContentInfoTag:
                {
                    // Request: a ContentMetaKey. Response: the matching PackagedContentInfo.
                    ncm::ContentMetaKey key;
                    NN_RESULT_THROW_UNLESS(header.size == sizeof(key), ResultLocalCommunicationInvalidDataSize());

                    NN_RESULT_DO(m_Connection.ReceiveData(&key, sizeof(key), sizeof(key)));

                    ncm::PackagedContentInfo info;
                    NN_RESULT_DO(GetPackagedContentInfo(&info, key));

                    auto sendHeader = MakeLocalCommunicationDeliveryProtocolResponseHeader(header.tag, sizeof(info));
                    NN_RESULT_DO(m_Connection.Send(sendHeader, &info, sizeof(info)));
                }
                break;
            case LocalCommunicationDeliveryProtocolTag::GetContentTag:
                {
                    // Request: a LocalCommunicationContentInfo. Response: the whole content,
                    // streamed through m_Buffer.
                    LocalCommunicationContentInfo info;
                    NN_RESULT_THROW_UNLESS(header.size == sizeof(info), ResultLocalCommunicationInvalidDataSize());

                    NN_RESULT_DO(m_Connection.ReceiveData(&info, sizeof(info), sizeof(info)));

                    ncm::StorageId storageId;
                    NN_RESULT_DO(GetStorage(&storageId, info.contentId));

                    ncm::ContentStorage storage;
                    NN_RESULT_DO(ncm::OpenContentStorage(&storage, storageId));

                    int64_t contentLength;
                    NN_RESULT_DO(storage.GetSize(&contentLength, info.contentId));

                    // Fills m_Buffer with writeSize bytes of the content starting at offset.
                    // Presumably invoked by m_Connection.Send once per chunk (TODO confirm).
                    // Temporary contents are not counted towards progress.
                    auto sendContentFunc = [&](size_t writeSize, int64_t offset) -> Result
                    {
                        NN_RESULT_DO(storage.ReadContentIdFile(m_Buffer, writeSize, info.contentId, offset));
                        if (!info.isTemporary)
                        {
                            AddProgress(writeSize);
                        }

                        NN_RESULT_SUCCESS;
                    };

                    auto sendHeader = MakeLocalCommunicationDeliveryProtocolResponseHeader(header.tag, contentLength);

                    NN_RESULT_DO(m_Connection.Send(sendHeader, m_Buffer, m_BufferSize, sendContentFunc));
                }
                break;
    #if defined(NN_BUILD_CONFIG_OS_HORIZON)
            case LocalCommunicationDeliveryProtocolTag::GetCommonTicketTag:
                {
                    // Request: a RightsIdIncludingKeyId.
                    // Response: a ticket/certificate size header, then the ticket, then the certificate.
                    es::RightsIdIncludingKeyId rightsId;
                    NN_RESULT_THROW_UNLESS(header.size == sizeof(rightsId), ResultLocalCommunicationInvalidDataSize());
                    NN_RESULT_DO(m_Connection.ReceiveData(&rightsId, sizeof(rightsId), sizeof(rightsId)));

                    size_t ticketSize;
                    size_t certSize;
                    NN_RESULT_DO(es::GetCommonTicketAndCertificateSize(&ticketSize, &certSize, rightsId));

                    // TORIAEZU (provisional): take the buffers from the heap.
                    // The array form of unique_ptr makes the release use delete[], matching new[].
                    // std::nothrow makes an allocation failure reach the null checks below.
                    std::unique_ptr<Bit8[]> commonTicketBuffer(new (std::nothrow) Bit8[ticketSize]);
                    NN_RESULT_THROW_UNLESS(commonTicketBuffer, ResultBufferNotEnough());
                    std::unique_ptr<Bit8[]> commonCertificateBuffer(new (std::nothrow) Bit8[certSize]);
                    NN_RESULT_THROW_UNLESS(commonCertificateBuffer, ResultBufferNotEnough());

                    size_t outTicketSize;
                    size_t outCertSize;
                    NN_RESULT_DO(es::GetCommonTicketAndCertificateData(&outTicketSize, &outCertSize, commonTicketBuffer.get(), commonCertificateBuffer.get(), ticketSize, certSize, rightsId));

                    LocalCommunicationDeliveryProtocolGetCommonTicketResponseHeader responseHeader;
                    responseHeader.ticketSize = outTicketSize;
                    responseHeader.certificateSize = outCertSize;

                    auto sendHeader = MakeLocalCommunicationDeliveryProtocolResponseHeader(header.tag, sizeof(responseHeader) + responseHeader.ticketSize + responseHeader.certificateSize);
                    NN_RESULT_DO(m_Connection.SendHeader(sendHeader));

                    NN_RESULT_DO(m_Connection.SendData(&responseHeader, sizeof(responseHeader), sizeof(responseHeader)));

                    NN_RESULT_DO(m_Connection.SendData(commonTicketBuffer.get(), responseHeader.ticketSize, responseHeader.ticketSize));
                    NN_RESULT_DO(m_Connection.SendData(commonCertificateBuffer.get(), responseHeader.certificateSize, responseHeader.certificateSize));
                }
                break;
    #endif
            case LocalCommunicationDeliveryProtocolTag::SendTotalSizeTag:
                {
                    // Request: the total number of bytes the peer will fetch.
                    // Response: an empty acknowledgement.
                    int64_t totalSize;
                    NN_RESULT_THROW_UNLESS(header.size == sizeof(totalSize), ResultLocalCommunicationInvalidDataSize());
                    NN_RESULT_DO(m_Connection.ReceiveData(&totalSize, sizeof(totalSize), sizeof(totalSize)));
                    SetTotalSize(totalSize);

                    auto sendHeader = MakeLocalCommunicationDeliveryProtocolResponseHeader(header.tag, 0);
                    NN_RESULT_DO(m_Connection.SendHeader(sendHeader));
                }
                break;

            case LocalCommunicationDeliveryProtocolTag::EndSessionTag:
                needsProcess = false;
                break;
            default:
                NN_RESULT_THROW(ResultLocalCommunicationUnexpectedTag());
            }
        }

        NN_RESULT_SUCCESS;
    }

    // Finds the Meta-type content registered under key in the content meta database
    // of storageId. On success, outValue receives its ContentInfo together with a
    // hash computed over the whole content.
    // NOTE: the hash calculation uses m_Buffer as scratch space, so its contents are overwritten.
    // Returns nim::ResultContentNotFound() if no Meta content is listed for key.
    Result LocalCommunicationSendContentTaskBase::GetPackagedContentInfoImpl(ncm::PackagedContentInfo* outValue, const ncm::ContentMetaKey& key, ncm::StorageId storageId) NN_NOEXCEPT
    {
        ncm::ContentMetaDatabase db;
        NN_RESULT_DO(ncm::OpenContentMetaDatabase(&db, storageId));

        // Page through the content list in fixed-size batches.
        const int BatchSize = 16;
        ncm::ContentInfo batch[BatchSize];
        for (int offset = 0; ; )
        {
            int listed;
            NN_RESULT_DO(db.ListContentInfo(&listed, batch, BatchSize, key, offset));

            for (int index = 0; index < listed; ++index)
            {
                ncm::ContentInfo& candidate = batch[index];
                if (candidate.type != ncm::ContentType::Meta)
                {
                    continue;
                }

                outValue->info = candidate;
                // The content meta does not contain its own hash, so it has to be computed here.
                NN_RESULT_DO(CalculateContentMetaHash(&outValue->hash, m_Buffer, m_BufferSize, candidate.GetId(), storageId));
                NN_RESULT_SUCCESS;
            }

            // A short batch means the list is exhausted.
            if (listed < BatchSize)
            {
                break;
            }
            offset += listed;
        }

        NN_RESULT_THROW(nim::ResultContentNotFound());
    }

    // Adds progress bytes to the sent-size counter while holding m_ProgressMutex.
    void LocalCommunicationSendContentTaskBase::AddProgress(int64_t progress) NN_NOEXCEPT
    {
        std::lock_guard<os::Mutex> progressLock(m_ProgressMutex);
        m_Progress.sentSize = m_Progress.sentSize + progress;
    }


    // Returns a snapshot of the transfer progress taken under m_ProgressMutex.
    // lastResult is filled from m_LastResult, the most recently recorded failure.
    LocalCommunicationSendContentProgress LocalCommunicationSendContentTaskBase::GetProgress() const NN_NOEXCEPT
    {
        std::lock_guard<os::Mutex> lock(m_ProgressMutex);
        LocalCommunicationSendContentProgress progress = m_Progress;
        // lastResult is filled by a raw byte copy from m_LastResult. Refuse to compile
        // if that copy would read past the end of m_LastResult.
        static_assert(sizeof(progress.lastResult) <= sizeof(m_LastResult), "lastResult must not be larger than the stored Result");
        std::memcpy(&progress.lastResult, &m_LastResult, sizeof(progress.lastResult));

        return progress;
    }

    // Stores the total transfer size announced by the peer, under m_ProgressMutex.
    void LocalCommunicationSendContentTaskBase::SetTotalSize(int64_t totalSize) NN_NOEXCEPT
    {
        std::lock_guard<os::Mutex> progressLock(m_ProgressMutex);
        m_Progress.totalSize = totalSize;
    }
}}}
