﻿//==============================================================================
//
//  Main entry for Thread Test
//
//==============================================================================

#include <string>
#include <stdlib.h>
#include <unordered_map>
#include "nn\nn_Result.h"
#include "nn\nn_TimeSpan.h"
#include "nn\os\os_Mutex.h"
#include "nn\os\os_Thread.h"
#include <nn\nn_Log.h>
#include <atomic>


// NOTE: Threads can only be created on cores 0, 1, and 2 on this system. Trying to make a thread
//       on any other core will result in an exception, by design.


// -------HELPERS-------

// Mutex shared by all tests. nnMain initializes it as a recursive mutex; callers of
// GetThreadNumber() hold it while ThreadInfos is read.
nn::os::MutexType GMutex;
// Value passed between threads to sequence each test; ResetTest() puts it back to -1.
// NOTE(review): volatile only stops the compiler from caching the value; it is not a C++
//               synchronization primitive, so these cross-thread accesses are formally a data
//               race. std::atomic<int> (<atomic> is already included) would be the portable fix,
//               but it may change what the debugger steps into on the marked lines — confirm
//               against the xml test definitions before switching.
volatile int GInt = -1;

// Bookkeeping recorded for each thread created through CreateThread().
struct ThreadInfo
{
    int   ThreadNum;  // Creation index; the first thread registered after ClearThreads() is 0.
    char* Stack;      // Stack buffer from aligned_alloc(); released with free() in WaitFreeThreads().
};

// Registry of threads created through CreateThread(), so tests can refer to them by a short
// number: the first thread registered after ClearThreads() is thread 0, the next thread 1, etc.
// Maps each ThreadType* to its number and its stack buffer.
// NOTE(review): std::unordered_map does not synchronize itself; concurrent access from several
//               threads must be serialized (this file uses GMutex for that).
std::unordered_map<nn::os::ThreadType*, ThreadInfo> ThreadInfos;
// Number handed to the next thread registered by AddThreadInfo().
int CurThreadIndex = 0;

// Forgets every registered thread and restarts numbering, so the next thread registered is
// thread 0 again. Does not free any stacks or thread objects.
void ClearThreads()
{
    CurThreadIndex = 0;
    ThreadInfos.clear();
}

// Registers Thread under the next creation index (0, 1, 2, ... since the last ClearThreads())
// and remembers its stack so WaitFreeThreads() can release it.
//
// Takes GMutex (recursive, see nnMain): threads started earlier may be looking up
// ThreadInfos at the same time, and inserting into an unordered_map can rehash it under them.
void AddThreadInfo( nn::os::ThreadType* Thread, char* Stack )
{
    nn::os::LockMutex( &GMutex );

    // One lookup instead of two operator[] calls.
    ThreadInfo& Info = ThreadInfos[Thread];
    Info.ThreadNum = CurThreadIndex++;
    Info.Stack = Stack;

    nn::os::UnlockMutex( &GMutex );
}

int GetThreadNumber()
{
    nn::os::ThreadType* Thread = nn::os::GetCurrentThread();

    auto ThreadPos = ThreadInfos.find( Thread );
    if( ThreadPos != ThreadInfos.end() )
    {
        return ThreadPos->second.ThreadNum;
    }
    else
    {
        return -1;
    }
}

// Returns the thread registered with creation index ThreadNumber, or nullptr if no such
// thread is registered (yet).
//
// Takes GMutex itself (it is recursive, so callers that already hold it are fine): this is
// called from worker threads while the main thread may still be registering new threads.
nn::os::ThreadType* GetThread( int ThreadNumber )
{
    nn::os::ThreadType* Found = nullptr;

    nn::os::LockMutex( &GMutex );
    for( const auto& Entry : ThreadInfos )   // by reference: no per-entry copy
    {
        if( Entry.second.ThreadNum == ThreadNumber )
        {
            Found = Entry.first;
            break;
        }
    }
    nn::os::UnlockMutex( &GMutex );

    return Found;
}

// Creates a thread that runs Func with ideal core IdealCore and default priority. Each thread
// gets a freshly allocated, 4096-byte-aligned stack of StackSize bytes. The thread is
// registered in ThreadInfos and then started.
//
// Func is a parameterless function cast to nn::os::ThreadFunction; the argument passed is
// nullptr and Func ignores it.
// NOTE(review): calling through a mismatched function-pointer type is formally undefined in
//               C++; this relies on the platform ABI tolerating it.
// NOTE: aligned_alloc requires StackSize to be a multiple of Alignment; every caller in this
//       file passes 4096.
//
// Ownership: the returned ThreadType, its stack and its ThreadInfos entry are released by
// WaitFreeThreads(). Aborts if the stack or the thread cannot be created.
template <typename FunctionType>
nn::os::ThreadType* CreateThread( FunctionType* Func, int StackSize, int IdealCore )
{
    const int Alignment = 4096;

    nn::os::ThreadType* Thread = new nn::os::ThreadType();

    char* Stack = static_cast<char*>( aligned_alloc( Alignment, StackSize ) );
    if( Stack == nullptr )
    {
        NN_LOG( "CreateThread: failed to allocate a %d byte stack\n", StackSize );
        abort();
    }

    nn::Result Result = nn::os::CreateThread( Thread, reinterpret_cast<nn::os::ThreadFunction>( Func ),
        nullptr, Stack, StackSize, nn::os::DefaultThreadPriority, IdealCore );
    if( Result.IsFailure() )
    {
        // Starting a thread that was never created is undefined; fail loudly instead.
        NN_LOG( "CreateThread: nn::os::CreateThread failed for core %d\n", IdealCore );
        abort();
    }

    // Register BEFORE starting. The thread functions call GetThreadNumber()/GetThread() right
    // away, and an unregistered thread gets -1. That sends RecursiveSteppingFunc into unbounded
    // recursion, and in the callstack test it leaves the other threads spinning forever.
    AddThreadInfo( Thread, Stack );

    nn::os::StartThread( Thread );

    return Thread;
}

// Waits for each of the Count threads in Threads to finish, then tears each one down. For
// every thread it destroys the OS thread object, frees the stack allocated by CreateThread(),
// drops the ThreadInfos entry and deletes the ThreadType itself.
// Every entry must have come from CreateThread() and must not be used afterwards.
void WaitFreeThreads( nn::os::ThreadType** Threads, int Count )
{
    // Join all threads before releasing any of them.
    for( int i = 0; i < Count; ++i )
    {
        nn::os::WaitThread( Threads[i] );
    }

    for( int i = 0; i < Count; ++i )
    {
        nn::os::ThreadType* Thread = Threads[i];

        // Destroy the thread object before releasing the stack it was created with.
        nn::os::DestroyThread( Thread );

        nn::os::LockMutex( &GMutex );
        // find(), not operator[]: never insert an entry just to read it. The entry is erased
        // so the map never keeps a key that a later `new` could hand out again.
        auto Pos = ThreadInfos.find( Thread );
        if( Pos != ThreadInfos.end() )
        {
            free( Pos->second.Stack );
            ThreadInfos.erase( Pos );
        }
        nn::os::UnlockMutex( &GMutex );

        // Allocated with new in CreateThread(); previously leaked.
        delete Thread;
    }
}

// Returns the shared test state to its starting point before each test: forgets all
// registered threads and sets GInt back to -1.
void ResetTest()
{
    ClearThreads();
    GInt = -1;
}

// -------TESTS---------

// NOTE: Stepping tests are by nature very sensitive to code changes. Modifying a stepping test
//       will likely require a modification of the xml test definition.

// Simple generic stepping test.
// Passes BasicThreadStepFunc1 makes through its loop while GInt == 1 before handing off.
const int nBasicThreadStepAmount = 3;
// Iterations of the busy loop BasicThreadStepFunc1 steps through after the hand-off.
const int nBasicThreadBusyLoops = 10;

// Thread 0 of the basic stepping test. While GInt == 1 (set by the driver before the threads
// start) it counts nBasicThreadStepAmount loop passes. It then sets GInt = 2 to wake thread 1,
// which sets GInt = 3 to wake thread 2, which sets GInt = -7777.
// NOTE(review): the //ThreadKeyNN markers appear to be line anchors for the external xml test
//               definition — keep each one on its current statement.
void BasicThreadStepFunc1()
{
    int A = 0;

    while( true )                                                                                       //ThreadKey00
    {
        if( GInt == 1 )                                                                                 //ThreadKey01
        {
            A++;                                                                                        //ThreadKey02

            if( A == nBasicThreadStepAmount )                                                           //ThreadKey03
            {
                // Kick off logic in thread 1.
                GInt = 2;                                                                               //ThreadKey04

                // Thread 2 (BasicThreadStepFunc3) should set GInt to -7777. Just step through here
                // to ensure enough time is given for this to happen and that we don't move to
                // different threads.
                for( int i = 0; i < nBasicThreadBusyLoops; ++i )                                        //ThreadKey05
                {
                    int Dummy = i;                                                                      //ThreadKey06
                    ++Dummy;                                                                            //ThreadKey07
                }

                // Check GInt is -7777, which proves the other threads were still running as
                // we were stepping through thread 0.
                int Hold = 0;                                                                           //ThreadKey08
                break;
            }
        }
    }
}

// Thread 1 of the basic stepping test: spins until thread 0 sets GInt to 2, then sets it to 3
// to wake thread 2.
void BasicThreadStepFunc2()
{
    while( true )
    {
        if( GInt == 2 )
        {
            GInt = 3;
            break;
        }
    }
}

// Thread 2 of the basic stepping test: spins until thread 1 sets GInt to 3, then sets it to
// the sentinel -7777 that thread 0 (BasicThreadStepFunc1) expects to see.
void BasicThreadStepFunc3()
{
    while( true )
    {
        if( GInt == 3 )
        {
            GInt = -7777;
            break;
        }
    }
}

void ThreadBasicSteppingTest()
{
    ResetTest();

    GInt = 1;

    const int NumThreads = 3;
    nn::os::ThreadType* Threads[NumThreads];

    Threads[0] = CreateThread( BasicThreadStepFunc1, 4096, 0 );
    Threads[1] = CreateThread( BasicThreadStepFunc2, 4096, 1 );
    Threads[2] = CreateThread( BasicThreadStepFunc3, 4096, 2 );

    WaitFreeThreads( Threads, NumThreads );
}

// Number of times each thread runs the loop body of ConcurrentStepFunc.
const int nFunctionThreadLoops = 3;

// Complex stepping test with function calls and threads that go over the same code
// concurrently.
// Recurses i times. At the bottom (i == 0) it reads the calling thread's number under GMutex.
// ConcurrentStepFunc already holds GMutex when this runs, so this relies on GMutex being
// recursive (see nnMain). ThreadNum is unused by the code and is presumably inspected by
// the debugger test.
// NOTE(review): a negative i never reaches 0 and recurses until the thread's stack overflows.
//               The caller chain passes GetThreadNumber(), which is -1 for an unregistered thread.
void RecursiveSteppingFunc( int i )
{                                                                                        //ThreadKey28
    if( i == 0 )                                                                         //ThreadKey29
    {
        LockMutex( &GMutex );                                                            //ThreadKey30
        int ThreadNum = GetThreadNumber();                                               //ThreadKey31
        UnlockMutex( &GMutex );                                                          //ThreadKey32
        return;                                                                          //ThreadKey33
    }
    else
    {
        RecursiveSteppingFunc( i - 1 );                                                  //ThreadKey34
    }

    // Reached only when i != 0, after the recursive call returns.
    int Step1 = 0;                                                                       //ThreadKey44
}                                                                                        //ThreadKey35

// Plain (non-template) layer of the stepping call chain: forwards to RecursiveSteppingFunc.
void NormalSteppingFunc( int Iterations )
{                                                                                        //ThreadKey36
    RecursiveSteppingFunc( Iterations );                                                 //ThreadKey37
}                                                                                        //ThreadKey38

// Template layer of the stepping call chain (ConcurrentStepFunc instantiates it with an int):
// forwards to NormalSteppingFunc.
template <typename T>
void TemplateSteppingFunc( T Iterations )
{                                                                                        //ThreadKey39
    NormalSteppingFunc( Iterations );                                                    //ThreadKey40
}                                                                                        //ThreadKey41

// Body shared by all three threads of the function stepping test. Each loop pass steps through
// a few trivial statements, then, while holding GMutex, descends
// TemplateSteppingFunc -> NormalSteppingFunc -> RecursiveSteppingFunc using the calling
// thread's number as the recursion depth.
void ConcurrentStepFunc()
{                                                                                        //ThreadKey19
    // All threads fall into this function at slightly different times. This allows us
    // to test that stepping only follows a single thread, even in the same code region.
    for( int i = 0; i < nFunctionThreadLoops; ++i )                                      //ThreadKey20
    {
        int Step1 = 0;                                                                   //ThreadKey21
        int Step2 = 0;                                                                   //ThreadKey22
        int Step3 = 0;                                                                   //ThreadKey23
        int Step4 = 0;                                                                   //ThreadKey24

        LockMutex( &GMutex );                                                            //ThreadKey25
        int ThreadNum = GetThreadNumber();                                               //ThreadKey43
        TemplateSteppingFunc( ThreadNum );                                               //ThreadKey26
        UnlockMutex( &GMutex );                                                          //ThreadKey27
    }
}                                                                                        //ThreadKey42

// Thread 0 of the function stepping test: sleeps 200 ms so that it enters ConcurrentStepFunc
// after threads 1 (100 ms) and 2 (no delay).
void FunctionThreadStepFunc1()
{
    nn::os::SleepThread( nn::TimeSpan::FromMilliSeconds( 200 ) );
    ConcurrentStepFunc();
}

// Thread 1 of the function stepping test: sleeps 100 ms before entering ConcurrentStepFunc.
void FunctionThreadStepFunc2()
{
    nn::os::SleepThread( nn::TimeSpan::FromMilliSeconds( 100 ) );
    ConcurrentStepFunc();
}

// Thread 2 of the function stepping test: enters ConcurrentStepFunc immediately.
void FunctionThreadStepFunc3()
{
    ConcurrentStepFunc();                                                                //ThreadKey18
}

void ThreadFunctionSteppingTest()
{
    ResetTest();

    const int NumThreads = 3;
    nn::os::ThreadType* Threads[NumThreads];

    Threads[0] = CreateThread( FunctionThreadStepFunc1, 4096, 0 );
    Threads[1] = CreateThread( FunctionThreadStepFunc2, 4096, 1 );
    Threads[2] = CreateThread( FunctionThreadStepFunc3, 4096, 2 );

    WaitFreeThreads( Threads, NumThreads );
}

// Simple stepping test where one thread waits for another to join it.
// Busy-loop iterations WaitingForThreadStepFunc steps through before it exits.
const int nWaitThreadLoops = 5;

// Thread 1 of the wait test: does some steppable busy work and exits. Thread 0
// (WaitingThreadStepFunc) waits for it to finish.
void WaitingForThreadStepFunc()
{
    int Start = 0;                                                                  //ThreadKey10

    // Step and be busy for a while.
    for( int i = 0; i < nWaitThreadLoops; ++i )                                     //ThreadKey11
    {
        int Step1 = 0;                                                              //ThreadKey12
        int Step2 = 0;                                                              //ThreadKey13
        int Step3 = 0;                                                              //ThreadKey14
    }

    int Exit = 0;                                                                   //ThreadKey15
}

void WaitingThreadStepFunc()
{
    nn::os::ThreadType* ThreadsToWaitFor[] = { GetThread( 1 ) };
    WaitFreeThreads( ThreadsToWaitFor, 1 );

    int Step1 = 0;                                                                  //ThreadKey16
    int Step2 = 0;                                                                  //ThreadKey17
}

void ThreadWaitSteppingTest()
{
    ResetTest();

    const int NumThreads = 2;
    nn::os::ThreadType* Threads[NumThreads];

    Threads[0] = CreateThread( WaitingThreadStepFunc, 4096, 0 );
    Threads[1] = CreateThread( WaitingForThreadStepFunc, 4096, 1 );

    nn::os::ThreadType* WaitingThreads[] = { Threads[0] };
    WaitFreeThreads( WaitingThreads, 1 );
}

// Test to show that different threads have different callstacks. Some calls are followed by a
// dummy "int Dummy = 0;" statement so that the IP shown in VS and TMAPI matches, because the
// Debug Engine reports the IP slightly off from its true position.

// Empty tag type. Its only use is as the template argument of CallstackTemplateFunc, so that
// thread 0's callstack contains a template instantiation.
class CallstackClass
{};

// Leaf frame of the callstack test. Each thread arrives here through a different call chain
// (listed in the "Expects" comments) and runs its marked "Hold" line in turn, sequenced by GInt:
//   thread 0 runs it immediately, then sets GInt = 2;
//   thread 1 spins until GInt == 2, runs it, then sets GInt = 3;
//   thread 2 spins until GInt == 3, runs it, then sets GInt = 4.
// Any other thread number falls through and returns.
void CallstackTopLevelFunc()
{
    nn::os::LockMutex( &GMutex );
    int ThreadNum = GetThreadNumber();
    nn::os::UnlockMutex( &GMutex );

    if( ThreadNum == 0 )
    {
        // Expects: Func1 -> SharedFunc -> TemplateFunc -> TopLevelFunc
        int Hold = 0;                                                       //ThreadKey45
        GInt = 2;
    }
    else if( ThreadNum == 1 )
    {
        while( GInt != 3 )
        {
            if( GInt == 2 )
            {
                // Expects: Func2 -> SharedFunc -> YetAnotherFunc -> TopLevelFunc
                int Hold = 0;                                               //ThreadKey46
                GInt = 3;
            }
        }
    }
    else if( ThreadNum == 2 )
    {
        while( GInt != 4 )
        {
            if( GInt == 3 )
            {
                // Expects: Func3 -> SharedFunc -> YetAnotherFunc -> RecursiveFunc
                //          -> RecursiveFunc -> RecursiveFunc -> TopLevelFunc
                int Hold = 0;                                               //ThreadKey47
                GInt = 4;
            }
        }
    }
}

// Recurses down to Count == 0 and calls CallstackTopLevelFunc there, so Count + 1
// CallstackRecursiveFunc frames are on the stack at that point (thread 2 passes 2).
// Returns Count for Count >= 0; a negative Count never terminates.
int CallstackRecursiveFunc( int Count )
{
    if( Count == 0 )
    {
        CallstackTopLevelFunc();                                                //ThreadKey64
        return 0;                                                               //ThreadKey65
    }

    return 1 + CallstackRecursiveFunc( Count - 1 );                             //ThreadKey66
}                                                                               //ThreadKey67

// Middle layer for threads 1 and 2 of the callstack test. Thread 1 calls CallstackTopLevelFunc
// directly; thread 2 goes through CallstackRecursiveFunc( 2 ). Other thread numbers do nothing.
// The Dummy statements give the debugger a line to report after each call returns.
void CallstackYetAnotherFunc()
{
    nn::os::LockMutex( &GMutex );
    int ThreadNum = GetThreadNumber();
    nn::os::UnlockMutex( &GMutex );

    switch( ThreadNum )
    {
        case 1:
        {
            CallstackTopLevelFunc();                                            //ThreadKey54
            int Dummy = 0;                                                      //ThreadKey55
            break;
        }
        case 2:
        {
            CallstackRecursiveFunc( ThreadNum );                                //ThreadKey62
            int Dummy = 0;                                                      //ThreadKey63
            break;
        }
    }
}

// Template layer used only by thread 0 of the callstack test (instantiated with CallstackClass).
// T is not used inside the function. The return value 74.2 is discarded by the only caller,
// CallstackSharedFunc.
template <typename T>
double CallstackTemplateFunc()
{
    CallstackTopLevelFunc();                                                //ThreadKey48

    int Dummy = 0;                                                          //ThreadKey49

    return 74.2;
}

// First frame below each callstack-test entry point. Dispatches on the calling thread's
// number: thread 0 takes the template path (CallstackTemplateFunc<CallstackClass>), threads
// 1 and 2 go through CallstackYetAnotherFunc, and any other number does nothing.
void CallstackSharedFunc()
{
    nn::os::LockMutex( &GMutex );
    int ThreadNum = GetThreadNumber();
    nn::os::UnlockMutex( &GMutex );

    switch( ThreadNum )
    {
        case 0:
        {
            CallstackTemplateFunc<CallstackClass>();                            //ThreadKey50
            int Dummy = 0;                                                      //ThreadKey51
            break;
        }
        case 1:
        case 2:
        {
            CallstackYetAnotherFunc();                                          //ThreadKey56
            int Dummy = 0;                                                      //ThreadKey57
            break;
        }
    }
}

// Entry point of thread 0 in the callstack test.
void CallstackThreadFunc1()
{
    CallstackSharedFunc();                                                      //ThreadKey52
}                                                                               //ThreadKey53

// Entry point of thread 1 in the callstack test.
void CallstackThreadFunc2()
{
    CallstackSharedFunc();                                                      //ThreadKey58
}                                                                               //ThreadKey59

// Entry point of thread 2 in the callstack test.
void CallstackThreadFunc3()
{
    CallstackSharedFunc();                                                      //ThreadKey60
}                                                                               //ThreadKey61

void ThreadCallstackTest()
{
    ResetTest();

    const int NumThreads = 3;
    nn::os::ThreadType* Threads[NumThreads];

    Threads[0] = CreateThread( CallstackThreadFunc1, 4096, 0 );
    Threads[1] = CreateThread( CallstackThreadFunc2, 4096, 1 );
    Threads[2] = CreateThread( CallstackThreadFunc3, 4096, 2 );

    WaitFreeThreads( Threads, NumThreads );
}

// Test where two different threads enter the same function at different times,
// with a single breakpoint that is never removed.
// Function entered by both threads of the shared-function test. The //ThreadKey09 marker
// presumably identifies the line where the test places its breakpoint.
void SharedFunc()
{
    // Two threads will enter the same function: thread 0 enters first, thread 1 enters five
    // seconds later. This just shows that two threads can hit a breakpoint in a single function.
    nn::os::LockMutex( &GMutex );
    int ThreadNum = GetThreadNumber();
    nn::os::UnlockMutex( &GMutex );                                                                  //ThreadKey09
}

// Thread 0 of the shared-function test: enters SharedFunc immediately.
void SharedThreadFunc1()
{
    SharedFunc();
}

// Thread 1 of the shared-function test: sleeps 5 seconds, then enters SharedFunc.
// NOTE(review): the comment below says the delay lets the test disable the breakpoint, while
//               the description above SharedFunc says the breakpoint is never removed and both
//               threads hit it — confirm which matches the xml test definition.
void SharedThreadFunc2()
{
    // thread 1 will enter the function later, giving our test time to disable the breakpoint.
    nn::os::SleepThread( nn::TimeSpan::FromSeconds( 5 ) );
    SharedFunc();
}

void SharedFunctionTest()
{
    ResetTest();

    const int NumThreads = 2;
    nn::os::ThreadType* Threads[NumThreads];

    Threads[0] = CreateThread( SharedThreadFunc1, 4096, 0 );
    Threads[1] = CreateThread( SharedThreadFunc2, 4096, 1 );

    WaitFreeThreads( Threads, NumThreads );
}

// Program entry point: initializes the shared mutex, then runs every test in order.
extern "C" void nnMain ( void )
{
    // Recursive (second argument true): RecursiveSteppingFunc locks GMutex again while
    // ConcurrentStepFunc already holds it.
    nn::os::InitializeMutex( &GMutex, true, nn::os::MutexLockLevelMax );

    typedef void ( *TestFunc )();

    // Keep SharedFunctionTest last: it has a sleep call and would slow down any test after it.
    const TestFunc Tests[] =
    {
        ThreadBasicSteppingTest,
        ThreadFunctionSteppingTest,
        ThreadWaitSteppingTest,
        ThreadCallstackTest,
        SharedFunctionTest,
    };

    for( TestFunc Test : Tests )
    {
        Test();
    }
}
