﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <memory>

#include <nn/nn_SdkLog.h>

#include <nn/util/util_BitArray.h>
#include <nn/util/util_Decompression.h>

#include <nn/gfx/gfx_ShaderInfo.h>

#include <nn/gfx/gfx_ResShaderData.h>
#include <nn/gfx/detail/gfx_Shader-api.vk.1.h>

#include "gfx_VkHelper.h"

namespace nn {
namespace gfx {
namespace detail {

typedef ApiVariationVk1 Target;

namespace {

// Builds a bit mask of the binding slots used by `stage` of `pShader`, according to the
// shader's reflection data.
//
// Bit N of the result is set when slot N is referenced by a sampler, constant buffer,
// unordered access buffer or image of the stage.
// - Without reflection data every bit is set (all slots are assumed valid).
// - With reflection data but no entry for `stage`, the result is 0.
// Slots outside [0, 64) cannot be represented in the mask and are ignored.
//
// pShader : shader whose ToData()->pReflection has already been assigned.
// stage   : stage to inspect; used directly as an index into the per-stage table.
uint64_t GenerateValidMask( ShaderImpl< Target >* pShader, ShaderStage stage ) NN_NOEXCEPT
{
    if ( !pShader->ToData()->pReflection.ptr )
    {
        // Assume all slots are valid
        return ~0ULL;
    }

    static nn::util::BinTPtr< ResShaderReflectionStageData >
        const ResShaderReflectionData::* s_pResShaderReflectionStages[] =
    {
        &ResShaderReflectionData::pVertexReflection,
        &ResShaderReflectionData::pHullReflection,
        &ResShaderReflectionData::pDomainReflection,
        &ResShaderReflectionData::pGeometryReflection,
        &ResShaderReflectionData::pPixelReflection,
        &ResShaderReflectionData::pComputeReflection
    };

    static nn::util::BinTPtr< nn::util::ResDic >
        const ResShaderReflectionStageData::*s_pInterfaceDics[] =
    {
        &ResShaderReflectionStageData::pSamplerDic,
        &ResShaderReflectionStageData::pConstantBufferDic,
        &ResShaderReflectionStageData::pUnorderedAccessBufferDic,
        &ResShaderReflectionStageData::pImageDic,
    };

    // Start index of each interface type's entries in pShaderSlotArray; parallel to s_pInterfaceDics.
    static int32_t const ResShaderReflectionStageData::*s_pOffsets[] =
    {
        &ResShaderReflectionStageData::offsetSampler,
        &ResShaderReflectionStageData::offsetConstantBuffer,
        &ResShaderReflectionStageData::offsetUnorderedAccessBuffer,
        &ResShaderReflectionStageData::offsetImage
    };

    static const int InterfaceTypeCount =
        static_cast< int >( sizeof( s_pInterfaceDics ) / sizeof( s_pInterfaceDics[ 0 ] ) );

    // The mask holds one bit per slot; shifting a 64-bit value by a negative amount or by
    // 64 or more is undefined behavior, so only slots in [0, MaskBitCount) are recorded.
    static const int MaskBitCount = 64;

    uint64_t ret = 0ULL;

    const ResShaderReflectionData* pResShaderReflection =
        static_cast< const ResShaderReflectionData* >( pShader->ToData()->pReflection );
    const ResShaderReflectionStageData* pResShaderReflectionStage =
        ( pResShaderReflection->*s_pResShaderReflectionStages[ stage ] ).Get();
    if ( pResShaderReflectionStage == NULL )
    {
        // The stage has no reflection entry: no slot is used.
        return ret;
    }

    for ( int idxType = 0; idxType < InterfaceTypeCount; ++idxType )
    {
        const nn::util::ResDic* pResDic =
            ( pResShaderReflectionStage->*s_pInterfaceDics[ idxType ] ).Get();
        if ( pResDic == NULL )
        {
            continue;
        }

        // Loop-invariant for this interface type.
        int offset = s_pOffsets[ idxType ] == NULL
            ? 0 : pResShaderReflectionStage->*s_pOffsets[ idxType ];

        int count = pResDic->GetCount();
        for ( int entryIdx = 0; entryIdx < count; ++entryIdx )
        {
            int idxFound = pResDic->FindIndex( pResDic->GetKey( entryIdx ) );
            if ( idxFound < 0 )
            {
                continue;
            }

            int slot = pResShaderReflectionStage->pShaderSlotArray.Get()[ offset + idxFound ];
            if ( slot >= 0 && slot < MaskBitCount )
            {
                ret |= 1ULL << slot;
            }
        }
    }

    return ret;
}

// Primary template: used for every shader code type that has no explicit specialization.
// It performs no work and always reports ShaderInitializeResult_InvalidType.
template< ShaderCodeType CodeType >
ShaderInitializeResult InitializeShader( DeviceImpl< Target >* pDevice, ShaderImpl< Target >* pThis,
    const ShaderInfo& info )
{
    // Nothing is needed from the arguments to reject the request.
    NN_UNUSED( info );
    NN_UNUSED( pThis );
    NN_UNUSED( pDevice );

    const ShaderInitializeResult unsupportedResult = ShaderInitializeResult_InvalidType;
    return unsupportedResult;
}

// Creates one VkShaderModule per stage from the IR code stored in `info`.
//
// A stage's code may be zlib-compressed: when ShaderCode::decompressedCodeSize is non-zero,
// the code is inflated into a temporary driver allocation before the module is created.
// For every stage, hShaderModule receives the created module (or a null handle when the stage
// has no code) and validSlotMask receives the mask derived from the reflection data.
//
// Returns ShaderInitializeResult_SetupFailed if a temporary allocation, decompression or
// vkCreateShaderModule fails; ShaderInitializeResult_Success otherwise.
// NOTE(review): modules already created for earlier stages are not destroyed when a later
// stage fails.
template<>
ShaderInitializeResult InitializeShader< ShaderCodeType_Ir >(
    DeviceImpl< Target >* pDevice, ShaderImpl< Target >* pThis, const ShaderInfo& info )
{
    NN_SDK_ASSERT_NOT_NULL( pThis );
    VkDevice vkDevice = CastToVkDispatchableObject< VkDevice >( pDevice->ToData()->hDevice );

    if ( info.ToData()->flags.GetBit( ShaderInfo::DataType::Flag_ResShader ) )
    {
        // The ShaderInfo is the `info` member of a ResShaderProgramData; step back to the
        // enclosing structure to reach its reflection data.
        const ResShaderProgramData* pResShaderProgram = nn::util::ConstBytePtr( &info,
            -static_cast< ptrdiff_t >( offsetof( ResShaderProgramData, info ) ) ).Get< ResShaderProgramData >();
        pThis->ToData()->pReflection = pResShaderProgram->pShaderReflection.Get();
    }

    for ( int idxStage = 0; idxStage < ShaderStage_End; ++idxStage )
    {
        VkShaderModule shaderModule = VK_NULL_HANDLE;
        const ShaderCode* pShaderCode = static_cast< const ShaderCode* >(
            info.GetShaderCodePtr( static_cast< ShaderStage >( idxStage ) ) );
        if ( pShaderCode )
        {
            VkShaderModuleCreateInfo moduleCreateInfo;
            moduleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
            moduleCreateInfo.pNext = NULL;
            moduleCreateInfo.flags = 0;
            moduleCreateInfo.codeSize = pShaderCode->codeSize;
            moduleCreateInfo.pCode = pShaderCode->pCode;

            // Temporary buffer holding decompressed code; stays NULL for uncompressed code.
            char* pDecompressedCode = NULL;
            if ( pShaderCode->decompressedCodeSize )
            {
                pDecompressedCode = static_cast< char* >(
                    Vk::AllocDriverMemory( pShaderCode->decompressedCodeSize, 8 ) );
                char* pWorkBuffer = static_cast< char* >(
                    Vk::AllocDriverMemory( nn::util::DecompressZlibWorkBufferSize, 8 ) );

                bool resultDecompress = false;
                if ( pDecompressedCode != NULL && pWorkBuffer != NULL )
                {
                    resultDecompress = nn::util::DecompressZlib( pDecompressedCode,
                        pShaderCode->decompressedCodeSize, pShaderCode->pCode,
                        pShaderCode->codeSize, pWorkBuffer, nn::util::DecompressZlibWorkBufferSize );
                }

                // The work buffer is only needed while decompressing.
                if ( pWorkBuffer != NULL )
                {
                    Vk::FreeDriverMemory( pWorkBuffer );
                }

                NN_SDK_ASSERT( resultDecompress );
                if ( !resultDecompress )
                {
                    // Never hand a missing or partially written buffer to the driver.
                    if ( pDecompressedCode != NULL )
                    {
                        Vk::FreeDriverMemory( pDecompressedCode );
                    }
                    return ShaderInitializeResult_SetupFailed;
                }

                moduleCreateInfo.codeSize = pShaderCode->decompressedCodeSize;
                moduleCreateInfo.pCode = reinterpret_cast< uint32_t* >( pDecompressedCode );
            }

            VkResult result;
            result = NN_GFX_CALL_VK_FUNCTION( vkCreateShaderModule( vkDevice, &moduleCreateInfo,
                pDevice->ToData()->pAllocationCallback, &shaderModule ) );

            if ( pDecompressedCode != NULL )
            {
                Vk::FreeDriverMemory( pDecompressedCode );
            }

            if ( result != VK_SUCCESS )
            {
                return ShaderInitializeResult_SetupFailed;
            }
        }

        // Written for every stage so that stages without code hold a null handle rather than
        // whatever value was left in the data from before.
        pThis->ToData()->hShaderModule[ idxStage ] = CastFromVkNonDispatchableObject< VkShaderModule >( shaderModule );
        pThis->ToData()->validSlotMask[ idxStage ] = GenerateValidMask( pThis, static_cast< ShaderStage >( idxStage ) );
    }

    return ShaderInitializeResult_Success;
}

// Creates one VkShaderModule per stage from the source code stored in `info`.
//
// Requires the VkDeviceExtension_NvGlslShader device extension and returns
// ShaderInitializeResult_SetupFailed immediately when it is not available.
// IsArray == false : each stage's code is a ShaderCode passed to the driver as-is.
// IsArray == true  : each stage's code is a SourceArrayCode whose pieces are concatenated,
//                    followed by a terminating NUL, into a temporary buffer first.
// For every stage, hShaderModule receives the created module (or a null handle when the stage
// has no code) and validSlotMask receives the mask derived from the reflection data.
//
// Returns ShaderInitializeResult_SetupFailed if the extension is missing, the temporary
// allocation fails or vkCreateShaderModule fails; ShaderInitializeResult_Success otherwise.
// NOTE(review): modules already created for earlier stages are not destroyed when a later
// stage fails.
template< bool IsArray >
ShaderInitializeResult InitializeSourceShader(
    DeviceImpl< Target >* pDevice, ShaderImpl< Target >* pThis, const ShaderInfo& info ) NN_NOEXCEPT
{
    NN_SDK_ASSERT_NOT_NULL( pThis );

    if ( !pDevice->ToData()->availableExtensions.GetBit( VkDeviceExtension_NvGlslShader ) )
    {
        return ShaderInitializeResult_SetupFailed;
    }

    VkDevice vkDevice = CastToVkDispatchableObject< VkDevice >( pDevice->ToData()->hDevice );

    if ( info.ToData()->flags.GetBit( ShaderInfo::DataType::Flag_ResShader ) )
    {
        // The ShaderInfo is the `info` member of a ResShaderProgramData; step back to the
        // enclosing structure to reach its reflection data.
        const ResShaderProgramData* pResShaderProgram = nn::util::ConstBytePtr( &info,
            -static_cast< ptrdiff_t >( offsetof( ResShaderProgramData, info ) ) ).Get< ResShaderProgramData >();
        pThis->ToData()->pReflection = pResShaderProgram->pShaderReflection.Get();
    }

    for ( int idxStage = 0; idxStage < ShaderStage_End; ++idxStage )
    {
        VkShaderModule shaderModule = VK_NULL_HANDLE;
        const void* pCode = info.GetShaderCodePtr( static_cast< ShaderStage >( idxStage ) );
        if ( pCode )
        {
            VkShaderModuleCreateInfo moduleCreateInfo;
            moduleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
            moduleCreateInfo.pNext = NULL;
            moduleCreateInfo.flags = 0;

            // Temporary concatenation buffer; only allocated when IsArray is true.
            char* pConnectedCode = NULL;

            if ( NN_STATIC_CONDITION( IsArray ) )
            {
                const SourceArrayCode* pSourceArrayCode = static_cast< const SourceArrayCode* >( pCode );
                int codeCount = static_cast< int >( pSourceArrayCode->codeArrayLength );

                int codeSize = 0;
                for ( int idxCode = 0; idxCode < codeCount; ++idxCode )
                {
                    codeSize += pSourceArrayCode->pCodeSizeArray.ptr[ idxCode ];
                }

                // One extra byte for the terminating NUL.
                pConnectedCode = static_cast< char* >( Vk::AllocDriverMemory( codeSize + 1, 8 ) );
                if ( pConnectedCode == NULL )
                {
                    return ShaderInitializeResult_SetupFailed;
                }
                pConnectedCode[ codeSize ] = 0;

                int writeOffset = 0;
                for ( int idxCode = 0; idxCode < codeCount; ++idxCode )
                {
                    memcpy( &pConnectedCode[ writeOffset ], pSourceArrayCode->pCodePtrArray.ptr[ idxCode ],
                        pSourceArrayCode->pCodeSizeArray.ptr[ idxCode ] );
                    writeOffset += pSourceArrayCode->pCodeSizeArray.ptr[ idxCode ];
                }

                moduleCreateInfo.codeSize = codeSize;
                moduleCreateInfo.pCode = reinterpret_cast< const uint32_t* >( pConnectedCode );
            }
            else
            {
                const ShaderCode* pShaderCode = static_cast< const ShaderCode* >( pCode );
                moduleCreateInfo.codeSize = pShaderCode->codeSize;
                moduleCreateInfo.pCode = pShaderCode->pCode;
            }

            VkResult result;
            NN_GFX_CALL_VK_FUNCTION( result = vkCreateShaderModule( vkDevice, &moduleCreateInfo,
                pDevice->ToData()->pAllocationCallback, &shaderModule ) );

            if ( pConnectedCode != NULL )
            {
                Vk::FreeDriverMemory( pConnectedCode );
            }

            if ( result != VK_SUCCESS )
            {
                return ShaderInitializeResult_SetupFailed;
            }
        }

        // Written for every stage so that stages without code hold a null handle rather than
        // whatever value was left in the data from before.
        pThis->ToData()->hShaderModule[ idxStage ] = CastFromVkNonDispatchableObject< VkShaderModule >( shaderModule );
        pThis->ToData()->validSlotMask[ idxStage ] = GenerateValidMask( pThis, static_cast< ShaderStage >( idxStage ) );
    }

    return ShaderInitializeResult_Success;
}

}

// Returns 1 for every device: this implementation places no alignment requirement on
// shader binary code.
size_t ShaderImpl< Target >::GetBinaryCodeAlignment( DeviceImpl< Target >* ) NN_NOEXCEPT
{
    const size_t requiredAlignment = 1;
    return requiredAlignment;
}

// Constructs the shader in the not-initialized state; Initialize() must succeed before the
// object can be queried.
ShaderImpl< Target >::ShaderImpl() NN_NOEXCEPT
{
    // Only the state is set here; the device, flags and per-stage data are set by Initialize().
    this->state = State_NotInitialized;
}

// Destroys the object. Unless Flag_Shared is set, the shader must already have been
// finalized; the destructor itself releases nothing.
ShaderImpl< Target >::~ShaderImpl() NN_NOEXCEPT
{
    NN_SDK_ASSERT( this->flags.GetBit( Flag_Shared ) || this->state == State_NotInitialized );
}

// Initializes the shader on `pDevice` from `info`, dispatching on info.GetCodeType().
//
// Returns ShaderInitializeResult_InvalidType for a code type outside the dispatch table
// (and, via the unspecialized InitializeShader template, for binary code). Otherwise it
// returns the result of the per-type initializer. The object becomes State_Initialized
// only when that result is ShaderInitializeResult_Success.
ShaderInitializeResult ShaderImpl< Target >::Initialize( DeviceImpl< Target >* pDevice, const InfoType& info ) NN_NOEXCEPT
{
    NN_SDK_REQUIRES( this->state == State_NotInitialized );
    NN_SDK_ASSERT( info.IsSeparationEnabled() || ( info.GetShaderCodePtr(
        ShaderStage_Vertex ) || info.GetShaderCodePtr( ShaderStage_Compute ) ) );

    this->pGfxDevice = pDevice;
    this->flags = info.ToData()->flags;

    // Per-ShaderCodeType initializer, indexed by info.GetCodeType().
    static ShaderInitializeResult( *const InitializeFunction[] )(
        DeviceImpl< Target >*, ShaderImpl< Target >*, const ShaderInfo& ) =
    {
        InitializeShader< ShaderCodeType_Binary >, // unspecialized: returns InvalidType
        InitializeShader< ShaderCodeType_Ir >,
        InitializeSourceShader< false >,
        InitializeSourceShader< true >
    };
    static const int InitializeFunctionCount =
        static_cast< int >( sizeof( InitializeFunction ) / sizeof( InitializeFunction[ 0 ] ) );

    // Reject code types outside the table instead of indexing past its end.
    const int codeType = static_cast< int >( info.GetCodeType() );
    if ( codeType < 0 || codeType >= InitializeFunctionCount )
    {
        return ShaderInitializeResult_InvalidType;
    }

    ShaderInitializeResult result = InitializeFunction[ codeType ]( pDevice, this, info );
    // gfx_DebugFontTextWriter is implemented to cope with this function returning something
    // other than ShaderInitializeResult_Success, so for now failures are returned to the
    // caller instead of being asserted.
    //NN_SDK_ASSERT( result == ShaderInitializeResult_Success );

    if ( result != ShaderInitializeResult_Success )
    {
        return result;
    }

    this->flags.SetBit( Flag_Shared, false );
    this->state = State_Initialized;
    return ShaderInitializeResult_Success;
}

// Finalizes an initialized, non-shared shader. The shader objects are handed to
// Vk::RegisterDestroyingShaderObject, and the object returns to State_NotInitialized.
void ShaderImpl< Target >::Finalize( DeviceImpl< Target >* pDevice ) NN_NOEXCEPT
{
    NN_SDK_REQUIRES( this->state == State_Initialized );
    NN_SDK_ASSERT( !this->flags.GetBit( Flag_Shared ) );

    // Destruction is deferred: the shader objects are registered so that they are destroyed
    // only after the pipeline cache entries referring to them have been cleared.
    Vk::RegisterDestroyingShaderObject( pDevice, this );

    this->state = State_NotInitialized;
}

// Returns the binding slot recorded in the reflection data for the interface variable
// `pName` of type `shaderInterfaceType` in `stage`.
//
// Returns -1 in any of these cases:
// - the shader has no reflection data;
// - `stage` or `shaderInterfaceType` is out of range (also a REQUIRES failure);
// - the stage has no reflection entry;
// - the stage has no dictionary for that interface type;
// - `pName` is not found.
int ShaderImpl< Target >::GetInterfaceSlot( ShaderStage stage,
    ShaderInterfaceType shaderInterfaceType, const char* pName ) const NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL( pName );
    NN_SDK_REQUIRES( IsInitialized( *this ) );

    if ( !this->pReflection.ptr )
    {
        return -1;
    }

    static nn::util::BinTPtr< ResShaderReflectionStageData >
        const ResShaderReflectionData::* s_pResShaderReflectionStages[] =
    {
        &ResShaderReflectionData::pVertexReflection,
        &ResShaderReflectionData::pHullReflection,
        &ResShaderReflectionData::pDomainReflection,
        &ResShaderReflectionData::pGeometryReflection,
        &ResShaderReflectionData::pPixelReflection,
        &ResShaderReflectionData::pComputeReflection
    };

    static nn::util::BinTPtr< nn::util::ResDic >
        const ResShaderReflectionStageData::*s_pInterfaceDics[] =
    {
        &ResShaderReflectionStageData::pShaderInputDic,
        &ResShaderReflectionStageData::pShaderOutputDic,
        &ResShaderReflectionStageData::pSamplerDic,
        &ResShaderReflectionStageData::pConstantBufferDic,
        &ResShaderReflectionStageData::pUnorderedAccessBufferDic,
        &ResShaderReflectionStageData::pImageDic,
    };

    // Start index of each interface type's entries in pShaderSlotArray; parallel to
    // s_pInterfaceDics. Shader inputs have no offset member and start at index 0.
    static int32_t const ResShaderReflectionStageData::*s_pOffsets[] =
    {
        NULL,
        &ResShaderReflectionStageData::offsetShaderOutput,
        &ResShaderReflectionStageData::offsetSampler,
        &ResShaderReflectionStageData::offsetConstantBuffer,
        &ResShaderReflectionStageData::offsetUnorderedAccessBuffer,
        &ResShaderReflectionStageData::offsetImage
    };

    static const int StageCount = static_cast< int >(
        sizeof( s_pResShaderReflectionStages ) / sizeof( s_pResShaderReflectionStages[ 0 ] ) );
    static const int InterfaceTypeCount = static_cast< int >(
        sizeof( s_pInterfaceDics ) / sizeof( s_pInterfaceDics[ 0 ] ) );

    // Both enums are used directly as table indices; out-of-range values would read past the tables.
    const int idxStage = static_cast< int >( stage );
    const int idxInterface = static_cast< int >( shaderInterfaceType );
    const bool isStageValid = idxStage >= 0 && idxStage < StageCount;
    const bool isInterfaceValid = idxInterface >= 0 && idxInterface < InterfaceTypeCount;
    NN_SDK_REQUIRES( isStageValid );
    NN_SDK_REQUIRES( isInterfaceValid );
    if ( !isStageValid || !isInterfaceValid )
    {
        return -1;
    }

    const ResShaderReflectionData* pResShaderReflection =
        static_cast< const ResShaderReflectionData* >( this->pReflection );
    const ResShaderReflectionStageData* pResShaderReflectionStage =
        ( pResShaderReflection->*s_pResShaderReflectionStages[ idxStage ] ).Get();
    if ( pResShaderReflectionStage == NULL )
    {
        return -1;
    }

    const nn::util::ResDic* pResDic =
        ( pResShaderReflectionStage->*s_pInterfaceDics[ idxInterface ] ).Get();
    if ( pResDic == NULL )
    {
        return -1;
    }

    int idxFound = pResDic->FindIndex( pName );
    if ( idxFound < 0 )
    {
        return -1;
    }

    int offset = s_pOffsets[ idxInterface ] == NULL
        ? 0 : pResShaderReflectionStage->*s_pOffsets[ idxInterface ];
    return pResShaderReflectionStage->pShaderSlotArray.Get()[ offset + idxFound ];
}

// Writes the compute work group size recorded in the reflection data to the three outputs.
//
// All three outputs are set to -1 when the size is unknown. That happens when the shader has
// no reflection data, or (after a debug assertion) when the reflection data has no compute
// stage.
void ShaderImpl< Target >::GetWorkGroupSize( int* pOutWorkGroupSizeX,
    int* pOutWorkGroupSizeY, int* pOutWorkGroupSizeZ ) const NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL( pOutWorkGroupSizeX );
    NN_SDK_REQUIRES_NOT_NULL( pOutWorkGroupSizeY );
    NN_SDK_REQUIRES_NOT_NULL( pOutWorkGroupSizeZ );
    NN_SDK_REQUIRES( IsInitialized( *this ) );

    const ResShaderReflectionStageData* pComputeReflection = NULL;
    if ( this->pReflection )
    {
        const ResShaderReflectionData* pResShaderReflection =
            static_cast< const ResShaderReflectionData* >( this->pReflection );
        pComputeReflection = pResShaderReflection->pComputeReflection.Get();
        // Callers are expected to query only shaders that have compute-stage reflection.
        // Release builds fall through to the -1 result below instead of dereferencing NULL.
        NN_SDK_ASSERT_NOT_NULL( pComputeReflection );
    }

    if ( pComputeReflection != NULL )
    {
        *pOutWorkGroupSizeX = pComputeReflection->computeWorkGroupSizeX;
        *pOutWorkGroupSizeY = pComputeReflection->computeWorkGroupSizeY;
        *pOutWorkGroupSizeZ = pComputeReflection->computeWorkGroupSizeZ;
    }
    else
    {
        *pOutWorkGroupSizeX = -1;
        *pOutWorkGroupSizeY = -1;
        *pOutWorkGroupSizeZ = -1;
    }
}

}
}
}
