﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include "FileDownloader.h"

#include <nn/nn_Log.h>
#include <nn/nn_Assert.h>
#include <nn/ssl/ssl_Context.h>
#include <nn/util/util_ScopeExit.h>
#include <atomic>

// Evaluates a curl easy-API call `exp` (an expression yielding CURLcode) exactly once and, if the
// result is not CURLE_OK, returns that CURLcode from the enclosing function.
// The enclosing function must therefore be able to return a CURLcode.
// Wrapped in do { ... } while (false) so the macro behaves as a single statement (safe after an
// unbraced if/else and requires a trailing semicolon).
#define CURL_EASY_DO(exp) \
    do                          \
    {                           \
        CURLcode code_ = (exp); \
                                \
        if (code_ != CURLE_OK)  \
        {                       \
            return code_;       \
        }                       \
    }                           \
    while (NN_STATIC_CONDITION(false))

namespace
{
    // Source of request IDs for AddRequest(). It lives at namespace scope, so every FileDownloader
    // instance shares it. It starts at 0 and AddRequest() pre-increments it, so the first ID handed
    // out is 1. Being std::atomic, increments from different threads do not race.
    std::atomic<uint64_t> g_RequestIdCounter{0};
}

// Constructs a downloader that keeps at most sessionCountMax transfers active at the same time
// during Execute(). sessionCountMax must be greater than 0 (asserted).
// The cancel event is created in manual-clear mode. No completion callback is registered until
// SetCompletionCallback() is called.
FileDownloader::FileDownloader(int sessionCountMax) NN_NOEXCEPT
    : m_SessionCount(0)
    , m_SessionCountMax(sessionCountMax)
    , m_CancelEvent(nn::os::EventClearMode_ManualClear)
    , m_pCallback(nullptr)
    , m_pCallbackParam(nullptr)
{
    NN_ASSERT_GREATER(sessionCountMax, 0);
}

void FileDownloader::AddRequest(uint64_t* pOutRequestId, const char* pUrl, void* pBuffer, size_t size)
{
    NN_ASSERT_NOT_NULL(pOutRequestId);
    NN_ASSERT_NOT_NULL(pUrl);
    NN_ASSERT_NOT_NULL(pBuffer);
    NN_ASSERT_GREATER(size, 0u);

    RequestItem request = {};

    request.requestId = ++g_RequestIdCounter;

    request.pUrl = pUrl;
    request.pBuffer = pBuffer;
    request.size = size;

    m_RequestQueue.push(request);

    *pOutRequestId = request.requestId;
}

// Runs every queued request, keeping at most m_SessionCountMax transfers active at once.
// Blocks until all requests are processed, Cancel() is observed, or a curl multi/setup error
// occurs.
//
// Finished transfers are reported through the completion callback by HandleCompletion().
// On return, the scope-exit handler passes whatever is still queued or in flight to
// ClearRequests()/ClearSessions(). Those report each such request with CURLE_ABORTED_BY_CALLBACK.
void FileDownloader::Execute()
{
    if (m_RequestQueue.empty())
    {
        return;
    }

    CURLM* pMulti = curl_multi_init();

    if (!pMulti)
    {
        ClearRequests();
        return;
    }

    // However we leave, abort anything left over and release the multi handle.
    NN_UTIL_SCOPE_EXIT
    {
        ClearRequests();
        ClearSessions(pMulti);

        curl_multi_cleanup(pMulti);
    };

    if (DequeueRequests(pMulti) != CURLE_OK)
    {
        return;
    }

    int stillRunning = 0;

    curl_multi_perform(pMulti, &stillRunning);

    do
    {
        int numFds = 0;
        CURLMcode code = curl_multi_wait(pMulti, nullptr, 0, 10, &numFds);

        if (code != CURLM_OK)
        {
            return;
        }
        if (numFds == 0)
        {
            // No socket activity was reported; back off briefly instead of busy-looping.
            nn::os::SleepThread(nn::TimeSpan::FromMilliSeconds(5));
        }

        curl_multi_perform(pMulti, &stillRunning);

        if (m_CancelEvent.TryWait())
        {
            return;
        }

        int completionCount = HandleCompletion(pMulti);

        m_SessionCount -= completionCount;

        NN_ASSERT_GREATER_EQUAL(m_SessionCount, 0);

        if (m_SessionCount != m_SessionCountMax)
        {
            // Refill the freed slots from the queue.
            if (DequeueRequests(pMulti) != CURLE_OK)
            {
                return;
            }
            if (completionCount == m_SessionCountMax)
            {
                // Every previous transfer finished in this pass. Start the newly added ones right
                // away instead of waiting for the next iteration.
                curl_multi_perform(pMulti, &stillRunning);
            }
        }
    }
    // stillRunning only describes the transfers as of the most recent curl_multi_perform().
    // Two kinds of session are not reflected there but are still counted in m_SessionCount:
    // sessions added after that call, and sessions that finished inside it before their
    // CURLMSG_DONE was handled. Keep iterating until those have gone through HandleCompletion()
    // too; otherwise they (and anything still queued) would be reported as aborted.
    // While requests remain queued, DequeueRequests() has filled every slot, so
    // m_SessionCount > 0 covers the queue as well.
    while (stillRunning > 0 || m_SessionCount > 0);
}

// Requests cancellation of Execute().
// Execute() checks m_CancelEvent once per loop iteration and returns when it is set. Its
// scope-exit cleanup then reports remaining requests as aborted.
// NOTE(review): m_CancelEvent is created with EventClearMode_ManualClear and nothing in the code
// shown here clears it. A Cancel() therefore appears to also abort every later Execute(); confirm
// this is intended.
void FileDownloader::Cancel() NN_NOEXCEPT
{
    this->m_CancelEvent.Signal();
}

// Registers the function invoked when a request finishes or is aborted.
// pParam is passed back unchanged as the last argument of every invocation. Passing a null
// pCallback disables notifications, because every call site checks m_pCallback before calling it.
void FileDownloader::SetCompletionCallback(CompletionCallback pCallback, void* pParam) NN_NOEXCEPT
{
    this->m_pCallbackParam = pParam;
    this->m_pCallback = pCallback;
}

void FileDownloader::ClearRequests()
{
    while (!m_RequestQueue.empty())
    {
        RequestItem request = m_RequestQueue.front();

        m_RequestQueue.pop();

        if (m_pCallback)
        {
            m_pCallback(request.requestId, CURLE_ABORTED_BY_CALLBACK, 0, request.pBuffer, 0, m_pCallbackParam);
        }
    }
}

// Tears down every active session. For each one it:
//   - detaches the easy handle from pMulti;
//   - frees the handle;
//   - if a completion callback is set, reports the request as aborted
//     (CURLE_ABORTED_BY_CALLBACK, HTTP status 0, 0 bytes).
// The session list is empty afterwards.
void FileDownloader::ClearSessions(CURLM* pMulti)
{
    for (auto it = m_SessionList.begin(); it != m_SessionList.end(); ++it)
    {
        CURL* const pHandle = it->pHandle;

        curl_multi_remove_handle(pMulti, pHandle);
        curl_easy_cleanup(pHandle);

        if (!m_pCallback)
        {
            continue;
        }
        m_pCallback(it->requestId, CURLE_ABORTED_BY_CALLBACK, 0, it->pBuffer, 0, m_pCallbackParam);
    }

    m_SessionList.clear();
}

// Moves queued requests into active curl sessions until the queue is empty or m_SessionCountMax
// sessions are active.
//
// Returns CURLE_OK on success. Otherwise it returns either the error from CreateSession() or a
// CURLcode mapped from curl_multi_add_handle()'s failure. By that point the failing request has
// already been taken off the queue and is in no session, so nothing else would ever report it.
// It is therefore passed to the completion callback (if set) with that error code before
// returning. Requests still in the queue are left untouched.
CURLcode FileDownloader::DequeueRequests(CURLM* pMulti)
{
    if (m_SessionCount == m_SessionCountMax)
    {
        return CURLE_OK;
    }

    while (!m_RequestQueue.empty())
    {
        RequestItem request = m_RequestQueue.front();

        m_RequestQueue.pop();

        SessionItem initial = {};
        m_SessionList.push_back(initial);

        // Push back first and then take the element's address, so the list's element type does not
        // have to be SessionItem*. The pointer is handed to curl as CURLOPT_WRITEDATA, so this relies
        // on the list not relocating existing elements.
        SessionItem* pSession = &m_SessionList.back();

        CURLcode curlCode = CreateSession(pSession, request);

        if (curlCode != CURLE_OK)
        {
            m_SessionList.pop_back();

            // The request is neither queued nor in a session any more; report it here.
            if (m_pCallback)
            {
                m_pCallback(request.requestId, curlCode, 0, request.pBuffer, 0, m_pCallbackParam);
            }
            return curlCode;
        }

        CURLMcode curlmCode = curl_multi_add_handle(pMulti, pSession->pHandle);

        if (curlmCode != CURLM_OK)
        {
            curlCode = (curlmCode == CURLM_OUT_OF_MEMORY) ? CURLE_OUT_OF_MEMORY : CURLE_FAILED_INIT;

            // The handle never joined pMulti. Release it and drop the session so the request is
            // reported once, with the real failure reason.
            curl_easy_cleanup(pSession->pHandle);
            m_SessionList.pop_back();

            if (m_pCallback)
            {
                m_pCallback(request.requestId, curlCode, 0, request.pBuffer, 0, m_pCallbackParam);
            }
            return curlCode;
        }

        if (++m_SessionCount == m_SessionCountMax)
        {
            break;
        }
    }

    return CURLE_OK;
}

// Creates and configures a curl easy handle for `request` and records it, together with the
// request's bookkeeping, in *pOutSession.
//
// pOutSession is registered as the write callback's user data (CURLOPT_WRITEDATA), so it must keep
// a stable address for the lifetime of the transfer.
// Returns CURLE_OK on success. Otherwise it returns the first failing curl code; in that case the
// easy handle is released by the scope-exit handler and *pOutSession is not filled in.
CURLcode FileDownloader::CreateSession(SessionItem* pOutSession, const RequestItem& request) NN_NOEXCEPT
{
    CURL* pCurl = curl_easy_init();

    if (!pCurl)
    {
        return CURLE_OUT_OF_MEMORY;
    }

    // Frees the handle on any early return. It is disarmed by nulling pCurl once ownership moves
    // into pOutSession.
    NN_UTIL_SCOPE_EXIT
    {
        if (pCurl)
        {
            curl_easy_cleanup(pCurl);
        }
    };

    CURL_EASY_DO(curl_easy_setopt(pCurl, CURLOPT_URL, request.pUrl));

    // curl_easy_setopt is variadic and reads these numeric options as `long`, so pass long literals.
    // Passing an int is undefined behavior where int and long differ in size.

    // Connection timeout (seconds).
    CURL_EASY_DO(curl_easy_setopt(pCurl, CURLOPT_CONNECTTIMEOUT, 30L));

    // Abort when the transfer speed stays below 1 byte/s for 30 seconds.
    CURL_EASY_DO(curl_easy_setopt(pCurl, CURLOPT_LOW_SPEED_LIMIT, 1L));
    CURL_EASY_DO(curl_easy_setopt(pCurl, CURLOPT_LOW_SPEED_TIME, 30L));

    // Maximum accepted size: the caller's buffer capacity. The _LARGE variant takes a curl_off_t,
    // so the value is not truncated on platforms where long is 32-bit.
    CURL_EASY_DO(curl_easy_setopt(pCurl, CURLOPT_MAXFILESIZE_LARGE, static_cast<curl_off_t>(request.size)));

    CURL_EASY_DO(curl_easy_setopt(pCurl, CURLOPT_WRITEFUNCTION, HttpWriteFunction));
    CURL_EASY_DO(curl_easy_setopt(pCurl, CURLOPT_WRITEDATA, pOutSession));
    CURL_EASY_DO(curl_easy_setopt(pCurl, CURLOPT_SSL_CTX_FUNCTION, SslCtxFunction));

    pOutSession->requestId = request.requestId;
    pOutSession->pBuffer = request.pBuffer;
    pOutSession->size = request.size;
    pOutSession->downloaded = 0;

    // Transfer ownership of the handle to the session and disarm the scope-exit cleanup.
    pOutSession->pHandle = pCurl;
    pCurl = nullptr;

    return CURLE_OK;
}

// Drains curl's pending completion messages. For each finished transfer it:
//   - reports the transfer result, HTTP status and byte count to the completion callback (if set);
//   - removes the session from m_SessionList;
//   - detaches and frees the easy handle.
// Returns the number of finished transfers (CURLMSG_DONE messages) processed.
int FileDownloader::HandleCompletion(CURLM* pMulti)
{
    int count = 0;

    int remaining = 0;
    CURLMsg* pMsg = nullptr;

    while ((pMsg = curl_multi_info_read(pMulti, &remaining)) != nullptr)
    {
        if (pMsg->msg != CURLMSG_DONE)
        {
            continue;
        }

        // Copy everything needed out of the message up front. libcurl documents that the data
        // returned by curl_multi_info_read() does not survive curl_multi_remove_handle() or
        // curl_easy_cleanup(), both of which are called below.
        CURL* pEasy = pMsg->easy_handle;
        CURLcode result = pMsg->data.result;

        long statusCode = 0;
        curl_easy_getinfo(pEasy, CURLINFO_RESPONSE_CODE, &statusCode);

        for (SessionItem& session : m_SessionList)
        {
            if (session.pHandle == pEasy)
            {
                if (m_pCallback)
                {
                    m_pCallback(session.requestId, result, statusCode, session.pBuffer, session.downloaded, m_pCallbackParam);
                }
                break;
            }
        }

        m_SessionList.remove_if([pEasy](SessionItem& session) {return session.pHandle == pEasy;});

        curl_multi_remove_handle(pMulti, pEasy);
        curl_easy_cleanup(pEasy);

        count++;
    }

    return count;
}

// CURLOPT_SSL_CTX_FUNCTION callback (registered in CreateSession).
// Treats pSsl as an nn::ssl::Context and creates it with automatic SSL version negotiation.
// NOTE(review): this assumes the platform's curl port passes an nn::ssl::Context here; confirm.
// Returns CURLE_OK on success and CURLE_ABORTED_BY_CALLBACK if context creation fails.
CURLcode FileDownloader::SslCtxFunction(CURL* pCurl, void* pSsl, void* pParam) NN_NOEXCEPT
{
    NN_UNUSED(pCurl);
    NN_UNUSED(pParam);

    auto* const pContext = static_cast<nn::ssl::Context*>(pSsl);
    const bool created = !pContext->Create(nn::ssl::Context::SslVersion_Auto).IsFailure();

    return created ? CURLE_OK : CURLE_ABORTED_BY_CALLBACK;
}

// CURLOPT_WRITEFUNCTION callback (registered in CreateSession with the SessionItem as
// CURLOPT_WRITEDATA).
// Appends each received chunk to the session's caller-provided buffer.
// If the chunk would not fit in the remaining capacity, it returns 0 instead of the chunk size.
// Under the libcurl write-callback contract, that fails the transfer.
size_t FileDownloader::HttpWriteFunction(char* pBuffer, size_t size, size_t count, void* pParam) NN_NOEXCEPT
{
    SessionItem* const pSession = static_cast<SessionItem*>(pParam);
    const size_t chunkSize = size * count;

    if (pSession->size < pSession->downloaded + chunkSize)
    {
        return 0;
    }

    nn::Bit8* const pDestination = static_cast<nn::Bit8*>(pSession->pBuffer) + pSession->downloaded;
    std::memcpy(pDestination, pBuffer, chunkSize);
    pSession->downloaded += chunkSize;

    return chunkSize;
}
