﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <nn/os.h>
#include <nn/init.h>
#include <nn/nn_Log.h>
#include <nn/nn_Assert.h>
#include <nn/lmem/lmem_ExpHeap.h>
#include <nnt/nntest.h>
#include <nn/socket.h>
#include <nn/ssl.h>

#include <Common/testCommonUtil.h>
#include <Common/testInfraInfo.h>
#include <Common/testServerPki.h>
#include <Common/testClientPki.h>

#include <Utils/CommandLineParser.h>

// ------------------------------------------------------------------------------------------------
// Build flags for tests to run
// ------------------------------------------------------------------------------------------------
#define RUN_SERVER_CERT_CHAIN     // Verifies getting server cert and chain after handshake

// ------------------------------------------------------------------------------------------------
// Global parameters
// ------------------------------------------------------------------------------------------------
namespace
{
// Shared test helper object. The tests in this file use it to set up and finalize
// the network and to create and close TCP sockets.
SslTestCommonUtil        g_CommonUtil;
// Memory pool passed to nn::socket::Initialize() by InitTest; declared with
// 4096-byte alignment.
NN_ALIGNAS(4096) uint8_t g_SocketMemoryPoolBuffer[nn::socket::DefaultSocketMemoryPoolSize];
} // Un-named namespace

// Startup hook: sizes the memory heap, takes one block covering the whole heap
// and registers that block as the region malloc allocates from.
// NOTE(review): ASSERT_TRUE is a test-framework macro - confirm it is usable this
// early in startup.
extern "C" void nninitStartup()
{
    NN_LOG("nninitStartup loaded %p\n", nninitStartup);

    // The same byte count is used for the heap size, the block and the allocator.
    const size_t HeapBytes = 128 * 1024 * 1024;

    // Set the overall size of the memory heap
    auto heapResult = nn::os::SetMemoryHeapSize( HeapBytes );
    ASSERT_TRUE( heapResult.IsSuccess() );

    // Reserve the memory region used by malloc from the memory heap
    uintptr_t heapBase = 0;
    auto blockResult = nn::os::AllocateMemoryBlock( &heapBase, HeapBytes );
    ASSERT_TRUE( blockResult.IsSuccess() );

    // Set the memory region that malloc will use
    nn::init::InitializeAllocator( reinterpret_cast<void*>(heapBase), HeapBytes );
}

//-------------------------------------------------------------------------------------------------
//  Tests
//-------------------------------------------------------------------------------------------------
// Brings up everything the later tests rely on: parses the "--NetProfile"
// command-line option, sets up the network through the shared test utility,
// then initializes the socket library and the SSL library.
TEST(InitTest, Success)
{
    // Register the single supported option; the parsed UUID lands in profileId.
    nn::util::Uuid profileId = nn::util::InvalidUuid;
    NATF::Utils::ParserGroup optionParser;
    optionParser.AddParser(NATF::Utils::UuidParser ("--NetProfile", &nn::util::InvalidUuid, profileId));

    // Fetch the host-provided command line and parse it.
    int    hostArgc = nn::os::GetHostArgc();
    char** hostArgv = nn::os::GetHostArgv();
    if (!optionParser.Parse(hostArgc, hostArgv))
    {
        NN_LOG("\n * Failed to parse command line arguements!\n\n");
        FAIL();
        return;
    }

    // Network first, then the socket layer on top of the static memory pool.
    ASSERT_TRUE(g_CommonUtil.SetupNetwork(profileId));

    const auto socketInitResult = nn::socket::Initialize(
        g_SocketMemoryPoolBuffer,
        nn::socket::DefaultSocketMemoryPoolSize,
        nn::socket::MinSocketAllocatorSize,
        nn::socket::DefaultConcurrencyLimit);
    ASSERT_TRUE(socketInitResult.IsSuccess());

    // Finally the SSL library itself.
    const auto sslInitResult = nn::ssl::Initialize();
    ASSERT_TRUE(sslInitResult.IsSuccess());
}

#if defined(RUN_SERVER_CERT_CHAIN)
// Calls f(arg1), stores the outcome in a variable named `result` that must already
// exist in the enclosing scope, and executes `break` when result.IsFailure().
// Must be used inside a loop (or switch) that the `break` can leave. The macro
// expands to more than one statement, so do not use it as the unbraced body of an
// if/else/for.
#define CALL_1_ARG_EXIT_ON_ERR(f, arg1) \
    result = f(arg1);                   \
    if (result.IsFailure()) {           \
        break;                          \
    }

// Two-argument variant of CALL_1_ARG_EXIT_ON_ERR: calls f(arg1, arg2), stores the
// outcome in the enclosing scope's `result` and executes `break` on failure.
// The same usage restrictions apply.
#define CALL_2_ARG_EXIT_ON_ERR(f, arg1, arg2) \
    result = f(arg1, arg2);                   \
    if (result.IsFailure()) {                 \
        break;                                \
    }

// Host name of the server under test: passed to CreateTcpSocket() and set as the
// connection's host name in setupSslConnection().
const char               GoogleServer[] = "google.com";

// Opens a TCP socket to GoogleServer and prepares pConn for a handshake on it.
//
// The connection is created on pCtx, given the new socket, the host name, the
// default verify option and SessionCacheMode_None. No handshake is performed here.
//
// pCtx  - SSL context the connection is created on.
// pConn - connection object to (re)create.
//
// Returns the result of the first step that failed, ResultNoTcpConnection if the
// socket could not be created, or the result of the last step when all succeed.
// On failure the socket opened here is closed and, if the connection had already
// been created, it is destroyed again, so a failed call does not leave resources
// behind for the caller to clean up.
static nn::Result setupSslConnection(nn::ssl::Context    *pCtx,
                                     nn::ssl::Connection *pConn)
{
    nn::Result result;
    int        socketFd          = -1;
    bool       connectionCreated = false;

    do
    {
        socketFd = g_CommonUtil.CreateTcpSocket(true,
                                                ServerPort_Normal,
                                                GoogleServer,
                                                0);
        if (socketFd < 0)
        {
            result = nn::ssl::ResultNoTcpConnection();
            break;
        }

        CALL_1_ARG_EXIT_ON_ERR(pConn->Create, pCtx);
        connectionCreated = true;   // from here on a failure must also Destroy()

        CALL_1_ARG_EXIT_ON_ERR(pConn->SetSocketDescriptor, socketFd);
        CALL_2_ARG_EXIT_ON_ERR(pConn->SetHostName, GoogleServer, strlen(GoogleServer));
        CALL_1_ARG_EXIT_ON_ERR(pConn->SetVerifyOption, nn::ssl::Connection::VerifyOption::VerifyOption_Default);
        CALL_1_ARG_EXIT_ON_ERR(pConn->SetSessionCacheMode, nn::ssl::Connection::SessionCacheMode_None);
    } while (NN_STATIC_CONDITION(false));

    if (result.IsFailure())
    {
        // Same teardown order the tests use: close the socket, then destroy the
        // connection.
        // NOTE(review): assumes the descriptor stays owned by the caller even after
        // SetSocketDescriptor(), as the tests' explicit CloseTcpSocket() calls
        // suggest - confirm against the nn::ssl documentation.
        if (socketFd >= 0)
        {
            g_CommonUtil.CloseTcpSocket(socketFd);
        }
        if (connectionCreated)
        {
            pConn->Destroy();
        }
    }

    return result;
}


// Verifies retrieval of the server certificate and certificate chain through
// four handshakes against GoogleServer:
//   Pass 1 - plain DoHandshake() with no certificate buffer.
//   Pass 2 - SetServerCertBuffer() + DoHandshake(&size, &count): a too-small
//            buffer must fail with ResultInsufficientServerCertBuffer, a buffer
//            of the reported size must yield exactly one certificate.
//   Pass 3 - same buffer-based ("old") API with OptionType_GetServerCertChain
//            enabled: more than one certificate is expected.
//   Pass 4 - DoHandshake() overload that takes the buffer directly ("new" API):
//            every certificate must match the one obtained in pass 3.
TEST(ShimServerCertChain, Success)
{
    nn::Result result;

    NN_LOG("Server chain test start...\n");

    nn::ssl::Context*    pSslContext = new nn::ssl::Context();
    nn::ssl::Connection* pSslConnection = new nn::ssl::Connection();
    ASSERT_TRUE(pSslContext != nullptr);
    ASSERT_TRUE(pSslConnection != nullptr);

    result = pSslContext->Create(nn::ssl::Context::SslVersion_Auto);
    ASSERT_TRUE(result.IsSuccess());

    //  The same nn::ssl::Connection object will be re-used.  However, after
    //  handshake is done, a new one cannot be started on the same socket
    //  as it will block and there is no API to reset the handshake state.
    //  So each time we need to perform a new handshake, we destroy the
    //  existing Connection (but not delete it), close the socket, then
    //  recreate the Connection and the socket.
    result = setupSslConnection(pSslContext, pSslConnection);
    ASSERT_TRUE(result.IsSuccess());

    //  First pass, use old API with no cert buffer (so no chain received)
    NN_LOG("\nPass 1, regular handshake\n");
    result = pSslConnection->DoHandshake();
    EXPECT_TRUE(result.IsSuccess());

    int socketFd = -1;
    result = pSslConnection->GetSocketDescriptor(&socketFd);
    ASSERT_TRUE(result.IsSuccess());
    g_CommonUtil.CloseTcpSocket(socketFd);
    pSslConnection->Destroy();

    //  Second pass, use old API with cert buffer
    uint32_t serverCertSize = 0;
    uint32_t serverCertChainCount = 0;
    uint32_t tmp = 0;               // receives the cert size from DoHandshake; not inspected
    char     *serverCertBuf = nullptr;
    char     smallBuf[4];           // deliberately too small to hold a certificate

    NN_LOG("\nPass 2, server cert only API w/cert buf\n");
    result = setupSslConnection(pSslContext, pSslConnection);
    ASSERT_TRUE(result.IsSuccess());

    //  NOTE(review): this result is overwritten by the handshake below without
    //  being checked - confirm whether setting an undersized buffer is expected
    //  to succeed before adding an assertion here.
    result = pSslConnection->SetServerCertBuffer(smallBuf, sizeof(smallBuf));
    result = pSslConnection->DoHandshake(&tmp, &serverCertChainCount);
    ASSERT_TRUE(result.IsFailure());
    ASSERT_TRUE(nn::ssl::ResultInsufficientServerCertBuffer::Includes(result));
    result = pSslConnection->GetNeededServerCertBufferSize(&serverCertSize);
    ASSERT_TRUE(result.IsSuccess());
    serverCertBuf = new char[serverCertSize];
    ASSERT_TRUE(serverCertBuf != nullptr);

    //  Do handshake with the new buffer, get only the server cert
    result = pSslConnection->GetSocketDescriptor(&socketFd);
    ASSERT_TRUE(result.IsSuccess());
    g_CommonUtil.CloseTcpSocket(socketFd);
    pSslConnection->Destroy();
    result = setupSslConnection(pSslContext, pSslConnection);
    ASSERT_TRUE(result.IsSuccess());
    result = pSslConnection->SetServerCertBuffer(serverCertBuf, serverCertSize);
    ASSERT_TRUE(result.IsSuccess());
    result = pSslConnection->DoHandshake(&tmp, &serverCertChainCount);
    ASSERT_TRUE(result.IsSuccess());
    ASSERT_TRUE(serverCertChainCount == 1);
    delete[] serverCertBuf;
    serverCertBuf = nullptr;

    result = pSslConnection->GetSocketDescriptor(&socketFd);
    ASSERT_TRUE(result.IsSuccess());
    g_CommonUtil.CloseTcpSocket(socketFd);
    pSslConnection->Destroy();

    //  Third pass, request the entire chain through the old (SetServerCertBuffer) API
    NN_LOG("\nPass 3, get full chain using old API\n");
    result = setupSslConnection(pSslContext, pSslConnection);
    ASSERT_TRUE(result.IsSuccess());

    uint32_t serverCertChainSize = 0;
    result = pSslConnection->SetOption(nn::ssl::Connection::OptionType_GetServerCertChain, true);
    ASSERT_TRUE(result.IsSuccess());

    //  No buffer is set on this connection: the handshake is expected to fail
    //  and is only done to learn the buffer size needed for the whole chain
    result = pSslConnection->DoHandshake();
    ASSERT_TRUE(result.IsFailure());
    ASSERT_TRUE(nn::ssl::ResultInsufficientServerCertBuffer::Includes(result));
    result = pSslConnection->GetNeededServerCertBufferSize(&serverCertChainSize);
    ASSERT_TRUE(result.IsSuccess());
    EXPECT_TRUE(serverCertChainSize > serverCertSize);
    serverCertBuf = new char[serverCertChainSize];
    ASSERT_TRUE(serverCertBuf != nullptr);

    result = pSslConnection->GetSocketDescriptor(&socketFd);
    ASSERT_TRUE(result.IsSuccess());
    g_CommonUtil.CloseTcpSocket(socketFd);
    pSslConnection->Destroy();

    result = setupSslConnection(pSslContext, pSslConnection);
    ASSERT_TRUE(result.IsSuccess());
    result = pSslConnection->SetOption(nn::ssl::Connection::OptionType_GetServerCertChain, true);
    ASSERT_TRUE(result.IsSuccess());

    //  Handshake with the OLD APIs
    result = pSslConnection->SetServerCertBuffer(serverCertBuf, serverCertChainSize);
    ASSERT_TRUE(result.IsSuccess());
    result = pSslConnection->DoHandshake(&tmp, &serverCertChainCount);
    ASSERT_TRUE(result.IsSuccess());

    //  Verify there are multiple certs in the chain, dump their length
    EXPECT_TRUE(serverCertChainCount > 1);
    for (uint32_t i = 0; i < serverCertChainCount; i++)
    {
        nn::ssl::Connection::ServerCertDetail detail;

        result = pSslConnection->GetServerCertDetail(&detail,
                                                     serverCertBuf,
                                                     i);
        EXPECT_TRUE(result.IsSuccess());
        NN_LOG("Cert %u, len %u\n", i, detail.dataSize);
    }

    result = pSslConnection->GetSocketDescriptor(&socketFd);
    ASSERT_TRUE(result.IsSuccess());
    g_CommonUtil.CloseTcpSocket(socketFd);
    pSslConnection->Destroy();

    //  Fourth pass, pass the chain buffer directly to DoHandshake (new API).
    //  serverCertBuf still holds the pass 3 chain and is used for comparison.
    NN_LOG("\nPass 4, get full chain using new API\n");
    result = setupSslConnection(pSslContext, pSslConnection);
    ASSERT_TRUE(result.IsSuccess());
    result = pSslConnection->SetOption(nn::ssl::Connection::OptionType_GetServerCertChain, true);
    ASSERT_TRUE(result.IsSuccess());

    char *directServerCertBuf = new char[serverCertChainSize];
    ASSERT_TRUE(directServerCertBuf != nullptr);
    memset(directServerCertBuf, 0, serverCertChainSize);
    result = pSslConnection->DoHandshake(&tmp,
                                         &serverCertChainCount,
                                         directServerCertBuf,
                                         serverCertChainSize);
    ASSERT_TRUE(result.IsSuccess());

    EXPECT_TRUE(serverCertChainCount > 1);
    for (uint32_t i = 0; i < serverCertChainCount; i++)
    {
        nn::ssl::Connection::ServerCertDetail detail;
        nn::ssl::Connection::ServerCertDetail detail2;

        result = pSslConnection->GetServerCertDetail(&detail,
                                                     directServerCertBuf,
                                                     i);
        EXPECT_TRUE(result.IsSuccess());
        NN_LOG("Cert %u, len %u\n", i, detail.dataSize);

        //  Same index in the chain captured during pass 3
        result = pSslConnection->GetServerCertDetail(&detail2,
                                                     serverCertBuf,
                                                     i);
        EXPECT_TRUE(result.IsSuccess());
        ASSERT_TRUE(detail.dataSize == detail2.dataSize);
        ASSERT_TRUE(memcmp(detail.pDerData, detail2.pDerData, detail.dataSize) == 0);
    }

    delete[] directServerCertBuf;
    delete[] serverCertBuf;

    //  Close the pass 4 socket like every earlier pass does. EXPECT (not ASSERT)
    //  so the connection and context are still released if the lookup fails.
    result = pSslConnection->GetSocketDescriptor(&socketFd);
    EXPECT_TRUE(result.IsSuccess());
    if (result.IsSuccess())
    {
        g_CommonUtil.CloseTcpSocket(socketFd);
    }

    pSslConnection->Destroy();
    delete pSslConnection;
    pSslContext->Destroy();
    delete pSslContext;
    NN_LOG("Server cert chain test done\n");
}    // NOLINT(impl/function_size)
#endif    //  RUN_SERVER_CERT_CHAIN

// Shuts down what InitTest initialized, in reverse order: the SSL library,
// the socket library, then the network set up through the shared test utility.
TEST(FinalizeTest, Success)
{
    // Only the SSL finalization reports a result that is checked here.
    const auto sslFinalizeResult = nn::ssl::Finalize();
    EXPECT_TRUE(sslFinalizeResult.IsSuccess());

    nn::socket::Finalize();
    g_CommonUtil.FinalizeNetwork();
}
