﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <filesystem>

#include <nn/gfx/gfx_Enum.h>

#include <gfxTool_Compiler-gx-binary.h>
#include <gfxTool_ShaderCompilerContext.h>
#include <gfxTool_CompileOption-glsl.h>
#include <gfxTool_VariationGrouper.h>
#include <gfxTool_GroupSource.h>
#include <gfxTool_CompileOptionManager.h>
#include <gfxTool_ShaderSourceManager.h>
#include <gfxTool_VariationManager.h>
#include <gfxTool_CompileOutput.h>
#include <gfxTool_OptionOutput.h>
#include <gfxTool_CompileOption-gx.h>

namespace {

// Pointer-to-member table mapping a stage index (the same index the stage loops in this
// file iterate over, 0 .. ShaderStage::End - 1) to the GSH2CompileSetup3 field that
// receives that stage's source text.
// The two nullptr slots have no GSH2 source field; CompileGroup rejects a source supplied
// for them with a "Tessellation shaders are not supported in GX2" error (presumably the
// hull/domain slots -- confirm against the ShaderStage enum).
// The table is read-only, so the array itself is const (like s_pStatistics in CompileGroup).
static const char* GSH2CompileSetup3::* const s_pStageSource[] =
{
    &GSH2CompileSetup3::vs_source,
    nullptr,
    nullptr,
    &GSH2CompileSetup3::gs_source,
    &GSH2CompileSetup3::ps_source,
    &GSH2CompileSetup3::cs_source
};

// File name of the shader-utilities DLL. PreCompile appends it to the directory that
// contains the shader compiler module to build the path it loads.
const char* GetShaderUtilsDllName()
{
    static const char s_DllFileName[] = "shaderUtils.dll";
    return s_DllFileName;
}

// Copies the `count` records of the array *ppSrc into *pDst, then redirects *ppSrc at the
// first copied record (or at nullptr when no record was copied). Afterwards the caller's
// struct refers to storage owned by *pDst instead of the original source array.
template< typename TGx2Var >
void RetrieveGx2Var( typename nn::gfxTool::Custom< std::vector< TGx2Var > >::Type* pDst,
    int count, TGx2Var** ppSrc )
{
    const TGx2Var* pCursor = *ppSrc;
    pDst->reserve( count );
    for( int remaining = count; remaining > 0; --remaining )
    {
        pDst->push_back( *pCursor );
        ++pCursor;
    }

    if( pDst->empty() )
    {
        *ppSrc = nullptr;
    }
    else
    {
        *ppSrc = &pDst->front();
    }
}

// Overload for records that carry a `name` C string. Each record is copied into *pDst,
// its name text is copied into *pDstName, and the copied record's `name` is repointed at
// that copy. Finally *ppSrc is redirected at the first copied record (nullptr when none).
// The name storage is list-based (Custom< std::list<...> >), so its elements are not
// relocated by later emplace_back calls and the stored c_str() pointers stay valid
// while the list lives.
template< typename TGx2Var >
void RetrieveGx2Var( typename nn::gfxTool::Custom< std::vector< TGx2Var > >::Type* pDst,
    nn::gfxTool::Custom< std::list< nn::gfxTool::Custom< std::string >::Type > >::Type* pDstName,
    int count, TGx2Var** ppSrc )
{
    const TGx2Var* pCursor = *ppSrc;
    pDst->reserve( count );
    for( int remaining = count; remaining > 0; --remaining, ++pCursor )
    {
        pDst->push_back( *pCursor );
        pDstName->emplace_back( pCursor->name );
        auto& copiedVar = pDst->back();
        copiedVar.name = pDstName->back().c_str();
    }

    if( pDst->empty() )
    {
        *ppSrc = nullptr;
    }
    else
    {
        *ppSrc = &pDst->front();
    }
}

// Deep-copies one compiled GX2 shader stage out of the compiler-owned GX2 program and
// registers the copy with pProgramOutput.
//
// The shader struct, its code buffer (plus the copy-shader buffer for geometry shaders)
// and every variable table it points to are duplicated. The duplicates are owned by a
// Gx2Introspection object attached with SetAdditionalData, and the copied struct's
// pointers are redirected at them. CompileGx2 destroys the compiler's program with
// Gsh2DestroyGx2Program3 after this function has run for every stage.
//
// pProgramOutput  Destination; when nullptr (stage not compiled) nothing is done.
// pGx2Shader      Stage struct inside the compiler output; when nullptr nothing is done.
// pDump           Optional dump text; copied into the stage option output when non-null.
// pInfoLog        Optional info log; attached to the stage option output when non-null.
// stage           Stage under which the code and option output are registered.
template< typename TGx2Shader >
void RetrieveProgramOutput( nn::gfxTool::ProgramOutput* pProgramOutput,
    const TGx2Shader* pGx2Shader, const char* pDump,
    std::shared_ptr< nn::gfxTool::Custom< std::string >::Type > pInfoLog, nn::gfx::ShaderStage stage )
{
    if( pProgramOutput == nullptr || pGx2Shader == nullptr )
    {
        return;
    }

    std::shared_ptr< nn::gfxTool::Gx2Introspection > pIntrospection( new nn::gfxTool::Gx2Introspection() );

    std::shared_ptr< TGx2Shader > pCode( new TGx2Shader() );
    memcpy( pCode.get(), pGx2Shader, sizeof( TGx2Shader ) );
    // The code buffer is allocated with new[], so it must be released with delete[];
    // a shared_ptr< char > without an explicit deleter would use scalar delete (UB).
    std::shared_ptr< char > pShaderCode( new char[ pCode->shaderSize ], std::default_delete< char[] >() );
    memcpy( pShaderCode.get(), pCode->shaderPtr, pCode->shaderSize );
    pIntrospection->shaderPtr = pShaderCode;
    pCode->shaderPtr = pIntrospection->shaderPtr.get();
    pProgramOutput->SetCode( nn::gfxTool::StaticCastAuto( stage ), pCode );
    if( NN_STATIC_CONDITION( ( std::is_same< TGx2Shader, GX2GeometryShader >::value ) ) )
    {
        // Geometry shaders carry an additional copy-shader buffer that must be duplicated too.
        auto pGeometryShader = reinterpret_cast< GX2GeometryShader* >( pCode.get() );
        std::shared_ptr< char > pCopyShaderCode( new char[ pGeometryShader->copyShaderSize ],
            std::default_delete< char[] >() );
        memcpy( pCopyShaderCode.get(), pGeometryShader->copyShaderPtr,
            pGeometryShader->copyShaderSize );
        pIntrospection->copyShaderPtr = pCopyShaderCode;
        pGeometryShader->copyShaderPtr = pIntrospection->copyShaderPtr.get();
    }

    RetrieveGx2Var( &pIntrospection->uniformBlocks,
        &pIntrospection->names, pCode->numUniformBlocks, &pCode->uniformBlocks );
    RetrieveGx2Var( &pIntrospection->uniformVars,
        &pIntrospection->names, pCode->numUniforms, &pCode->uniformVars );
    RetrieveGx2Var( &pIntrospection->initialValues, pCode->numInitialValues, &pCode->initialValues );
    // _loopVars needs a cast to GX2LoopVar* (it appears to be untyped), so it is copied
    // through a typed local. RetrieveGx2Var redirects only that local, so the pointer
    // has to be written back. Otherwise pCode keeps pointing into the compiler-owned
    // program, which is destroyed once all stages have been retrieved.
    auto pLoopVar = static_cast< GX2LoopVar* >( pCode->_loopVars );
    RetrieveGx2Var( &pIntrospection->loopVars, pCode->_numLoops, &pLoopVar );
    pCode->_loopVars = pLoopVar;
    RetrieveGx2Var( &pIntrospection->samplerVars,
        &pIntrospection->names, pCode->numSamplers, &pCode->samplerVars );
    if( NN_STATIC_CONDITION( ( std::is_same< TGx2Shader, GX2VertexShader >::value ) ) )
    {
        // Only vertex shaders have attribute variables.
        auto pVertexShader = reinterpret_cast< GX2VertexShader* >( pCode.get() );
        RetrieveGx2Var( &pIntrospection->attribVars, &pIntrospection->names,
            pVertexShader->numAttribs, &pVertexShader->attribVars );
    }
    pProgramOutput->SetAdditionalData( pIntrospection );

    auto pOptionStageOutput = static_cast< nn::gfxTool::OptionOutputProgramCommon* >(
        pProgramOutput->GetOptionOutput( nngfxToolShaderCompilerOptionOutputType_ProgramCommon )
        )->GetOptionOutputStageCommon( nn::gfxTool::StaticCastAuto( stage ) );
    if( pDump )
    {
        std::shared_ptr< nn::gfxTool::Custom< std::string >::Type > dump(
            new nn::gfxTool::Custom< std::string >::Type( pDump ) );
        pOptionStageOutput->SetDump( dump );
    }
    if( pInfoLog.get() )
    {
        pOptionStageOutput->SetInfoLog( pInfoLog );
    }
}

}

namespace nn {
namespace gfxTool {

// Compiler variation implemented in this file: GX low-level API, binary code output.
using Target = CompilerVariation< static_cast<
    nngfxToolShaderCompilerLowLevelApiType >( nngfxToolShaderCompilerLowLevelApiType_Gx ),
    nngfxToolShaderCompilerCodeType_Binary >;

// Sets up state shared by every CompileGroup call of this compile:
//  - loads the shader-utils DLL from the directory of the shader compiler module if it
//    is not loaded yet; throws nngfxToolResultCode_DllNotFound when loading fails;
//  - decides how many GSH2 handles each variation uses (1 when stages are compiled
//    together, one per stage that has a source when separation is enabled) and creates
//    that many handles for every variation;
//  - fills m_Option from the common compile option;
//  - reports the GSH2 ABI version through pOutput.
void Compiler< Target >::PreCompile( CompileOutput* pOutput,
    const ShaderCompilerContext* pContext, const nngfxToolShaderCompilerCompileArg* pArg )
{
    if( !m_ShaderUtilsDll.IsInitialized() )
    {
        auto dllPath = std::tr2::sys::path(
            GetModulePath( GetShaderCompilerModuleHandle() ) ).parent_path().string();
        dllPath += "/";
        dllPath += GetShaderUtilsDllName();
        if( !m_ShaderUtilsDll.Initialize( dllPath.c_str() ) )
        {
            NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_DllNotFound, "%s is not found.", dllPath.c_str() );
        }
    }

    const auto pCommonOption = pContext->GetCompileOptionManager()->GetCompileOption<
        nngfxToolShaderCompilerOptionType_Common >();

    if( pCommonOption->IsSeparationEnabled() )
    {
        // Separate compilation: one handle per stage that has a source.
        m_HandleCountPerVariation = 0;
        for( int stageIndex = 0; stageIndex < static_cast< int >( ShaderStage::End ); ++stageIndex )
        {
            if( GetStageSource( pArg, StaticCastAuto( stageIndex ) ).pValue )
            {
                ++m_HandleCountPerVariation;
            }
        }
    }
    else
    {
        // Linked compilation: all stages share a single handle.
        m_HandleCountPerVariation = 1;
    }

    GSH2Setup setup = {};
    setup.gpu = GPU_VERSION_GPU7;
    const auto handleCount = m_HandleCountPerVariation * pArg->variationCount;
    m_Handles.reserve( handleCount );
    for( auto remaining = static_cast< int >( handleCount ); remaining > 0; --remaining )
    {
        m_Handles.push_back( m_ShaderUtilsDll.Gsh2Initialize( &setup ) );
    }

    auto& option = m_Option;
    option.forceUniformBlock = 0;
    option.dumpShaders = pCommonOption->IsDumpEnabled() ? 1 : 0;
    option.skipSparkDebug = ( pCommonOption->GetDebugInfoLevel()
        == nngfxToolShaderCompilerDebugInfoLevel_None ) ? 1 : 0;
    option.optimize = 0;
    option.optFlags = 1;
    option.getStats = pCommonOption->IsShaderStatisticsEnabled() ? 1 : 0;

    pOutput->lowLevelCompilerVerison = m_ShaderUtilsDll.Gsh2GetAbiVersion();
}

// Compiles one variation (idxGroup is used directly as the variation index).
//
// For every stage that has a source in pArg, the GLSL text is assembled in this order:
//  1. the GLSL header,
//  2. the common preprocessor definitions,
//  3. this variation's preprocessor definitions,
//  4. this variation's variation constants,
//  5. the shader source around its variation buffer.
// Then:
//  - separation disabled: all stage sources go into one GSH2CompileSetup3, and a single
//    CompileGx2 call on this variation's handle compiles them together;
//  - separation enabled: each stage is compiled on its own handle, and only when this
//    variation is the head of that stage's variation group (PostCompile copies the
//    head's result to the rest of the group).
// Throws nngfxToolResultCode_Unsupported when a source is supplied for a stage slot that
// s_pStageSource marks as unsupported (nullptr).
void Compiler< Target >::CompileGroup( CompileOutput* pOutput, const ShaderCompilerContext* pContext,
    const nngfxToolShaderCompilerCompileArg* pArg, int idxGroup )
{
    auto idxVariation = idxGroup;

    GSH2CompileSetup3 compileSetup = {};
    compileSetup.abi_version = GSH2_ABI_VERSION_CURRENT;
    compileSetup.lang = SHADERLANG_GLSL;
    compileSetup.vs_source_filename = "";
    compileSetup.gs_source_filename = "";
    compileSetup.ps_source_filename = "";
    compileSetup.cs_source_filename = "";

    auto pCompileOptionManager = pContext->GetCompileOptionManager();
    auto pCommonOption = pCompileOptionManager->GetCompileOption<
        nngfxToolShaderCompilerOptionType_Common >();
    auto pGlslOption = pCompileOptionManager->GetCompileOption<
        nngfxToolShaderCompilerOptionType_Glsl >();
    auto pGxOption = pCompileOptionManager->GetCompileOption< static_cast<
        nngfxToolShaderCompilerOptionType >( nngfxToolShaderCompilerOptionType_Gx ) >();

    compileSetup.options = m_Option;
    compileSetup.optimizeFlags = pGxOption->GetOption()->optimizeFlags;
    compileSetup.spark_output_dir = pGxOption->GetOption()->pSparkOutputDir;

    // Statistics output field for each stage, indexed like s_pStageSource.
    static GSH2ShaderStats* GSH2CompileSetup3::* const s_pStatistics[] =
    {
        &GSH2CompileSetup3::vs_stats,
        nullptr,
        nullptr,
        &GSH2CompileSetup3::gs_stats,
        &GSH2CompileSetup3::ps_stats,
        &GSH2CompileSetup3::cs_stats
    };

    auto pVariationOutput = pOutput->GetVariationOutput( idxVariation );
    auto pProgramOutput = pVariationOutput->GetBinaryOutput();
    auto pInfo = pProgramOutput->GetInfo();
    pInfo->flags.SetBit( nn::gfx::ShaderInfoData::Flag_SeparationEnable,
        pCommonOption->IsSeparationEnabled() );
    pInfo->binaryFormat = 0;
    pInfo->codeType = nn::gfxTool::StaticCastAuto( nn::gfx::ShaderCodeType_Binary );
    pInfo->sourceFormat = 0;

    std::shared_ptr< OptionOutputProgramCommon > pOptionOutputProgramCommon(
        new OptionOutputProgramCommon() );
    pOptionOutputProgramCommon->Initialize( pArg );
    pProgramOutput->AddOptionOutput(
        nngfxToolShaderCompilerOptionOutputType_ProgramCommon, pOptionOutputProgramCommon );

    // Assembled source per stage. The compile setups keep raw c_str() pointers into these
    // strings, so they must outlive the CompileGx2 calls below.
    Custom< std::string >::Type sources[ ShaderStage::End ];
    for( int idxStage = 0, idxHandle = 0; idxStage < static_cast< int >( ShaderStage::End ); ++idxStage )
    {
        auto stage = static_cast< ShaderStage >( idxStage );
        if( GetStageSource( pArg, stage ).pValue == nullptr )
        {
            continue;
        }
        if( s_pStageSource[ idxStage ] == nullptr )
        {
            NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_Unsupported,
                "Tessellation shaders are not supported in GX2" );
        }
        if( pCommonOption->IsSeparationEnabled() )
        {
            // Only the head of this stage's variation group is compiled; PostCompile
            // shares its result with the other members of the group.
            auto headVariation = pContext->GetVariationManager()->GetVariationGroup(
                stage )->VariationToHeadVaritaion( idxVariation );
            if( headVariation != idxVariation )
            {
                continue;
            }
        }

        // Statistics buffer for this stage only. It is attached (below) to the setup that
        // actually compiles this stage rather than to the shared compileSetup. Otherwise,
        // with separation enabled, each per-stage compile would also be handed the
        // statistics buffers of the stages compiled before it.
        GSH2ShaderStats* pStageStatistics = nullptr;
        if( compileSetup.options.getStats )
        {
            auto pOptionStageOutput = pOptionOutputProgramCommon->GetOptionOutputStageCommon( stage );
            std::shared_ptr< GSH2ShaderStats > pStatistics( new GSH2ShaderStats() );
            pStageStatistics = pStatistics.get();
            pOptionStageOutput->SetShaderStatistics( pStatistics );
        }

        auto variationConstantGroup = pContext->GetVariationManager(
            )->GetVariationConstantGroup( stage )->GetVariationToGroupTable().at( idxVariation );
        auto& variationConstantSource = pContext->GetVariationManager(
            )->GetVariationConstantSource( stage )->GetSources().at( variationConstantGroup );
        auto preprocessorDefinitionGroup = pContext->GetVariationManager(
            )->GetPreprocessorDefinitionGroup( stage )->GetVariationToGroupTable().at( idxVariation );
        auto& preprocessorDefinitionSource = pContext->GetVariationManager(
            )->GetPreprocessorDefinitionSource( stage )->GetSources().at( preprocessorDefinitionGroup );

        auto pShaderSource = pContext->GetShaderSourceManager(
            )->GetShaderSource( StaticCastAuto( stage ) );

        auto& source = sources[ idxStage ];
        source.append( *pGlslOption->GetGlslHeader().get() );
        source.append( *pCommonOption->GetPreprocessorDefinitionSource().get() );
        source.append( preprocessorDefinitionSource );
        source.append( variationConstantSource );
        source.append( pShaderSource->beforeVariationBufferView.data(),
            pShaderSource->beforeVariationBufferView.length() );
        source.append( pShaderSource->variationBufferSource.data(),
            pShaderSource->variationBufferSource.length() );
        source.append( pShaderSource->afterVariationBufferView.data(),
            pShaderSource->afterVariationBufferView.length() );

        if( pCommonOption->IsSeparationEnabled() )
        {
            auto compileSetupSeparate = compileSetup;
            compileSetupSeparate.*s_pStageSource[ idxStage ] = sources[ idxStage ].c_str();
            compileSetupSeparate.*s_pStatistics[ idxStage ] = pStageStatistics;
            CompileGx2( pProgramOutput, m_Handles[ idxVariation * m_HandleCountPerVariation
                + idxHandle++ ], &compileSetupSeparate, pCommonOption->GetCodePage(), idxVariation );
        }
        else
        {
            compileSetup.*s_pStageSource[ idxStage ] = sources[ idxStage ].c_str();
            compileSetup.*s_pStatistics[ idxStage ] = pStageStatistics;
        }
    }

    if( !pCommonOption->IsSeparationEnabled() )
    {
        // Linked compilation: one handle per variation, all collected stages at once.
        CompileGx2( pProgramOutput, m_Handles[ idxVariation ],
            &compileSetup, pCommonOption->GetCodePage(), idxVariation );
    }
} // NOLINT

// With separation enabled, CompileGroup compiles each stage only for the head variation
// of that stage's variation group. This pass fills in the remaining variations. For every
// stage that has a source, a variation that is not its group's head receives:
//  - the head's stage code pointer, and
//  - a copy of the head's stage option output.
// Nothing is done yet when separation is disabled.
void Compiler< Target >::PostCompile( CompileOutput* pOutput,
    const ShaderCompilerContext* pContext, const nngfxToolShaderCompilerCompileArg* pArg )
{
    auto pCommonOption = pContext->GetCompileOptionManager()->GetCompileOption<
        nngfxToolShaderCompilerOptionType_Common >();
    if( !pCommonOption->IsSeparationEnabled() )
    {
        // TODO
        return;
    }

    for( int idxVariation = 0, variationCount = NumericCastAuto( pArg->variationCount );
        idxVariation < variationCount; ++idxVariation )
    {
        auto* pProgramOutput = pOutput->GetVariationOutput( idxVariation )->GetBinaryOutput();
        for( int idxStage = 0; idxStage < static_cast< int >( ShaderStage::End ); ++idxStage )
        {
            auto stage = static_cast< ShaderStage >( idxStage );
            if( GetStageSource( pArg, stage ).pValue == nullptr )
            {
                continue;
            }

            auto headVariation = pContext->GetVariationManager()->GetVariationGroup(
                stage )->VariationToHeadVaritaion( idxVariation );
            if( headVariation != idxVariation )
            {
                auto* pHeadOutput = pOutput->GetVariationOutput( headVariation )->GetBinaryOutput();
                *GetStageCodePtr( pProgramOutput->GetInfo(), stage ) =
                    *GetStageCodePtr( pHeadOutput->GetInfo(), stage );

                auto pStageOutput = static_cast< OptionOutputProgramCommon* >(
                    pProgramOutput->GetOptionOutput( nngfxToolShaderCompilerOptionOutputType_ProgramCommon )
                    )->GetOptionOutputStageCommon( stage );
                // Source is the head variation's option output (pHeadOutput). Reading it
                // from pProgramOutput would make the assignment below copy onto itself.
                auto pHeadStageOutput = static_cast< OptionOutputProgramCommon* >(
                    pHeadOutput->GetOptionOutput( nngfxToolShaderCompilerOptionOutputType_ProgramCommon )
                    )->GetOptionOutputStageCommon( stage );
                *pStageOutput->GetOutput() = *pHeadStageOutput->GetOutput();
            }
        }
    }
}

// Compiles the stage sources set in *pCompileSetup using `handle`, then deep-copies each
// stage whose source pointer is set into pProgramOutput (see RetrieveProgramOutput).
// The compiler's GX2 program and `handle` itself are destroyed both on normal return and
// on compile failure, so a handle can be used for only one call.
// On failure, throws nngfxToolResultCode_CompileError. The message holds the variation
// index, the source of every stage in the setup and the compiler's info log (when one is
// provided), converted with ConvertEncoding( ..., codePage, 0 ).
void Compiler< Target >::CompileGx2( ProgramOutput* pProgramOutput,
    GSH2Handle handle, GSH2CompileSetup3* pCompileSetup, int codePage, int idxVariation )
{
    GSH2CompileOutput3 compileOutput = {};
    auto succeeded = m_ShaderUtilsDll.Gsh2CompileProgram3(
        handle, pCompileSetup, &compileOutput );

    if( !succeeded )
    {
        // TODO
        Custom< std::string >::Type errorLog;
        errorLog.append( "\n[Variation: " ).append( LexicalCast< Custom<
            std::string >::Type >( idxVariation ) ).append( "]\n" );

        // Section headers, indexed like s_pStageSource (nullptr for the unsupported slots).
        static const char* const s_StageNames[] =
        {
            "----Vertex Shader----",
            nullptr,
            nullptr,
            "----Geometry Shader----",
            "----Pixel Shader----",
            "----Compute Shader----"
        };
        for( int idxStage = 0; idxStage < static_cast< int >( ShaderStage::End ); ++idxStage )
        {
            if( s_pStageSource[ idxStage ] && pCompileSetup->*s_pStageSource[ idxStage ] )
            {
                errorLog.append( "\n" ).append( s_StageNames[ idxStage ] ).append( "\n" );
                errorLog.append( pCompileSetup->*s_pStageSource[ idxStage ] );
            }
        }
        errorLog.append( "\n----Error Log----\n" );
        // pInfoLog may be null (the success path below checks for that), and appending a
        // null C string to a std::string is undefined behavior.
        if( compileOutput.pInfoLog )
        {
            errorLog.append( compileOutput.pInfoLog );
        }
        m_ShaderUtilsDll.Gsh2DestroyGx2Program3( handle, &compileOutput.gx2Program );
        m_ShaderUtilsDll.Gsh2Destroy( handle );
        auto result = ConvertEncoding( errorLog, codePage, 0 );
        NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_CompileError, "%s", result.c_str() );
    }

    // One info log is shared by every stage retrieved from this compile.
    std::shared_ptr< Custom< std::string >::Type > pInfoLog;
    if( compileOutput.pInfoLog )
    {
        pInfoLog.reset( new Custom< std::string >::Type( compileOutput.pInfoLog ) );
    }

    // Stages without a source get a null destination, which RetrieveProgramOutput skips.
    RetrieveProgramOutput( pCompileSetup->vs_source ? pProgramOutput : nullptr,
        &compileOutput.gx2Program.vs, compileOutput.pVSDump, pInfoLog, nn::gfx::ShaderStage_Vertex );
    RetrieveProgramOutput( pCompileSetup->gs_source ? pProgramOutput : nullptr,
        &compileOutput.gx2Program.gs, compileOutput.pGSDump, pInfoLog, nn::gfx::ShaderStage_Geometry );
    RetrieveProgramOutput( pCompileSetup->ps_source ? pProgramOutput : nullptr,
        &compileOutput.gx2Program.ps, compileOutput.pPSDump, pInfoLog, nn::gfx::ShaderStage_Pixel );
    RetrieveProgramOutput( pCompileSetup->cs_source ? pProgramOutput : nullptr,
        &compileOutput.gx2Program.cs, compileOutput.pCSDump, pInfoLog, nn::gfx::ShaderStage_Compute );

    // Everything needed has been deep-copied above, so the compiler-owned data can go.
    m_ShaderUtilsDll.Gsh2DestroyGx2Program3( handle, &compileOutput.gx2Program );
    m_ShaderUtilsDll.Gsh2Destroy( handle );
}

}
}
