﻿//==============================================================================
//  DLLCode object module
//==============================================================================

#include "stdafx.h"
#include <nn/init.h>
#include <nn/ro.h>
#include <nn/svc/svc_Base.h>
#include <nn/ro/detail/ro_NroHeader.h>
#include "DLLCode.h"
#include <nn/fs.h>
#include <nn/os.h>
#include <nn/mem/mem_StandardAllocator.h>
#include <nn\nn_Log.h>

//================================================================================
// Module-level stuff
//================================================================================
// Set to true by dll_code::Init() and back to false by dll_code::Close(), so
// that nn::ro::Initialize() / nn::ro::Finalize() run once per Init/Close pair.
static bool sSystemInitialized = false;

namespace
{
    // Allocator supplied to dll_code::Init(); LoadFile() and UnloadFile() use
    // it to allocate and free file buffers.
    static nn::mem::StandardAllocator* s_pAllocator;
    // Copy of the mount name supplied to dll_code::Init(); LoadFile() prepends
    // it to every file name it opens.
    static char s_FileMountName[1024];
}

//================================================================================

// One-time setup of the DLL loading system.
//
//   pAllocator - allocator that LoadFile()/UnloadFile() use for file buffers.
//   pMountName - prefix that LoadFile() prepends to every file name. NULL is
//                treated as an empty prefix. Names longer than the internal
//                buffer are truncated rather than overflowing it.
//
// Calls made while the system is already initialized are ignored.
void dll_code::Init( nn::mem::StandardAllocator* pAllocator, const char* pMountName )
{
    if( sSystemInitialized == false )
    {
        s_pAllocator = pAllocator;
        nn::ro::Initialize();

        // Bounded copy: strncpy does not terminate the destination when the
        // source fills it completely, so the last byte is terminated by hand.
        const char* pName = ( pMountName != NULL ) ? pMountName : "";
        strncpy( s_FileMountName, pName, sizeof(s_FileMountName) - 1 );
        s_FileMountName[ sizeof(s_FileMountName) - 1 ] = '\0';

        sSystemInitialized = true;
    }
}

//================================================================================

// Shuts the DLL loading system down by calling nn::ro::Finalize().
// Does nothing unless a previous Init() is still in effect.
void dll_code::Close()
{
    if( !sSystemInitialized )
    {
        return;
    }

    nn::ro::Finalize();
    sSystemInitialized = false;
}

//================================================================================
// Dll instance objects
//==============================================================================

// Creates an instance in the "nothing loaded" state for the module named
// pSource. LoadSources() later builds ".nrr/<pSource>.nrr" and "<pSource>.nro"
// from the stored name.
dll_code::dll_code( const char* pSource )
{
    // No file buffers or BSS buffer are owned yet.
    m_pNrr = nullptr;
    m_NrrSize = 0;
    m_pNro = nullptr;
    m_NroSize = 0;
    m_pBss = nullptr;
    m_BufferSize = 0;

    // Zero the nn::ro bookkeeping structures so Unload() sees them as
    // neither registered nor loaded.
    memset( &m_Info, 0, sizeof(m_Info) );
    memset( &m_Module, 0, sizeof(m_Module) );

    // NOTE(review): unbounded copy - m_Source is declared outside this file,
    // so pSource is assumed to fit; confirm against the class declaration.
    strcpy( m_Source, pSource );
}

//==============================================================================

// Releases whatever this instance still holds by delegating to Unload().
dll_code::~dll_code()
{
    this->Unload();
}

//==============================================================================

// Reads the two files that make up the module into memory:
//   ".nrr/<m_Source>.nrr"  -> m_pNrr / m_NrrSize
//   "<m_Source>.nro"       -> m_pNro / m_NroSize
//
// Returns true only when both files were loaded. Returns false without
// touching any file when the module name is too long for the local name
// buffer. When the NRR loads but the NRO does not, the NRR buffer stays in
// m_pNrr; Unload() releases it.
bool dll_code::LoadSources( )
{
    char Source[1024];

    // The longest name built below is ".nrr/" + m_Source + ".nrr". Refuse
    // names that would not fit (including the terminator) instead of letting
    // strcpy/strcat run past the end of the buffer.
    const size_t DecorationLength = strlen( ".nrr/" ) + strlen( ".nrr" );
    if( strlen( m_Source ) + DecorationLength >= sizeof(Source) )
    {
        return false;
    }

    //=============================================================
    // Create the NRR name.
    strcpy( Source, ".nrr/" );
    strcat( Source, m_Source );
    strcat( Source, ".nrr" );

    uint64_t Size = 0;
    if( LoadFile( Source, &m_pNrr, &Size ) == false )
    {
        return false;
    }
    m_NrrSize = Size;

    //=============================================================
    // Create the NRO name.
    strcpy( Source, m_Source );
    strcat( Source, ".nro" );

    Size = 0;
    const bool Loaded = LoadFile( Source, &m_pNro, &Size );
    m_NroSize = Size;

    return Loaded;
}

//==============================================================================

bool dll_code::Load( )
{
    if( LoadSources() == false )
    {
        return false;
    }

    nn::Result res = nn::ro::RegisterModuleInfo( &m_Info, m_pNrr );
    if( res.IsFailure() )
    {
        REPORT("ERROR registering module info:  0x%x\n", res.GetInnerValueForDebug() );
        return false;
    }

    res = nn::ro::GetBufferSize( &m_BufferSize, m_pNro );
    if( res.IsSuccess() )
    {
        if (m_BufferSize != 0)
        {
            m_pBss = new char[m_BufferSize];
        }

        res = nn::ro::LoadModule(&m_Module, m_pNro, m_pBss, m_BufferSize, nn::ro::BindFlag_Now);
    }

    return res.IsSuccess();
}

//==============================================================================

// Tears down whatever Load() managed to set up, in this order: unload the
// module, unregister the module info, free the NRR buffer, free the NRO
// buffer, free the BSS buffer. Each step is skipped when there is nothing to
// undo, so the call is safe on a partially loaded or never loaded instance.
// Always returns true.
bool dll_code::Unload( )
{
    const bool ModuleIsLoaded = ( m_Module._state == nn::ro::Module::State_Loaded );
    if ( ModuleIsLoaded )
    {
        nn::ro::UnloadModule( &m_Module );
    }

    const bool InfoIsRegistered = ( m_Info._state == nn::ro::RegistrationInfo::State_Registered );
    if ( InfoIsRegistered )
    {
        nn::ro::UnregisterModuleInfo( &m_Info );
    }

    // UnloadFile() ignores NULL, so no explicit check is needed here.
    UnloadFile( m_pNrr );
    m_pNrr = nullptr;

    UnloadFile( m_pNro );
    m_pNro = nullptr;

    // delete [] on a null pointer is a no-op.
    delete [] m_pBss;
    m_pBss = nullptr;

    return true;
}

//==============================================================================

bool dll_code::FindSymbol( uintptr_t* pAddress, const char* pSymbolName )
{
    nn::Result res = nn::ro::LookupModuleSymbol( pAddress, &m_Module, pSymbolName );
    return res.IsSuccess();
}

//==============================================================================
// File IO stuff
//==============================================================================
// Scratch buffer in which LoadFile() assembles "<mount name><file name>".
// It is a single file-scope buffer, so every LoadFile() call shares it.
static char sLoadFileName[1024];
// Alignment LoadFile() passes to the allocator for file buffers (0x1000 = 4096 bytes).
static const size_t DefaultAlign = 0x1000;

// Reads the whole file "<mount name><pFileName>" into a newly allocated
// buffer aligned to DefaultAlign.
//
//   pFileName       - file name relative to the mount name given to Init().
//   pLoadedFile     - receives the buffer on success; always set to NULL on
//                     failure. Release it with UnloadFile().
//   pLoadedFileSize - receives the file size in bytes on success; left
//                     untouched on failure.
//
// Returns true on success. Returns false when no allocator has been set, the
// combined name does not fit the name buffer, the file cannot be opened, its
// size cannot be read or is not positive, the allocation fails, or the read
// fails.
bool dll_code::LoadFile( const char* pFileName, void** pLoadedFile, uint64_t* pLoadedFileSize )
{
    //Init these to known values.
    bool Ret = false;
    *pLoadedFile = NULL;

    // Without an allocator (Init() not called) there is nowhere to put the file.
    if( s_pAllocator == NULL )
    {
        REPORT( "LoadFile FAILED for %s:  no allocator set\n", pFileName );
        return false;
    }

    //Make the file name, refusing names that would overflow the buffer
    //(the terminator needs one byte as well).
    if( strlen( s_FileMountName ) + strlen( pFileName ) >= sizeof(sLoadFileName) )
    {
        REPORT( "LoadFile FAILED for %s:  file name too long\n", pFileName );
        return false;
    }
    strcpy( sLoadFileName, s_FileMountName );
    strcat( sLoadFileName, pFileName );

    nn::fs::FileHandle fHandle;
    nn::Result result = nn::fs::OpenFile( &fHandle, sLoadFileName, nn::fs::OpenMode_Read );
    if (result.IsFailure())
    {
        REPORT( "OpenFile FAILED for %s:  %d\n", pFileName, result.GetInnerValueForDebug() );
        return false;
    }

    int64_t fileLength = 0;
    result = nn::fs::GetFileSize( &fileLength, fHandle );

    if( result.IsSuccess() && fileLength > 0 )
    {
        char* pDataBuffer = static_cast<char*>( s_pAllocator->Allocate( fileLength + 1, DefaultAlign ) );
        if( pDataBuffer == NULL )
        {
            // Never hand a null destination to ReadFile.
            REPORT( "Allocate FAILED for %s:  %lld bytes\n", pFileName, static_cast<long long>( fileLength + 1 ) );
        }
        else
        {
            result = nn::fs::ReadFile( fHandle, 0, pDataBuffer, fileLength );
            if( result.IsSuccess() )
            {
                *pLoadedFileSize = static_cast<uint64_t>( fileLength );
                *pLoadedFile = pDataBuffer;
                Ret = true;
            }
            else
            {
                REPORT( "ReadFile FAILED for %s:  %d\n", pFileName, result.GetInnerValueForDebug() );
                s_pAllocator->Free( pDataBuffer );
            }
        }
    }

    // The handle is closed on every path that opened it.
    nn::fs::CloseFile(fHandle);

    return Ret;
}

//==============================================================================

// Returns a buffer obtained from LoadFile() to the allocator.
// A NULL pointer is accepted and ignored. Always returns true.
bool dll_code::UnloadFile( void* pLoadedFile )
{
    if( pLoadedFile == NULL )
    {
        return true;
    }

    s_pAllocator->Free( pLoadedFile );
    return true;
}

//==============================================================================

