﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include "testNim_Nas.h"

#include <cctype>
#include <memory>
#include <curl/curl.h>

#include <nn/nn_SdkAssert.h>
#define RAPIDJSON_NO_INT64DEFINE
#define RAPIDJSON_NAMESPACE             nne::rapidjson
#define RAPIDJSON_NAMESPACE_BEGIN       namespace nne { namespace rapidjson {
#define RAPIDJSON_NAMESPACE_END         }}
#define RAPIDJSON_ASSERT(x)             NN_SDK_ASSERT(x)
#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 1 //NOLINT(preprocessor/const)
#define RAPIDJSON_HAS_CXX11_TYPETRAITS  1 //NOLINT(preprocessor/const)
#if defined(NN_BUILD_CONFIG_OS_WIN32)
#pragma warning(push)
#pragma warning(disable : 4244)
#pragma warning(disable : 4668)
#pragma warning(disable : 4702)
#endif
#include <rapidjson/document.h>
#if defined(NN_BUILD_CONFIG_OS_WIN32)
#pragma warning(pop)
#endif

#include <nn/nn_Abort.h>
#include <nn/nn_Common.h>
#include <nn/nn_Log.h>
#include <nn/nsd/nsd_ApiForMenu.h>
#include <nn/util/util_FormatString.h>
#include <nn/util/util_ScopeExit.h>

#include "testNim_Http.h"

#define NNT_ACCOUNT_NAS_HELPER_URI "https://hikarie.nntt.mng.nintendo.net/api/v1/nas"

namespace accounts {
namespace {

// Parses a NAS authorization request URI and captures the query parameters
// the NAS helper needs: client_id, redirect_uri, state, code_challenge and
// scope.
//
// UpdateImpl copies each value verbatim as handed over by the UriParser base
// (the caller percent-decodes redirect_uri itself). Every getter aborts when
// its parameter was not accepted during parsing.
class NasAuthorizationRequestParser final
    : public UriParser<255 + 1>
{
private:
    UriLookupEntry m_ClientIdEntry;
    UriLookupEntry m_RedirectUriEntry;
    UriLookupEntry m_StateEntry;
    UriLookupEntry m_CodeChallengeEntry;
    UriLookupEntry m_ScopeEntry;

    // Host part of the request URI; its address is handed to the UriParser base.
    char m_Host[256];

    uint64_t m_ClientId;
    char m_RedirectUri[256];
    char m_State[128];
    char m_CodeChallenge[256];
    char m_Scope[256];

    // Returns true when all `length` characters of `pStr` are hexadecimal digits.
    static bool IsHexadecimal(const char* pStr, size_t length) NN_NOEXCEPT
    {
        for (size_t i = 0u; i < length; ++ i)
        {
            // <cctype> functions require a value representable as unsigned char.
            if (!std::isxdigit(static_cast<unsigned char>(pStr[i])))
            {
                return false;
            }
        }
        return true;
    }

    // Converts one hexadecimal digit to its value (0-15); `c` must satisfy std::isxdigit.
    static uint8_t ConvertHexToInteger(char c) NN_NOEXCEPT
    {
        const auto uc = static_cast<unsigned char>(c);
        return std::isdigit(uc)
            ? static_cast<uint8_t>(uc - '0')
            : static_cast<uint8_t>(std::tolower(uc) - 'a' + 10);
    }

    // Reads exactly sizeof(T) * 2 hex digits from `pStr` (most significant first).
    template <typename T>
    static T ExtractHexadecimal(const char* pStr, size_t length) NN_NOEXCEPT
    {
        NN_SDK_ASSERT(length >= sizeof(T) * 2);
        NN_UNUSED(length);
        T out = 0;
        for (size_t i = 0u; i < sizeof(T) * 2; ++ i)
        {
            out = static_cast<T>((out << 4) | ConvertHexToInteger(pStr[i]));
        }
        return out;
    }

    // Copies `length` bytes of `value` into `dst` and NUL-terminates it.
    // Callers check `length < N` beforehand.
    template <size_t N>
    static void StoreValue(char (&dst)[N], const char* value, size_t length) NN_NOEXCEPT
    {
        NN_SDK_ASSERT(length < N);
        std::strncpy(dst, value, length);
        dst[length] = '\0';
    }

    // Called by the UriParser base for each key/value pair. Pairs with an
    // unknown key, an oversized value or (for client_id) a value that is not
    // exactly 16 hex digits are ignored.
    virtual void UpdateImpl(const char* key, size_t keyLength, const char* value, size_t valueLength) NN_NOEXCEPT NN_OVERRIDE
    {
        if (m_ClientIdEntry.CanAccept(key, keyLength)
            && valueLength == sizeof(uint64_t) * 2
            && IsHexadecimal(value, valueLength))
        {
            m_ClientId = ExtractHexadecimal<uint64_t>(value, valueLength);
            m_ClientIdEntry.MarkAccepted();
        }
        else if (m_RedirectUriEntry.CanAccept(key, keyLength) && valueLength < sizeof(m_RedirectUri))
        {
            StoreValue(m_RedirectUri, value, valueLength);
            m_RedirectUriEntry.MarkAccepted();
        }
        else if (m_StateEntry.CanAccept(key, keyLength) && valueLength < sizeof(m_State))
        {
            StoreValue(m_State, value, valueLength);
            m_StateEntry.MarkAccepted();
        }
        else if (m_CodeChallengeEntry.CanAccept(key, keyLength) && valueLength < sizeof(m_CodeChallenge))
        {
            StoreValue(m_CodeChallenge, value, valueLength);
            m_CodeChallengeEntry.MarkAccepted();
        }
        else if (m_ScopeEntry.CanAccept(key, keyLength) && valueLength < sizeof(m_Scope))
        {
            StoreValue(m_Scope, value, valueLength);
            m_ScopeEntry.MarkAccepted();
        }
    }
public:
    // host: NUL-terminated host part of the request URI; truncated to 255 characters.
    explicit NasAuthorizationRequestParser(const char* host) NN_NOEXCEPT
        // NOTE(review): m_Host is filled only in the body below; this assumes the
        // base merely stores the pointer during construction — confirm.
        : UriParser<255 + 1>(m_Host)
        , m_ClientIdEntry("client_id")
        , m_RedirectUriEntry("redirect_uri")
        , m_StateEntry("state")
        , m_CodeChallengeEntry("code_challenge")
        , m_ScopeEntry("scope")
        , m_ClientId(0)
    {
        // strncpy does not terminate when `host` fills the buffer, so terminate explicitly.
        std::strncpy(m_Host, host, sizeof(m_Host) - 1);
        m_Host[sizeof(m_Host) - 1] = '\0';
    }

    uint64_t GetClientId() const NN_NOEXCEPT
    {
        NN_ABORT_UNLESS(m_ClientIdEntry);
        return m_ClientId;
    }
    const char* GetRedirectUri() const NN_NOEXCEPT
    {
        NN_ABORT_UNLESS(m_RedirectUriEntry);
        return m_RedirectUri;
    }
    const char* GetState() const NN_NOEXCEPT
    {
        NN_ABORT_UNLESS(m_StateEntry);
        return m_State;
    }
    const char* GetCodeChallenge() const NN_NOEXCEPT
    {
        NN_ABORT_UNLESS(m_CodeChallengeEntry);
        return m_CodeChallenge;
    }
    const char* GetScope() const NN_NOEXCEPT
    {
        NN_ABORT_UNLESS(m_ScopeEntry);
        return m_Scope;
    }
};

void DecodeRedirectUri(char* decoded, size_t decodedSize, const char *redirect)
{
    auto N = strlen(redirect) + 1;
    int count = 0;
    for (size_t i = 0u; i < N && redirect[i] != '\0'; ++ i)
    {
        auto c = redirect[i];
        if (c == '%')
        {
            auto c1 = redirect[i + 1];
            auto c2 = redirect[i + 2];
            if (c1 == '3' && c2 == 'A') { decoded[count ++] = ':'; }
            else if (c1 == '2' && c2 == 'F') { decoded[count ++] = '/'; }
            else { NN_ABORT("Unexpected escaped charactor: %%%c%c\n", c1, c2); }
            i += 2;
        }
        else
        {
            decoded[count ++] = c;
        }
    }
    decoded[count ++] = '\0';
}

} // ~namespace accounts::<anonymous>

// Runs the Nintendo Account authorization flow through the NAS helper test
// service and builds the redirect response a client would receive.
//
// response : receives
//            "<decoded redirect_uri>?expires_in=0&state=...&code=...&id_token=..."
// id       : Nintendo Account e-mail address; URL-encoded before sending.
// password : Nintendo Account password; URL-encoded before sending.
// url      : authorization request URI. It must contain '?' followed by the
//            client_id, redirect_uri, state, code_challenge and scope parameters.
// verbose  : forwarded to SimpleDownloader::Setup.
//
// Returns true on success. Returns false (after logging) when the transfer
// fails, the helper replies with a non-200 status, the reply does not fit in
// the local buffer, the JSON cannot be parsed, or the JSON lacks
// state/code/id_token. Malformed arguments and inconsistent replies abort.
bool GetAuthorizationViaNasHelper(Buffer& response, const char* id, const char* password, const char* url, bool verbose) NN_NOEXCEPT
{
    // libcurl global state; released on every exit path below.
    NN_ABORT_UNLESS(curl_global_init(CURL_GLOBAL_DEFAULT) == CURLE_OK);
    NN_UTIL_SCOPE_EXIT
    {
        curl_global_cleanup();
    };

    auto curl = curl_easy_init();
    NN_ABORT_UNLESS(curl != nullptr);
    NN_UTIL_SCOPE_EXIT
    {
        curl_easy_cleanup(curl);
    };

    // Extract the parameters. The host is everything before '?'; without the
    // null check a URL lacking '?' would make the pointer subtraction undefined.
    const char* paramString = std::strchr(url, '?');
    NN_ABORT_UNLESS(paramString != nullptr);
    const auto hostLength = static_cast<size_t>(paramString - url);
    char host[256];
    NN_ABORT_UNLESS(hostLength < sizeof(host));
    std::strncpy(host, url, hostLength);
    host[hostLength] = '\0';

    std::unique_ptr<NasAuthorizationRequestParser> pParser(new NasAuthorizationRequestParser(host));
    ACCOUNTS_FAILURE_UNLESS(pParser->Parse(url, strlen(url) + 1));

    // URL-encode the ID and password.
    Buffer encodedId(1024);
    Buffer encodedPassword(1024);
    auto l = UrlEncode(encodedId.Get<char>(), encodedId.GetSize(), id, strlen(id) + 1);
    NN_ABORT_UNLESS(l < static_cast<int>(encodedId.GetSize()));
    l = UrlEncode(encodedPassword.Get<char>(), encodedPassword.GetSize(), password, strlen(password) + 1);
    NN_ABORT_UNLESS(l < static_cast<int>(encodedPassword.GetSize()));

    // Build the request to the NAS helper and send it.
    Buffer nasHelperRequest(4096);
    nn::nsd::EnvironmentIdentifier envId;
    nn::nsd::GetEnvironmentIdentifier(&envId);
    // %llx requires unsigned long long; uint64_t may be unsigned long on some targets.
    const auto clientId = static_cast<unsigned long long>(pParser->GetClientId());
    l = nn::util::SNPrintf(
        nasHelperRequest.Get<char>(), nasHelperRequest.GetSize(),
        "response_type=code+id_token&code_challenge_method=S256"
        "&email=%s&password=%s"
        "&environment=%s&client_id=%016llx&redirect_uri=%s"
        "&scope=%s&state=%s&code_challenge=%s",
        encodedId.Get<char>(), encodedPassword.Get<char>(),
        envId.value, clientId, pParser->GetRedirectUri(),
        pParser->GetScope(), pParser->GetState(), pParser->GetCodeChallenge());
    NN_ABORT_UNLESS(l < static_cast<int>(nasHelperRequest.GetSize()));

    Buffer nasHelperResponse(4096 * 2);
    SimpleDownloader dl(curl);
    dl.Initialize(nasHelperResponse.Get<char>(), nasHelperResponse.GetSize());
    dl.Setup(NNT_ACCOUNT_NAS_HELPER_URI "/authorize", nasHelperRequest.Get<char>(), verbose);
    size_t nasHelperResponseSize;
    auto curlCode = dl.Invoke(&nasHelperResponseSize);
    if (curlCode != CURLE_OK)
    {
        NN_LOG("[accounts] CURL Error: %d\n", curlCode);
        return false;
    }
    auto httpCode = dl.GetHttpCode();
    if (httpCode != 200)
    {
        NN_LOG("[accounts] HTTP Error: %ld\n", httpCode);
        return false;
    }

    // One byte must remain for the NUL terminator required by ParseInsitu.
    if (!(nasHelperResponseSize < nasHelperResponse.GetSize()))
    {
        NN_LOG("[accounts] Invalid response size: %zu\n", nasHelperResponseSize);
        return false;
    }
    nasHelperResponse.Get<char>()[nasHelperResponseSize] = '\0';

    // Parse the response (in place; the buffer is modified).
    nne::rapidjson::Document document;
    document.ParseInsitu(nasHelperResponse.Get<char>());
    if (document.HasParseError())
    {
#define RAPIDJSON_ERROR(code) \
        case nne::rapidjson::kParseError ## code: \
            NN_LOG("[accounts] JSON Error: %s\n", # code); \
            break

        switch (document.GetParseError())
        {
        RAPIDJSON_ERROR(DocumentEmpty);
        RAPIDJSON_ERROR(DocumentRootNotSingular);
        RAPIDJSON_ERROR(ValueInvalid);
        RAPIDJSON_ERROR(ObjectMissName);
        RAPIDJSON_ERROR(ObjectMissColon);
        RAPIDJSON_ERROR(ObjectMissCommaOrCurlyBracket);
        RAPIDJSON_ERROR(ArrayMissCommaOrSquareBracket);
        RAPIDJSON_ERROR(StringUnicodeEscapeInvalidHex);
        RAPIDJSON_ERROR(StringUnicodeSurrogateInvalid);
        RAPIDJSON_ERROR(StringEscapeInvalid);
        RAPIDJSON_ERROR(StringMissQuotationMark);
        RAPIDJSON_ERROR(StringInvalidEncoding);
        RAPIDJSON_ERROR(NumberTooBig);
        RAPIDJSON_ERROR(NumberMissFraction);
        RAPIDJSON_ERROR(NumberMissExponent);
        RAPIDJSON_ERROR(Termination);
        RAPIDJSON_ERROR(UnspecificSyntaxError);
        default:
            NN_LOG("[accounts] JSON Unknown Error: %d\n", static_cast<int>(document.GetParseError()));
        }

#undef RAPIDJSON_ERROR
        return false;
    }
    if (!(true
        && document.IsObject()
        && document.HasMember("state")
        && document.HasMember("code")
        && document.HasMember("id_token")))
    {
        NN_LOG("[accounts] Failed to authN/authZ for Nintendo Account\n");
        return false;
    }

    // state: must echo the value sent in the request.
    auto& eState = document["state"];
    NN_ABORT_UNLESS(eState.IsString());
    NN_ABORT_UNLESS(std::strcmp(pParser->GetState(), eState.GetString()) == 0);
    // code
    auto& eCode = document["code"];
    NN_ABORT_UNLESS(eCode.IsString());
    auto codeLength = static_cast<int>(eCode.GetStringLength());
    NN_ABORT_UNLESS(codeLength <= nn::account::NintendoAccountAuthorizationCodeLengthMax);
    // id_token
    auto& eIdToken = document["id_token"];
    NN_ABORT_UNLESS(eIdToken.IsString());
    auto idTokenLength = static_cast<int>(eIdToken.GetStringLength());
    NN_ABORT_UNLESS(idTokenLength <= nn::account::NintendoAccountIdTokenLengthMax);

    // Build the response URI.
    char redirectUri[256];
    DecodeRedirectUri(redirectUri, sizeof(redirectUri), pParser->GetRedirectUri());
    l = nn::util::SNPrintf(
        response.Get<char>(), response.GetSize(),
        "%s?expires_in=0&state=%s&code=%s&id_token=%s",
        redirectUri, pParser->GetState(), eCode.GetString(), eIdToken.GetString());
    NN_ABORT_UNLESS(static_cast<uint32_t>(l) < response.GetSize());

    return true;
} // NOLINT(readability/fn_size)

} // ~namespace accounts
