﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <nn/os.h>
#include <nn/init.h>
#include <nn/nn_Log.h>
#include <nn/nn_Assert.h>
#include <nn/lmem/lmem_ExpHeap.h>
#include <nnt/nntest.h>
#include <nn/socket.h>
#include <nn/ssl.h>

#include <Common/testCommonUtil.h>
#include <Common/testInfraInfo.h>
#include <Common/testServerPki.h>
#include <Common/testClientPki.h>

// ------------------------------------------------------------------------------------------------
// Build flags for tests to run
// ------------------------------------------------------------------------------------------------
#define RUN_SETGET
#define RUN_INVALID
#define RUN_CREATE_FLUSH
#define RUN_SERIAL
#define RUN_WITH_INTERNET
#define RUN_SESSION_TICKET
// #define RUN_PARARELL

// ------------------------------------------------------------------------------------------------
// Build flags
// ------------------------------------------------------------------------------------------------
//#define NO_RESOLVER

// ------------------------------------------------------------------------------------------------
// Global parameters
// ------------------------------------------------------------------------------------------------
namespace
{
// Public HTTPS hosts contacted by the "Internet" test case.
// Each entry occupies a fixed 64-byte slot, so a host name (including its
// terminating NUL) must fit in 64 characters.
const char TestHostList[][64] = {
    "www.google.com",
    "facebook.com",
    "twitter.com",
    "youtube.com",
    "www.netflix.com",
    "www.amazon.com",
    "www.google.co.jp",
};
// Port number used together with the hosts in TestHostList.
const int StandardSslPortNumber = 443;

// Helper object used by InitTest/FinalizeTest to set up and finalize the network.
SslTestCommonUtil        g_CommonUtil;
// Memory pool handed to nn::socket::Initialize(); declared with 4096-byte alignment.
NN_ALIGNAS(4096) uint8_t g_SocketMemoryPoolBuffer[nn::socket::DefaultSocketMemoryPoolSize];
} // Unnamed namespace

// Startup hook: sizes the memory heap, takes one block covering the whole heap
// and registers that block with nn::init::InitializeAllocator() for malloc.
extern "C" void nninitStartup()
{
    NN_LOG("nninitStartup loaded %p\n", nninitStartup);

    // Total size of the memory heap
    const size_t heapSizeInBytes = 128 * 1024 * 1024;

    auto setHeapSizeResult = nn::os::SetMemoryHeapSize( heapSizeInBytes );
    ASSERT_TRUE( setHeapSizeResult.IsSuccess() );

    // Reserve the memory region used by malloc from the memory heap
    uintptr_t heapBlockAddress = 0;
    auto allocateBlockResult = nn::os::AllocateMemoryBlock( &heapBlockAddress, heapSizeInBytes );
    ASSERT_TRUE( allocateBlockResult.IsSuccess() );

    // Set up the reserved region as the memory area for malloc
    void* const pHeapBlock = reinterpret_cast<void*>(heapBlockAddress);
    nn::init::InitializeAllocator( pHeapBlock, heapSizeInBytes );
}

// ------------------------------------------------------------------------------------------------
// Threads
// ------------------------------------------------------------------------------------------------
#if defined(RUN_PARARELL)
namespace
{
const uint32_t NumberOfThreads = 4;
const size_t   ThreadStackSize = 1024 * 32;

NN_OS_ALIGNAS_THREAD_STACK uint8_t g_ThreadStack[NumberOfThreads][ThreadStackSize];
nn::os::ThreadType                 g_ThreadTid[NumberOfThreads];

// Thread entry point used by RunWorkerThreads().
// Creates its own HTTPS client, selects SessionCacheMode_SessionId on the
// connection, performs one SSL handshake and finalizes the client.
// The thread argument is not used.
void WorkerFunction(void* /* arg */)
{
    SimpleHttpsClient httpsClient(
        false, // Blocking
        ServerName,
        ServerPort_Normal);
#ifdef NO_RESOLVER
    httpsClient.SetIpAddress(ServerIpAddress);
#endif
    EXPECT_TRUE(httpsClient.Initialize(
        nn::ssl::Context::SslVersion::SslVersion_Auto,
        nn::ssl::Connection::VerifyOption::VerifyOption_None));

    // Enable session cache
    nn::ssl::Connection* const pConn = httpsClient.GetSslConnection();
    ASSERT_TRUE(pConn != nullptr);
    const nn::Result result =
        pConn->SetSessionCacheMode(nn::ssl::Connection::SessionCacheMode_SessionId);
    EXPECT_TRUE(result.IsSuccess());

    EXPECT_TRUE(httpsClient.PerformSslHandshake(false));
    httpsClient.Finalize();
}

// Creates up to NumberOfThreads worker threads running WorkerFunction(), starts
// them, waits for all of them and destroys them.
// Only threads whose creation succeeded are started, waited on and destroyed.
void RunWorkerThreads()
{
    // Signed copy of the thread count so that the loop indices (printed with %d)
    // are not compared against an unsigned value.
    const int threadCount = static_cast<int>(NumberOfThreads);

    NN_LOG("[RunWorkerThreads] Start");

    int createdCount = 0;
    while (createdCount < threadCount)
    {
        const nn::Result result = nn::os::CreateThread(
            &g_ThreadTid[createdCount],
            WorkerFunction,
            nullptr,
            g_ThreadStack[createdCount],
            ThreadStackSize,
            nn::os::LowestThreadPriority);
        EXPECT_TRUE(result.IsSuccess());
        if (!result.IsSuccess())
        {
            // Do not touch a thread object that was never created.
            break;
        }
        ++createdCount;
    }

    for (int i = 0; i < createdCount; ++i)
    {
        nn::os::StartThread(&g_ThreadTid[i]);
    }

    for (int i = 0; i < createdCount; ++i)
    {
        NN_LOG("Waiting thread %d\n", i);
        nn::os::WaitThread(&g_ThreadTid[i]);
        NN_LOG("DONE:Waiting thread %d\n", i);
    }

    for (int i = 0; i < createdCount; ++i)
    {
        NN_LOG("cleaning thread %d\n", i);
        nn::os::DestroyThread(&g_ThreadTid[i]);
        NN_LOG("DONE:cleaning thread %d\n", i);
    }

    NN_LOG("[RunWorkerThreads] End");
}
}
#endif // RUN_PARARELL

//-------------------------------------------------------------------------------------------------
//  Tests
//-------------------------------------------------------------------------------------------------
// Brings up, in this order, the network, the socket library and the SSL library.
// Every step is a fatal assertion: the test stops at the first failure.
TEST(InitTest, Success)
{
    auto networkResult = g_CommonUtil.SetupNetwork();
    ASSERT_TRUE(networkResult.IsSuccess());

    auto socketResult = nn::socket::Initialize(
        g_SocketMemoryPoolBuffer,
        nn::socket::DefaultSocketMemoryPoolSize,
        nn::socket::MinSocketAllocatorSize,
        nn::socket::DefaultConcurrencyLimit);
    ASSERT_TRUE(socketResult.IsSuccess());

    auto sslResult = nn::ssl::Initialize();
    ASSERT_TRUE(sslResult.IsSuccess());
}

#if defined(RUN_SETGET)
// Checks that the session cache mode of a connection can be read, switched to
// SessionCacheMode_None and switched back to SessionCacheMode_SessionId.
TEST(SetGetSessionCacheMode, Success)
{
    SimpleHttpsClient httpsClient(
        false, // Blocking
        ServerName,
        ServerPort_Normal);
#ifdef NO_RESOLVER
    httpsClient.SetIpAddress(ServerIpAddress);
#endif
    EXPECT_TRUE(httpsClient.Initialize(
        nn::ssl::Context::SslVersion::SslVersion_Auto,
        nn::ssl::Connection::VerifyOption::VerifyOption_None));

    nn::ssl::Connection* const pConn = httpsClient.GetSslConnection();
    ASSERT_TRUE(pConn != nullptr);

    // Get session cache mode (SessionCacheMode_SessionId is default).
    // currMode starts with a defined value because the EXPECT_* checks below are
    // non-fatal: it is compared and logged even when the getter reports a failure.
    nn::ssl::Connection::SessionCacheMode currMode = nn::ssl::Connection::SessionCacheMode_None;
    nn::Result result = pConn->GetSessionCacheMode(&currMode);
    EXPECT_TRUE(result.IsSuccess());
    EXPECT_TRUE(currMode == nn::ssl::Connection::SessionCacheMode_SessionId);
    NN_LOG("Current session cache mode: %x\n", currMode);

    // Set session cache - SessionCacheMode_None
    result = pConn->SetSessionCacheMode(nn::ssl::Connection::SessionCacheMode_None);
    EXPECT_TRUE(result.IsSuccess());
    // Get it again
    result = pConn->GetSessionCacheMode(&currMode);
    EXPECT_TRUE(result.IsSuccess());
    EXPECT_TRUE(currMode == nn::ssl::Connection::SessionCacheMode_None);
    NN_LOG("Current session cache mode: %x\n", currMode);

    // Set session cache - SessionCacheMode_SessionId
    result = pConn->SetSessionCacheMode(nn::ssl::Connection::SessionCacheMode_SessionId);
    EXPECT_TRUE(result.IsSuccess());
    // Get it again
    result = pConn->GetSessionCacheMode(&currMode);
    EXPECT_TRUE(result.IsSuccess());
    EXPECT_TRUE(currMode == nn::ssl::Connection::SessionCacheMode_SessionId);
    NN_LOG("Current session cache mode: %x\n", currMode);

    httpsClient.Finalize();
}
#endif // RUN_SETGET

#if defined(RUN_INVALID)
// Calls FlushSessionCache() on a connection for which no handshake was performed
// and expects a result included in nn::ssl::ResultNoSslConnection.
TEST(Invalid, Success)
{
    SimpleHttpsClient httpsClient(
        false, // Blocking
        ServerName,
        ServerPort_Normal);
#ifdef NO_RESOLVER
    httpsClient.SetIpAddress(ServerIpAddress);
#endif
    const bool isInitialized = httpsClient.Initialize(
        nn::ssl::Context::SslVersion::SslVersion_Auto,
        nn::ssl::Connection::VerifyOption::VerifyOption_None);
    EXPECT_TRUE(isInitialized);

    nn::ssl::Connection* const pConn = httpsClient.GetSslConnection();
    ASSERT_TRUE(pConn != nullptr);

    // PerformSslHandshake() is deliberately not called before the flush.
    const nn::Result flushResult = pConn->FlushSessionCache();
    EXPECT_TRUE(nn::ssl::ResultNoSslConnection::Includes(flushResult));

    httpsClient.Finalize();
}
#endif // RUN_INVALID

#if defined(RUN_CREATE_FLUSH)
// Performs one SSL handshake and then expects FlushSessionCache() on the same
// connection to succeed.
TEST(CreateFlush, Success)
{
    SimpleHttpsClient httpsClient(
        false, // Blocking
        ServerName,
        ServerPort_Normal);
#ifdef NO_RESOLVER
    httpsClient.SetIpAddress(ServerIpAddress);
#endif
    EXPECT_TRUE(httpsClient.Initialize(
        nn::ssl::Context::SslVersion::SslVersion_Auto,
        nn::ssl::Connection::VerifyOption::VerifyOption_None));

    nn::ssl::Connection* const pConn = httpsClient.GetSslConnection();
    ASSERT_TRUE(pConn != nullptr);

    EXPECT_TRUE(httpsClient.PerformSslHandshake(false));
    EXPECT_TRUE(pConn->FlushSessionCache().IsSuccess());
    httpsClient.Finalize();
}
#endif // RUN_CREATE_FLUSH

#if defined(RUN_SERIAL)
// Runs handshakes against the same server in two phases:
//   1. eight clients one after another, each finalized and deleted before the
//      next one is created;
//   2. eight clients that are all created and initialized first, then all
//      perform their handshake, then all are finalized and deleted.
TEST(RunSerial, Success)
{
    static const int NumberOfClients = 8;
    const nn::ssl::Connection::VerifyOption verifyOption =
        nn::ssl::Connection::VerifyOption::VerifyOption_None;

    // Phase 1: one short-lived client per iteration
    for (int round = 0; round < NumberOfClients; ++round)
    {
        SimpleHttpsClient* const pClient = new SimpleHttpsClient(
            false, // Blocking
            ServerName,
            ServerPort_Normal);
#ifdef NO_RESOLVER
        pClient->SetIpAddress(ServerIpAddress);
#endif
        EXPECT_TRUE(pClient->Initialize(
            nn::ssl::Context::SslVersion::SslVersion_Auto,
            verifyOption));
        EXPECT_TRUE(pClient->PerformSslHandshake(false));
        pClient->Finalize();
        delete pClient;
    }

    // Phase 2: all clients exist at the same time
    SimpleHttpsClient* pClients[NumberOfClients];

    for (int index = 0; index < NumberOfClients; ++index)
    {
        pClients[index] = new SimpleHttpsClient(
            false, // Blocking
            ServerName,
            ServerPort_Normal);
#ifdef NO_RESOLVER
        pClients[index]->SetIpAddress(ServerIpAddress);
#endif
        EXPECT_TRUE(pClients[index]->Initialize(
            nn::ssl::Context::SslVersion::SslVersion_Auto,
            verifyOption));
    }

    for (int index = 0; index < NumberOfClients; ++index)
    {
        EXPECT_TRUE(pClients[index]->PerformSslHandshake(false));
    }

    for (int index = 0; index < NumberOfClients; ++index)
    {
        pClients[index]->Finalize();
        delete pClients[index];
    }
}
#endif // RUN_SERIAL

#if defined(RUN_WITH_INTERNET)
// Performs an SSL handshake with every host in TestHostList, then repeats the
// handshakes in a second pass. Each pass uses a fresh client per host.
TEST(Internet, Success)
{
    const nn::ssl::Connection::VerifyOption verifyOption =
        nn::ssl::Connection::VerifyOption::VerifyOption_None;
    // Unsigned element count, matched by the unsigned loop indices below.
    const size_t hostCount = sizeof(TestHostList) / sizeof(TestHostList[0]);

    // Create session cache
    for (size_t i = 0; i < hostCount; ++i)
    {
        // The client only lives for one iteration, so it is kept on the stack.
        SimpleHttpsClient httpsClient(
            false, // Blocking
            TestHostList[i],
            StandardSslPortNumber);
#ifdef NO_RESOLVER
        httpsClient.SetIpAddress(ServerIpAddress);
#endif
        EXPECT_TRUE(httpsClient.Initialize(
            nn::ssl::Context::SslVersion::SslVersion_Auto,
            verifyOption));
        EXPECT_TRUE(httpsClient.PerformSslHandshake(false));
        httpsClient.Finalize();
    }

    // Performs SSL handshake again with the session cache
    for (size_t i = 0; i < hostCount; ++i)
    {
        SimpleHttpsClient httpsClient(
            false, // Blocking
            TestHostList[i],
            StandardSslPortNumber);
#ifdef NO_RESOLVER
        httpsClient.SetIpAddress(ServerIpAddress);
#endif
        EXPECT_TRUE(httpsClient.Initialize(
            nn::ssl::Context::SslVersion::SslVersion_Auto,
            verifyOption));
        EXPECT_TRUE(httpsClient.PerformSslHandshake(false));

        //EXPECT_TRUE(httpsClient.GetSslConnection()->FlushSessionCache().IsSuccess());
        httpsClient.Finalize();
    }
}
#endif // RUN_WITH_INTERNET

#if defined(RUN_PARARELL)
// Performs one handshake on the calling thread and, while that client is still
// alive, lets RunWorkerThreads() perform further handshakes on worker threads.
TEST(RunPararell, Success)
{
    SimpleHttpsClient httpsClient(
        false, // Blocking
        ServerName,
        ServerPort_Normal);
#ifdef NO_RESOLVER
    httpsClient.SetIpAddress(ServerIpAddress);
#endif
    const bool isInitialized = httpsClient.Initialize(
        nn::ssl::Context::SslVersion::SslVersion_Auto,
        nn::ssl::Connection::VerifyOption::VerifyOption_None);
    EXPECT_TRUE(isInitialized);

    const bool isHandshakeDone = httpsClient.PerformSslHandshake(false);
    EXPECT_TRUE(isHandshakeDone);

    // Performs the other handshake when there's session cache
    RunWorkerThreads();

    httpsClient.Finalize();
}
#endif // RUN_PARARELL


#if defined(RUN_SESSION_TICKET)
// NOTE
// Enable NN_DETAIL_SSL_ENABLE_DEBUG_PRINT and NN_DETAIL_SSL_DBG_PRINT_NSS #defines in ssl_Build.h
// to test it with debug logs because there's no way to see how session caches are managed.
// Don't forget to enable DEBUG, TRACE and NSS_HAVE_GETENV defined as AdditionalPreprocessorMacros
// in the nact file of libnn_nss.
//
// Runs a fixed sequence of handshakes with different session cache modes. For
// every table entry a new client is created, the requested cache mode is set,
// one handshake is performed and, when requested, the session cache is flushed.
TEST(RunSessionTicket, Success)
{
    // One handshake to perform.
    struct RunSessionTicketTestMode
    {
        nn::ssl::Connection::SessionCacheMode cacheMode;    // mode set before the handshake
        bool                                  isFlushCache; // flush the session cache after the handshake
        int                                   index;        // number of the scenario (used for logging)
        const char*                           caseName;     // step name within the scenario (used for logging)
    };

    const RunSessionTicketTestMode testMode[] = {
        // 1: Create a cache with session ticket and use it next
        {nn::ssl::Connection::SessionCacheMode_SessionTicket, false, 1, "First Case" },
        {nn::ssl::Connection::SessionCacheMode_SessionTicket, true,  1, "Second Case"},
        // 2: Create a cache with session ticket and use session ID next, second handshake will be done without a cache
        {nn::ssl::Connection::SessionCacheMode_SessionTicket, false, 2, "First Case" },
        {nn::ssl::Connection::SessionCacheMode_SessionId,     true,  2, "Second Case"},
        // 3: Create a cache with session ID and use session ticket next, second handshake will be done by session ID
        {nn::ssl::Connection::SessionCacheMode_SessionId,     false, 3, "First Case" },
        {nn::ssl::Connection::SessionCacheMode_SessionTicket, true,  3, "Second Case"},
        // 4: Create a cache with session ID, then create an SSL connection without a cache, second handshake will be done without a cache
        {nn::ssl::Connection::SessionCacheMode_SessionId,     false, 4, "First Case" },
        {nn::ssl::Connection::SessionCacheMode_None,          false, 4, "Second Case"},
        {nn::ssl::Connection::SessionCacheMode_SessionId,     true,  4, "Third Case" }, // to delete session cache
        // 5: Create a cache with session ticket, then create an SSL connection without a cache, second handshake will be done without a cache
        {nn::ssl::Connection::SessionCacheMode_SessionTicket, false, 5, "First Case" },
        {nn::ssl::Connection::SessionCacheMode_None,          false, 5, "Second Case"},
        {nn::ssl::Connection::SessionCacheMode_SessionTicket, true,  5, "Third Case" }, // to delete session cache
    };

    const size_t runCount = sizeof(testMode) / sizeof(testMode[0]);

    for (size_t i = 0; i < runCount; ++i)
    {
        const RunSessionTicketTestMode& mode = testMode[i];

        NN_LOG("\n\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
        NN_LOG("!!!! Test case - %d : %s\n", mode.index, mode.caseName);
        NN_LOG("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n\n\n");

        // NOTE(review): the Internet test uses StandardSslPortNumber for public
        // hosts; confirm that ServerPort_Normal is the intended port here.
        SimpleHttpsClient httpsClient(
            false, // Blocking
            "www.google.com",
            ServerPort_Normal);
        EXPECT_TRUE(httpsClient.Initialize(
            nn::ssl::Context::SslVersion::SslVersion_Auto,
            nn::ssl::Connection::VerifyOption::VerifyOption_None));

        nn::ssl::Connection* const pConn = httpsClient.GetSslConnection();
        ASSERT_TRUE(pConn != nullptr);

        const nn::Result setModeResult = pConn->SetSessionCacheMode(mode.cacheMode);
        if (!setModeResult.IsSuccess())
        {
            // The fatal assertion below leaves the test immediately, so the
            // client is finalized here instead of at the end of the iteration.
            httpsClient.Finalize();
        }
        ASSERT_TRUE(setModeResult.IsSuccess());

        EXPECT_TRUE(httpsClient.PerformSslHandshake(false));

        if (mode.isFlushCache)
        {
            EXPECT_TRUE(pConn->FlushSessionCache().IsSuccess());
        }
        httpsClient.Finalize();
    }
}
#endif // RUN_SESSION_TICKET

// Shuts down what InitTest brought up, in reverse order: the SSL library, the
// socket library and finally the network.
TEST(FinalizeTest, Success)
{
    auto sslFinalizeResult = nn::ssl::Finalize();
    EXPECT_TRUE(sslFinalizeResult.IsSuccess());

    nn::socket::Finalize();
    g_CommonUtil.FinalizeNetwork();
}
