﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <nn/gfx/gfx_Shader.h>
#include <nn/gfx/gfx_ResShaderData.h>

#include <nn/gfx/detail/gfx_Shader-api.gl.4.h>
#include <nn/gfx/detail/gfx_Shader-api.gx.2.h>
#include <nn/gfx/detail/gfx_Shader-api.nvn.8.h>
#include <nn/gfx/detail/gfx_Shader-api.d3d.11.h>
#include <nn/gfx/detail/gfx_Shader-api.vk.1.h>

#include <nn/gfxTool/gfxTool_Util.h>

#include <gfxTool_ResShaderProgramBinarizer.h>
#include <gfxTool_ResShaderVariationBinarizer.h>
#include <gfxTool_ResShaderContainerBinarizer.h>
#include <gfxTool_ShaderBinarizerContext.h>
#include <gfxTool_ShaderCodeBinarizer.h>
#include <gfxTool_NvnShaderCodeBinarizer.h>
#include <gfxTool_ShaderReflectionBinarizer.h>
#include <gfxTool_FullReflectionBinarizer.h>

namespace {

// Pointer-to-member for each per-stage code pointer of nn::gfx::ShaderInfoData,
// in the order vertex, hull, domain, geometry, pixel, compute.
//
// The table is indexed by stage number: the loops in this file run idxStage
// from 0 to ShaderStage::End and read g_pStageCode[ idxStage ]. They apply it both to the
// compiler's ShaderInfoData and to ResTarget::info.
// NOTE(review): this assumes ShaderStage enumerates stages in this exact order
// and that ShaderStage::End does not exceed the number of entries here -
// confirm against the ShaderStage definition.
// ('static' is redundant inside an anonymous namespace but harmless.)
static nn::gfx::detail::Ptr< const void > nn::gfx::ShaderInfoData::* const g_pStageCode[] =
{
    &nn::gfx::ShaderInfoData::pVertexShaderCode,
    &nn::gfx::ShaderInfoData::pHullShaderCode,
    &nn::gfx::ShaderInfoData::pDomainShaderCode,
    &nn::gfx::ShaderInfoData::pGeometryShaderCode,
    &nn::gfx::ShaderInfoData::pPixelShaderCode,
    &nn::gfx::ShaderInfoData::pComputeShaderCode
};

// Returns the number of bytes reserved for one nn::gfx shader object
// (TShader< ApiType< ... > >::DataType) for the given low-level API.
//
// The per-API sizes are fixed upper bounds. The static asserts guarantee that
// each bound covers the actual DataType size for that API.
//
// targetLowLevelApiType indexes ObjSizeTable, ordered Gl, Gx, Nvn, D3d, Vk.
// NOTE(review): this assumes the nngfxToolShaderCompilerLowLevelApiType
// enumerators use the same order starting at 0 - confirm.
//
// Throws nngfxToolResultCode_InternalError (via NN_GFXTOOL_THROW) when
// targetLowLevelApiType is outside the table.
size_t GetSizeOfShaderObj( int targetLowLevelApiType )
{
    const size_t GlSize = 64;
    const size_t GxSize = 64;
    const size_t NvnSize = 256;
    const size_t D3dSize = 256;
    const size_t VkSize = 128;

    NN_STATIC_ASSERT( GlSize >= sizeof( nn::gfx::TShader<
        nn::gfx::ApiType< nn::gfx::LowLevelApi_Gl > >::DataType ) );
    NN_STATIC_ASSERT( GxSize >= sizeof( nn::gfx::TShader<
        nn::gfx::ApiType< nn::gfx::LowLevelApi_Gx > >::DataType ) );
    NN_STATIC_ASSERT( NvnSize >= sizeof( nn::gfx::TShader<
        nn::gfx::ApiType< nn::gfx::LowLevelApi_Nvn > >::DataType ) );
    NN_STATIC_ASSERT( D3dSize >= sizeof( nn::gfx::TShader<
        nn::gfx::ApiType< nn::gfx::LowLevelApi_D3d > >::DataType ) );
    NN_STATIC_ASSERT( VkSize >= sizeof( nn::gfx::TShader<
        nn::gfx::ApiType< nn::gfx::LowLevelApi_Vk > >::DataType ) );

    const size_t ObjSizeTable[] = { GlSize, GxSize, NvnSize, D3dSize, VkSize };

    // Hold the entry count as int so the range check below compares int with
    // int. Comparing the int argument directly against a size_t triggers
    // -Wsign-compare.
    const int ObjSizeTableCount = static_cast< int >(
        sizeof( ObjSizeTable ) / sizeof( *ObjSizeTable ) );

    if( targetLowLevelApiType < 0 || targetLowLevelApiType >= ObjSizeTableCount )
    {
        NN_GFXTOOL_THROW( nngfxToolResultCode_InternalError );
    }

    return ObjSizeTable[ targetLowLevelApiType ];
}

// Searches the option outputs of a compiled shader program for the
// ProgramCommon entry and returns the reflection it carries.
//
// pOutput - compiler output for one program; must be non-null. Its first
//           optionOutputCount entries of pOptionOutputArray are examined.
// Returns the pReflection of the first ProgramCommon option output, or
// nullptr when no such entry exists.
const nngfxToolShaderCompilerShaderReflection* GetReflectionOutput(
    const nngfxToolShaderCompilerShaderProgramOutput* pOutput )
{
    const int optionOutputCount = static_cast< int >( pOutput->optionOutputCount );
    for( int index = 0; index < optionOutputCount; ++index )
    {
        const auto& optionOutput = pOutput->pOptionOutputArray[ index ];
        if( optionOutput.optionOutputType != nngfxToolShaderCompilerOptionOutputType_ProgramCommon )
        {
            continue;
        }

        // For ProgramCommon entries the opaque pOutput payload is a
        // nngfxToolShaderCompilerOptionOutputProgramCommon.
        return static_cast< nngfxToolShaderCompilerOptionOutputProgramCommon* >(
            optionOutput.pOutput )->pReflection;
    }
    return nullptr;
}

}

namespace nn {
namespace gfxTool {

ResShaderProgramBinarizer::ResShaderProgramBinarizer()
    : m_pTarget()
    , m_pParent()
{
}

// Defaulted here in the .cpp rather than in the class definition.
// NOTE(review): presumably so the owning pointer members (m_pCodeBinarizer,
// m_pReflectionBinarizer, m_pFullReflectionBinarizer) are destroyed in a
// translation unit where their pointee types are complete - confirm against
// the header.
ResShaderProgramBinarizer::~ResShaderProgramBinarizer() = default;

void ResShaderProgramBinarizer::Initialize(
    const BinarizationTarget* pTarget, const ResShaderVariationBinarizer* pParent )
{
    m_pTarget = pTarget;
    m_pParent = pParent;

    auto apiType = m_pParent->GetParent()->GetArg()->pCompileArg->targetLowLevelApiType;
    auto codeType = GetInfoData()->codeType;
    for( int idxStage = 0; idxStage < static_cast< int >( ShaderStage::End ); ++idxStage )
    {
        if( static_cast< nn::gfx::ShaderInfoData* >(
            pTarget->nngfxShaderInfoData )->*g_pStageCode[ idxStage ] == nullptr )
        {
            continue;
        }

        auto stage = static_cast< ShaderStage >( idxStage );
        m_MergeCode[ idxStage ] = -1;
        if( codeType == nn::gfx::ShaderCodeType_SourceArray )
        {
            m_pCodeBinarizer[ idxStage ].reset( new SourceArrayBinarizer() );
            static_cast< SourceArrayBinarizer* >( m_pCodeBinarizer[ idxStage ].get() )->Initialize(
                StaticCastAuto( GetStageCode( GetInfoData(), stage ) ), this, stage );
        }
        else
        {
            if( apiType == nngfxToolShaderCompilerLowLevelApiType_Nvn )
            {
                m_pCodeBinarizer[ idxStage ].reset( new NvnShaderCodeBinarizer() );
                static_cast< NvnShaderCodeBinarizer* >( m_pCodeBinarizer[ idxStage ].get() )->Initialize(
                    StaticCastAuto( GetStageCode( GetInfoData(), stage ) ), this, stage );
            }
            else
            {
                m_pCodeBinarizer[ idxStage ].reset( new ShaderCodeBinarizer() );
                static_cast< ShaderCodeBinarizer* >( m_pCodeBinarizer[ idxStage ].get() )->Initialize(
                    StaticCastAuto( GetStageCode( GetInfoData(), stage ) ), this );

                if( apiType == nngfxToolShaderCompilerLowLevelApiType_Vk )
                {
                    if( auto pReflection = GetReflectionOutput( pTarget ) )
                    {
                        m_pReflectionBinarizer.reset( new ShaderReflectionBinarizer() );
                        m_pReflectionBinarizer.get()->Initialize( pReflection );
                    }
                }
            }
        }
    }

    if( auto pReflection = GetReflectionOutput( pTarget ) )
    {
        if( codeType == nn::gfx::ShaderCodeType_Binary &&
            apiType == nngfxToolShaderCompilerLowLevelApiType_Nvn ||
            codeType == nn::gfx::ShaderCodeType_Binary &&
            apiType == nngfxToolShaderCompilerLowLevelApiType_Vk )
        {
            m_pReflectionBinarizer.reset( new ShaderReflectionBinarizer() );
            m_pReflectionBinarizer.get()->Initialize( pReflection );
        }

        if( m_pParent->GetParent()->GetArg()->enableFullReflection != 0 )
        {
            m_pFullReflectionBinarizer.reset( new FullReflectionBinarizer() );
            m_pFullReflectionBinarizer->Initialize( pReflection );
        }
    }
}

// Registers the memory blocks owned by this program with pContext, then lets
// each child binarizer register its own blocks.
//
// Registration order:
//   1. the shader object storage m_Obj, in Section::Obj;
//   2. each stage code binarizer this program owns, in stage order, in
//      Section::Common;
//   3. the shader reflection, then the full reflection, if they exist, in
//      Section::Common.
// The children's RegisterChild calls then follow in the same order.
// NOTE(review): the order presumably determines the output layout, so keep it.
// A stage is owned when it has a code binarizer and m_MergeCode < 0. Merged
// stages point at another variation's code instead (see Link).
void ResShaderProgramBinarizer::RegisterChild( ShaderBinarizerContext* pContext )
{
    const int stageCount = static_cast< int >( ShaderStage::End );
    auto ownsStageCode = [ this ]( int stageIndex )
    {
        return m_pCodeBinarizer[ stageIndex ] != nullptr && m_MergeCode[ stageIndex ] < 0;
    };

    pContext->AddMemoryBlock( StaticCastAuto( Section::Obj ), &m_Obj );

    for( int stageIndex = 0; stageIndex < stageCount; ++stageIndex )
    {
        if( !ownsStageCode( stageIndex ) )
        {
            continue;
        }
        pContext->AddMemoryBlock( StaticCastAuto( Section::Common ),
            m_pCodeBinarizer[ stageIndex ].get() );
    }

    if( m_pReflectionBinarizer.get() != nullptr )
    {
        pContext->AddMemoryBlock( StaticCastAuto( Section::Common ),
            m_pReflectionBinarizer.get() );
    }
    if( m_pFullReflectionBinarizer.get() != nullptr )
    {
        pContext->AddMemoryBlock( StaticCastAuto( Section::Common ),
            m_pFullReflectionBinarizer.get() );
    }

    // Second pass: recurse into the children in the same order.
    for( int stageIndex = 0; stageIndex < stageCount; ++stageIndex )
    {
        if( !ownsStageCode( stageIndex ) )
        {
            continue;
        }
        m_pCodeBinarizer[ stageIndex ]->RegisterChild( pContext );
    }

    if( m_pReflectionBinarizer.get() != nullptr )
    {
        m_pReflectionBinarizer->RegisterChild( pContext );
    }
    if( m_pFullReflectionBinarizer.get() != nullptr )
    {
        m_pFullReflectionBinarizer->RegisterChild( pContext );
    }
}

// Sizes this program's blocks. The ResTarget header is sized first. The shader
// object storage m_Obj is sized from the target low-level API via
// GetSizeOfShaderObj, which throws for an unknown API.
void ResShaderProgramBinarizer::CalculateSize()
{
    SetSizeBy< ResTarget >();

    const auto pCompileArg = m_pParent->GetParent()->GetArg()->pCompileArg;
    const size_t shaderObjSize = GetSizeOfShaderObj( pCompileArg->targetLowLevelApiType );
    m_Obj.SetSize( shaderObjSize );
}

// Records the pointer fix-ups of this program's ResTarget in pContext.
//
// pTarget is deliberately a null ResTarget*. It is only used to form member
// addresses (&pTarget->member), which are passed to LinkPtr together with `this`.
// NOTE(review): presumably LinkPtr treats these as offsets within this block.
// Confirm against ShaderBinarizerContext.
//
// Links, in order: the parent, each stage's code pointer, the shader object,
// the shader reflection, and the full reflection. The last two may be null
// when those binarizers were not created.
// A stage whose m_MergeCode is non-negative shares its code. It is linked to
// the same stage's code binarizer in the program of variation m_MergeCode,
// and that program is selected by this program's code type.
//
// Throws nngfxToolResultCode_InternalError (via NN_GFXTOOL_THROW) when a merged
// stage is linked with a code type outside the selector table. Previously that
// was an unchecked out-of-bounds array read.
void ResShaderProgramBinarizer::Link( ShaderBinarizerContext* pContext )
{
    // Program-binarizer getter per code type, indexed by GetInfoData()->codeType.
    // Source and SourceArray share the source program. NOTE(review): this
    // assumes the ShaderCodeType order Binary, IntermediateLanguage, Source,
    // SourceArray - confirm.
    static const ResShaderProgramBinarizer* (
        ResShaderVariationBinarizer::*s_GetProgramBinarizer [])() const =
    {
        &ResShaderVariationBinarizer::GetBinaryProgramBinarizer,
        &ResShaderVariationBinarizer::GetIntermediateLanguageProgramBinarizer,
        &ResShaderVariationBinarizer::GetSourceProgramBinarizer,
        &ResShaderVariationBinarizer::GetSourceProgramBinarizer
    };
    const int getProgramBinarizerCount = static_cast< int >(
        sizeof( s_GetProgramBinarizer ) / sizeof( *s_GetProgramBinarizer ) );

    const ResTarget* pTarget = nullptr;

    pContext->LinkPtr( this, &pTarget->pParent, GetParent() );

    for( int idxStage = 0; idxStage < static_cast< int >( ShaderStage::End ); ++idxStage )
    {
        if( m_pCodeBinarizer[ idxStage ] == nullptr )
        {
            continue;
        }

        const nn::util::BinPtr* pStageCode = reinterpret_cast< const nn::util::BinPtr* >(
            &( pTarget->info.*g_pStageCode[ idxStage ] ) );

        if( m_MergeCode[ idxStage ] < 0 )
        {
            // This program owns the stage code.
            pContext->LinkPtr( this, pStageCode, m_pCodeBinarizer[ idxStage ].get() );
            continue;
        }

        // Shared stage code: range-check the code type before using it as an
        // index, in the same way GetSizeOfShaderObj checks its API type.
        const int codeType = static_cast< int >( GetInfoData()->codeType );
        if( codeType < 0 || codeType >= getProgramBinarizerCount )
        {
            NN_GFXTOOL_THROW( nngfxToolResultCode_InternalError );
        }

        auto pMergeProgram = ( m_pParent->GetParent()->GetVariationBinarizers()->at(
            m_MergeCode[ idxStage ] ).*s_GetProgramBinarizer[ codeType ] )();
        pContext->LinkPtr( this, pStageCode, pMergeProgram->m_pCodeBinarizer[ idxStage ].get() );
    }

    pContext->LinkPtr( this, &pTarget->pObj, &m_Obj );

    pContext->LinkPtr( this, &pTarget->pShaderReflection, m_pReflectionBinarizer.get() );
    pContext->LinkPtr( this, &pTarget->pShaderCompilerReflection, m_pFullReflectionBinarizer.get() );
}

// Writes this program's ResTarget into the output buffer. It copies the
// compiler's shader info, sets Flag_ResShader on it, records the size of the
// shader object storage sized in CalculateSize, and zero-fills that storage.
void ResShaderProgramBinarizer::Convert( ShaderBinarizerContext* pContext )
{
    auto pResTarget = Get< ResTarget >( pContext->GetPtr() );

    auto& resInfo = pResTarget->info;
    resInfo = *GetInfoData();
    resInfo.flags.SetBit( nn::gfx::ShaderInfo::DataType::Flag_ResShader, true );

    const auto shaderObjSize = m_Obj.GetSize();
    pResTarget->objSize = NumericCastAuto( shaderObjSize );

    // The shader object bytes are emitted zeroed.
    // NOTE(review): presumably the object is constructed at load time - confirm.
    memset( m_Obj.Get( pContext->GetPtr() ), 0, shaderObjSize );
}

}
}
