﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <nn/gfx/gfx_ResShaderData-api.nvn.h>

#include <nn/gfxTool/gfxTool_Util.h>

#include <gfxTool_ShaderCodeMerger.h>
#include <gfxTool_ShaderCodeBinarizer.h>
#include <gfxTool_ResShaderProgramBinarizer.h>
#include <gfxTool_ResShaderVariationBinarizer.h>
#include <gfxTool_ResShaderContainerBinarizer.h>
#include <gfxTool_NvnShaderCodeBinarizer.h>

namespace nn {
namespace gfxTool {

namespace {

static nn::gfxTool::ResShaderProgramBinarizer* (
    nn::gfxTool::ResShaderVariationBinarizer::* const s_GetProgramBinarizer[])() =
{
    &nn::gfxTool::ResShaderVariationBinarizer::GetBinaryProgramBinarizer,
    &nn::gfxTool::ResShaderVariationBinarizer::GetIntermediateLanguageProgramBinarizer,
    &nn::gfxTool::ResShaderVariationBinarizer::GetSourceProgramBinarizer
};

// Strict "less than" ordering for two ShaderCode objects.
// Codes are ordered by size first; codes of the same size are ordered by their bytes.
// Returns false when both objects refer to the same code buffer.
bool LessShaderCode( const nn::gfx::ShaderCode* lhs, const nn::gfx::ShaderCode* rhs )
{
    const bool isSameSize = lhs->codeSize == rhs->codeSize;
    if( !isSameSize )
    {
        return lhs->codeSize < rhs->codeSize;
    }

    // Identical buffers are equivalent, so the byte comparison can be skipped.
    const bool isSameBuffer = lhs->pCode.ptr == rhs->pCode.ptr;
    return !isSameBuffer && memcmp( lhs->pCode, rhs->pCode, lhs->codeSize ) < 0;
}

// Strict "less than" ordering for two raw code buffers.
//
// pLhsCode / lhsCodeSize : first buffer and its size in bytes
// pRhsCode / rhsCodeSize : second buffer and its size in bytes
//
// Buffers are ordered by size first; buffers of the same size are ordered by their bytes
// (memcmp order). Returns false when the buffers are equivalent.
bool LessCode( const void* pLhsCode, size_t lhsCodeSize, const void* pRhsCode, size_t rhsCodeSize )
{
    if( lhsCodeSize != rhsCodeSize )
    {
        return lhsCodeSize < rhsCodeSize;
    }
    // Both sizes are equal from here on.
    // Zero-length buffers are always equivalent; returning early also keeps memcmp from
    // being called with a null pointer, which is undefined even for a zero length.
    if( lhsCodeSize == 0 || pLhsCode == pRhsCode )
    {
        return false;
    }
    return memcmp( pLhsCode, pRhsCode, lhsCodeSize ) < 0;
}

// Strict "less than" ordering for one entry of two source arrays.
// Compares the idxSource-th source of lhs and rhs: by size first, then by bytes.
// Returns false when both entries refer to the same source buffer.
bool LessSource( const nn::gfx::SourceArrayCode* lhs,
    const nn::gfx::SourceArrayCode* rhs, int idxSource )
{
    const auto lhsSize = lhs->pCodeSizeArray[ idxSource ];
    const auto rhsSize = rhs->pCodeSizeArray[ idxSource ];
    if( lhsSize != rhsSize )
    {
        return lhsSize < rhsSize;
    }

    const auto& lhsCode = lhs->pCodePtrArray[ idxSource ];
    const auto& rhsCode = rhs->pCodePtrArray[ idxSource ];
    // Identical buffers are equivalent, so the byte comparison can be skipped.
    const bool isSameBuffer = lhsCode.ptr == rhsCode.ptr;
    return !isSameBuffer && memcmp( lhsCode, rhsCode, lhsSize ) < 0;
}

}

void ShaderCodeMerger::MergeShaderCode( const Custom< std::vector< ResShaderVariationBinarizer*
    > >::Type& pSortBinarizersOriginal, ShaderStage stage, int idxCodeType )
{
    auto pSortBinarizers = pSortBinarizersOriginal;
    std::sort( pSortBinarizers.begin(), pSortBinarizers.end(), [ & ](
        decltype( pSortBinarizers[ 0 ] ) lhs, decltype( pSortBinarizers[ 0 ] ) rhs ) {
            auto pLhsShaderCode = ( lhs->*s_GetProgramBinarizer[ idxCodeType ] )(
                )->GetShaderCodeBinarizer< ShaderCodeBinarizer >( stage )->GetShaderCode();
            auto pRhsShaderCode = ( rhs->*s_GetProgramBinarizer[ idxCodeType ] )(
                )->GetShaderCodeBinarizer< ShaderCodeBinarizer >( stage )->GetShaderCode();
            return LessCode( pLhsShaderCode->pCode.ptr, pLhsShaderCode->codeSize,
                pRhsShaderCode->pCode.ptr, pRhsShaderCode->codeSize ); } );
    m_Count[ static_cast< int >( MergeType::Code ) ] +=
        static_cast< int >( pSortBinarizersOriginal.size() );
    for( int idxSort = 1, merge = 0; idxSort < static_cast<
        decltype( idxSort ) >( pSortBinarizersOriginal.size() ); ++idxSort )
    {
        auto pProgram = ( pSortBinarizers.at( idxSort )->*s_GetProgramBinarizer[ idxCodeType ] )();
        auto pTargetProgram = ( pSortBinarizers.at( idxSort - 1 )->*s_GetProgramBinarizer[ idxCodeType ] )();
        auto pLhsShaderCode = pTargetProgram->GetShaderCodeBinarizer<
            ShaderCodeBinarizer >( stage )->GetShaderCode();
        auto pRhsShaderCode = pProgram->GetShaderCodeBinarizer<
            ShaderCodeBinarizer >( stage )->GetShaderCode();
        if( LessCode( pLhsShaderCode->pCode.ptr, pLhsShaderCode->codeSize,
                pRhsShaderCode->pCode.ptr, pRhsShaderCode->codeSize ) )
        {
            merge = idxSort;
        }
        else
        {
            ++m_MergedCount[ static_cast< int >( MergeType::Code ) ];
            pProgram->MergeCode( stage, StaticCastAuto(
                pSortBinarizers.at( merge ) - pSortBinarizersOriginal.at( 0 ) ) );
        }
    }
}

void ShaderCodeMerger::MergeSourceArrayCode( const Custom< std::vector< ResShaderVariationBinarizer*
    > >::Type& pSortBinarizersOriginal, ShaderStage stage, int codeArrayLength )
{
    for( int idxSource = 0; idxSource < codeArrayLength; ++idxSource )
    {
        auto pSortBinarizers = pSortBinarizersOriginal;
        std::sort( pSortBinarizers.begin(), pSortBinarizers.end(), [ & ](
            decltype( pSortBinarizers[ 0 ] ) lhs, decltype( pSortBinarizers[ 0 ] ) rhs ) {
            return LessSource( lhs->GetSourceProgramBinarizer(
            )->GetShaderCodeBinarizer< SourceArrayBinarizer >( stage )->GetSourceArrayCode(),
            rhs->GetSourceProgramBinarizer()->GetShaderCodeBinarizer<
            SourceArrayBinarizer >( stage )->GetSourceArrayCode(), idxSource ); } );
        m_Count[ static_cast< int >( MergeType::SourceArray ) ] +=
            static_cast< int >( pSortBinarizersOriginal.size() );
        for( int idxSort = 1, merge = 0; idxSort < static_cast<
            decltype( idxSort ) >( pSortBinarizersOriginal.size() ); ++idxSort )
        {
            auto pProgram = pSortBinarizers.at( idxSort )->GetSourceProgramBinarizer();
            auto pTargetProgram = pSortBinarizers.at( idxSort - 1 )->GetSourceProgramBinarizer();
            if( LessSource( pTargetProgram->GetShaderCodeBinarizer< SourceArrayBinarizer >(
                stage )->GetSourceArrayCode(), pProgram->GetShaderCodeBinarizer<
                SourceArrayBinarizer >( stage )->GetSourceArrayCode(), idxSource ) )
            {
                merge = idxSort;
            }
            else
            {
                ++m_MergedCount[ static_cast<int>( MergeType::SourceArray ) ];
                pProgram->GetShaderCodeBinarizer< SourceArrayBinarizer >( stage )->MergeSource(
                    idxSource, StaticCastAuto( pSortBinarizers.at( merge ) - pSortBinarizersOriginal.at( 0 ) ) );
            }
        }
    }
}

void ShaderCodeMerger::MergeNvnShaderCode( const Custom< std::vector<
    ResShaderVariationBinarizer* > >::Type& pSortBinarizersOriginal, ShaderStage stage )
{
    auto LessNvnControl = []( const void* pLhsControl, size_t lhsControlSize,
        const void* pRhsControl, size_t rhsControlSize ) -> bool
    {
        // デバッグ情報が ON のときはコントロールセクションにビルド ID とデバッグハッシュが入っている
        const int buildIdOffset = 1896;
        const int buildIdAndDebugHashSize = 16 + 8;
        // GPU コードバージョン 1.11 以降はコントロールセクションにソースと GLASM と ucode のハッシュが入っている
        const int codeHashOffset = 2000;
        const int sourceAndGlasmAndUcodeHashSize = 8 + 8 + 8;

        if( lhsControlSize != rhsControlSize )
        {
            return lhsControlSize < rhsControlSize;
        }
        if( pLhsControl == pRhsControl )
        {
            return false;
        }
        if( lhsControlSize <= buildIdOffset )
        {
            return memcmp( pLhsControl, pRhsControl, lhsControlSize ) < 0;
        }

        if( auto firstComp = memcmp( pLhsControl, pRhsControl, buildIdOffset ) )
        {
            return firstComp < 0;
        }
        int latterStart = buildIdOffset + buildIdAndDebugHashSize;
        if( lhsControlSize >= codeHashOffset + sourceAndGlasmAndUcodeHashSize )
        {
            if( auto secondComp = memcmp( nn::util::ConstBytePtr( pLhsControl, latterStart ).Get(),
                nn::util::ConstBytePtr( pRhsControl, latterStart ).Get(), codeHashOffset - latterStart ) )
            {
                return secondComp < 0;
            }
            latterStart = codeHashOffset + sourceAndGlasmAndUcodeHashSize;
        }
        return memcmp( nn::util::ConstBytePtr( pLhsControl, latterStart ).Get(),
            nn::util::ConstBytePtr( pRhsControl, latterStart ).Get(),
            lhsControlSize - latterStart ) < 0;
    };

    auto MergeCodeImpl = [ & ](
        std::function< const void* ( const NvnShaderCodeBinarizer* pNvnShaderCode ) > GetCode,
        std::function< size_t ( const NvnShaderCodeBinarizer* pNvnShaderCode ) > GetCodeSize,
        std::function< bool (const void* pLhs, size_t lhsSize, const void* pRhs, size_t rhsSize ) > LessFunc,
        std::function< void ( NvnShaderCodeBinarizer* pBinarizer, int idxTarget ) > Merge )
    {
        auto pSortBinarizers = pSortBinarizersOriginal;
        std::sort( pSortBinarizers.begin(), pSortBinarizers.end(), [ & ](
            decltype( pSortBinarizers[ 0 ] ) lhs, decltype( pSortBinarizers[ 0 ] ) rhs ) {
                auto pLhsNvnShaderCode = lhs->GetBinaryProgramBinarizer(
                    )->GetShaderCodeBinarizer< NvnShaderCodeBinarizer >( stage );
                auto pRhsNvnShaderCode = rhs->GetBinaryProgramBinarizer(
                    )->GetShaderCodeBinarizer< NvnShaderCodeBinarizer >( stage );
                return LessFunc( GetCode( pLhsNvnShaderCode ), GetCodeSize( pLhsNvnShaderCode ),
                    GetCode( pRhsNvnShaderCode ), GetCodeSize( pRhsNvnShaderCode ) ); } );
        for( int idxSort = 1, merge = 0; idxSort < static_cast<
            decltype( idxSort ) >( pSortBinarizersOriginal.size() ); ++idxSort )
        {
            auto pProgram = pSortBinarizers.at( idxSort )->GetBinaryProgramBinarizer();
            auto pTargetProgram = pSortBinarizers.at( idxSort - 1 )->GetBinaryProgramBinarizer();
            auto pLhsNvnShaderCode = pTargetProgram->GetShaderCodeBinarizer<
                NvnShaderCodeBinarizer >( stage );
            auto pRhsNvnShaderCode = pProgram->GetShaderCodeBinarizer<
                NvnShaderCodeBinarizer >( stage );
            if( LessFunc( GetCode( pLhsNvnShaderCode ), GetCodeSize( pLhsNvnShaderCode ),
                GetCode( pRhsNvnShaderCode ), GetCodeSize( pRhsNvnShaderCode ) ) )
            {
                merge = idxSort;
            }
            else
            {
                auto pNvnShaderCodeBinarizer =
                    pProgram->GetShaderCodeBinarizer< NvnShaderCodeBinarizer >( stage );
                auto targetVariation = pSortBinarizers.at( merge ) - pSortBinarizersOriginal.at( 0 );
                Merge( pNvnShaderCodeBinarizer, StaticCastAuto( targetVariation ) );
            }
        }
    };

    m_Count[ static_cast<int>( MergeType::Data ) ] += static_cast< int >( pSortBinarizersOriginal.size() );
    MergeCodeImpl( []( const NvnShaderCodeBinarizer* pNvnShaderCode ) {
        return pNvnShaderCode->GetShaderCode()->pData.Get(); },
        []( const NvnShaderCodeBinarizer* pNvnShaderCode ) {
            return pNvnShaderCode->GetShaderCode()->dataSize; }, LessCode,
        [ & ]( NvnShaderCodeBinarizer* pBinarizer, int idxTarget ) {
            ++m_MergedCount[ static_cast<int>( MergeType::Data ) ];
            pBinarizer->MergeData( idxTarget ); } );

    auto pFirstShaderCodeBinarizer = static_cast< const NvnShaderCodeBinarizer* >(
        pSortBinarizersOriginal.at( 0 )->GetBinaryProgramBinarizer()->GetShaderCodeBinarizer( stage ) );
    if( pFirstShaderCodeBinarizer->GetDecomposedControlSectionBinarizer() )
    {
        using SubsectionType = NvnDecomposedControlSectionBinarizer::ControlSubsectionType;
        for( int idxSubsection = 0; idxSubsection < static_cast<
            int >( SubsectionType::End ); ++idxSubsection )
        {
            int idxMergeType = idxSubsection + static_cast< int >( MergeType::MetaData );
            m_Count[ idxMergeType ] += static_cast< int >( pSortBinarizersOriginal.size() );
            MergeCodeImpl( [ & ]( const NvnShaderCodeBinarizer* pNvnShaderCode ) {
                return pNvnShaderCode->GetDecomposedControlSectionBinarizer()->GetControlSubsection(
                    StaticCastAuto( idxSubsection ) ).pData; },
                [ & ]( const NvnShaderCodeBinarizer* pNvnShaderCode ) {
                    return pNvnShaderCode->GetDecomposedControlSectionBinarizer()->GetControlSubsection(
                    StaticCastAuto( idxSubsection ) ).mergeBlock.block.GetSize(); }, LessNvnControl,
                [ & ]( NvnShaderCodeBinarizer* pBinarizer, int idxTarget ) {
                    ++m_MergedCount[ idxMergeType ];
                    pBinarizer->GetDecomposedControlSectionBinarizer()->MergeControlSubsection(
                    StaticCastAuto( idxSubsection ), idxTarget ); } );
        }
    }
    else
    {
        m_Count[ static_cast<int>( MergeType::Control ) ] += static_cast< int >( pSortBinarizersOriginal.size() );
        MergeCodeImpl( []( const NvnShaderCodeBinarizer* pNvnShaderCode ) {
            return pNvnShaderCode->GetShaderCode()->pControl.Get(); },
            []( const NvnShaderCodeBinarizer* pNvnShaderCode ) {
                return pNvnShaderCode->GetShaderCode()->controlSize; }, LessNvnControl,
            [ & ]( NvnShaderCodeBinarizer* pBinarizer, int idxTarget ) {
                ++m_MergedCount[ static_cast<int>( MergeType::Control ) ];
                pBinarizer->MergeControl( idxTarget ); } );
    }
} // NOLINT

void ShaderCodeMerger::ResetCount()
{
    for( auto&& count : m_Count )
    {
        count = 0;
    }
    for( auto&& mergedCount : m_MergedCount )
    {
        mergedCount = 0;
    }
}

// TODO: Equality check for GX2 shaders
//
// Merges duplicated shader code across all variations of a shader container.
//
// pResShaderContainerBinarizer : container whose variation binarizers are processed
// (second parameter)           : unused
//
// The merge counters are reset first. Then, for every shader stage and every code type
// (see s_GetProgramBinarizer), the first variation decides whether there is anything to
// merge and which strategy applies: source arrays, NVN shader code, or generic shader code.
void ShaderCodeMerger::MergeCode( ResShaderContainerBinarizer* pResShaderContainerBinarizer,
    const nngfxToolShaderConverterConvertArg* )
{
    ResetCount();

    auto& variationBinarizers = *pResShaderContainerBinarizer->GetVariationBinarizers();
    if( variationBinarizers.size() == 0 )
    {
        // Nothing to merge. This also keeps at( 0 ) below from being called on an empty container.
        return;
    }

    // Collect the addresses of the variation binarizers; the merge routines sort copies of
    // this list and identify a merge target by its pointer distance from the first entry.
    Custom< std::vector< decltype( &variationBinarizers[ 0 ] ) > >::Type pSortBinarizersOriginal;
    pSortBinarizersOriginal.resize( variationBinarizers.size() );
    std::transform( variationBinarizers.begin(), variationBinarizers.end(), pSortBinarizersOriginal.begin(),
        []( decltype( variationBinarizers[ 0 ] )& value ) { return &value; } );

    // The first variation is used as the representative for all checks below.
    auto& firstVariation = variationBinarizers.at( 0 );
    for( int idxStage = 0; idxStage < static_cast< int >( ShaderStage::End ); ++idxStage )
    {
        auto stage = static_cast< ShaderStage >( idxStage );
        for( int idxCodeType = 0, codeTypeCount =
            sizeof( s_GetProgramBinarizer ) / sizeof( *s_GetProgramBinarizer );
            idxCodeType < codeTypeCount; ++idxCodeType )
        {
            auto pFirstProgram = ( firstVariation.*s_GetProgramBinarizer[ idxCodeType ] )();
            // Skip code types without a program, and stages without shader code.
            if( pFirstProgram == nullptr || pFirstProgram->GetShaderCodeBinarizer< void >( stage ) == nullptr )
            {
                continue;
            }

            if( pFirstProgram->GetInfoData()->codeType == nn::gfx::ShaderCodeType_SourceArray )
            {
                auto pFirstSourceArrayCode = pFirstProgram->GetShaderCodeBinarizer<
                    SourceArrayBinarizer >( stage );
                if( pFirstSourceArrayCode == nullptr )
                {
                    continue;
                }
                MergeSourceArrayCode( pSortBinarizersOriginal, stage,
                    pFirstSourceArrayCode->GetSourceArrayCode()->codeArrayLength );
            }
            else
            {
                if( pResShaderContainerBinarizer->GetArg()->pCompileArg->targetLowLevelApiType
                    == nngfxToolShaderCompilerLowLevelApiType_Nvn )
                {
                    MergeNvnShaderCode( pSortBinarizersOriginal, stage );
                }
                else
                {
                    MergeShaderCode( pSortBinarizersOriginal, stage, idxCodeType );
                }
            }
        }
    }
}

}
}
