﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <memory>

#include <nn/util/util_BitArray.h>
#include <nn/util/util_Decompression.h>

#include <nn/gfx/gfx_ShaderInfo.h>

#include <nn/gfx/detail/gfx_Device-api.d3d.11.h>
#include <nn/gfx/detail/gfx_Shader-api.d3d.11.h>

#include "gfx_D3dHelper.h"

namespace nn {
namespace gfx {
namespace detail {

// API variation for which this file provides the ShaderImpl specialization (Direct3D 11).
typedef ApiVariationD3d11 Target;

namespace {

// Compiles the HLSL source of every stage supplied by info and creates the matching D3D11 objects.
//
// For each stage that has shader code, the source is compiled with D3DCompile (entry point "main",
// 5_0 profile of that stage), a reflection interface is obtained with D3DReflect and the
// stage-specific shader object is created on the device.
//
// The program, reflector and bytecode blob of a stage are stored in pThis only when all three were
// obtained. On any failure everything created for the failing stage is released here, nothing is
// stored for that stage, and ShaderInitializeResult_SetupFailed is returned. Stages that were
// processed successfully before the failure remain stored in pThis and must be released by the caller.
//
// pThis   : shader object that receives the created handles.
// pDevice : device whose D3D11 device and debug mode setting are used.
// info    : shader description providing the per-stage source code.
ShaderInitializeResult InitializeSourceShader(
    ShaderImpl< Target >* pThis, DeviceImpl< Target >* pDevice, const ShaderInfo& info ) NN_NOEXCEPT
{
    NN_SDK_ASSERT_NOT_NULL( pThis );
    NN_SDK_ASSERT_NOT_NULL( pDevice );

    DeviceImpl<Target>::DataType &deviceData = pDevice->ToData();

    // Compile target profile for each stage, indexed by ShaderStage.
    static const char* const shaderStageTarget[] =
    {
        "vs_5_0",
        "hs_5_0",
        "ds_5_0",
        "gs_5_0",
        "ps_5_0",
        "cs_5_0"
    };
    NN_STATIC_ASSERT( NN_GFX_ARRAY_LENGTH( shaderStageTarget ) == ShaderStage_End );

    for( int idxStage = 0; idxStage < ShaderStage_End; ++idxStage )
    {
        const void* pCode = info.GetShaderCodePtr( static_cast< ShaderStage >( idxStage ) );
        if( pCode == NULL )
        {
            // No code was supplied for this stage.
            continue;
        }

        const ShaderCode* pShaderCode = static_cast< const ShaderCode* >( pCode );
        const char* pCodeString = static_cast< const char* >( pShaderCode->pCode );
        // Keep the full width of the size; D3DCompile takes a SIZE_T.
        const SIZE_T codeSize = static_cast< SIZE_T >( pShaderCode->codeSize );

        UINT compileFlag = D3DCOMPILE_ENABLE_STRICTNESS;
        if ( deviceData.debugMode )
        {
            compileFlag |= D3DCOMPILE_DEBUG | D3DCOMPILE_SKIP_OPTIMIZATION;
        }

        ID3DBlob* pBlob = NULL;
        ID3DBlob* pErrorMessages = NULL;
        // The entry point is fixed to "main".
        HRESULT hResult = NN_GFX_CALL_D3D_FUNCTION( D3DCompile( pCodeString, codeSize, NULL,
            NULL, NULL, "main", shaderStageTarget[ idxStage ], compileFlag, 0, &pBlob, &pErrorMessages ) );

        if ( FAILED( hResult ) || pBlob == NULL )
        {
            // The message blob is optional, so do not assume it exists.
            const char* pMessage = ( pErrorMessages != NULL ) ?
                static_cast< const char* >( pErrorMessages->GetBufferPointer() ) : "D3DCompile failed";
            NN_SDK_ASSERT( 0, "%s", pMessage );
            NN_UNUSED( pMessage );

            if ( pErrorMessages != NULL )
            {
                NN_GFX_CALL_D3D_FUNCTION( pErrorMessages->Release() );
            }
            if ( pBlob != NULL )
            {
                NN_GFX_CALL_D3D_FUNCTION( pBlob->Release() );
            }
            return ShaderInitializeResult_SetupFailed;
        }

        // A message blob can also be returned on success (warnings); it is not needed any further.
        if ( pErrorMessages != NULL )
        {
            NN_GFX_CALL_D3D_FUNCTION( pErrorMessages->Release() );
            pErrorMessages = NULL;
        }

        const void* pByteCode = pBlob->GetBufferPointer();
        const SIZE_T byteCodeSize = pBlob->GetBufferSize();

        // Reflection only needs the bytecode, so obtain it before creating the program. That way a
        // failure at any step can be cleaned up without knowing the concrete shader interface type.
        ID3D11ShaderReflection* pReflector = NULL;
        hResult = NN_GFX_CALL_D3D_FUNCTION( D3DReflect( pByteCode, byteCodeSize,
            IID_ID3D11ShaderReflection, reinterpret_cast< void** >( &pReflector ) ) );

        if ( FAILED( hResult ) || pReflector == NULL )
        {
            if ( pReflector != NULL )
            {
                NN_GFX_CALL_D3D_FUNCTION( pReflector->Release() );
            }
            NN_GFX_CALL_D3D_FUNCTION( pBlob->Release() );
            return ShaderInitializeResult_SetupFailed;
        }

        ID3D11Device* pD3dDevice = static_cast< ID3D11Device* >( deviceData.renderingContext.hD3dDevice );

        switch ( static_cast< ShaderStage >( idxStage ) )
        {
        case ShaderStage_Vertex:
            {
                ID3D11VertexShader* pVertexShader = NULL;
                hResult = NN_GFX_CALL_D3D_FUNCTION( pD3dDevice->CreateVertexShader(
                    pByteCode, byteCodeSize, NULL, &pVertexShader ) );
                pThis->ToData()->pShaderProgram[ idxStage ] = pVertexShader;
            }
            break;
        case ShaderStage_Hull:
            {
                ID3D11HullShader* pHullShader = NULL;
                hResult = NN_GFX_CALL_D3D_FUNCTION( pD3dDevice->CreateHullShader(
                    pByteCode, byteCodeSize, NULL, &pHullShader ) );
                pThis->ToData()->pShaderProgram[ idxStage ] = pHullShader;
            }
            break;
        case ShaderStage_Domain:
            {
                ID3D11DomainShader* pDomainShader = NULL;
                hResult = NN_GFX_CALL_D3D_FUNCTION( pD3dDevice->CreateDomainShader(
                    pByteCode, byteCodeSize, NULL, &pDomainShader ) );
                pThis->ToData()->pShaderProgram[ idxStage ] = pDomainShader;
            }
            break;
        case ShaderStage_Geometry:
            {
                ID3D11GeometryShader* pGeometryShader = NULL;
                hResult = NN_GFX_CALL_D3D_FUNCTION( pD3dDevice->CreateGeometryShader(
                    pByteCode, byteCodeSize, NULL, &pGeometryShader ) );
                pThis->ToData()->pShaderProgram[ idxStage ] = pGeometryShader;
            }
            break;
        case ShaderStage_Pixel:
            {
                ID3D11PixelShader* pPixelShader = NULL;
                hResult = NN_GFX_CALL_D3D_FUNCTION( pD3dDevice->CreatePixelShader(
                    pByteCode, byteCodeSize, NULL, &pPixelShader ) );
                pThis->ToData()->pShaderProgram[ idxStage ] = pPixelShader;
            }
            break;
        case ShaderStage_Compute:
            {
                ID3D11ComputeShader* pComputeShader = NULL;
                hResult = NN_GFX_CALL_D3D_FUNCTION( pD3dDevice->CreateComputeShader(
                    pByteCode, byteCodeSize, NULL, &pComputeShader ) );
                pThis->ToData()->pShaderProgram[ idxStage ] = pComputeShader;
            }
            break;
        default: NN_UNEXPECTED_DEFAULT;
        }

        if ( FAILED( hResult ) || !IsD3dHandleValid( pThis->ToData()->pShaderProgram[ idxStage ] ) )
        {
            // Leave the stage empty and drop the objects that were created for it.
            pThis->ToData()->pShaderProgram[ idxStage ] = NULL;
            NN_GFX_CALL_D3D_FUNCTION( pReflector->Release() );
            NN_GFX_CALL_D3D_FUNCTION( pBlob->Release() );
            return ShaderInitializeResult_SetupFailed;
        }

        // Everything for this stage exists; hand ownership over to the shader object.
        pThis->ToData()->pShaderReflector[ idxStage ] = pReflector;
        pThis->ToData()->pShaderBlob[ idxStage ] = pBlob;
    }

    return ShaderInitializeResult_Success;
}

// Releases the D3D11 objects held by pThis for every stage and resets the stored handles to NULL.
//
// A stage is processed only when its program handle is valid; stages without a valid program are
// skipped entirely. For a processed stage the program is released through its stage-specific
// interface, then the reflector and the bytecode blob are released if they are present.
void DeletePrograms( ShaderImpl< Target >* pThis ) NN_NOEXCEPT
{
    NN_SDK_ASSERT_NOT_NULL( pThis );
    for( int idxStage = 0; idxStage < static_cast< int >( ShaderStage_End ); ++idxStage )
    {
        if( !IsD3dHandleValid( pThis->ToData()->pShaderProgram[ idxStage ] ) )
        {
            continue;
        }

        switch ( static_cast< ShaderStage >( idxStage ) )
        {
        case ShaderStage_Vertex:
            {
                ID3D11VertexShader* pVertexShader = static_cast< ID3D11VertexShader* >(
                    pThis->ToData()->pShaderProgram[ idxStage ] );
                NN_GFX_CALL_D3D_FUNCTION( pVertexShader->Release() );
            }
            break;
        case ShaderStage_Hull:
            {
                ID3D11HullShader* pHullShader = static_cast< ID3D11HullShader* >(
                    pThis->ToData()->pShaderProgram[ idxStage ] );
                NN_GFX_CALL_D3D_FUNCTION( pHullShader->Release() );
            }
            break;
        case ShaderStage_Domain:
            {
                ID3D11DomainShader* pDomainShader = static_cast< ID3D11DomainShader* >(
                    pThis->ToData()->pShaderProgram[ idxStage ] );
                NN_GFX_CALL_D3D_FUNCTION( pDomainShader->Release() );
            }
            break;
        case ShaderStage_Geometry:
            {
                ID3D11GeometryShader* pGeometryShader = static_cast< ID3D11GeometryShader* >(
                    pThis->ToData()->pShaderProgram[ idxStage ] );
                NN_GFX_CALL_D3D_FUNCTION( pGeometryShader->Release() );
            }
            break;
        case ShaderStage_Pixel:
            {
                ID3D11PixelShader* pPixelShader = static_cast< ID3D11PixelShader* >(
                    pThis->ToData()->pShaderProgram[ idxStage ] );
                NN_GFX_CALL_D3D_FUNCTION( pPixelShader->Release() );
            }
            break;
        case ShaderStage_Compute:
            {
                ID3D11ComputeShader* pComputeShader = static_cast< ID3D11ComputeShader* >(
                    pThis->ToData()->pShaderProgram[ idxStage ] );
                NN_GFX_CALL_D3D_FUNCTION( pComputeShader->Release() );
            }
            break;
        default: NN_UNEXPECTED_DEFAULT;
        }
        pThis->ToData()->pShaderProgram[ idxStage ] = NULL;

        // The reflector can be missing even though the program exists (for example when shader
        // reflection failed during setup), so it must not be released unconditionally.
        ID3D11ShaderReflection* pReflector = static_cast< ID3D11ShaderReflection* >(
            pThis->ToData()->pShaderReflector[ idxStage ] );
        if ( pReflector != NULL )
        {
            NN_GFX_CALL_D3D_FUNCTION( pReflector->Release() );
        }
        pThis->ToData()->pShaderReflector[ idxStage ] = NULL;

        ID3DBlob* pBlob = static_cast< ID3DBlob* >( pThis->ToData()->pShaderBlob[ idxStage ] );
        if ( pBlob != NULL )
        {
            NN_GFX_CALL_D3D_FUNCTION( pBlob->Release() );
        }
        pThis->ToData()->pShaderBlob[ idxStage ] = NULL;
    }
}

}

// Returns the alignment for binary shader code. This implementation always returns 1 and
// does not look at the device argument.
size_t ShaderImpl< Target >::GetBinaryCodeAlignment( DeviceImpl< Target >* ) NN_NOEXCEPT
{
    const size_t binaryCodeAlignment = 1;
    return binaryCodeAlignment;
}

// Constructs the shader object in the not-initialized state. No other member is set here.
ShaderImpl< Target >::ShaderImpl() NN_NOEXCEPT
{
    this->state = State_NotInitialized;
}

// Destroys the shader object. Asserts that the object is in the not-initialized state or has
// the Flag_Shared bit set; the destructor itself releases nothing.
ShaderImpl< Target >::~ShaderImpl() NN_NOEXCEPT
{
    // The flags are consulted only when the state is not State_NotInitialized.
    NN_SDK_ASSERT( this->state == State_NotInitialized || this->flags.GetBit( Flag_Shared ) );
}

// Initializes the shader from HLSL source code described by info.
//
// Preconditions: pDevice is not NULL, the object is not initialized, and the source format of
// info is ShaderSourceFormat_Hlsl. Unless separation is enabled, info must provide code for the
// vertex or the compute stage.
//
// Returns ShaderInitializeResult_Success and moves the object to the initialized state when every
// supplied stage was set up. Otherwise the already created per-stage objects are released, the
// object stays not initialized, and the failure result is returned.
ShaderInitializeResult ShaderImpl< Target >::Initialize( DeviceImpl< Target >* pDevice, const InfoType& info ) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL( pDevice );
    NN_SDK_REQUIRES( this->state == State_NotInitialized );
    NN_SDK_REQUIRES( info.GetSourceFormat() == ShaderSourceFormat_Hlsl );
    NN_SDK_ASSERT( info.IsSeparationEnabled() || ( info.GetShaderCodePtr(
        ShaderStage_Vertex ) || info.GetShaderCodePtr( ShaderStage_Compute ) ) );

    this->pGfxDevice = pDevice;
    this->flags = info.ToData()->flags;

    // Clear every per-stage handle, not only the program: the reflector of a stage is also
    // inspected on its own later, so stages without code must not keep indeterminate values.
    for( int idxStage = 0; idxStage < static_cast< int >( ShaderStage_End ); ++idxStage )
    {
        this->pShaderProgram[ idxStage ] = NULL;
        this->pShaderReflector[ idxStage ] = NULL;
        this->pShaderBlob[ idxStage ] = NULL;
    }

    // Currently only source code input is supported.
    NN_SDK_ASSERT( info.GetCodeType() == ShaderCodeType_Source );
    ShaderInitializeResult result = InitializeSourceShader( this, pDevice, info );

    if( result != ShaderInitializeResult_Success )
    {
        // Release whatever was created for the stages that succeeded before the failure.
        DeletePrograms( this );
        return result;
    }

    this->flags.SetBit( Flag_Shared, false );
    this->state = State_Initialized;

    return ShaderInitializeResult_Success;
}

// Releases the per-stage D3D11 objects of this shader and returns it to the not-initialized
// state. The shader must be initialized and must not have the Flag_Shared bit set.
// The device argument is not used.
void ShaderImpl< Target >::Finalize( DeviceImpl< Target >* pDevice ) NN_NOEXCEPT
{
    NN_UNUSED( pDevice );

    NN_SDK_REQUIRES( this->state == State_Initialized );
    NN_SDK_ASSERT( !this->flags.GetBit( Flag_Shared ) );

    // Drop the programs, reflectors and blobs first, then mark the object as finalized.
    DeletePrograms( this );
    this->state = State_NotInitialized;
}

int ShaderImpl< Target >::GetInterfaceSlot( ShaderStage stage,
    ShaderInterfaceType shaderInterfaceType, const char* pName ) const NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL( pName );
    NN_SDK_REQUIRES( IsInitialized( *this ) );
    NN_UNUSED( shaderInterfaceType );

    int ret = -1;

    ID3D11ShaderReflection* pReflector = static_cast< ID3D11ShaderReflection* >( this->pShaderReflector[ stage ] );
    NN_SDK_ASSERT_NOT_NULL( pReflector );

    switch ( shaderInterfaceType )
    {
    case ShaderInterfaceType_Input:
        {
            D3D11_SHADER_DESC shaderDesc;
            pReflector->GetDesc( &shaderDesc );

            for ( UINT idxInputSlot = 0; idxInputSlot < shaderDesc.InputParameters; ++idxInputSlot )
            {
                D3D11_SIGNATURE_PARAMETER_DESC signatureDesc;
                HRESULT hResult = pReflector->GetInputParameterDesc( idxInputSlot, &signatureDesc );

                if ( SUCCEEDED( hResult ) )
                {
                    if ( strcmp( signatureDesc.SemanticName, pName ) == 0 )
                    {
                        ret = static_cast< int >( idxInputSlot );
                        break;
                    }
                    // index付きの名前のチェック
                    // input parameterの名前の末尾に数字が付いている場合、
                    // D3D11_SIGNATURE_PARAMETER_DESC::SemanticNameは末尾数字を除いた前半の文字列のみを返し
                    // D3D11_SIGNATURE_PARAMETER_DESC::SemanticIndexは末尾数字を返します。
                    size_t nameLength = strlen( pName );
                    size_t semanticNameLength = strlen( signatureDesc.SemanticName );
                    if ( semanticNameLength < nameLength &&
                        strncmp( signatureDesc.SemanticName, pName, semanticNameLength ) == 0 )
                    {
                        int digitCount = 0;
                        for ( int idx = static_cast< int >( semanticNameLength );
                            idx < static_cast< int >( nameLength ); ++idx )
                        {
                            if ( pName[ idx ] >= '0' && pName[ idx ] <= '9' )
                            {
                                ++digitCount;
                            }
                        }

                        if ( digitCount == static_cast< int >( nameLength - semanticNameLength ) )
                        {
                            UINT index = static_cast< UINT >( std::atoi( &pName[ semanticNameLength ] ) );
                            if ( index == signatureDesc.SemanticIndex )
                            {
                                ret = static_cast< int >( idxInputSlot );
                                break;
                            }
                        }
                    }
                }
            }
        }
        break;
    default:
        {
            D3D11_SHADER_INPUT_BIND_DESC desc;

            HRESULT hResult = NN_GFX_CALL_D3D_FUNCTION( pReflector->GetResourceBindingDescByName( pName, &desc ) );

            if ( SUCCEEDED( hResult ) )
            {
                ret = static_cast< int >( desc.BindPoint );
            }
        }
        break;
    }

    return ret;
}

// Retrieves the thread group size of the compute stage from its reflector and writes the
// X, Y and Z components to the three output pointers.
//
// Preconditions: all output pointers are not NULL, the shader is initialized, and the compute
// stage has a valid reflector.
void ShaderImpl< Target >::GetWorkGroupSize( int* pOutWorkGroupSizeX,
    int* pOutWorkGroupSizeY, int* pOutWorkGroupSizeZ ) const NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL( pOutWorkGroupSizeX );
    NN_SDK_REQUIRES_NOT_NULL( pOutWorkGroupSizeY );
    NN_SDK_REQUIRES_NOT_NULL( pOutWorkGroupSizeZ );
    NN_SDK_REQUIRES( IsInitialized( *this ) );
    NN_SDK_REQUIRES( IsD3dHandleValid( this->pShaderReflector[ ShaderStage_Compute ] ) );

    ID3D11ShaderReflection* pComputeReflector =
        static_cast< ID3D11ShaderReflection* >( this->pShaderReflector[ ShaderStage_Compute ] );

    // Components in X, Y, Z order, filled in by the reflector.
    UINT groupSize[ 3 ];
    NN_GFX_CALL_D3D_FUNCTION( pComputeReflector->GetThreadGroupSize(
        &groupSize[ 0 ], &groupSize[ 1 ], &groupSize[ 2 ] ) );

    *pOutWorkGroupSizeX = static_cast< int >( groupSize[ 0 ] );
    *pOutWorkGroupSizeY = static_cast< int >( groupSize[ 1 ] );
    *pOutWorkGroupSizeZ = static_cast< int >( groupSize[ 2 ] );
}

}
}
}
