﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <cstdlib>
#include <cstdio>
#include <cstring>

#include <nn/init.h>
#include <nn/os.h>
#include <nn/nn_Common.h>
#include <nn/nn_Log.h>
#include <nn/nn_Assert.h>

#include <nnc/nn_Macro.h>
#include <nnc/nn_Result.h>

#include <nn/ssl.h>

#include <nnt/nntest.h>
#include <nnt/base/testBase_Exit.h>
#include <nnt/result/testResult_Assert.h>
#include <nnt/nnt_Argument.h>

using namespace std;

namespace nnt { namespace ssl {


// Brings up the SSL library for the rest of this file; the matching
// nn::ssl::Finalize() call lives in the SslFinalize test.
TEST(SslInit, Success)
{
    const nn::Result            result = nn::ssl::Initialize();

    EXPECT_TRUE(result.IsSuccess());
}


// Creates a single SSL context (SslVersion_Auto) and destroys it again.
TEST(ContextCreateDestroy, Success)
{
    nn::ssl::Context            ctx = nn::ssl::Context();
    nn::Result                  result;

    result = ctx.Create(nn::ssl::Context::SslVersion::SslVersion_Auto);
    // Stop on failure: calling Destroy() on a context that was never created
    // would only add a second, misleading failure to the report.
    ASSERT_TRUE(result.IsSuccess());

    result = ctx.Destroy();
    EXPECT_TRUE(result.IsSuccess());
}


// Verifies the context limit: the first nn::ssl::MaxContextCount Create()
// calls are expected to succeed and the one after that to fail.
TEST(MultiSslContext, ImposeLimits)
{
    nn::ssl::Context            testCtxs[nn::ssl::MaxContextCount + 1];
    bool                        created[nn::ssl::MaxContextCount + 1] = {};
    nn::Result                  result;

    for (int i = 0; i < (nn::ssl::MaxContextCount + 1); i++)
    {
        testCtxs[i] = nn::ssl::Context();
        result = testCtxs[i].Create(nn::ssl::Context::SslVersion::SslVersion_Auto);
        created[i] = result.IsSuccess();
        if (i < nn::ssl::MaxContextCount)
        {
            EXPECT_FALSE(result.IsFailure());
        }
        else
        {
            EXPECT_TRUE(result.IsFailure());
        }
    }

    // Destroy exactly the contexts that exist. This includes the extra one if
    // the limit was not enforced, so no context leaks into later tests, and
    // skips any whose Create() failed.
    for (int i = 0; i < (nn::ssl::MaxContextCount + 1); i++)
    {
        if (created[i])
        {
            result = testCtxs[i].Destroy();
            EXPECT_FALSE(result.IsFailure());
        }
    }
}


// Creates a context, creates one connection on it, then tears both down.
TEST(SslConnectionCreateDestroy, Success)
{
    nn::ssl::Context            ctx;
    nn::ssl::Connection         conn;
    nn::Result                  result;

    ctx = nn::ssl::Context();
    result = ctx.Create(nn::ssl::Context::SslVersion::SslVersion_Auto);
    ASSERT_FALSE(result.IsFailure());

    conn = nn::ssl::Connection();
    result = conn.Create(&ctx);
    EXPECT_FALSE(result.IsFailure());

    // Only destroy a connection that was actually created, and check the
    // result like every other Destroy() in this file.
    if (result.IsSuccess())
    {
        result = conn.Destroy();
        EXPECT_FALSE(result.IsFailure());
    }

    result = ctx.Destroy();
    EXPECT_FALSE(result.IsFailure());
}


// Verifies the connection limit on a single context: the first
// nn::ssl::MaxConnectionCount Create() calls are expected to succeed and the
// next one to fail.
TEST(MultiSslConnection, ImposeLimits)
{
    nn::ssl::Context            ctx;
    nn::ssl::Connection         conns[nn::ssl::MaxConnectionCount + 1];
    bool                        created[nn::ssl::MaxConnectionCount + 1] = {};
    nn::Result                  result;

    ctx = nn::ssl::Context();
    result = ctx.Create(nn::ssl::Context::SslVersion::SslVersion_Auto);
    ASSERT_FALSE(result.IsFailure());

    for (int i = 0; i < (nn::ssl::MaxConnectionCount + 1); i++)
    {
        conns[i] = nn::ssl::Connection();
        result = conns[i].Create(&ctx);
        created[i] = result.IsSuccess();
        if (i < nn::ssl::MaxConnectionCount)
        {
            EXPECT_FALSE(result.IsFailure());
        }
        else
        {
            EXPECT_TRUE(result.IsFailure());
        }
    }

    // Destroy every connection that exists before tearing down the context.
    // This includes the extra one if the limit was not enforced.
    for (int i = 0; i < (nn::ssl::MaxConnectionCount + 1); i++)
    {
        if (created[i])
        {
            result = conns[i].Destroy();
            EXPECT_FALSE(result.IsFailure());
        }
    }

    result = ctx.Destroy();
    EXPECT_FALSE(result.IsFailure());
}


// Checks that nn::ssl::MaxConnectionCount bounds the total number of
// connections spread across two contexts. Even-indexed connections go to
// ctx2 and odd-indexed ones to ctx1. The first MaxConnectionCount creates are
// expected to succeed and the next one to fail.
TEST(MultiSslContextConnection, ImposeLimits)
{
    nn::ssl::Context            ctx1;
    nn::ssl::Context            ctx2;
    nn::ssl::Connection         conns[nn::ssl::MaxConnectionCount + 1];
    bool                        created[nn::ssl::MaxConnectionCount + 1] = {};
    nn::Result                  result;

    ctx1 = nn::ssl::Context();
    result = ctx1.Create(nn::ssl::Context::SslVersion::SslVersion_Auto);
    ASSERT_FALSE(result.IsFailure());

    ctx2 = nn::ssl::Context();
    result = ctx2.Create(nn::ssl::Context::SslVersion::SslVersion_Auto);
    if (result.IsFailure())
    {
        // Best-effort release of ctx1 so the early return below does not leak
        // it into later tests. `result` still holds ctx2's failure.
        ctx1.Destroy();
    }
    ASSERT_FALSE(result.IsFailure());

    for (int i = 0; i < (nn::ssl::MaxConnectionCount + 1); i++)
    {
        nn::ssl::Context        *curCtx = ((i & 0x1) == 0) ? &ctx2 : &ctx1;
        conns[i] = nn::ssl::Connection();
        result = conns[i].Create(curCtx);
        created[i] = result.IsSuccess();
        if (i < nn::ssl::MaxConnectionCount)
        {
            EXPECT_FALSE(result.IsFailure());
        }
        else
        {
            EXPECT_TRUE(result.IsFailure());
        }
    }

    // Destroy every connection that exists, including the extra one if the
    // limit was not enforced, before tearing down the contexts.
    for (int i = 0; i < (nn::ssl::MaxConnectionCount + 1); i++)
    {
        if (created[i])
        {
            result = conns[i].Destroy();
            EXPECT_FALSE(result.IsFailure());
        }
    }

    result = ctx1.Destroy();
    EXPECT_FALSE(result.IsFailure());

    result = ctx2.Destroy();
    EXPECT_FALSE(result.IsFailure());
}


static void ContextCreatorThreadCb(void *arg)
{
    NN_UNUSED(arg);

    nn::ssl::Context            ctx;
    nn::Result                  result;
    nn::os::ThreadType          *thread = nn::os::GetCurrentThread();

    NN_LOG("[ContextCreatorThreadCb-%p] start\n", thread);

    result = ctx.Create(nn::ssl::Context::SslVersion::SslVersion_Auto);
    ASSERT_TRUE(result.IsSuccess());

    NN_LOG("[ContextCreatorThreadCb-%p] ctx created %p\n", thread, &ctx);

    nn::os::YieldThread();
    result = ctx.Destroy();
    ASSERT_TRUE(result.IsSuccess());

    NN_LOG("[ContextCreatorThreadCb-%p] ctx destroyed %p\n", thread, &ctx);
}


// Stacks for the two worker threads. Both MultiThreadContextCreateDestroy and
// SharedContext reuse them. That is safe only because each test waits for
// and destroys its threads before returning.
static const size_t                         ThreadStackSize = 8192;
static NN_OS_ALIGNAS_GUARDED_STACK uint8_t  t1Stack[ThreadStackSize];
static NN_OS_ALIGNAS_GUARDED_STACK uint8_t  t2Stack[ThreadStackSize];

// Runs ContextCreatorThreadCb on two threads at once. Each thread creates and
// destroys its own private context.
TEST(MultiThreadContextCreateDestroy, Success)
{
    nn::os::ThreadType          t1;
    nn::os::ThreadType          t2;
    nn::Result                  result;

    result = nn::os::CreateThread(&t1,
                                  ContextCreatorThreadCb,
                                  nullptr,
                                  reinterpret_cast<void *>(t1Stack),
                                  sizeof(t1Stack),
                                  nn::os::DefaultThreadPriority);
    ASSERT_TRUE(result.IsSuccess());

    NN_LOG("[MultiThreadContextCreateDestroy] *t1 is %p\n", &t1);

    result = nn::os::CreateThread(&t2,
                                  ContextCreatorThreadCb,
                                  nullptr,
                                  reinterpret_cast<void *>(t2Stack),
                                  sizeof(t2Stack),
                                  nn::os::DefaultThreadPriority);
    if (result.IsFailure())
    {
        // t1 was created but never started. Release it before the ASSERT
        // below returns, otherwise the thread object leaks.
        nn::os::DestroyThread(&t1);
    }
    ASSERT_TRUE(result.IsSuccess());

    NN_LOG("[MultiThreadContextCreateDestroy] *t2 is %p\n", &t2);

    nn::os::StartThread(&t1);
    nn::os::StartThread(&t2);
    nn::os::WaitThread(&t2);
    nn::os::WaitThread(&t1);
    nn::os::DestroyThread(&t1);
    nn::os::DestroyThread(&t2);
}


static void SharedContextThreadCb(void *arg)
{
    nn::ssl::Context            *ctx = reinterpret_cast<nn::ssl::Context *>(arg);
    nn::ssl::Connection         conns[nn::ssl::MaxConnectionCount / 2];
    nn::Result                  result;
    int                         i;

    nn::os::ThreadType          *thread = nn::os::GetCurrentThread();

    NN_LOG("[SharedContextThreadCb-%p] start, ctx is %p\n", thread, ctx);

    for (i = 0; i < (nn::ssl::MaxConnectionCount / 2); i++)
    {
        conns[i] = nn::ssl::Connection();
        result = conns[i].Create(ctx);

        NN_LOG("[SharedContextThreadCb-%p] create conn %d (%p)\n",
               thread,
               i,
               &conns[i]);
        EXPECT_TRUE(result.IsSuccess());
    }

    for (i = 0; i < (nn::ssl::MaxConnectionCount / 2); i++)
    {
        result = conns[i].Destroy();

        NN_LOG("[SharedContextThreadCb-%p] destroy conn %d (%p)\n",
               thread,
               i,
               &conns[i]);

        EXPECT_TRUE(result.IsSuccess());
    }
}


// Two threads concurrently create and destroy connections on one shared
// context. The context is destroyed only after both threads have finished.
TEST(SharedContext, Success)
{
    nn::os::ThreadType          t1;
    nn::os::ThreadType          t2;
    nn::Result                  result;
    nn::ssl::Context            ctx;

    result = ctx.Create(nn::ssl::Context::SslVersion::SslVersion_Auto);
    ASSERT_TRUE(result.IsSuccess());

    result = nn::os::CreateThread(&t1,
                                  SharedContextThreadCb,
                                  &ctx,
                                  reinterpret_cast<void *>(t1Stack),
                                  sizeof(t1Stack),
                                  nn::os::DefaultThreadPriority);
    if (result.IsFailure())
    {
        // Best-effort cleanup so the early return below does not leak ctx.
        // `result` still holds the CreateThread failure.
        ctx.Destroy();
    }
    ASSERT_TRUE(result.IsSuccess());

    NN_LOG("[SharedContext] *t1 is %p\n", &t1);

    result = nn::os::CreateThread(&t2,
                                  SharedContextThreadCb,
                                  &ctx,
                                  reinterpret_cast<void *>(t2Stack),
                                  sizeof(t2Stack),
                                  nn::os::DefaultThreadPriority);
    if (result.IsFailure())
    {
        // t1 was created but never started. Release it and the context
        // before the ASSERT below returns.
        nn::os::DestroyThread(&t1);
        ctx.Destroy();
    }
    ASSERT_TRUE(result.IsSuccess());

    NN_LOG("[SharedContext] *t2 is %p\n", &t2);

    nn::os::StartThread(&t1);
    nn::os::StartThread(&t2);
    nn::os::WaitThread(&t2);
    nn::os::WaitThread(&t1);
    nn::os::DestroyThread(&t1);
    nn::os::DestroyThread(&t2);

    // Both workers are done with ctx at this point. Check the teardown too.
    result = ctx.Destroy();
    EXPECT_TRUE(result.IsSuccess());
}


// Shuts down the SSL library started by SslInit. It is declared last so that,
// with GoogleTest's default (unshuffled) ordering, it runs after every other
// SSL test in this file.
TEST(SslFinalize, Success)
{
    ::nn::ssl::Finalize();
}

} }

extern "C"
{

void nnMain()
{
    // Pass the host-supplied command line to GoogleTest so its options
    // (filters, repeat, ...) apply, then run every registered test and
    // report the aggregate status as the process exit code.
    int         argc = nnt::GetHostArgc();
    char      **argv = nnt::GetHostArgv();

    ::testing::InitGoogleTest(&argc, argv);

    nnt::Exit(RUN_ALL_TESTS());
}

void nninitStartup()
{
    // Size of the OS memory heap. The entire heap is handed to malloc below.
    const size_t MemoryHeapSize = 24 * 1024 * 1024;
    auto result = nn::os::SetMemoryHeapSize( MemoryHeapSize );

    // NOTE(review): if NN_ASSERT is compiled out in release builds, these
    // failures would go unnoticed -- confirm the intended build configuration.
    NN_ASSERT( result.IsSuccess() );

    // Allocate the memory region that malloc will use from the memory heap.
    uintptr_t address = 0;

    result = nn::os::AllocateMemoryBlock( &address, MemoryHeapSize );
    NN_ASSERT( result.IsSuccess() );

    // Register that region as the backing store for malloc.
    nn::init::InitializeAllocator( reinterpret_cast<void*>(address), MemoryHeapSize );

    // Named cast instead of a C-style cast. Converting a function pointer to
    // void* is conditionally-supported and is only used here for logging.
    NN_LOG("nninitStartup: loaded at %p\n", reinterpret_cast<void *>(&nninitStartup));
}

}
