﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <cstring>
#include <cstdlib>

#include <nn/os.h>
#include <nn/init.h>
#include <nn/nn_Log.h>
#include <nn/nn_Assert.h>
#include <nn/lmem/lmem_ExpHeap.h>
#include <nnt/nntest.h>
#include <nn/socket.h>
#include <nn/ssl.h>
#include <curl/curl.h>

#include <Common/testCommonUtil.h>
#include <Common/testInfraInfo.h>
#include <Common/testServerPki.h>
#include <Common/testClientPki.h>
#include <Common/testClientPkiNoPwd.h>

// Portable spelling of snprintf: Windows builds use the MSVC runtime's
// _snprintf, every other target uses the standard snprintf.
// NOTE(review): MSVC's _snprintf does not NUL-terminate the output when it has
// to truncate it (it returns a negative value instead), unlike C99 snprintf;
// callers must not rely on termination after a truncated write.
#ifdef NN_BUILD_CONFIG_OS_WIN
#define MY_SNPRINTF _snprintf
#else
#define MY_SNPRINTF snprintf
#endif

namespace
{
// Helper used by nnMain() to bring the network up (SetupNetwork) and down (FinalizeNetwork).
SslTestCommonUtil        g_CommonUtil;
// Backing memory handed to nn::socket::Initialize() in nnMain().
NN_ALIGNAS(4096) uint8_t g_SocketMemoryPoolBuffer[nn::socket::DefaultSocketMemoryPoolSize];
// Size in bytes of the URL buffer in ConfigParams.
constexpr uint32_t       HostNameLength         = 128;
// Size in bytes of each string (option name or help text) in g_ArgOptions.
constexpr uint32_t       ArgOptionMessageLength = 128;
// Size in bytes of each proxy string buffer (address, user name, password) in ConfigParams.
constexpr uint32_t       ProxyStrBuffLength     = 256;

// Column index into the command-line option table: every row holds the option
// string and the help text printed for it.
enum ArgOptionType
{
    ArgOptionType_Command = 0,  // The literal switch, e.g. "--url".
    ArgOptionType_Message,      // Description shown by the help output.
    ArgOptionType_Max           // Number of columns.
};

// Row index into the command-line option table. The parser matches an argument
// to a row and then switches on that row index, so the enumerators must stay in
// the same order as the table rows.
enum ArgOption
{
    ArgOption_Help = 0,                 // --help
    ArgOption_Host,                     // --url <url>
    ArgOption_Port,                     // --port <number>
    ArgOption_DisableServerValidation,  // --disable_server_validation
    ArgOption_DisableNameValidation,    // --disable_name_validation
    ArgOption_ServerCert,               // --server_cert <port>
    ArgOption_ClientCert,               // --client_cert <port>
    ArgOption_ProxyAddress,             // --proxy_addr <address>
    ArgOption_ProxyPort,                // --proxy_port <number>
    ArgOption_ProxyUserName,            // --proxy_user <name>
    ArgOption_ProxyPassword,            // --proxy_pass <password>
    ArgOption_Max                       // Number of supported options.
};

// Settings collected from the command line; reset by InitializeConfigParameters()
// and filled in by ConfigureParameters().
struct ConfigParams
{
    char        pHostName[HostNameLength];         // URL from --url (nnMain adds a non-default port to it).
    uint32_t    portNumber;                        // --port value; defaults to 443.
    bool        isPerformServerValidation;         // Cleared by --disable_server_validation.
    bool        isPerformNameValidation;           // Cleared by --disable_name_validation.
    const char* pServerCert;                       // CA certificate chosen by --server_cert, or nullptr.
    const char* pClientPki;                        // Client PKI chosen by --client_cert, or nullptr.
    const char* pPassword;                         // Password for pClientPki, or nullptr.
    uint32_t    serverCertSize;                    // Size of *pServerCert in bytes.
    uint32_t    clientPkiSize;                     // Size of *pClientPki in bytes.
    uint32_t    passwordSize;                      // Size of *pPassword in bytes.
    char        proxyAddress[ProxyStrBuffLength];  // --proxy_addr; an empty string means no proxy.
    uint16_t    proxyPort;                         // --proxy_port; defaults to 8080.
    char        proxyUserName[ProxyStrBuffLength]; // --proxy_user; empty when not given.
    char        proxyPassword[ProxyStrBuffLength]; // --proxy_pass; empty when not given.
};

// Command-line option table. Rows are indexed by ArgOption; columns are indexed
// by ArgOptionType (the option string, then its help text). The row order must
// match the ArgOption enumerators, because ConfigureParameters() switches on the
// index of the matching row.
// The array has one row more than there are options. That trailing row is
// zero-filled and is never read: PrintHelp() and ConfigureParameters() both
// stop at ArgOption_Max.
static const char g_ArgOptions[ArgOption_Max + 1][ArgOptionType_Max][ArgOptionMessageLength] = {
    {"--help", "This command."},
    {"--url", "Specify URL."},
    {"--port", "Specify port number (443 will be used when not specified)."},
    {"--disable_server_validation", "Disable server validation (it's performed by default)."},
    {"--disable_name_validation", "Disable host name validation (it's performed by default)."},
    {"--server_cert", "Specify server cert for ntd-net-server1 (this option should be followed by port number - e.g.: --server_cert 444)"},
    {"--client_cert", "Specify client cert for ntd-net-server1 (this option should be followed by port number - e.g.: --client_cert 444)"},
    {"--proxy_addr", "Specify proxy address."},
    {"--proxy_port", "Specify proxy port number (8080 will be used by default when address is set but port is not set)"},
    {"--proxy_user", "Specify proxy user name."},
    {"--proxy_pass", "Specify proxy password."},
};

void PrintHelp()
{
    NN_LOG("Usage\n");
    for(int i=0; i < ArgOption_Max; i++)
    {
        NN_LOG(" %s - %s\n", g_ArgOptions[i][ArgOptionType_Command], g_ArgOptions[i][ArgOptionType_Message]);
    }
    NN_LOG("example: testSsl_SslLibcurlUtil.exe --url https://www.google.com --port 443 --disable_server_validation\n");
}

// Resets *pInConfig to the defaults used when no option overrides them:
// - empty URL and proxy strings;
// - port 443;
// - server and host name validation enabled;
// - no server certificate, client PKI or password (null pointers, zero sizes);
// - proxy port 8080.
void InitializeConfigParameters(ConfigParams* pInConfig)
{
    // Value-initialization zero-fills every character buffer and size field
    // and sets every pointer member to nullptr.
    *pInConfig = ConfigParams();

    // Defaults that are not zero.
    pInConfig->portNumber                = 443;
    pInConfig->isPerformServerValidation = true;
    pInConfig->isPerformNameValidation   = true;
    pInConfig->proxyPort                 = 8080;
}

// Copies a command-line value into a fixed-size character buffer.
// Fails (and logs) instead of truncating when pSource does not fit, and always
// leaves pDestination NUL-terminated on success. Plain strncpy() with the full
// buffer size would leave the buffer unterminated for long values.
bool CopyOptionValue(char* pDestination, size_t destinationSize, const char* pSource)
{
    if (strlen(pSource) >= destinationSize)
    {
        NN_LOG("Option value is too long (limit %u characters): %s\n",
            static_cast<unsigned int>(destinationSize - 1), pSource);
        return false;
    }
    strncpy(pDestination, pSource, destinationSize - 1);
    pDestination[destinationSize - 1] = '\0';
    return true;
}

// Parses pText as a decimal TCP port number in the range [1, 65535].
// Returns false and leaves *pOutPort unchanged when pText is empty, has trailing
// characters, or is out of range. atoi() would silently accept all of these.
bool ParsePortValue(const char* pText, uint32_t* pOutPort)
{
    char* pEnd = nullptr;
    const long value = strtol(pText, &pEnd, 10);
    if (pEnd == pText || *pEnd != '\0' || value < 1 || value > 65535)
    {
        return false;
    }
    *pOutPort = static_cast<uint32_t>(value);
    return true;
}

// Parses the host command line (argv[1] .. argv[argc - 1]) into *pInConfig.
//
// *pInConfig is first reset with InitializeConfigParameters(). Each argument is
// matched against the option strings in g_ArgOptions. Options that take a value
// consume the following argument.
//
// Returns true when every argument was parsed. Returns false, after logging a
// message, when any of the following happens:
// - --help is given;
// - an option is unknown;
// - an option's value is missing, too long or invalid;
// - no certificate is known for the given port.
bool ConfigureParameters(int argc, char** argv, ConfigParams* pInConfig)
{
    InitializeConfigParameters(pInConfig);

    for (int i = 1; i < argc; i++)
    {
        bool isFound = false;
        int  j;

        for (j = 0; j < ArgOption_Max; j++)
        {
            if (!strcmp(argv[i], g_ArgOptions[j][ArgOptionType_Command]))
            {
                isFound = true;
                break;
            }
        }

        if (isFound == false)
        {
            NN_LOG("Unknown option was passed (%s).\n", argv[i]);
            PrintHelp();
            return false;
        }

        // For options that take a value, "++i >= argc" means the option was the
        // last argument, so there is no value (argv[argc] must not be read).
        switch (j)
        {
        case ArgOption_Help:
            PrintHelp();
            return false;
        case ArgOption_Host:
            if (++i >= argc)
            {
                NN_LOG("Actual URL name is not passed.\n");
                return false;
            }
            if (!CopyOptionValue(pInConfig->pHostName, HostNameLength, argv[i]))
            {
                return false;
            }
            break;
        case ArgOption_Port:
            if (++i >= argc)
            {
                NN_LOG("Actual port number is not passed.\n");
                return false;
            }
            if (!ParsePortValue(argv[i], &pInConfig->portNumber))
            {
                NN_LOG("Invalid port number (%s).\n", argv[i]);
                return false;
            }
            break;
        case ArgOption_DisableServerValidation:
            pInConfig->isPerformServerValidation = false;
            break;
        case ArgOption_DisableNameValidation:
            pInConfig->isPerformNameValidation = false;
            break;
        case ArgOption_ServerCert:
            {
                if (++i >= argc)
                {
                    NN_LOG("Actual port number to select server cert is not passed.\n");
                    return false;
                }
                // NOTE(review): sizeof() below is only the certificate size if the
                // g_pTestCaCert* symbols are arrays (not pointers); confirm in testServerPki.h.
                int tmpPortNumber = atoi(argv[i]);
                switch (tmpPortNumber)
                {
                case ServerPort_ExpiredCert:
                    pInConfig->pServerCert = g_pTestCaCertExpired;
                    pInConfig->serverCertSize = sizeof(g_pTestCaCertExpired);
                    break;
                case ServerPort_Normal:
                case ServerPort_ClientAuth:
                    pInConfig->pServerCert = g_pTestCaCert;
                    pInConfig->serverCertSize = sizeof(g_pTestCaCert);
                    break;
                case ServerPort_ClientAuthNoPwd:
                    pInConfig->pServerCert = g_pTestCaCert2;
                    pInConfig->serverCertSize = sizeof(g_pTestCaCert2);
                    break;
                default:
                    NN_LOG("There's no dedicated CA certificate for given port number.\n");
                    return false;
                }
            }
            break;
        case ArgOption_ClientCert:
            {
                if (++i >= argc)
                {
                    NN_LOG("Actual port number to select client cert is not passed.\n");
                    return false;
                }
                int tmpPortNumber = atoi(argv[i]);
                switch (tmpPortNumber)
                {
                case ServerPort_ClientAuth:
                    pInConfig->pClientPki    = reinterpret_cast<const char*>(g_pTestClientPki);
                    pInConfig->clientPkiSize = g_pTestClientPkiSize;
                    pInConfig->pPassword     = g_pTestClientPkiPassword;
                    pInConfig->passwordSize  = g_TestClientPkiPasswordLength;
                    break;
                case ServerPort_ClientAuthNoPwd:
                    pInConfig->pClientPki    = reinterpret_cast<const char*>(g_pTestClientPkiNoPwd);
                    pInConfig->clientPkiSize = g_pTestClientPkiNoPwdSize;
                    // This PKI has no password; drop one left by an earlier --client_cert.
                    pInConfig->pPassword     = nullptr;
                    pInConfig->passwordSize  = 0;
                    break;
                default:
                    NN_LOG("There's no dedicated client PKI for given port number.\n");
                    return false;
                }
            }
            break;
        case ArgOption_ProxyAddress:
            if (++i >= argc)
            {
                NN_LOG("Actual proxy address is not passed.\n");
                return false;
            }
            if (!CopyOptionValue(pInConfig->proxyAddress, ProxyStrBuffLength, argv[i]))
            {
                return false;
            }
            break;
        case ArgOption_ProxyPort:
            {
                if (++i >= argc)
                {
                    NN_LOG("Actual proxy port number is not passed.\n");
                    return false;
                }
                uint32_t proxyPort = 0;
                if (!ParsePortValue(argv[i], &proxyPort))
                {
                    NN_LOG("Invalid proxy port number (%s).\n", argv[i]);
                    return false;
                }
                // ParsePortValue() guarantees the value fits into 16 bits.
                pInConfig->proxyPort = static_cast<uint16_t>(proxyPort);
            }
            break;
        case ArgOption_ProxyUserName:
            if (++i >= argc)
            {
                NN_LOG("Actual proxy user name is not passed.\n");
                return false;
            }
            if (!CopyOptionValue(pInConfig->proxyUserName, ProxyStrBuffLength, argv[i]))
            {
                return false;
            }
            break;
        case ArgOption_ProxyPassword:
            if (++i >= argc)
            {
                NN_LOG("Actual proxy password name is not passed.\n");
                return false;
            }
            if (!CopyOptionValue(pInConfig->proxyPassword, ProxyStrBuffLength, argv[i]))
            {
                return false;
            }
            break;
        default:
            NN_LOG("Internal parser error.\n");
            return false;
        }
    }

    return true;
} // NOLINT(impl/function_size)

// SSL-context callback registered through SimpleCurlHttpsClient::SetCurlCtxFunction().
//
// pSslContext is the nn::ssl::Context handed over via CURLOPT_SSL_CTX_FUNCTION.
// The callback creates that context (SslVersion_Auto). When pUserData is
// non-null, it is treated as the SimpleCurlHttpsClient that owns the request,
// and the callback then:
// - imports the server CA certificate (PEM) if the client asks for it;
// - imports the client PKI if the client asks for it;
// - stores each resulting cert store id back on the client.
//
// Returns 0 on success. Returns static_cast<size_t>(-1) when creating the
// context or any import fails.
size_t CurlSslContextCallback(CURL* pCurl, void* pSslContext, void* pUserData)
{
    const size_t CallbackFailure = static_cast<size_t>(-1);

    // Obtain pointer to the SSL context passed by CURLOPT_SSL_CTX_FUNCTION
    nn::ssl::Context* pContext = static_cast<nn::ssl::Context*>(pSslContext);

    // Create SSL context
    nn::Result result = pContext->Create(nn::ssl::Context::SslVersion_Auto);
    if (result.IsFailure())
    {
        return CallbackFailure;
    }

    // Without a client object there is nothing to import.
    if (pUserData == nullptr)
    {
        return 0;
    }

    SimpleCurlHttpsClient* pClient = static_cast<SimpleCurlHttpsClient*>(pUserData);

    // NOTE(review): CertStoreId is assumed to be an integer type of at most 64
    // bits (it is assigned 0 here). It is widened to unsigned long long for
    // logging because "%d" is the wrong specifier for a type wider than int.
    if (pClient->IsImportServerPki())
    {
        nn::ssl::CertStoreId certStoreId = 0;
        result = pContext->ImportServerPki(
            &certStoreId,
            pClient->GetServerCert(),
            pClient->GetServerCertSize(),
            nn::ssl::CertificateFormat_Pem);
        if (result.IsFailure())
        {
            NN_LOG("Importing server cert failed (desc:%d)\n", result.GetDescription());
            return CallbackFailure;
        }
        pClient->SetServerCertStoreId(certStoreId);
        NN_LOG("Server cert imported (certstore id:%llu)\n", static_cast<unsigned long long>(certStoreId));
    }

    if (pClient->IsImportClientPki())
    {
        nn::ssl::CertStoreId certStoreId = 0;
        result = pContext->ImportClientPki(
            &certStoreId,
            pClient->GetClientPki(),
            pClient->GetPassword(),
            pClient->GetClientPkiSize(),
            pClient->GetPasswordSize());
        if (result.IsFailure())
        {
            NN_LOG("Importing client cert failed (desc:%d)\n", result.GetDescription());
            return CallbackFailure;
        }
        NN_LOG("Client cert/key imported (certstore id:%llu)\n", static_cast<unsigned long long>(certStoreId));
        pClient->SetClientCertStoreId(certStoreId);
    }

    return 0;
}
} // Un-named namespace

// Start-up hook that runs before nnMain(): reserves a 128 MiB memory heap,
// allocates the whole heap as one memory block, and installs that block as the
// allocator used by malloc.
// NOTE(review): ASSERT_TRUE only returns from this function when a step fails,
// which would leave malloc without a backing allocator; consider an
// unconditional abort here instead.
extern "C" void nninitStartup()
{
    NN_LOG("nninitStartup loaded %p\n", nninitStartup);

    // Total size of the memory heap; all of it is handed to malloc below.
    const size_t heapSize = 128 * 1024 * 1024;

    // Set the overall size of the memory heap.
    auto heapResult = nn::os::SetMemoryHeapSize( heapSize );
    ASSERT_TRUE( heapResult.IsSuccess() );

    // Allocate the memory region malloc will use from the memory heap.
    uintptr_t blockAddress = 0;
    auto blockResult = nn::os::AllocateMemoryBlock( &blockAddress, heapSize );
    ASSERT_TRUE( blockResult.IsSuccess() );

    // Register that region as malloc's backing memory.
    nn::init::InitializeAllocator( reinterpret_cast<void*>(blockAddress), heapSize );
}

// Writes pUrl into pOut (outSize bytes) with ":<port>" inserted right after the
// URL's authority, i.e. the text after "scheme://" up to the first '/', '?' or
// '#'. Any path or query therefore stays after the port, instead of the port
// being appended to the path. A URL without a leading "scheme://" is treated as
// starting with the authority.
// Returns false when the result does not fit into pOut. pOut may then hold
// truncated, possibly unterminated, data and must not be used.
static bool BuildUrlWithPort(char* pOut, size_t outSize, const char* pUrl, uint32_t port)
{
    const size_t firstDelimiter = strcspn(pUrl, "/?#");
    const char*  pSchemeEnd     = strstr(pUrl, "://");
    // Only a "://" that occurs before the first '/', '?' or '#' marks a scheme.
    const char*  pAuthority     = (pSchemeEnd != nullptr && static_cast<size_t>(pSchemeEnd - pUrl) < firstDelimiter)
                                      ? (pSchemeEnd + 3)
                                      : pUrl;
    const size_t authorityEnd   = static_cast<size_t>(pAuthority - pUrl) + strcspn(pAuthority, "/?#");

    const int written = MY_SNPRINTF(pOut, outSize, "%.*s:%u%s",
        static_cast<int>(authorityEnd), pUrl,
        static_cast<unsigned int>(port),
        pUrl + authorityEnd);
    // C99 snprintf reports truncation by returning a length >= outSize. MSVC
    // _snprintf returns a negative value, or exactly outSize without writing a
    // terminator. All of these cases are rejected here.
    return written >= 0 && static_cast<size_t>(written) < outSize;
}

// Program entry point. Parses the host command line and logs the effective
// configuration. It then brings up the network, the socket library and
// libcurl, performs one request with SimpleCurlHttpsClient, and tears
// everything down again.
extern "C" void nnMain()
{
    ConfigParams config;

    if (!ConfigureParameters(nn::os::GetHostArgc(), nn::os::GetHostArgv(), &config))
    {
        return;
    }

    if (config.pHostName[0] == '\0')
    {
        NN_LOG("URL is not specified.\n");
        return;
    }

    // The port is written into the URL only when it differs from the default 443.
    if (config.portNumber != 443)
    {
        char tmpHostName[HostNameLength] = {0};
        if (!BuildUrlWithPort(tmpHostName, sizeof(tmpHostName), config.pHostName, config.portNumber))
        {
            NN_LOG("URL is too long to add the port number (%s).\n", config.pHostName);
            return;
        }
        // BuildUrlWithPort() succeeded, so tmpHostName is terminated and shorter
        // than HostNameLength; strncpy therefore copies the terminator too.
        strncpy(config.pHostName, tmpHostName, HostNameLength);
    }

    NN_LOG(" URL              : %s\n", config.pHostName);
    NN_LOG(" Server validation: %s\n", (config.isPerformServerValidation)?("YES"):("NO"));
    NN_LOG(" Name validation  : %s\n", (config.isPerformNameValidation)?("YES"):("NO"));
    NN_LOG(" Proxy            : %s (address: %s port: %d usrname:%s password:%s)\n",
        (config.proxyAddress[0] != '\0')?("Enabled"):("Disabled"),
        config.proxyAddress,
        config.proxyPort,
        (config.proxyUserName[0] != '\0')?(config.proxyUserName):("NONE"),
        (config.proxyPassword[0] != '\0')?(config.proxyPassword):("NONE"));

    ASSERT_TRUE(g_CommonUtil.SetupNetwork().IsSuccess());
    ASSERT_TRUE(nn::socket::Initialize(
        g_SocketMemoryPoolBuffer,
        nn::socket::DefaultSocketMemoryPoolSize,
        nn::socket::MinSocketAllocatorSize,
        nn::socket::DefaultConcurrencyLimit).IsSuccess());
    ASSERT_TRUE(curl_global_init(CURL_GLOBAL_DEFAULT) == CURLE_OK);

    // The client lives in its own scope so that it is destroyed before
    // curl_global_cleanup(), nn::socket::Finalize() and FinalizeNetwork() run.
    // Presumably its destructor releases curl/socket resources; confirm
    // against SimpleCurlHttpsClient.
    {
        SimpleCurlHttpsClient curlSimple(
            config.pHostName,
            config.pServerCert,    // CA certificate
            config.pClientPki,     // client certificate
            config.pPassword,      // password
            config.serverCertSize, // CA certificate buffer size
            config.clientPkiSize,  // client certificate buffer size
            config.passwordSize,   // password buffer size
            SimpleCurlHttpsClient::InitMode_Manual);
        curlSimple.SetCurlCtxFunction(CurlSslContextCallback);
        curlSimple.EnableCurlVerbose();
        curlSimple.ConfigureValidation(
            config.isPerformServerValidation, // peer CA validation,
            config.isPerformNameValidation    // name validation
        );
        if (config.proxyAddress[0] != '\0')
        {
            curlSimple.SetupProxy(
                config.proxyAddress,
                (config.proxyUserName[0] == '\0')?(nullptr):(config.proxyUserName),
                (config.proxyPassword[0] == '\0')?(nullptr):(config.proxyPassword),
                config.proxyPort);
        }

        curlSimple.Perform();
    }

    curl_global_cleanup();
    nn::socket::Finalize();
    g_CommonUtil.FinalizeNetwork();
}
