﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <memory>

#include <nn/os.h>
#include <nn/init.h>
#include <nn/nn_Log.h>
#include <nn/nn_Assert.h>
#include <nn/lmem/lmem_ExpHeap.h>
#include <nnt/nntest.h>
#include <nn/socket.h>
#include <nn/ssl.h>
#include <curl/curl.h>

#include <Common/testCommonUtil.h>

#include <Utils/CommandLineParser.h>

// Leaves the innermost enclosing loop (or switch) when a libcurl call did not
// return CURLE_OK.
//
// The check is deliberately NOT wrapped in the usual do { ... } while (0):
// a `break` placed inside such a wrapper only terminates the wrapper itself,
// so the caller's loop would silently continue after a failure. The trailing
// `else static_cast<void>(0)` consumes the caller's semicolon, which keeps the
// macro usable as a single statement.
//
// res: expression yielding a CURLcode; it is evaluated exactly once.
#define BREAK_ON_CURL_FAILURE(res)    \
    if ((res) != CURLE_OK)            \
    {                                 \
        break;                        \
    }                                 \
    else                              \
        static_cast<void>(0)

extern "C" void nninitStartup()
{
    NN_LOG("nninitStartup loaded %p\n", nninitStartup);
    // メモリヒープの全体サイズを設定する
    const size_t MemoryHeapSize = 128 * 1024 * 1024;
    auto result = nn::os::SetMemoryHeapSize( MemoryHeapSize );

    ASSERT_TRUE( result.IsSuccess() );

    // メモリヒープから malloc で使用するメモリ領域を確保
    uintptr_t address = 0;

    result = nn::os::AllocateMemoryBlock( &address, MemoryHeapSize );
    ASSERT_TRUE( result.IsSuccess() );

    // malloc 用のメモリ領域を設定する
    nn::init::InitializeAllocator( reinterpret_cast<void*>(address), MemoryHeapSize );
}

namespace
{
// Parameters for one request issued by PerformHttps(). A pointer to this
// structure is also handed to CurlSslContextCallback() as its user data.
struct ConnectionInfo
{
    // Proxy endpoint; PerformHttps() reads it only when isUseProxy is true.
    struct ProxyInfo
    {
        const char* pUrl;   // Value passed to CURLOPT_PROXY
        uint16_t    port;   // Value passed to CURLOPT_PROXYPORT
    };

    const char* pUrl;               // Value passed to CURLOPT_URL
    bool        isVerbose;          // true: enable CURLOPT_VERBOSE
    bool        isReadDeviceCert;   // true: the SSL context callback registers the internal device client certificate
    bool        isUseProxy;         // true: apply the settings in proxyInfo
    ProxyInfo   proxyInfo;
};

// Proxy URL that the NintendoCa test stores in ConnectionInfo::proxyInfo.pUrl.
const char* ProxyAddress = "http://proxy.nintendo.co.jp";

// URLs requested by the NintendoCa test (each entry is a fixed 256-byte slot).
const char TestUrlListNintendo[][256] = {
    "https://dauth-dd1.ndas.srv.nintendo.net",
};

// URLs requested by the CommercialCas test (each entry is a fixed 256-byte slot).
// The trailing comment on every entry names a CA - presumably the root CA that
// the server's certificate chains up to (TODO confirm).
const char TestUrlListCommercial[][256] = {
    "https://addtrustexternalcaroot-ev.comodoca.com/",          // AddTrust External CA Root
    "https://comodocertificationauthority-ev.comodoca.com/",    // COMODO Certification Authority
    "https://mobile2025.cybertrust.ne.jp/",                     // Baltimore CyberTrust Root
    "https://evup.cybertrust.ne.jp/ctj-ev-upgrader/mtest.html", // Cybertrust Global Root
    "https://sha256.managedpki.ne.jp/",                         // Verizon Global Root CA
    "https://assured-id-root.digicert.com/",                    // DigiCert Assured ID Root CA
    "https://assured-id-root-g2.digicert.com/",                 // DigiCert Assured ID Root G2
    "https://global-root.digicert.com/",                        // DigiCert Global Root CA
    "https://global-root-g2.digicert.com/",                     // DigiCert Global Root G2
    "https://ev-root.digicert.com/",                            // DigiCert High Assurance EV Root CA
    "https://2048test.entrust.net/",                            // Entrust.net Certification Authority (2048)
    "https://validev.entrust.net/",                             // Entrust Root Certification Authority
    "https://validg2.entrust.net/",                             // Entrust Root Certification Authority - G2

    "https://ssltest11.bbtest.net/",                            // GeoTrust Global CA
    "https://ssltest21.bbtest.net/",                            // GeoTrust Primary Certification Authority - G3
    "https://www.geotrust.com/",                                // GeoTrust Primary Certification Authority
    "https://2029.globalsign.com/",                             // GlobalSign Root CA - R3
    "https://ssltest8.bbtest.net/",                             // thawte Primary Root CA - G3
    "https://ssltest6.bbtest.net/",                             // thawte Primary Root CA
    "https://ssltest4.verisign.co.jp/",                         // VeriSign Class 3 Public Primary Certification Authority - G3
    "https://ssltest9.verisign.co.jp/",                         // VeriSign Class 3 Public Primary Certification Authority - G5

    // The entries below are disabled; the reason is not recorded in this file.
    // "https://utndatacorpsgc-ev.comodoca.com/",                  // UTN - DATACorp SGC
    // "https://utnuserfirsthardware-ev.comodoca.com/",            // UTN-USERFirst-Hardware
    // "https://ssltest19.bbtest.net/",                            // GeoTrust Global CA 2
    // "https://2028.globalsign.com/",                             // GlobalSign Root CA
    // "https://2021.globalsign.com/",                             // GlobalSign Root CA - R2
    // "https://ptnr-verisign256.bbtest.net/"                      // VeriSign Universal Root Certification Authority
};

// Number of entries in each URL table (total array size / size of one entry).
const int                TestUrlCommercialCount = sizeof(TestUrlListCommercial) / sizeof(TestUrlListCommercial[0]);
const int                TestUrlNintendoCount = sizeof(TestUrlListNintendo) / sizeof(TestUrlListNintendo[0]);
// Helper object used by InitTest / FinalizeTest to set up and finalize the network.
SslTestCommonUtil        g_CommonUtil;
// Memory pool passed to nn::socket::Initialize() in InitTest; aligned to 4096 bytes.
NN_ALIGNAS(4096) uint8_t g_SocketMemoryPoolBuffer[nn::socket::DefaultSocketMemoryPoolSize];

// Write callback installed through CURLOPT_WRITEFUNCTION. It ignores the
// received data and reports all size * nmemb bytes as consumed.
size_t WriteData(void *buffer, size_t size, size_t nmemb, void *userp)
{
    // Neither the payload nor the user pointer is inspected.
    static_cast<void>(buffer);
    static_cast<void>(userp);

    const size_t consumedBytes = size * nmemb;
    return consumedBytes;
}

// Callback installed through CURLOPT_SSL_CTX_FUNCTION.
//
// pCurl       : unused.
// pSslContext : pointer to the nn::ssl::Context to set up.
// pUserData   : optional ConnectionInfo* (installed through CURLOPT_SSL_CTX_DATA);
//               when its isReadDeviceCert flag is set, the internal device
//               client certificate is registered with the context.
// Returns 0 on success, static_cast<size_t>(-1) when creating the context or
// registering the certificate fails.
size_t CurlSslContextCallback(CURL* pCurl, void* pSslContext, void* pUserData)
{
    const size_t CallbackFailed = static_cast<size_t>(-1);

    // The SSL context handed over by CURLOPT_SSL_CTX_FUNCTION.
    nn::ssl::Context* const pContext = static_cast<nn::ssl::Context*>(pSslContext);

    // Create the SSL context.
    nn::Result result = pContext->Create(nn::ssl::Context::SslVersion_Auto);
    if (result.IsFailure())
    {
        NN_LOG("Create failed (Desc:%d)\n", result.GetDescription());
        return CallbackFailed;
    }

    // Without user data, or without a request for the device certificate,
    // there is nothing more to configure.
    const ConnectionInfo* const pInfo = static_cast<const ConnectionInfo*>(pUserData);
    if (pInfo == nullptr || !pInfo->isReadDeviceCert)
    {
        return 0;
    }

    nn::ssl::CertStoreId certStoreId;
    result = pContext->RegisterInternalPki(
        &certStoreId,
        nn::ssl::Context::InternalPki_DeviceClientCertDefault);
    if (result.IsFailure())
    {
        NN_LOG("Failed to register internal cert (Desc:%d)\n", result.GetDescription());
        return CallbackFailed;
    }

    return 0;
}

// Runs one libcurl transfer against connInfo->pUrl with peer and host
// verification enabled.
//
// connInfo : request parameters; the pointer is also installed as
//            CURLOPT_SSL_CTX_DATA so CurlSslContextCallback() can read it.
// Returns true only when curl_easy_perform() reports CURLE_OK. Returns false
// when the easy handle cannot be created or the transfer fails (the curl
// results are dumped in that case).
bool PerformHttps(ConnectionInfo* connInfo)
{
    bool isSuccess = false;

    // Create the handle before entering the do/while so every path that
    // leaves the loop owns a valid handle to clean up.
    CURL* curl = curl_easy_init();
    if (curl == nullptr)
    {
        return isSuccess;
    }

    do
    {
        BREAK_ON_CURL_FAILURE(curl_easy_setopt(curl, CURLOPT_URL, connInfo->pUrl));
        BREAK_ON_CURL_FAILURE(curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 15L));
        BREAK_ON_CURL_FAILURE(curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteData));
        BREAK_ON_CURL_FAILURE(curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, 1L));
        BREAK_ON_CURL_FAILURE(curl_easy_setopt(curl, CURLOPT_SSL_VERIFYHOST, 2L));
        BREAK_ON_CURL_FAILURE(curl_easy_setopt(curl, CURLOPT_SSL_CTX_FUNCTION, CurlSslContextCallback));
        BREAK_ON_CURL_FAILURE(curl_easy_setopt(curl, CURLOPT_SSL_CTX_DATA, static_cast<void*>(connInfo)));

        if (connInfo->isUseProxy == true)
        {
            // curl_easy_setopt() is variadic, so numeric option values must
            // have exactly the type libcurl reads (long). A uint16_t would be
            // promoted to int, which is a different size on LP64 targets.
            // NOTE(review): CURLOPT_PROXYAUTOCONFIG is not a stock libcurl
            // option; it is assumed to take a long like other flags - confirm.
            BREAK_ON_CURL_FAILURE(curl_easy_setopt(curl, CURLOPT_PROXYAUTOCONFIG, 0L));
            BREAK_ON_CURL_FAILURE(curl_easy_setopt(curl, CURLOPT_PROXYAUTH, CURLAUTH_BASIC));
            BREAK_ON_CURL_FAILURE(curl_easy_setopt(curl, CURLOPT_PROXY, connInfo->proxyInfo.pUrl));
            BREAK_ON_CURL_FAILURE(curl_easy_setopt(curl, CURLOPT_PROXYPORT, static_cast<long>(connInfo->proxyInfo.port)));
        }

        if (connInfo->isVerbose == true)
        {
            BREAK_ON_CURL_FAILURE(curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L));
        }

        const CURLcode res = curl_easy_perform(curl);
        if (res != CURLE_OK)
        {
            NN_LOG("curl_easy_perform FAILED: curl error: %d\n", res);
            SslTestCommonUtil::DumpCurlResults(curl);
        }
        else
        {
            NN_LOG("curl_easy_perform SUCCEEDED.\n");
            isSuccess = true;
        }
    } while (NN_STATIC_CONDITION(0));

    curl_easy_cleanup(curl);
    return isSuccess;
}

} // Un-named namespace

//-------------------------------------------------------------------------------------------------
//  Tests
//-------------------------------------------------------------------------------------------------
// Brings up everything the later tests rely on: parses the optional
// --NetProfile argument, sets up the network, initializes the socket library
// with the file-level memory pool, and initializes libcurl's global state.
TEST(InitTest, Success)
{
    nn::util::Uuid netProfile = nn::util::InvalidUuid;
    NATF::Utils::ParserGroup parser;

    parser.AddParser(NATF::Utils::UuidParser ("--NetProfile", &nn::util::InvalidUuid, netProfile));

    int      argc = nn::os::GetHostArgc();
    char**   argv = nn::os::GetHostArgv();

    if (!parser.Parse(argc, argv))
    {
        NN_LOG("\n * Failed to parse command line arguments!\n\n");
        // FAIL() records a fatal failure and returns from the test body, so
        // no explicit return is needed after it.
        FAIL();
    }

    ASSERT_TRUE(g_CommonUtil.SetupNetwork(netProfile));
    ASSERT_TRUE(nn::socket::Initialize(
        g_SocketMemoryPoolBuffer,
        nn::socket::DefaultSocketMemoryPoolSize,
        nn::socket::MinSocketAllocatorSize,
        nn::socket::DefaultConcurrencyLimit).IsSuccess());

    // ASSERT_EQ reports the actual CURLcode when initialization fails.
    ASSERT_EQ(CURLE_OK, curl_global_init(CURL_GLOBAL_DEFAULT));
}

// Requests every URL in TestUrlListCommercial without a proxy and without
// the device client certificate; each transfer is expected to succeed.
TEST(CommercialCas, Success)
{
    // Value-initialization zeroes every member: all flags false, all
    // pointers null, proxy port 0.
    ConnectionInfo connInfo = {};

    int index = 0;
    while (index < TestUrlCommercialCount)
    {
        connInfo.pUrl = TestUrlListCommercial[index];
        NN_LOG("\n[%d][%s]---------------------------------------------------------\n",
            index, connInfo.pUrl);
        // Non-fatal, so one unreachable site does not hide the others.
        EXPECT_TRUE(PerformHttps(&connInfo));
        ++index;
    }
}

// Requests every URL in TestUrlListNintendo through the proxy, presenting the
// device unique client certificate. The device must have that certificate
// available; see the page below for details:
// http://spdlybra.nintendo.co.jp/confluence/display/SIGLODOCEN/Device+client+certificate
TEST(NintendoCa, Success)
{
    // Start from an all-zero structure, then enable only what this test needs.
    ConnectionInfo connInfo = {};
    connInfo.isReadDeviceCert = true;
    connInfo.isUseProxy       = true;
    connInfo.proxyInfo.port   = 8080;
    connInfo.proxyInfo.pUrl   = ProxyAddress;

    int index = 0;
    while (index < TestUrlNintendoCount)
    {
        connInfo.pUrl = TestUrlListNintendo[index];
        NN_LOG("\n[%d][%s]---------------------------------------------------------\n",
            index, connInfo.pUrl);
        // Non-fatal, so the remaining URLs are still exercised after a failure.
        EXPECT_TRUE(PerformHttps(&connInfo));
        ++index;
    }
}

// Fetches a single built-in CA certificate through nn::ssl::BuiltInManager and
// imports its DER data into a fresh nn::ssl::Context, first for a Nintendo CA
// and then for a commercial CA.
//
// The buffer and the context are owned by std::unique_ptr so that a fatal
// assertion (which returns from the test body) cannot leak them.
TEST(GetBuiltInCa, Success)
{
    const nn::ssl::CaCertificateId targetIds[] =
    {
        nn::ssl::CaCertificateId_NintendoCAG3,   // Nintendo CA
        nn::ssl::CaCertificateId_AmazonRootCA1,  // commercial CA
    };

    for (const nn::ssl::CaCertificateId targetId : targetIds)
    {
        nn::ssl::CaCertificateId requestedId[] = { targetId };
        uint32_t                 bufSize = 0;

        //  Ask how much memory the certificate information needs
        nn::Result result =
            nn::ssl::BuiltInManager::GetBuiltInCertificateBufSize(&bufSize,
                                                                  requestedId,
                                                                  1);
        NN_LOG("[GetBuiltInCa, Success] need %u bytes\n", bufSize);
        ASSERT_TRUE(result.IsSuccess());

        std::unique_ptr<uint8_t[]> buf(new uint8_t[bufSize]);
        ASSERT_TRUE(buf.get() != nullptr);

        nn::ssl::BuiltInManager::BuiltInCertificateInfo *pCertInfoArray = nullptr;
        result =
            nn::ssl::BuiltInManager::GetBuiltInCertificates(&pCertInfoArray,
                                                            buf.get(),
                                                            bufSize,
                                                            requestedId,
                                                            1);
        ASSERT_TRUE(result.IsSuccess());

        //  Attempt to import the cert into a nn::ssl::Context (this will
        //  verify the DER data is all there and correct).
        std::unique_ptr<nn::ssl::Context> ctx(new nn::ssl::Context());
        ASSERT_TRUE(ctx.get() != nullptr);

        result = ctx->Create(nn::ssl::Context::SslVersion_Auto);
        ASSERT_TRUE(result.IsSuccess());

        //  Only non-fatal expectations from here on, so Destroy() below is
        //  always reached once the context has been created.
        nn::ssl::CertStoreId certId;
        char *pDerData =
            reinterpret_cast<char *>(pCertInfoArray[0].data.ptr.pCertificateDerData);
        result = ctx->ImportServerPki(&certId,
                                      pDerData,
                                      pCertInfoArray[0].certificateSize,
                                      nn::ssl::CertificateFormat_Der);
        EXPECT_TRUE(result.IsSuccess());

        ctx->Destroy();
    }
}

// Exercises the error paths of nn::ssl::BuiltInManager: null pointers, a
// buffer that is too small, and an unknown certificate ID.
//
// The certificate buffer is owned by std::unique_ptr so that a fatal assertion
// (which returns from the test body) cannot leak it.
TEST(GetBuiltInCa, Failure)
{
    uint32_t                                        bufSize = 0;
    nn::ssl::CaCertificateId                        nintendoId[] = { nn::ssl::CaCertificateId_NintendoCAG3 };
    nn::ssl::CaCertificateId                        badId[] = { static_cast<nn::ssl::CaCertificateId>(0x1234) };
    nn::Result                                      result;
    nn::ssl::BuiltInManager::BuiltInCertificateInfo *pCertInfoArray = nullptr;
    uint32_t                                        dummy = 0;

    //  Try to get size with invalid pointers
    result =
        nn::ssl::BuiltInManager::GetBuiltInCertificateBufSize(nullptr,
                                                              badId,
                                                              1);
    EXPECT_TRUE(result.IsFailure());
    EXPECT_TRUE(nn::ssl::ResultInvalidPointer::Includes(result));

    result =
        nn::ssl::BuiltInManager::GetBuiltInCertificateBufSize(&bufSize,
                                                              nullptr,
                                                              1);
    EXPECT_TRUE(result.IsFailure());
    EXPECT_TRUE(nn::ssl::ResultInvalidPointer::Includes(result));

    //  Query the size for an unknown ID. The call itself is expected to
    //  succeed; the invalid ID is reported through the per-certificate status
    //  that is checked at the end of this test.
    result =
        nn::ssl::BuiltInManager::GetBuiltInCertificateBufSize(&bufSize,
                                                              badId,
                                                              1);
    ASSERT_TRUE(result.IsSuccess());

    //  Attempt to get a cert with too small of a buffer
    result =
        nn::ssl::BuiltInManager::GetBuiltInCertificates(&pCertInfoArray,
                                                        reinterpret_cast<uint8_t *>(&dummy),
                                                        sizeof(dummy),
                                                        nintendoId,
                                                        1);
    EXPECT_TRUE(result.IsFailure());
    EXPECT_TRUE(nn::ssl::ResultBufferTooShort::Includes(result));

    //  Fetch the unknown ID using the size reported above
    std::unique_ptr<uint8_t[]> buf(new uint8_t[bufSize]);
    ASSERT_TRUE(buf.get() != nullptr);

    result =
        nn::ssl::BuiltInManager::GetBuiltInCertificates(&pCertInfoArray,
                                                        buf.get(),
                                                        bufSize,
                                                        badId,
                                                        1);
    ASSERT_TRUE(result.IsSuccess());

    //  Look at the cert info and make sure the status is set to INVALID
    ASSERT_TRUE(pCertInfoArray[0].status == nn::ssl::TrustedCertStatus_Invalid);
}

// Fetches two built-in CA certificates in one request and imports both into a
// single nn::ssl::Context.
//
// The buffer and the context are owned by std::unique_ptr so that a fatal
// assertion (which returns from the test body) cannot leak them.
TEST(GetMultiBuiltInCa, Success)
{
    const int                                       CertCount = 2;
    uint32_t                                        bufSize = 0;
    nn::ssl::CaCertificateId                        twoIds[] = { nn::ssl::CaCertificateId_NintendoCAG3,
                                                                 nn::ssl::CaCertificateId_AmazonRootCA1 };
    nn::Result                                      result;
    nn::ssl::BuiltInManager::BuiltInCertificateInfo *pCertInfoArray = nullptr;
    nn::ssl::CertStoreId                            certId;

    //  Ask how much memory both certificates need together
    result =
        nn::ssl::BuiltInManager::GetBuiltInCertificateBufSize(&bufSize,
                                                              twoIds,
                                                              CertCount);
    ASSERT_TRUE(result.IsSuccess());

    std::unique_ptr<uint8_t[]> buf(new uint8_t[bufSize]);
    ASSERT_TRUE(buf.get() != nullptr);

    result =
        nn::ssl::BuiltInManager::GetBuiltInCertificates(&pCertInfoArray,
                                                        buf.get(),
                                                        bufSize,
                                                        twoIds,
                                                        CertCount);
    ASSERT_TRUE(result.IsSuccess());

    //  Attempt to import the certs into a nn::ssl::Context (this will
    //  verify the DER data is all there and correct).
    std::unique_ptr<nn::ssl::Context> ctx(new nn::ssl::Context());
    ASSERT_TRUE(ctx.get() != nullptr);

    result = ctx->Create(nn::ssl::Context::SslVersion_Auto);
    ASSERT_TRUE(result.IsSuccess());

    //  Only non-fatal expectations from here on, so Destroy() below is always
    //  reached once the context has been created.
    for (int i = 0; i < CertCount; i++)
    {
        char *pDerData =
            reinterpret_cast<char *>(pCertInfoArray[i].data.ptr.pCertificateDerData);
        result = ctx->ImportServerPki(&certId,
                                      pDerData,
                                      pCertInfoArray[i].certificateSize,
                                      nn::ssl::CertificateFormat_Der);
        EXPECT_TRUE(result.IsSuccess());
    }

    ctx->Destroy();
}


// Fetches every built-in CA certificate (CaCertificateId_All), imports each
// one into a single nn::ssl::Context, and checks that more than one
// certificate was returned.
//
// The buffer and the context are owned by std::unique_ptr, and the context is
// destroyed before the final fatal assertion, so neither is leaked.
TEST(GetAllBuiltinCa, Success)
{
    uint32_t                                        bufSize = 0;
    nn::ssl::CaCertificateId                        allId[] = { nn::ssl::CaCertificateId_All };
    nn::Result                                      result;
    nn::ssl::BuiltInManager::BuiltInCertificateInfo *pCertInfoArray = nullptr;
    uint32_t                                        total = 0;
    nn::ssl::CertStoreId                            certId;

    //  Grab all CA certs
    result =
        nn::ssl::BuiltInManager::GetBuiltInCertificateBufSize(&bufSize,
                                                              allId,
                                                              1);
    NN_LOG("[GetAllBuiltinCa, Success] need %u bytes\n", bufSize);
    ASSERT_TRUE(result.IsSuccess());

    std::unique_ptr<uint8_t[]> buf(new uint8_t[bufSize]);
    ASSERT_TRUE(buf.get() != nullptr);

    result =
        nn::ssl::BuiltInManager::GetBuiltInCertificates(&pCertInfoArray,
                                                        buf.get(),
                                                        bufSize,
                                                        allId,
                                                        1);
    ASSERT_TRUE(result.IsSuccess());

    //  Attempt to import each cert into a nn::ssl::Context (this will
    //  verify the DER data is all there and correct).
    std::unique_ptr<nn::ssl::Context> ctx(new nn::ssl::Context());
    ASSERT_TRUE(ctx.get() != nullptr);

    result = ctx->Create(nn::ssl::Context::SslVersion_Auto);
    ASSERT_TRUE(result.IsSuccess());

    //  Walk each cert data returned, print each ID.  Also count how many
    //  we actually get and ensure it is > 1.
    //  The walk stops at the first entry whose id is CaCertificateId_All -
    //  presumably a terminator written by GetBuiltInCertificates (TODO confirm
    //  against the API documentation).
    for (int i = 0; pCertInfoArray[i].id != nn::ssl::CaCertificateId_All; i++)
    {
        char *pDerData =
            reinterpret_cast<char *>(pCertInfoArray[i].data.ptr.pCertificateDerData);
        NN_LOG("[GetAllBuiltinCa, Success] id %8.8X, status %8.8X, len %u, data %p\n",
               static_cast<uint32_t>(pCertInfoArray[i].id),
               static_cast<uint32_t>(pCertInfoArray[i].status),
               pCertInfoArray[i].certificateSize,
               pDerData);
        result = ctx->ImportServerPki(&certId,
                                      pDerData,
                                      pCertInfoArray[i].certificateSize,
                                      nn::ssl::CertificateFormat_Der);
        EXPECT_TRUE(result.IsSuccess());
        total++;
    }

    //  Destroy the context before the fatal assertion below so it is released
    //  whether or not the count check passes.
    ctx->Destroy();

    ASSERT_TRUE(total > 1);
}


// Releases what InitTest set up, in the reverse order of initialization:
// libcurl global state first, then the socket library, then the network.
TEST(FinalizeTest, Success)
{
    // Counterpart of curl_global_init() in InitTest.
    curl_global_cleanup();
    // Counterpart of nn::socket::Initialize() in InitTest.
    nn::socket::Finalize();
    // Counterpart of g_CommonUtil.SetupNetwork() in InitTest.
    g_CommonUtil.FinalizeNetwork();
}
