﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <nn/os.h>
#include <nn/init.h>
#include <nn/nn_Log.h>
#include <nn/nn_Assert.h>
#include <nn/lmem/lmem_ExpHeap.h>
#include <nnt/nntest.h>
#include <nn/socket.h>
#include <nn/ssl.h>

#include <Utils/CommandLineParser.h>

#include <Common/testCommonUtil.h>
#include <Common/testInfraInfo.h>
#include <Common/testServerPki.h>
#include <Common/testClientPki.h>

// ------------------------------------------------------------------------------------------------
// Build flags
// ------------------------------------------------------------------------------------------------
// When defined, InitTest parses --NetProfile from the host command line and calls
// g_CommonUtil.SetupNetwork(); ShimFinalize calls g_CommonUtil.FinalizeNetwork().
#define USE_NETWORK

// ------------------------------------------------------------------------------------------------
// Build flags for tests to run
// ------------------------------------------------------------------------------------------------
// Each flag below enables the test of the matching name (and, where applicable, the
// server / policy OID tables that test uses):
//   RUN_NEGATIVE      - TEST(Negative):       AddPolicyOid() error results
//   RUN_POSITIVE      - TEST(Positive):       AddPolicyOid() with a max-length OID
//   RUN_WITH_VALID_EV - TEST(RunWithValidEv): servers whose EV OID is built in
//   RUN_WITH_NO_OID   - TEST(RunWithNoOid):   servers whose EV OID is not built in
//   RUN_WITH_USER_OID - TEST(RunWithUserOid): same servers, OID added by the user
//   RUN_NO_EV         - TEST(RunWithNoEv):    non-EV server accepted via a user OID
#define RUN_NEGATIVE
#define RUN_POSITIVE
#define RUN_WITH_VALID_EV
#define RUN_WITH_NO_OID
#define RUN_WITH_USER_OID
#define RUN_NO_EV

// ------------------------------------------------------------------------------------------------
// Macros
// ------------------------------------------------------------------------------------------------

// Portable snprintf: _snprintf on Windows builds (NN_BUILD_CONFIG_OS_WIN), snprintf elsewhere.
// NOTE(review): MSVC's _snprintf is documented not to null-terminate on truncation, so
// buffers passed through this macro should be sized so truncation cannot happen.
#if defined (NN_BUILD_CONFIG_OS_WIN)
#define MY_SNPRINTF _snprintf
#else
#define MY_SNPRINTF snprintf
#endif

// ------------------------------------------------------------------------------------------------
// Params
// ------------------------------------------------------------------------------------------------
namespace
{
// NOTE(review): not referenced by any test in this file.
const int g_MsecPollTimeout = 5000;

// Shared test helper: network setup/teardown and TCP socket creation.
SslTestCommonUtil        g_CommonUtil;

// Memory pool handed to nn::socket::Initialize() in TEST(InitTest).
NN_ALIGNAS(4096) uint8_t g_SocketMemoryPoolBuffer[nn::socket::DefaultSocketMemoryPoolSize];

// Reference notes: policy OIDs observed on some servers (not all are used by the tables below).
// "assured-id-root-g2.digicert.com",        // DigiCert   2.16.840.1.114412.1.1
// "global-root-g2.digicert.com",            // DigiCert   2.16.840.1.114412.1.1
// "2048test.entrust.net",                   // Entrust    2.23.140.1.2.2
// "www.digicert.com",                       // DigiCert   2.16.840.1.114412.1.3.0.2 (not tested)

#ifdef RUN_WITH_VALID_EV
// Servers whose EV policy OID is in the built-in list; handshakes are expected to succeed.
const char TestUrlListEv[][256] = {
    "secure.comodo.com",                      // COMODO
    "addtrustexternalcaroot-ev.comodoca.com", // COMODO
    "www.cybertrust.ne.jp",                   // CyberTrust
    "evup.cybertrust.ne.jp",                  // Cybertrust
    "www.digicert.com",                       // DigiCert
    "www.entrust.net",                        // Entrust
    "validev.entrust.net",                    // Entrust
    "validg2.entrust.net",                    // Entrust
    "ssltest21.bbtest.net",                   // GeoTrust
    "www.geotrust.com",                       // GeoTrust
    "www.globalsign.com",                     // GlobalSign
    "2021.globalsign.com",                    // GlobalSign
    "2029.globalsign.com",                    // GlobalSign
    "www.godaddy.com",                        // Go Daddy
    "valid.gdig2.catest.godaddy.com",         // Go Daddy
    "www.starfieldtech.com",                  // Starfield
    "valid.sfig2.catest.starfieldtech.com",   // Starfield
    "www.thawte.com",                         // Thawte
    "ssltest8.bbtest.net",                    // Thawte
    "www.verisign.com",                       // VeriSign
    "ssltest6.jp.websecurity.symantec.com",   // VeriSign
};
const int TestUrlCountEv = sizeof(TestUrlListEv) / sizeof(TestUrlListEv[0]);
#endif // RUN_WITH_VALID_EV

#ifdef RUN_WITH_NO_OID
// Servers whose policy OID is NOT in the built-in list; verification is expected to fail
// with an untrusted-OID error.
const char TestUrlListNoOid[][256] = {
    "www.amazontrust.com",
    "assured-id-root.digicert.com",           // DigiCert - DigiCert Assured ID Root CA
};
const int TestUrlCountNoOid = sizeof(TestUrlListNoOid) / sizeof(TestUrlListNoOid[0]);
#endif // RUN_WITH_NO_OID

#ifdef RUN_WITH_USER_OID
// Servers whose policy OID is NOT in the built-in list; they validate once the matching
// UserOidList entry is registered with AddPolicyOid().
const char TestUrlListUserOid[][256] = {
    "www.amazontrust.com",
};
const int TestUrlCountUserOid = sizeof(TestUrlListUserOid) / sizeof(TestUrlListUserOid[0]);

// Policy OIDs to register for the servers above; entry i belongs to TestUrlListUserOid[i].
const char UserOidList[][256] = {
    "2.23.140.1.2.1",        // www.amazontrust.com
};

// The test indexes both tables with the same loop index, so their lengths must match.
static_assert(static_cast<int>(sizeof(UserOidList) / sizeof(UserOidList[0])) == TestUrlCountUserOid,
              "UserOidList needs exactly one entry per TestUrlListUserOid entry");
#endif // RUN_WITH_USER_OID

#ifdef RUN_NO_EV
// Servers presenting non-EV certificates.
const char TestUrlListNoEv[][256] = {
    "assured-id-root.digicert.com", // assured-id-root.digicert.com
};
const int TestUrlCountNoEv = sizeof(TestUrlListNoEv) / sizeof(TestUrlListNoEv[0]);

// Policy OIDs to register for the servers above; entry i belongs to TestUrlListNoEv[i].
const char NoEvUserOidList[][256] = {
    "2.16.840.1.114412.1.1", // assured-id-root.digicert.com (this is for Non EV SSL certs)
};

// The test indexes both tables with the same loop index, so their lengths must match.
static_assert(static_cast<int>(sizeof(NoEvUserOidList) / sizeof(NoEvUserOidList[0])) == TestUrlCountNoEv,
              "NoEvUserOidList needs exactly one entry per TestUrlListNoEv entry");
#endif // RUN_NO_EV

// ------------------------------------------------------------------------------------------------
// Functions
// ------------------------------------------------------------------------------------------------

// Creates pContext with automatic SSL version selection and returns the result of
// nn::ssl::Context::Create().
nn::Result CreateContext(nn::ssl::Context *pContext)
{
    return pContext->Create(nn::ssl::Context::SslVersion_Auto);
}

// Opens a TCP socket to hostName:443 via g_CommonUtil, creates pConnection on pInContext,
// attaches the socket, sets the host name and enables VerifyOption_All plus
// VerifyOption_EvCertPartial. Returns the first failing nn::Result, or success.
//
// NOTE(review): when the TCP socket cannot be created the failure is only logged and the
// function still returns success (no suitable failure Result is available here); callers
// will then see the failure from DoHandshake() on the uncreated connection.
nn::Result CreateConnection( nn::ssl::Connection* pConnection,
                             nn::ssl::Context *pInContext,
                             const char* hostName )
{
    nn::Result result = nn::ResultSuccess();

    do
    {
        int socketFd = g_CommonUtil.CreateTcpSocket(true, 443, hostName, 0);
        if (socketFd < 0)
        {
            NN_LOG( "Failed to create TCP socket for %s - fd: %d\n", hostName, socketFd );
            break;
        }

        result = pConnection->Create(pInContext);
        if (result.IsFailure())
        {
            NN_LOG( "Failed to create connection - result: %d\n", result.GetDescription() );
            // The connection was never created, so it does not own the socket: close it here
            // instead of leaking the descriptor.
            nn::socket::Close(socketFd);
            break;
        }

        // NOTE(review): ownership of socketFd on SetSocketDescriptor() failure is not visible
        // here, so the descriptor is deliberately not closed on that path (avoid double close).
        result = pConnection->SetSocketDescriptor(socketFd);
        if (result.IsFailure())
        {
            NN_LOG( "Failed SetSocketDescriptor - result: %d\n", result.GetDescription() );
            break;
        }

        uint32_t hostNameLen = static_cast<uint32_t>(strlen(hostName));
        result = pConnection->SetHostName(hostName, hostNameLen);
        if (result.IsFailure())
        {
            NN_LOG( "Failed SetHostName - result: %d\n", result.GetDescription() );
            break;
        }

        result = pConnection->SetVerifyOption(
            nn::ssl::Connection::VerifyOption::VerifyOption_All |
            nn::ssl::Connection::VerifyOption::VerifyOption_EvCertPartial);
        if (result.IsFailure())
        {
            NN_LOG( "Failed SetVerifyOption - result: %d\n", result.GetDescription() );
            break;
        }
    } while (NN_STATIC_CONDITION(false));

    return result;
}

} // Un-named namespace

//-----------------------------------------------------------------------------
//  Startup function
//-----------------------------------------------------------------------------
// Sets the memory heap to 16 MiB, allocates the whole heap as one memory block and
// registers that block as the backing region for malloc.
// NOTE(review): presumably invoked by the nn runtime before the tests run — confirm.
extern "C" void nninitStartup()
{
    NN_LOG("nninitStartup -> %p\n", (void *)nninitStartup);

    // Set the total size of the memory heap.
    const size_t MemoryHeapSize = 16 * 1024 * 1024;
    auto result = nn::os::SetMemoryHeapSize( MemoryHeapSize );

    EXPECT_TRUE( result.IsSuccess() );
    if (result.IsFailure())
    {
        // No heap: there is nothing to allocate from, so stop before touching the allocator.
        return;
    }

    // Allocate the region that malloc will use from the memory heap.
    uintptr_t address = 0;

    result = nn::os::AllocateMemoryBlock( &address, MemoryHeapSize );
    EXPECT_TRUE( result.IsSuccess() );
    if (result.IsFailure())
    {
        // address is still 0; never hand a null region to the allocator.
        return;
    }

    // Register the region as malloc's memory.
    nn::init::InitializeAllocator( reinterpret_cast<void*>(address), MemoryHeapSize );
}

// ------------------------------------------------------------------------------------------------
// Tests
// ------------------------------------------------------------------------------------------------
// Brings up the environment for the SSL tests: with USE_NETWORK, parses --NetProfile from
// the host arguments and sets up the network; then initializes the socket library using
// the file-scope memory pool.
TEST(InitTest, Success)
{
    NN_LOG("nninitStartup: loaded at %p\n", (void *)nninitStartup);

#if defined(USE_NETWORK)
    nn::util::Uuid profileId = nn::util::InvalidUuid;
    NATF::Utils::ParserGroup argParser;

    argParser.AddParser(NATF::Utils::UuidParser("--NetProfile", &nn::util::InvalidUuid, profileId));

    const int hostArgc  = nn::os::GetHostArgc();
    char**    hostArgv  = nn::os::GetHostArgv();
    const bool isParsed = argParser.Parse(hostArgc, hostArgv);

    if (!isParsed)
    {
        NN_LOG("\n * Failed to parse command line arguements!\n\n");
        FAIL();
        return;
    }

    ASSERT_TRUE(g_CommonUtil.SetupNetwork(profileId));
#endif

    nn::Result socketInitResult = nn::socket::Initialize(
        g_SocketMemoryPoolBuffer,
        nn::socket::DefaultSocketMemoryPoolSize,
        nn::socket::MinSocketAllocatorSize,
        nn::socket::DefaultConcurrencyLimit);
    EXPECT_TRUE(socketInitResult.IsSuccess());
}

//  Initializes the nn::ssl library. Must be declared before every other test that uses
//  SSL so that SSL is already initialized when those tests run.
TEST(ShimInit, Success)
{
    EXPECT_TRUE(nn::ssl::Initialize().IsSuccess());
}

#ifdef RUN_NEGATIVE
// Negative cases for nn::ssl::Context::AddPolicyOid():
//  1. a buffer length that does not count the terminating null,
//  2. an OID string whose buffer length (with null) is MaxPolicyOidStringLength + 1,
//  3. registering one OID more than MaxPolicyOidStringCount.
TEST(Negative, Success)
{
    nn::Result       result = nn::ResultSuccess();
    nn::ssl::Context sslContext;

    result = CreateContext(&sslContext);
    ASSERT_TRUE(result.IsSuccess());

    // Set string which length doesn't include null
    const char tmpStr1[] = "0.0.0.1";
    result = sslContext.AddPolicyOid(tmpStr1, static_cast<uint32_t>(strlen(tmpStr1)));
    EXPECT_TRUE(nn::ssl::ResultInvalidPolicyOidStringBufferLength::Includes(result));

    // Set too long policy OID string for ResultPolicyOidStringTooLong()
    char tmpStr2[nn::ssl::MaxPolicyOidStringLength + 1] = {0};
    for (int i = 0; i < nn::ssl::MaxPolicyOidStringLength; i++)
    {
        tmpStr2[i] = 'A';
    }
    result = sslContext.AddPolicyOid(tmpStr2, static_cast<uint32_t>(strlen(tmpStr2) + 1));
    EXPECT_TRUE(nn::ssl::ResultPolicyOidStringTooLong::Includes(result));

    // Set too many policy OIDs for ResultMaxPolicyOidRegistered().
    // Each row holds "0.0.0.0.0.0.0." (14 chars) plus any int index and the null, so the
    // generated OIDs can never be truncated (truncation would create duplicates, or with
    // _snprintf leave a row without a terminator).
    char tmpStrs[nn::ssl::MaxPolicyOidStringCount + 1][32] = {{0}};
    const int oidCount = static_cast<int>(sizeof(tmpStrs) / sizeof(tmpStrs[0]));
    for (int i = 0; i < oidCount; i++)
    {
        MY_SNPRINTF(tmpStrs[i], sizeof(tmpStrs[0]), "0.0.0.0.0.0.0.%d", i);
    }

    // Every registration up to the limit must succeed; only the extra (last) one fails.
    for (int i = 0; i < oidCount; i++)
    {
        result = sslContext.AddPolicyOid(tmpStrs[i], static_cast<uint32_t>(strlen(tmpStrs[i]) + 1));
        if (i == oidCount - 1)
        {
            EXPECT_TRUE(nn::ssl::ResultMaxPolicyOidRegistered::Includes(result));
        }
        else
        {
            ASSERT_TRUE(result.IsSuccess());
        }
    }

    // Cleanup
    result = sslContext.Destroy();
    ASSERT_TRUE(result.IsSuccess());
}
#endif // RUN_NEGATIVE

#ifdef RUN_POSITIVE
// Registers a policy OID string of the largest accepted size (a buffer of
// MaxPolicyOidStringLength bytes including the terminating null) and expects success.
TEST(Positive, Success)
{
    nn::Result       result = nn::ResultSuccess();
    nn::ssl::Context sslContext;

    result = CreateContext(&sslContext);
    ASSERT_TRUE(result.IsSuccess());

    // Fill every byte but the last with an alternating '0' / '.' pattern; the final
    // byte stays 0 and terminates the string.
    char maxLenOid[nn::ssl::MaxPolicyOidStringLength] = {0};
    for (int pos = 0; pos < nn::ssl::MaxPolicyOidStringLength - 1; ++pos)
    {
        maxLenOid[pos] = (pos % 2 == 0) ? '0' : '.';
    }

    const uint32_t oidBufferLength = static_cast<uint32_t>(strlen(maxLenOid)) + 1;
    result = sslContext.AddPolicyOid(maxLenOid, oidBufferLength);
    EXPECT_TRUE(result.IsSuccess());

    // Cleanup
    result = sslContext.Destroy();
    ASSERT_TRUE(result.IsSuccess());
}
#endif // RUN_POSITIVE

#ifdef RUN_WITH_VALID_EV
// Handshakes with every server in TestUrlListEv using only the built-in policy OIDs;
// each handshake is expected to succeed.
TEST(RunWithValidEv, Success)
{
    for (int i = 0; i < TestUrlCountEv; i++)
    {
        nn::ssl::Context    sslContext;
        nn::ssl::Connection sslConnection;

        nn::Result result = CreateContext(&sslContext);
        ASSERT_TRUE(result.IsSuccess());

        result = CreateConnection(&sslConnection, &sslContext, TestUrlListEv[i]);
        ASSERT_TRUE(result.IsSuccess());

        result = sslConnection.DoHandshake();
        if (result.IsFailure())
        {
            // The ASSERT below ends the whole test at the first failure, so record which
            // host failed and the certificate verification error before it does.
            nn::Result verifyCertError;
            sslConnection.GetVerifyCertError(&verifyCertError);
            NN_LOG("DoHandshake failed for %s: desc:%d VerifyCertError:%d\n",
                TestUrlListEv[i], result.GetDescription(), verifyCertError.GetDescription());
        }
        ASSERT_TRUE(result.IsSuccess());

        result = sslConnection.Destroy();
        ASSERT_TRUE(result.IsSuccess());

        result = sslContext.Destroy();
        ASSERT_TRUE(result.IsSuccess());
    }
}
#endif // RUN_WITH_VALID_EV

#ifdef RUN_WITH_NO_OID
// Servers whose policy OID is not built in: each handshake must fail certificate
// verification, and the verify error must be ResultSslErrorUntrustedOid.
TEST(RunWithNoOid, Success)
{
    for (int urlIndex = 0; urlIndex < TestUrlCountNoOid; ++urlIndex)
    {
        nn::ssl::Context    context;
        nn::ssl::Connection connection;
        const char* const   hostName = TestUrlListNoOid[urlIndex];

        nn::Result status = CreateContext(&context);
        ASSERT_TRUE(status.IsSuccess());

        status = CreateConnection(&connection, &context, hostName);
        ASSERT_TRUE(status.IsSuccess());

        status = connection.DoHandshake();
        EXPECT_TRUE(nn::ssl::ResultVerifyCertFailed::Includes(status));

        nn::Result certError;
        connection.GetVerifyCertError(&certError);
        EXPECT_TRUE(nn::ssl::ResultSslErrorUntrustedOid::Includes(certError));

        status = connection.Destroy();
        ASSERT_TRUE(status.IsSuccess());
        status = context.Destroy();
        ASSERT_TRUE(status.IsSuccess());
    }
}
#endif // RUN_WITH_NO_OID

#ifdef RUN_WITH_USER_OID
// Servers whose policy OID is not built in, but which validate once the matching
// UserOidList entry is registered on the context via AddPolicyOid().
TEST(RunWithUserOid, Success)
{
    for (int i = 0; i < TestUrlCountUserOid; i++)
    {
        nn::ssl::Context    sslContext;
        nn::ssl::Connection sslConnection;

        nn::Result result = CreateContext(&sslContext);
        ASSERT_TRUE(result.IsSuccess());

        result = sslContext.AddPolicyOid(UserOidList[i], static_cast<uint32_t>(strlen(UserOidList[i]) + 1));
        EXPECT_TRUE(result.IsSuccess());

        result = CreateConnection(&sslConnection, &sslContext, TestUrlListUserOid[i]);
        ASSERT_TRUE(result.IsSuccess());

        result = sslConnection.DoHandshake();

        // Query the verification error once and use it for both the diagnostic and the check
        // (previously it was fetched twice into two shadowing locals).
        nn::Result verifyCertError;
        sslConnection.GetVerifyCertError(&verifyCertError);
        if (result.IsFailure())
        {
            NN_LOG("DoHandshake failed for %s: desc:%d VerifyCertError:%d\n",
                TestUrlListUserOid[i], result.GetDescription(), verifyCertError.GetDescription());
        }
        EXPECT_TRUE(result.IsSuccess());
        EXPECT_TRUE(verifyCertError.IsSuccess());

        result = sslConnection.Destroy();
        ASSERT_TRUE(result.IsSuccess());
        result = sslContext.Destroy();
        ASSERT_TRUE(result.IsSuccess());
    }
}
#endif // RUN_WITH_USER_OID

#ifdef RUN_NO_EV
// Non-EV Certificate - but validates because of a User added Policy.
// For each host in TestUrlListNoEv, the matching NoEvUserOidList entry is registered
// with AddPolicyOid() before connecting, and the handshake is expected to succeed.
TEST(RunWithNoEv, Success)
{
    for (int idx = 0; idx < TestUrlCountNoEv; ++idx)
    {
        nn::ssl::Context    context;
        nn::ssl::Connection connection;
        const char* const   userOid  = NoEvUserOidList[idx];
        const char* const   hostName = TestUrlListNoEv[idx];

        nn::Result status = CreateContext(&context);
        ASSERT_TRUE(status.IsSuccess());

        const uint32_t oidBufferLength = static_cast<uint32_t>(strlen(userOid) + 1);
        status = context.AddPolicyOid(userOid, oidBufferLength);
        EXPECT_TRUE(status.IsSuccess());

        status = CreateConnection(&connection, &context, hostName);
        ASSERT_TRUE(status.IsSuccess());

        status = connection.DoHandshake();
        EXPECT_TRUE(status.IsSuccess());

        status = connection.Destroy();
        ASSERT_TRUE(status.IsSuccess());
        status = context.Destroy();
        ASSERT_TRUE(status.IsSuccess());
    }
}
#endif // RUN_NO_EV

// Tear-down: this test MUST be the last one declared, so that nn::ssl is finalized only
// after every SSL test has run. The socket library (and, with USE_NETWORK, the network)
// is shut down right after SSL.
TEST(ShimFinalize, Success)
{
    EXPECT_TRUE(nn::ssl::Finalize().IsSuccess());
    nn::socket::Finalize();
#ifdef USE_NETWORK
    g_CommonUtil.FinalizeNetwork();
#endif
}
