﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <filesystem>
#include <new>

#include <nn/gfx/gfx_ResShader.h>
#include <nn/gfx/gfx_ResShaderData-api.nvn.h>

#include <nn/gfxTool/gfxTool_Util.h>

#include <gfxTool_CompileOutput.h>
#include <gfxTool_OptionOutput.h>
#include <gfxTool_ShaderFileMerger.h>

namespace {

// Classifies a shader variation by which program kinds it contains.
// A program is considered present when its offset is non-zero. Binary takes precedence over IR,
// and IR over source, when picking the combined code type; a variation with no program at all
// yields nngfxToolShaderCompilerCodeType_Undefined.
nngfxToolShaderCompilerCodeType GetCodeType( const nn::gfx::ResShaderVariationData& shaderVariationData )
{
    const bool hasBinary = shaderVariationData.pBinaryProgram.GetOffset() != 0;
    const bool hasIr = shaderVariationData.pIntermediateLanguageProgram.GetOffset() != 0;
    const bool hasSource = shaderVariationData.pSourceProgram.GetOffset() != 0;

    if( hasBinary )
    {
        // A binary paired with IR reports Binary_Ir even if source is also present.
        return hasIr ? nngfxToolShaderCompilerCodeType_Binary_Ir
            : ( hasSource ? nngfxToolShaderCompilerCodeType_Binary_Source
                : nngfxToolShaderCompilerCodeType_Binary );
    }
    if( hasIr )
    {
        return hasSource ? nngfxToolShaderCompilerCodeType_Ir_Source
            : nngfxToolShaderCompilerCodeType_Ir;
    }
    return hasSource ? nngfxToolShaderCompilerCodeType_Source
        : nngfxToolShaderCompilerCodeType_Undefined;
}

// Returns the per-stage reflection block of pShaderReflectionData for the given stage.
// The stage's numeric value indexes a table of member pointers, so the table order must follow
// ShaderStage: vertex, hull, domain, geometry, pixel, compute.
// The returned pointer is whatever the stage's BinTPtr holds (it may be null when the stage has
// no reflection).
const nn::gfx::ResShaderReflectionStageData* GetReflectionStageData(
    const nn::gfx::ResShaderReflectionData* pShaderReflectionData, nn::gfxTool::ShaderStage stage )
{
    using StageMember = nn::util::BinTPtr< nn::gfx::ResShaderReflectionStageData >
        nn::gfx::ResShaderReflectionData::*;

    static const StageMember s_StageMembers[] =
    {
        &nn::gfx::ResShaderReflectionData::pVertexReflection,
        &nn::gfx::ResShaderReflectionData::pHullReflection,
        &nn::gfx::ResShaderReflectionData::pDomainReflection,
        &nn::gfx::ResShaderReflectionData::pGeometryReflection,
        &nn::gfx::ResShaderReflectionData::pPixelReflection,
        &nn::gfx::ResShaderReflectionData::pComputeReflection
    };

    const int stageIndex = static_cast< int >( stage );
    const auto& stagePointer = pShaderReflectionData->*s_StageMembers[ stageIndex ];
    return stagePointer.Get();
}

template< typename TReflection >
void SetShaderSlot( TReflection* pReflection, int slot, nn::gfxTool::ShaderStage stage )
{
    static int32_t nngfxToolShaderCompilerShaderSlot::*pReflectionSlot[] =
    {
        &nngfxToolShaderCompilerShaderSlot::vertexShaderSlot,
        &nngfxToolShaderCompilerShaderSlot::hullShaderSlot,
        &nngfxToolShaderCompilerShaderSlot::domainShaderSlot,
        &nngfxToolShaderCompilerShaderSlot::geometryShaderSlot,
        &nngfxToolShaderCompilerShaderSlot::pixelShaderSlot,
        &nngfxToolShaderCompilerShaderSlot::computeShaderSlot
    };
    pReflection->shaderSlot.*pReflectionSlot[ static_cast<int>( stage ) ] = slot;
}
template<>
void SetShaderSlot< nngfxToolShaderCompilerShaderInput >(
    nngfxToolShaderCompilerShaderInput* pReflection, int slot, nn::gfxTool::ShaderStage )
{
    pReflection->shaderSlot = slot;
}
template<>
void SetShaderSlot< nngfxToolShaderCompilerShaderOutput >(
    nngfxToolShaderCompilerShaderOutput* pReflection, int slot, nn::gfxTool::ShaderStage )
{
    pReflection->shaderSlot = slot;
}

// Appends one TReflection entry to pOutput for every key in pDic.
// Each entry gets the dictionary key as its name, the slot pSlotArray[ offset + i ] (routed
// through SetShaderSlot for the given stage), and the stage-reference flag for that stage.
// Does nothing when pDic is null.
template< typename TReflection >
void AddReflection( nn::gfxTool::OptionOutputReflection* pOutput,
    const nn::util::ResDic* pDic, const int* pSlotArray, int offset, nn::gfxTool::ShaderStage stage )
{
    // Indexed by the numeric value of ShaderStage.
    static const nngfxToolShaderCompilerReflectionStageReference s_StageReferences[] =
    {
        nngfxToolShaderCompilerReflectionStageReference_Vertex,
        nngfxToolShaderCompilerReflectionStageReference_Hull,
        nngfxToolShaderCompilerReflectionStageReference_Domain,
        nngfxToolShaderCompilerReflectionStageReference_Geometry,
        nngfxToolShaderCompilerReflectionStageReference_Pixel,
        nngfxToolShaderCompilerReflectionStageReference_Compute,
    };

    if( pDic == nullptr )
    {
        return;
    }

    const int stageIndex = StaticCastAuto( stage );
    const int entryCount = pDic->GetCount();
    for( int entry = 0; entry < entryCount; ++entry )
    {
        TReflection reflection = {};
        const auto key = pDic->GetKey( entry );
        reflection.name.pValue = key.data();
        reflection.name.length = nn::gfxTool::NumericCastAuto( key.length() );
        SetShaderSlot( &reflection, pSlotArray[ offset + entry ], stage );
        reflection.stages = s_StageReferences[ stageIndex ];
        pOutput->AddReflection( reflection );
    }
}

// Checks every entry of pArg->pMergeShaderFiles before the files are copied and relocated.
// Each entry must be non-null, carry the ResShaderFile signature, and have exactly the
// ResShaderFile major/minor/micro version this tool was built with.
// Throws nngfxToolResultCode_InvalidArgument on the first offending file.
void ValidateShaderFiles( const nngfxToolShaderConverterConvertArg* pArg )
{
    // The files have not been relocated yet; only header fields are inspected here.
    const int shaderFileCount = static_cast< int >( pArg->mergeShaderFileCount );
    for( int idxShaderFile = 0; idxShaderFile < shaderFileCount; ++idxShaderFile )
    {
        auto pShaderFileData = static_cast< const nn::gfx::ResShaderFileData* >(
            pArg->pMergeShaderFiles[ idxShaderFile ] );
        if( pShaderFileData == nullptr )
        {
            NN_GFXTOOL_THROW( nngfxToolResultCode_InvalidArgument );
        }
        if( !pShaderFileData->fileHeader.IsSignatureValid( nn::gfx::ResShaderFile::Signature ) )
        {
            NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_InvalidArgument,
                "Invalid signature (merge shader file %d)", idxShaderFile );
        }
        auto name = pShaderFileData->fileHeader.GetFileName();
        if( pShaderFileData->fileHeader.version.major != nn::gfx::ResShaderFile::MajorVersion ||
            pShaderFileData->fileHeader.version.minor != nn::gfx::ResShaderFile::MinorVersion ||
            pShaderFileData->fileHeader.version.micro != nn::gfx::ResShaderFile::MicroVersion )
        {
            // "%.*s" (precision, not width) so that exactly name.length() characters are printed;
            // name.data() is not guaranteed to be null-terminated.
            NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_InvalidArgument,
                "Invalid version (%.*s)\n"
                "    Expected: %d.%d.%d\n"
                "    Actual: %d.%d.%d\n",
                static_cast< int >( name.length() ), name.data(),
                nn::gfx::ResShaderFile::MajorVersion,
                nn::gfx::ResShaderFile::MinorVersion,
                nn::gfx::ResShaderFile::MicroVersion,
                static_cast< int >( pShaderFileData->fileHeader.version.major ),
                static_cast< int >( pShaderFileData->fileHeader.version.minor ),
                static_cast< int >( pShaderFileData->fileHeader.version.micro ) );
        }
    }
}

}

namespace nn {
namespace gfxTool {

// Constructs a merger with no merge result yet; the compile output is created by Merge().
ShaderFileMerger::ShaderFileMerger()
    : m_pCompileOutput( nullptr )
{
}

// Releases the compile output produced by Merge(), if one exists.
ShaderFileMerger::~ShaderFileMerger()
{
    if( m_pCompileOutput != nullptr )
    {
        auto pCompileOutput = static_cast< CompileOutput* >( m_pCompileOutput );
        CompileOutput::Delete( pCompileOutput );
    }
}

void ShaderFileMerger::Merge( const nngfxToolShaderConverterConvertArg* pArg )
{
    static const int LowLevelApiTypeTable[] =
    {
        0,
        nngfxToolShaderCompilerLowLevelApiType_Gl,
        nngfxToolShaderCompilerLowLevelApiType_Gx,
        nngfxToolShaderCompilerLowLevelApiType_D3d,
        nngfxToolShaderCompilerLowLevelApiType_Nvn,
        nngfxToolShaderCompilerLowLevelApiType_Vk
    };

    if( pArg->pMergeShaderFiles == nullptr )
    {
        NN_GFXTOOL_THROW( nngfxToolResultCode_InvalidArgument );
    }
    ValidateShaderFiles( pArg );

    m_ShaderFiles.clear();
    m_ShaderFiles.reserve( pArg->mergeShaderFileCount );
    for( int idxShaderFile = 0; idxShaderFile < static_cast<int>( pArg->mergeShaderFileCount ); ++idxShaderFile )
    {
        auto pSrcShaderFileData = static_cast< const nn::gfx::ResShaderFileData* >(
            pArg->pMergeShaderFiles[ idxShaderFile ] );
        if( pSrcShaderFileData->fileHeader.IsRelocated() )
        {
            auto name = pSrcShaderFileData->fileHeader.GetFileName();
            NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_InvalidArgument,
                "Relocated shader file is not acceptable (%d).",
                static_cast< int >( name.length() ), name.data() );
        }
        auto fileSize = pSrcShaderFileData->fileHeader.GetFileSize();
        m_ShaderFiles.push_back( std::unique_ptr< void, decltype( &free ) >( malloc( fileSize ), free ) );
        auto pDstShaderFileData = static_cast< nn::gfx::ResShaderFileData*>(
            m_ShaderFiles[ idxShaderFile ].get() );
        memcpy( pDstShaderFileData, pSrcShaderFileData, fileSize );
        if( !pDstShaderFileData->fileHeader.IsRelocated() )
        {
            pDstShaderFileData->fileHeader.GetRelocationTable()->Relocate();
        }
    }

    auto pOutput = CompileOutput::Create();
    m_pCompileOutput = static_cast< nngfxToolShaderCompilerCompileOutput* >( pOutput );

    auto pFirstShaderFileData = static_cast< const nn::gfx::ResShaderFileData* >( m_ShaderFiles[ 0 ].get() );
    auto firstShaderFileName = pFirstShaderFileData->fileHeader.GetFileName();
    auto pFirstShaderContainerData = reinterpret_cast< const nn::gfx::ResShaderContainerData* >(
        pFirstShaderFileData->fileHeader.GetFirstBlock() );
    nngfxToolShaderCompilerCodeType codeType = nngfxToolShaderCompilerCodeType_Undefined;
    int variationCount = 0;
    int shaderFileCount = StaticCastAuto( pArg->mergeShaderFileCount );
    for( int idxShaderFile = 0; idxShaderFile < shaderFileCount; ++idxShaderFile )
    {
        auto pShaderFileData = static_cast< const nn::gfx::ResShaderFileData* >( m_ShaderFiles[ idxShaderFile ].get() );
        auto shaderFileName = pShaderFileData->fileHeader.GetFileName();
        auto pShaderContainerData = reinterpret_cast< const nn::gfx::ResShaderContainerData* >(
            pShaderFileData->fileHeader.GetFirstBlock() );
        if( pFirstShaderContainerData->compilerVersion != pShaderContainerData->compilerVersion ||
            pFirstShaderContainerData->lowLevelCompilerVersion != pShaderContainerData->lowLevelCompilerVersion ||
            pFirstShaderContainerData->targetApiType != pShaderContainerData->targetApiType ||
            pFirstShaderContainerData->targetApiVersion != pShaderContainerData->targetApiVersion )
        {
            NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_InconsistentShaderFile,
                "Shader files (%*s and %*s) are inconsistent",
                static_cast< int >( firstShaderFileName.length() ), firstShaderFileName.data(),
                static_cast< int >( shaderFileName.length() ), shaderFileName.data() );
        }
        variationCount += pShaderContainerData->shaderVariationCount;
        for( int idxVariation = 0; idxVariation < static_cast< int >(
            pShaderContainerData->shaderVariationCount ); ++idxVariation )
        {
            auto pShaderVariationData = pShaderContainerData->pShaderVariationArray.Get();
            auto codeTypeVariation = GetCodeType( *pShaderVariationData );
            if( codeType != nngfxToolShaderCompilerCodeType_Undefined && codeType != codeTypeVariation )
            {
                NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_InconsistentShaderFile,
                    "Shader files (%*s and %*s) are inconsistent",
                    static_cast< int >( firstShaderFileName.length() ), firstShaderFileName.data(),
                    static_cast< int >( shaderFileName.length() ), shaderFileName.data() );
            }
            codeType = codeTypeVariation;
        }
    }

    pOutput->variationCount = variationCount;
    pOutput->compilerVersion = pFirstShaderContainerData->compilerVersion;
    pOutput->lowLevelCompilerVerison = pFirstShaderContainerData->lowLevelCompilerVersion;

    auto& compileArg = m_DummyCompileArg;
    compileArg.variationCount = variationCount;
    compileArg.targetCodeType = codeType;
    compileArg.targetLowLevelApiType = LowLevelApiTypeTable[ pFirstShaderContainerData->targetApiType ];
    compileArg.targetLowLevelApiVersion = pFirstShaderContainerData->targetApiVersion;
    pOutput->Initialize( &compileArg );

    auto CreateProgramOutput = [ & ]( ProgramOutput* pProgramOutput,
        nn::gfx::ResShaderProgramData* pShaderProgramData )
    {
        pProgramOutput->Initialize( &compileArg );
        *pProgramOutput->GetInfo() = pShaderProgramData->info;
        if( pArg->enableFullReflection && pShaderProgramData->pShaderCompilerReflection.Get() == nullptr )
        {
            NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_InvalidArgument,
                "Source shader file does not have full reflection." );
        }
        if( pShaderProgramData->pShaderCompilerReflection.Get() ||
            pShaderProgramData->pShaderReflection.Get() )
        {
            std::shared_ptr< OptionOutputProgramCommon > pOptionOutputProgramCommon(
                new OptionOutputProgramCommon );
            pOptionOutputProgramCommon->Initialize( &compileArg );
            pProgramOutput->AddOptionOutput(
                nngfxToolShaderCompilerOptionOutputType_ProgramCommon, pOptionOutputProgramCommon );
            if( auto pFullReflection = pShaderProgramData->pShaderCompilerReflection.Get() )
            {
                pOptionOutputProgramCommon->GetOutput()->pReflection =
                    pShaderProgramData->pShaderCompilerReflection.Get();
            }
            else if( auto pReflection = pShaderProgramData->pShaderReflection.Get() )
            {
                // 必要な情報だけ埋めたリフレクションを作成
                std::shared_ptr< OptionOutputReflection > pReflectionOutput( new OptionOutputReflection );
                for( int idxStage = 0; idxStage < static_cast<int>( ShaderStage::End ); ++idxStage )
                {
                    auto stage = static_cast<ShaderStage>( idxStage );
                    if( auto pStageReflection = GetReflectionStageData( pReflection, StaticCastAuto( idxStage ) ) )
                    {
                        AddReflection< nngfxToolShaderCompilerShaderInput >( pReflectionOutput.get(),
                            pStageReflection->pShaderInputDic.Get(), pStageReflection->pShaderSlotArray.Get(),
                            0, stage );
                        AddReflection< nngfxToolShaderCompilerShaderOutput >( pReflectionOutput.get(),
                            pStageReflection->pShaderOutputDic.Get(), pStageReflection->pShaderSlotArray.Get(),
                            pStageReflection->offsetShaderOutput, stage );
                        AddReflection< nngfxToolShaderCompilerSampler >( pReflectionOutput.get(),
                            pStageReflection->pSamplerDic.Get(), pStageReflection->pShaderSlotArray.Get(),
                            pStageReflection->offsetSampler, stage );
                        AddReflection< nngfxToolShaderCompilerConstantBuffer >( pReflectionOutput.get(),
                            pStageReflection->pConstantBufferDic.Get(), pStageReflection->pShaderSlotArray.Get(),
                            pStageReflection->offsetConstantBuffer, stage );
                        AddReflection< nngfxToolShaderCompilerUnorderedAccessBuffer >( pReflectionOutput.get(),
                            pStageReflection->pUnorderedAccessBufferDic.Get(), pStageReflection->pShaderSlotArray.Get(),
                            pStageReflection->offsetUnorderedAccessBuffer, stage );
                        AddReflection< nngfxToolShaderCompilerImage >( pReflectionOutput.get(),
                            pStageReflection->pImageDic.Get(), pStageReflection->pShaderSlotArray.Get(),
                            pStageReflection->offsetImage, stage );
                        if( auto pStageReflection2 = pStageReflection->pReflectionStageData2.Get() )
                        {
                            AddReflection< nngfxToolShaderCompilerSeparateTexture >( pReflectionOutput.get(),
                                pStageReflection2->pSeparateTextureDic.Get(), pStageReflection->pShaderSlotArray.Get(),
                                pStageReflection2->offsetSeparateTexture, stage );
                            AddReflection< nngfxToolShaderCompilerSeparateSampler >( pReflectionOutput.get(),
                                pStageReflection2->pSeparateSamplerDic.Get(), pStageReflection->pShaderSlotArray.Get(),
                                pStageReflection2->offsetSeparateSampler, stage );
                        }
                        if( stage == ShaderStage::Compute )
                        {
                            pReflectionOutput->GetOutput()->computeWorkGroupSizeX = pStageReflection->computeWorkGroupSizeX;
                            pReflectionOutput->GetOutput()->computeWorkGroupSizeY = pStageReflection->computeWorkGroupSizeY;
                            pReflectionOutput->GetOutput()->computeWorkGroupSizeZ = pStageReflection->computeWorkGroupSizeZ;
                        }
                    }
                }
                pOptionOutputProgramCommon->SetReflection( pReflectionOutput );
            }
        }
    };

    int idxNewVariation = 0;
    nn::gfx::ResShaderProgramData* pFirstBinaryProgram = nullptr;
    nn::gfx::ResShaderProgramData* pFirstIrProgram = nullptr;
    nn::gfx::ResShaderProgramData* pFirstSourceProgram = nullptr;
    for( int idxShaderFile = 0; idxShaderFile < shaderFileCount; ++idxShaderFile )
    {
        auto pShaderFileData = static_cast< nn::gfx::ResShaderFileData* >(
            m_ShaderFiles[ idxShaderFile ].get() );
        auto pShaderContainerData = reinterpret_cast< nn::gfx::ResShaderContainerData* >(
            pShaderFileData->fileHeader.GetFirstBlock() );
        for( int idxVariation = 0; idxVariation < static_cast< int >(
            pShaderContainerData->shaderVariationCount ); ++idxVariation )
        {
            auto pVariationOutput = pOutput->GetVariationOutput( idxNewVariation++ );
            auto pShaderVariationArray = pShaderContainerData->pShaderVariationArray.Get();
            auto pShaderVariationData = pShaderVariationArray + idxVariation;
            bool isConsistent = true;
            auto IsReflectionConsistent = []( const nn::gfx::ResShaderProgramData* pFirstProgram,
                const nn::gfx::ResShaderProgramData* pProgram )
            {
                return ( pFirstProgram->pShaderCompilerReflection.Get() != nullptr ) ==
                    ( pProgram->pShaderCompilerReflection.Get() != nullptr ) ||
                    ( pFirstProgram->pShaderReflection.Get() != nullptr ) !=
                    ( pProgram->pShaderReflection.Get() != nullptr );
            };
            if( auto pBinaryProgramData = pShaderVariationData->pBinaryProgram.Get() )
            {
                if( pFirstBinaryProgram )
                {
                    isConsistent = IsReflectionConsistent( pFirstBinaryProgram, pBinaryProgramData );
                }
                else
                {
                    pFirstBinaryProgram = pBinaryProgramData;
                }
                CreateProgramOutput( pVariationOutput->GetBinaryOutput(), pBinaryProgramData );
            }
            if( auto pIrProgramData = pShaderVariationData->pIntermediateLanguageProgram.Get() )
            {
                if( pFirstIrProgram )
                {
                    isConsistent = IsReflectionConsistent( pFirstIrProgram, pIrProgramData );
                }
                else
                {
                    pFirstIrProgram = pIrProgramData;
                }
                CreateProgramOutput( pVariationOutput->GetIntermediateLanguageOutput(), pIrProgramData );
            }
            if( auto pSourceProgramData = pShaderVariationData->pSourceProgram.Get() )
            {
                if( pFirstSourceProgram )
                {
                    isConsistent = IsReflectionConsistent( pFirstSourceProgram, pSourceProgramData );
                }
                else
                {
                    pFirstSourceProgram = pSourceProgramData;
                }
                CreateProgramOutput( pVariationOutput->GetSourceOutput(), pSourceProgramData );
            }
            if( !isConsistent )
            {
                auto shaderFileName = pShaderFileData->fileHeader.GetFileName();
                NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_InconsistentShaderFile,
                    "Shader files (%*s and %*s) are inconsistent",
                    static_cast< int >( firstShaderFileName.length() ), firstShaderFileName.data(),
                    static_cast< int >( shaderFileName.length() ), shaderFileName.data() );
            }
        }
    }
} // NOLINT

}
}
