﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include "test_Common.h"
#include <nn/svc/svc_Tcb.h>
#include <nn/svc/ipc/svc_SessionMessage.h>

namespace {

// Polling / settling interval used throughout this file: 100 * 1000 * 1000
// (100 ms, assuming nn::svc::SleepThread takes nanoseconds).
const int64_t SleepTime = 100 * 1000 * 1000;

// Page-aligned message buffers for the *WithUserBuffer IPC variants:
// two pages for the server side, one page for the client side.
nn::Bit32 g_ServerBuffer[0x2000 / sizeof(nn::Bit32)] __attribute__((aligned(0x1000)));
nn::Bit32 g_ClientBuffer[0x1000 / sizeof(nn::Bit32)] __attribute__((aligned(0x1000)));

// Two page-aligned stacks for threads created directly with nn::svc::CreateThread.
char g_Buffer[2][DefaultStackSize] __attribute__((aligned(0x1000)));

// Cross-thread flags: one thread writes them while another polls them in a
// sleep loop (DummyThread, the "while (g_SequenceCount == 0)" loops) or checks
// them with ASSERT_TRUE / NN_ASSERT. They are volatile so every check performs a
// fresh load instead of reusing a cached value, matching g_Lock below.
volatile bool g_Start;
volatile int g_SequenceCount;

// Lock word whose address is handed to ArbitrateLock / ArbitrateUnlock /
// WaitProcessWideKeyAtomic, so it may be modified outside normal program flow.
volatile nn::Bit32 g_Lock;

// Helper thread body used by WaitSynchronizationThread.
// Keeps the thread alive, sleeping SleepTime per round, until g_Start is set,
// then returns. AutoThreadExit presumably terminates the thread when this scope ends.
void DummyThread(uintptr_t arg)
{
    NN_UNUSED(arg);
    AutoThreadExit autoExit;

    for (;;)
    {
        if (g_Start)
        {
            break;
        }
        nn::svc::SleepThread(SleepTime);
    }
}

void WaitSynchronizationThread(uintptr_t arg)
{
    NN_UNUSED(arg);
    AutoThreadExit autoExit;
    nn::Result result;
    nn::svc::Handle handle;
    nn::svc::Handle handles[1];
    uintptr_t pc = reinterpret_cast<uintptr_t>(DummyThread);;
    uintptr_t sp = reinterpret_cast<uintptr_t>(&g_Buffer[1]) + sizeof(g_Buffer[1]);
    int32_t signalIndex;

    g_Start = false;
    result = nn::svc::CreateThread(&handle, pc, 0, sp, TestLowestThreadPriority, 0);
    ASSERT_RESULT_SUCCESS(result);
    handles[0] = handle;
    result = nn::svc::StartThread(handle);
    ASSERT_RESULT_SUCCESS(result);

    // WaitSynchronization を呼び出す前にCancelSynchronization が呼ばれる
    result = nn::svc::WaitSynchronization(&signalIndex, handles, 1, -1);
    ASSERT_RESULT_FAILURE_VALUE(result, nn::svc::ResultCancelled());

    g_SequenceCount = 1;

    // WaitSynchronization を呼び出した後にCancelSynchronization が呼ばれる
    result = nn::svc::WaitSynchronization(&signalIndex, handles, 1, -1);
    ASSERT_RESULT_FAILURE_VALUE(result, nn::svc::ResultCancelled());

    g_Start = true;
    result = nn::svc::WaitSynchronization(&signalIndex, handles, 1, -1);
    ASSERT_RESULT_SUCCESS(result);
    result = nn::svc::CloseHandle(handle);
    ASSERT_RESULT_SUCCESS(result);
}

// Worker for CancelReplyAndRecieveTest. `handle` points at the server session
// to receive on. Both ReplyAndReceive calls must return ResultCancelled:
// the first because cancellation was requested before the thread started,
// the second because the creator cancels it while it is blocked.
void WaitReplyAndRecevieThread(nn::svc::Handle *handle)
{
    AutoThreadExit autoExit;

    nn::svc::Handle sessions[1] = { *handle };
    int32_t received;

    // CancelSynchronization was called before ReplyAndReceive was entered.
    nn::Result result = nn::svc::ReplyAndReceive(
        &received, sessions, 1, nn::svc::INVALID_HANDLE_VALUE, -1);
    ASSERT_RESULT_FAILURE_VALUE(result, nn::svc::ResultCancelled());

    // Ask the creator to cancel us while we block in the next call.
    g_SequenceCount = 1;

    // CancelSynchronization is called after ReplyAndReceive was entered.
    result = nn::svc::ReplyAndReceive(
        &received, sessions, 1, nn::svc::INVALID_HANDLE_VALUE, -1);
    ASSERT_RESULT_FAILURE_VALUE(result, nn::svc::ResultCancelled());
}

// Worker for CancelReplyAndRecieveWithUserBufferTest. Same as
// WaitReplyAndRecevieThread, but receives into g_ServerBuffer via
// ReplyAndReceiveWithUserBuffer. Both calls must return ResultCancelled.
void WaitReplyAndRecevieWithUserBufferThread(nn::svc::Handle *handle)
{
    AutoThreadExit autoExit;

    const uintptr_t bufferAddr = reinterpret_cast<uintptr_t>(g_ServerBuffer);
    const size_t bufferSize = sizeof(g_ServerBuffer);

    nn::svc::Handle sessions[1] = { *handle };
    int32_t received;

    // CancelSynchronization was called before ReplyAndReceiveWithUserBuffer was entered.
    nn::Result result = nn::svc::ReplyAndReceiveWithUserBuffer(
        &received, bufferAddr, bufferSize, sessions, 1, nn::svc::INVALID_HANDLE_VALUE, -1);
    ASSERT_RESULT_FAILURE_VALUE(result, nn::svc::ResultCancelled());

    // Ask the creator to cancel us while we block in the next call.
    g_SequenceCount = 1;

    // CancelSynchronization is called after ReplyAndReceiveWithUserBuffer was entered.
    result = nn::svc::ReplyAndReceiveWithUserBuffer(
        &received, bufferAddr, bufferSize, sessions, 1, nn::svc::INVALID_HANDLE_VALUE, -1);
    ASSERT_RESULT_FAILURE_VALUE(result, nn::svc::ResultCancelled());
}

// Argument bundle passed (by address, as the thread argument) to WaitLockThread.
struct TestData
{
    nn::svc::Handle handle; // first argument to ArbitrateLock; the test stores this thread handle (plus WaitMask) in the lock word
    uintptr_t addr;         // address of the lock word (&g_Lock in CancelArbitrateLockTest)
};

// Worker for CancelArbitrateLockTest. `arg` is a TestData*.
// Publishes g_SequenceCount = 1, calls ArbitrateLock on data.addr with tag
// value 1 (expected to block until the creator unlocks), and publishes
// g_SequenceCount = 2 once the call returns successfully.
void WaitLockThread(uintptr_t arg)
{
    AutoThreadExit autoExit;
    const TestData& data = *reinterpret_cast<TestData*>(arg);

    NN_ASSERT(g_SequenceCount == 0);
    g_SequenceCount = 1;

    nn::Result result = nn::svc::ArbitrateLock(data.handle, data.addr, 1);
    ASSERT_RESULT_SUCCESS(result);

    g_SequenceCount = 2;
}

const uintptr_t WaitKey = 0;

void WaitProcessWideThread(uintptr_t addr)
{
    AutoThreadExit autoExit;

    NN_ASSERT(g_SequenceCount == 0);
    nn::Bit32 newValue = 1;
    g_SequenceCount = 1;
    nn::Result result = nn::svc::WaitProcessWideKeyAtomic(addr, WaitKey, newValue, -1);
    NN_ASSERT_RESULT_SUCCESS(result);
    NN_ASSERT(g_SequenceCount == 1);
    g_SequenceCount = 2;
}

void SendSyncRequestThread(uintptr_t arg)
{
    AutoThreadExit autoExit;

    nn::Bit32* pMsgBuffer = nn::svc::ipc::GetMessageBuffer();
    nn::svc::Handle handle = *reinterpret_cast<nn::svc::Handle*>(arg);

    nn::svc::ipc::MessageBuffer ipcMsg(pMsgBuffer);
    ipcMsg.SetNull();

    NN_ASSERT(g_SequenceCount == 0);
    g_SequenceCount = 1;
    nn::Result result = nn::svc::SendSyncRequest(handle);
    NN_ASSERT(g_SequenceCount == 1);
    NN_ASSERT_RESULT_SUCCESS(result);
    g_SequenceCount = 2;
}

void SendSyncRequestWithUserBufferThread(uintptr_t arg)
{
    AutoThreadExit autoExit;

    nn::Bit32* pMsgBuffer = g_ClientBuffer;
    nn::svc::Handle handle = *reinterpret_cast<nn::svc::Handle*>(arg);

    nn::svc::ipc::MessageBuffer ipcMsg(pMsgBuffer);
    ipcMsg.SetNull();

    NN_ASSERT(g_SequenceCount == 0);
    g_SequenceCount = 1;
    nn::Result result = nn::svc::SendSyncRequestWithUserBuffer(
            reinterpret_cast<uintptr_t>(g_ClientBuffer), sizeof(g_ClientBuffer), handle);
    NN_ASSERT(g_SequenceCount == 1);
    NN_ASSERT_RESULT_SUCCESS(result);
    g_SequenceCount = 2;
}
} // namespace

// TEST 22-5 .. 22-8: CancelSynchronization against a thread in WaitSynchronization.
// For every ideal core and every test priority, a WaitSynchronizationThread is
// created, cancelled before it starts, and cancelled again once it reports
// (g_SequenceCount != 0) that it is about to block. The worker itself checks the
// ResultCancelled values; this side checks each svc call and waits for the worker.
TEST(CancelSynchronization, CancelWaitSynchronizaitionTest)
{
    const uintptr_t entry = reinterpret_cast<uintptr_t>(WaitSynchronizationThread);
    const uintptr_t stackTop = reinterpret_cast<uintptr_t>(&g_Buffer[0]) + sizeof(g_Buffer[0]);

    for (int32_t core = 0; core < NumCore; ++core)
    {
        for (int32_t prio = TestHighestThreadPriority; prio <= TestLowestThreadPriority; ++prio)
        {
            nn::Result result;
            nn::svc::Handle worker;

            g_SequenceCount = 0;
            result = nn::svc::CreateThread(&worker, entry, 0, stackTop, prio, core);
            ASSERT_RESULT_SUCCESS(result);

            // TEST 22-5 (same core), 22-6 (different core):
            // cancelling before WaitSynchronization is called makes it return immediately.
            result = nn::svc::CancelSynchronization(worker);
            ASSERT_RESULT_SUCCESS(result);

            result = nn::svc::StartThread(worker);
            ASSERT_RESULT_SUCCESS(result);

            // Wait until the worker announces its second, blocking wait.
            while (g_SequenceCount == 0)
            {
                nn::svc::SleepThread(SleepTime);
            }

            // TEST 22-7 (same core), 22-8 (different core):
            // cancelling after WaitSynchronization is called wakes the thread.
            result = nn::svc::CancelSynchronization(worker);
            ASSERT_RESULT_SUCCESS(result);

            int32_t index;
            result = nn::svc::WaitSynchronization(&index, &worker, 1, -1);
            ASSERT_RESULT_SUCCESS(result);
            result = nn::svc::CloseHandle(worker);
            ASSERT_RESULT_SUCCESS(result);
        }
    }
}

// TEST 22-9 .. 22-12: CancelSynchronization against a thread in ReplyAndReceive.
// One session is shared by all iterations; the worker receives on its server
// side. For every ideal core and priority the worker is cancelled before it
// starts and again once it reports that it is about to block.
TEST(CancelSynchronization, CancelReplyAndRecieveTest)
{
    const uintptr_t entry = reinterpret_cast<uintptr_t>(WaitReplyAndRecevieThread);
    const uintptr_t stackTop = reinterpret_cast<uintptr_t>(&g_Buffer[0]) + sizeof(g_Buffer[0]);

    nn::svc::Handle serverSession;
    nn::svc::Handle clientSession;
    nn::Result result = nn::svc::CreateSession(&serverSession, &clientSession, false, 0);
    ASSERT_RESULT_SUCCESS(result);

    const uintptr_t workerArg = reinterpret_cast<uintptr_t>(&serverSession);

    for (int32_t core = 0; core < NumCore; ++core)
    {
        for (int32_t prio = TestHighestThreadPriority; prio <= TestLowestThreadPriority; ++prio)
        {
            nn::svc::Handle worker;

            g_SequenceCount = 0;
            result = nn::svc::CreateThread(&worker, entry, workerArg, stackTop, prio, core);
            ASSERT_RESULT_SUCCESS(result);

            // TEST 22-9 (same core), 22-10 (different core):
            // cancelling before ReplyAndReceive is called makes it return immediately.
            result = nn::svc::CancelSynchronization(worker);
            ASSERT_RESULT_SUCCESS(result);

            result = nn::svc::StartThread(worker);
            ASSERT_RESULT_SUCCESS(result);

            // Wait until the worker announces its second, blocking receive.
            while (g_SequenceCount == 0)
            {
                nn::svc::SleepThread(SleepTime);
            }

            // TEST 22-11 (same core), 22-12 (different core):
            // cancelling after ReplyAndReceive is called wakes the thread.
            result = nn::svc::CancelSynchronization(worker);
            ASSERT_RESULT_SUCCESS(result);

            int32_t index;
            result = nn::svc::WaitSynchronization(&index, &worker, 1, -1);
            ASSERT_RESULT_SUCCESS(result);
            result = nn::svc::CloseHandle(worker);
            ASSERT_RESULT_SUCCESS(result);
        }
    }

    result = nn::svc::CloseHandle(serverSession);
    ASSERT_RESULT_SUCCESS(result);
    result = nn::svc::CloseHandle(clientSession);
    ASSERT_RESULT_SUCCESS(result);
}

// TEST 22-13 .. 22-16: CancelSynchronization against a thread in
// ReplyAndReceiveWithUserBuffer. Same structure as CancelReplyAndRecieveTest,
// but the worker receives into a user-supplied buffer.
TEST(CancelSynchronization, CancelReplyAndRecieveWithUserBufferTest)
{
    const uintptr_t entry = reinterpret_cast<uintptr_t>(WaitReplyAndRecevieWithUserBufferThread);
    const uintptr_t stackTop = reinterpret_cast<uintptr_t>(&g_Buffer[0]) + sizeof(g_Buffer[0]);

    nn::svc::Handle serverSession;
    nn::svc::Handle clientSession;
    nn::Result result = nn::svc::CreateSession(&serverSession, &clientSession, false, 0);
    ASSERT_RESULT_SUCCESS(result);

    const uintptr_t workerArg = reinterpret_cast<uintptr_t>(&serverSession);

    for (int32_t core = 0; core < NumCore; ++core)
    {
        for (int32_t prio = TestHighestThreadPriority; prio <= TestLowestThreadPriority; ++prio)
        {
            nn::svc::Handle worker;

            g_SequenceCount = 0;
            result = nn::svc::CreateThread(&worker, entry, workerArg, stackTop, prio, core);
            ASSERT_RESULT_SUCCESS(result);

            // TEST 22-13 (same core), 22-14 (different core):
            // cancelling before ReplyAndReceiveWithUserBuffer is called makes it return immediately.
            result = nn::svc::CancelSynchronization(worker);
            ASSERT_RESULT_SUCCESS(result);

            result = nn::svc::StartThread(worker);
            ASSERT_RESULT_SUCCESS(result);

            // Wait until the worker announces its second, blocking receive.
            while (g_SequenceCount == 0)
            {
                nn::svc::SleepThread(SleepTime);
            }

            // TEST 22-15 (same core), 22-16 (different core):
            // cancelling after ReplyAndReceiveWithUserBuffer is called wakes the thread.
            result = nn::svc::CancelSynchronization(worker);
            ASSERT_RESULT_SUCCESS(result);

            int32_t index;
            result = nn::svc::WaitSynchronization(&index, &worker, 1, -1);
            ASSERT_RESULT_SUCCESS(result);
            result = nn::svc::CloseHandle(worker);
            ASSERT_RESULT_SUCCESS(result);
        }
    }

    result = nn::svc::CloseHandle(serverSession);
    ASSERT_RESULT_SUCCESS(result);
    result = nn::svc::CloseHandle(clientSession);
    ASSERT_RESULT_SUCCESS(result);
}

// CancelSynchronization must NOT release a thread blocked in ArbitrateLock.
// Each iteration stores (this thread's handle | WaitMask) in g_Lock, starts a
// WaitLockThread that calls ArbitrateLock on it, cancels that thread, and checks
// it is still blocked (g_SequenceCount stays 1). The test then calls
// ArbitrateUnlock and expects the worker to finish (g_SequenceCount == 2).
TEST(CancelSynchronization, CancelArbitrateLockTest)
{
    const uintptr_t lockAddr = reinterpret_cast<uintptr_t>(&g_Lock);
    const uintptr_t entry = reinterpret_cast<uintptr_t>(WaitLockThread);
    const uintptr_t stackTop = reinterpret_cast<uintptr_t>(g_Buffer) + sizeof(g_Buffer);

    for (int32_t core = 0; core < NumCore; ++core)
    {
        for (int32_t prio = TestHighestThreadPriority; prio <= TestLowestThreadPriority; ++prio)
        {
            g_Lock = 0;

            TestData data;
            data.handle = nn::svc::Handle(nn::os::GetCurrentThread()->_handle);
            data.addr = lockAddr;

            nnHandle self = data.handle;
            const nn::Bit32 lockValue = self.value | nn::svc::Handle::WaitMask;

            g_SequenceCount = 0;
            g_Lock = lockValue;

            TestThread thread(entry, reinterpret_cast<uintptr_t>(&data), stackTop, prio, core);
            thread.Start();

            nn::svc::SleepThread(SleepTime);
            ASSERT_TRUE(g_SequenceCount == 1);

            // CancelSynchronization cannot cancel ArbitrateLock.
            nn::Result result = nn::svc::CancelSynchronization(thread.GetHandle());
            ASSERT_RESULT_SUCCESS(result);

            nn::svc::SleepThread(SleepTime);
            ASSERT_TRUE(g_SequenceCount == 1);

            result = nn::svc::ArbitrateUnlock(lockAddr);
            ASSERT_RESULT_SUCCESS(result);

            thread.Wait();
            ASSERT_TRUE(g_SequenceCount == 2);
        }
    }
}

// CancelSynchronization must NOT release a thread blocked in
// WaitProcessWideKeyAtomic. Each iteration sets g_Lock to 1, starts a
// WaitProcessWideThread, cancels it, checks it is still blocked
// (g_SequenceCount stays 1), then signals WaitKey and expects the worker to
// finish (g_SequenceCount == 2).
TEST(CancelSynchronization, CancelWaitProcessWideKeyTest)
{
    const uintptr_t lockAddr = reinterpret_cast<uintptr_t>(&g_Lock);
    const uintptr_t entry = reinterpret_cast<uintptr_t>(WaitProcessWideThread);
    const uintptr_t stackTop = reinterpret_cast<uintptr_t>(g_Buffer) + sizeof(g_Buffer);

    for (int32_t core = 0; core < NumCore; ++core)
    {
        for (int32_t prio = TestHighestThreadPriority; prio <= TestLowestThreadPriority; ++prio)
        {
            g_SequenceCount = 0;
            g_Lock = 1;

            TestThread thread(entry, lockAddr, stackTop, prio, core);
            thread.Start();

            nn::svc::SleepThread(SleepTime);
            ASSERT_TRUE(g_SequenceCount == 1);

            // CancelSynchronization cannot cancel WaitProcessWideKeyAtomic.
            nn::Result result = nn::svc::CancelSynchronization(thread.GetHandle());
            ASSERT_RESULT_SUCCESS(result);

            nn::svc::SleepThread(SleepTime);
            ASSERT_TRUE(g_SequenceCount == 1);

            nn::svc::SignalProcessWideKey(WaitKey, 1);

            thread.Wait();
            ASSERT_TRUE(g_SequenceCount == 2);
        }
    }
}

// CancelSynchronization must NOT release a thread blocked in SendSyncRequest.
// For every ideal core and priority, a SendSyncRequestThread sends on the
// client session. After it reports g_SequenceCount == 1 it is cancelled and must
// still be blocked. The server side then receives the request and replies with
// a zero-timeout, reply-only ReplyAndReceive (expected ResultTimeout), and the
// worker is joined.
TEST(CancelSynchronization, CancelSendSyncRequestTest)
{
    const uintptr_t entry = reinterpret_cast<uintptr_t>(SendSyncRequestThread);
    const uintptr_t stackTop = reinterpret_cast<uintptr_t>(&g_Buffer[0]) + sizeof(g_Buffer[0]);

    nn::svc::Handle serverSession;
    nn::svc::Handle clientSession;
    nn::Result result = nn::svc::CreateSession(&serverSession, &clientSession, false, 0);
    ASSERT_RESULT_SUCCESS(result);

    const uintptr_t workerArg = reinterpret_cast<uintptr_t>(&clientSession);

    for (int32_t core = 0; core < NumCore; ++core)
    {
        for (int32_t prio = TestHighestThreadPriority; prio <= TestLowestThreadPriority; ++prio)
        {
            nn::svc::Handle worker;

            g_SequenceCount = 0;
            result = nn::svc::CreateThread(&worker, entry, workerArg, stackTop, prio, core);
            ASSERT_RESULT_SUCCESS(result);

            result = nn::svc::StartThread(worker);
            ASSERT_RESULT_SUCCESS(result);

            while (g_SequenceCount == 0)
            {
                nn::svc::SleepThread(SleepTime);
            }
            ASSERT_TRUE(g_SequenceCount == 1);

            // CancelSynchronization cannot cancel SendSyncRequest.
            result = nn::svc::CancelSynchronization(worker);
            ASSERT_RESULT_SUCCESS(result);

            nn::svc::SleepThread(SleepTime);
            ASSERT_TRUE(g_SequenceCount == 1);

            int32_t index;
            // Receive the pending request...
            result = nn::svc::ReplyAndReceive(&index, &serverSession, 1, nn::svc::INVALID_HANDLE_VALUE, -1);
            ASSERT_RESULT_SUCCESS(result);

            // ...then reply to it without waiting for anything else.
            result = nn::svc::ReplyAndReceive(&index, &serverSession, 0, serverSession, 0);
            ASSERT_RESULT_FAILURE_VALUE(result, nn::svc::ResultTimeout());

            result = nn::svc::WaitSynchronization(&index, &worker, 1, -1);
            ASSERT_RESULT_SUCCESS(result);
            result = nn::svc::CloseHandle(worker);
            ASSERT_RESULT_SUCCESS(result);
        }
    }

    result = nn::svc::CloseHandle(serverSession);
    ASSERT_RESULT_SUCCESS(result);
    result = nn::svc::CloseHandle(clientSession);
    ASSERT_RESULT_SUCCESS(result);
}

// CancelSynchronization must NOT release a thread blocked in
// SendSyncRequestWithUserBuffer. Same flow as CancelSendSyncRequestTest; only
// the client thread differs (it sends from g_ClientBuffer).
TEST(CancelSynchronization, CancelSendSyncRequestWithUserBufferTest)
{
    const uintptr_t entry = reinterpret_cast<uintptr_t>(SendSyncRequestWithUserBufferThread);
    const uintptr_t stackTop = reinterpret_cast<uintptr_t>(&g_Buffer[0]) + sizeof(g_Buffer[0]);

    nn::svc::Handle serverSession;
    nn::svc::Handle clientSession;
    nn::Result result = nn::svc::CreateSession(&serverSession, &clientSession, false, 0);
    ASSERT_RESULT_SUCCESS(result);

    const uintptr_t workerArg = reinterpret_cast<uintptr_t>(&clientSession);

    for (int32_t core = 0; core < NumCore; ++core)
    {
        for (int32_t prio = TestHighestThreadPriority; prio <= TestLowestThreadPriority; ++prio)
        {
            nn::svc::Handle worker;

            g_SequenceCount = 0;
            result = nn::svc::CreateThread(&worker, entry, workerArg, stackTop, prio, core);
            ASSERT_RESULT_SUCCESS(result);

            result = nn::svc::StartThread(worker);
            ASSERT_RESULT_SUCCESS(result);

            while (g_SequenceCount == 0)
            {
                nn::svc::SleepThread(SleepTime);
            }
            ASSERT_TRUE(g_SequenceCount == 1);

            // CancelSynchronization cannot cancel SendSyncRequestWithUserBuffer.
            result = nn::svc::CancelSynchronization(worker);
            ASSERT_RESULT_SUCCESS(result);

            nn::svc::SleepThread(SleepTime);
            ASSERT_TRUE(g_SequenceCount == 1);

            int32_t index;
            // Receive the pending request...
            result = nn::svc::ReplyAndReceive(&index, &serverSession, 1, nn::svc::INVALID_HANDLE_VALUE, -1);
            ASSERT_RESULT_SUCCESS(result);

            // ...then reply to it without waiting for anything else.
            result = nn::svc::ReplyAndReceive(&index, &serverSession, 0, serverSession, 0);
            ASSERT_RESULT_FAILURE_VALUE(result, nn::svc::ResultTimeout());

            result = nn::svc::WaitSynchronization(&index, &worker, 1, -1);
            ASSERT_RESULT_SUCCESS(result);
            result = nn::svc::CloseHandle(worker);
            ASSERT_RESULT_SUCCESS(result);
        }
    }

    result = nn::svc::CloseHandle(serverSession);
    ASSERT_RESULT_SUCCESS(result);
    result = nn::svc::CloseHandle(clientSession);
    ASSERT_RESULT_SUCCESS(result);
}

