﻿//==============================================================================
// Common code, used in all of our tests.
//==============================================================================

#include <nn/nn_Log.h>
#include <nn/nn_Abort.h>
#include <nn/init.h>
#include <nn/os.h>
#include <nn/fs.h>
#include <nn/mem/mem_StandardAllocator.h>
#include "..\DLLCode\DLLCode.h"
#include "..\Common\Common.h"
#define REPORT( ... ) NN_LOG(__VA_ARGS__)

//==============================================================================
//  Memory management
//==============================================================================
// Size in bytes of each memory block this file takes from the OS memory heap
// (0x4000000 = 64 MiB). nninitStartup sizes the heap to twice this value:
// one block backs nn::init's allocator, the other is claimed by Common_Init.
constexpr size_t HeapMemorySize = 0x4000000;

// Base address of the memory block that backs s_Allocator (set in Common_Init).
static uintptr_t s_HeapAddress{};

// Allocator serving the nn::fs allocation callbacks; also handed to dll_code::Init.
static nn::mem::StandardAllocator s_Allocator;

// Cache buffer passed to nn::fs::MountRom; allocated from s_Allocator.
static char* s_CacheBuffer{ nullptr };

// Allocation callback registered with nn::fs::SetAllocator in Common_Init.
// Forwards the request to the file-local s_Allocator and returns whatever it returns.
static void* Allocate(size_t size)
{
    void* const pMemory = s_Allocator.Allocate(size);
    return pMemory;
}

// Deallocation callback paired with Allocate for nn::fs::SetAllocator.
// The size argument is required by the callback signature but is not needed:
// s_Allocator.Free takes only the pointer.
static void Deallocate(void* p, size_t size)
{
    NN_UNUSED(size);
    s_Allocator.Free( p );
}

//==============================================================================

extern "C" void nninitStartup()
{
    nn::os::SetMemoryHeapSize(HeapMemorySize * 2);

    uintptr_t addr;
    auto result = nn::os::AllocateMemoryBlock(&addr, HeapMemorySize);

    nn::init::InitializeAllocator(reinterpret_cast<void*>(addr), HeapMemorySize);
}

//==============================================================================

void Common_Init()
{
    nn::os::AllocateMemoryBlock(&s_HeapAddress, HeapMemorySize);
    s_Allocator.Initialize(reinterpret_cast<void*>(s_HeapAddress), HeapMemorySize);

    //===========================================
    // Init the file system.
    nn::fs::SetAllocator( Allocate, Deallocate );

    size_t cacheSize = 0;
    nn::fs::QueryMountRomCacheSize(&cacheSize);
    s_CacheBuffer = (char*)s_Allocator.Allocate( cacheSize );

    //All our DLLs are in the ROM location - ie, in the deploy data directory
    nn::Result result = nn::fs::MountRom("rom", s_CacheBuffer, cacheSize);

    dll_code::Init( &s_Allocator, "rom:/" );
}

//==============================================================================

// Tears down everything Common_Init set up, in reverse order. First the DLL
// loader, then the ROM mount that uses s_CacheBuffer, then the cache buffer
// and the allocator, and last the memory block that backs the allocator.
// The globals are reset afterwards so they no longer refer to freed memory.
void Common_Close()
{
    dll_code::Close();
    nn::fs::Unmount("rom");

    s_Allocator.Free( s_CacheBuffer );
    s_CacheBuffer = nullptr;    // don't leave a dangling pointer to freed memory

    s_Allocator.Finalize();
    nn::os::FreeMemoryBlock(s_HeapAddress, HeapMemorySize);
    s_HeapAddress = 0;          // block returned to the OS; address no longer valid
}

//==============================================================================

// Signature of a DLL export that takes no arguments and returns an int.
using TestDLL_ReturnNumberProc = int (*)();

// Looks up pFunctionName in pModule and, if found, calls it as a
// TestDLL_ReturnNumberProc.
// Returns the export's result. Returns -1 if the symbol is not found, and logs
// a message in that case. Note: an export that itself returns -1 cannot be
// told apart from a missing symbol.
int Common_TestDLLReturnIntCall( dll_code* pModule, const char* pFunctionName )
{
    uintptr_t procAddress = 0;
    if( !pModule->FindSymbol( &procAddress, pFunctionName ) )
    {
        REPORT("Unable to find %s\n", pFunctionName );
        return -1;
    }

    const TestDLL_ReturnNumberProc pCall = reinterpret_cast<TestDLL_ReturnNumberProc>( procAddress );
    return pCall();
}

//==============================================================================

// Signature of a DLL export taking a writable string plus a 32-bit integer
// (presumably the export modifies the string in place — confirm against the DLL).
using TestDLL_ChangeStringProc = void (*)( char*, int32_t );

// Looks up pFunctionName in pModule and calls it with (pString, Value).
// Returns true if the symbol was found and called. Returns false otherwise,
// and logs a message in that case.
bool Common_TestDLLChangeStringCall( dll_code* pModule, const char* pFunctionName, char* pString, int32_t Value )
{
    uintptr_t procAddress = 0;
    if( !pModule->FindSymbol( &procAddress, pFunctionName ) )
    {
        REPORT("Unable to find %s\n", pFunctionName );
        return false;
    }

    const TestDLL_ChangeStringProc pCall = reinterpret_cast<TestDLL_ChangeStringProc>( procAddress );
    pCall( pString, Value );
    return true;
}

//==============================================================================

// Looks up pFunctionName in pModule and calls it as a float(int) function,
// passing Value.
// Returns the export's result. Returns 0.0f if the symbol is not found, and
// logs a message in that case.
float Common_TestDLLWorkCall( dll_code* pModule, const char* pFunctionName, int Value )
{
    using TestDLL_ReturnWorkProc = float (*)( int );

    uintptr_t procAddress = 0;
    if( !pModule->FindSymbol( &procAddress, pFunctionName ) )
    {
        REPORT("Unable to find %s\n", pFunctionName );
        return 0.0f;
    }

    const auto pCall = reinterpret_cast<TestDLL_ReturnWorkProc>( procAddress );
    return pCall( Value );
}

//==============================================================================

// Looks up pFunctionName in pModule and calls it as a void(int32_t) function,
// passing Duration. Logs a message if the symbol is not found.
void Common_TestDLLSleepCall( dll_code* pModule, const char* pFunctionName, int Duration )
{
    using SleepProc = void (*)( int32_t );

    uintptr_t procAddress = 0;
    if( pModule->FindSymbol( &procAddress, pFunctionName ) )
    {
        const SleepProc pCall = reinterpret_cast<SleepProc>( procAddress );
        pCall( Duration );
        return;
    }

    REPORT("Unable to find %s\n", pFunctionName );
}

//==============================================================================
