﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <algorithm>
#include <iterator>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <d3dcompiler.h>
#include <nn/nn_SdkAssert.h>
#include <nn/gfxTool/gfxTool_Error.h>
#include <gfxTool_HlslCrossCompilerDxCompiler.h>
#include <hlslcc.h>

#pragma comment(lib, "d3dcompiler.lib")

namespace nn {
namespace gfxTool {

/// Registers a preprocessor macro unless a macro with the same name already exists.
///
/// @param macros  Macro table (name -> value) to update.
/// @param name    Macro name.
/// @param value   Macro value.
///
/// An existing definition is never overwritten: the first definition of a name wins.
void AddMacro( std::unordered_map<std::string, std::string>& macros, const std::string& name, const std::string& value )
{
    // emplace() does a single lookup and leaves an existing entry untouched, which is exactly
    // "insert if absent". Passing the strings directly (instead of via c_str()) avoids extra
    // temporaries and keeps names/values intact even if they contain embedded NULs.
    macros.emplace( name, value );
}

/// Parses a command-line style macro definition and registers it in @p macros.
///
/// "NAME=value" is split at the first '='; a bare "NAME" is registered with the value "1".
/// As with the three-argument overload, an existing definition is not overwritten.
void AddMacro(std::unordered_map<std::string, std::string>& macros, const std::string& macro )
{
    const auto separator = macro.find( '=' );
    if( separator == std::string::npos )
    {
        // A bare name behaves like NAME=1.
        AddMacro( macros, macro.c_str(), "1" );
        return;
    }

    const std::string macroName = macro.substr( 0, separator );
    const char* const macroValue = macro.data() + separator + 1;
    AddMacro( macros, macroName, macroValue );
}

/// Returns whether a buffer starts with the UTF-8 byte order mark (EF BB BF).
///
/// @param pSrc    Start of the buffer. Must not be null.
/// @param length  Number of readable bytes at @p pSrc.
/// @return true if @p length >= 3 and the first three bytes are EF BB BF; false otherwise.
bool HasByteOrderMark( const char* pSrc, int length )
{
    NN_SDK_ASSERT_NOT_NULL( pSrc );

    if( length < 3 )
    {
        return false;
    }

    // Compare byte by byte: memcmp'ing the object representation of a uint32_t only matched
    // the BOM on little-endian hosts (and needed <cstring>).
    static const unsigned char Utf8Bom[3] = { 0xEF, 0xBB, 0xBF };
    for( int i = 0; i < 3; ++i )
    {
        if( static_cast<unsigned char>( pSrc[i] ) != Utf8Bom[i] )
        {
            return false;
        }
    }
    return true;
}

class IncludeFileIo : ID3DInclude
{
public:
    IncludeFileIo() : m_pD3dInclude( this )
    {
        char path[_MAX_PATH] = {};
        _fullpath( path, "", _MAX_PATH );
        m_ExecutedDirectory = path;
        std::replace( m_ExecutedDirectory.begin(), m_ExecutedDirectory.end(), L'\\', L'/' );
    }


    STDMETHOD(Open)(THIS_ D3D_INCLUDE_TYPE IncludeType, LPCSTR pFileName, LPCVOID pParentData, LPCVOID *ppData, UINT *pBytes)
    {
        NN_UNUSED( IncludeType );
        NN_UNUSED( pParentData );
        static std::mutex s_Mutex;
        std::lock_guard<decltype( s_Mutex )> lock( s_Mutex );
        std::string filePath;

        // Search include files
        auto OpenFile = [&pFileName, &filePath]( const std::string& baseDir, const std::string& includeDir )->HANDLE
        {
            HANDLE hFile = INVALID_HANDLE_VALUE;
            if( includeDir.find( ":" ) == std::string::npos )	// relative path
            {
                filePath = baseDir + "/" + includeDir + "/" + pFileName;
                hFile = ::CreateFileA( filePath.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr );
            }
            else // The case passed full path as include-directory value.
            {
                filePath = includeDir + "/" + pFileName;
                hFile = ::CreateFileA( filePath.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr );
            }
            return hFile;
        };

        HANDLE hFile = INVALID_HANDLE_VALUE;

        // #include is fullpath.
        if( std::string( pFileName ).find( ":" ) != std::string::npos )
        {
            hFile = ::CreateFileA( pFileName, GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr );
        }

        std::string parentPath;
        if( pParentData != nullptr )
        {
            for( auto& file : m_FileArray )
            {
                if( &( file.second[0] ) == pParentData )
                {
                    parentPath = std::string( file.first.begin(), file.first.begin() + file.first.find_last_of( '/' ) );
                    break;
                }
            }
        }
        else
        {
            parentPath = m_InitialSourceDirectory;
        }

        // First serach include file relative to a file #include is written.
        bool isDeprecated = false;
        if( hFile == INVALID_HANDLE_VALUE )
        {
            hFile = OpenFile( parentPath, "" );
        }

        // DEPRECATED code block: search again from current dir. This code block should be removed in future.
        if( hFile == INVALID_HANDLE_VALUE )
        {
            hFile = OpenFile( m_ExecutedDirectory, "" );
            if( hFile != INVALID_HANDLE_VALUE )
            {
                isDeprecated = true;
            }
        }

        // serach from include directories passed by --include-directory.
        if( ( hFile == INVALID_HANDLE_VALUE ) || isDeprecated )	// search again from includ-path if isDeprecated = true in order to surpress warning.
        {
            for( auto& includeDir : m_IncludeDirectories )
            {
                auto hFileTemp = OpenFile( m_ExecutedDirectory, includeDir );
                if( hFileTemp != INVALID_HANDLE_VALUE )
                {
                    hFile = hFileTemp;
                    isDeprecated = false;
                    break;
                }
            }
        }
        if( hFile == INVALID_HANDLE_VALUE )
        {
            NN_GFXTOOL_PRINT_ERROR( "Failed to open include file: %s\n", pFileName );
            return D3D11_ERROR_FILE_NOT_FOUND;
        }

        if( isDeprecated )	// DEPRECATED code block.
        {
            NN_GFXTOOL_PRINT_WARNING(
                "DEPRECATED: Include file \"%s\" is opened from current directory( executed directory )."
                " Check if the include file exists under the directories specified by include-directory option."
                " Current directory will be removed from default search path in the future.\n", pFileName );
        }

        auto fileSize = ::GetFileSize( hFile, nullptr );
        std::vector<char> buffer( fileSize );
        DWORD read = 0;
        BOOL result = TRUE;
        if( fileSize > 0 )
        {
            result = ::ReadFile( hFile, &buffer[ 0 ], fileSize, &read, nullptr );
        }
        ::CloseHandle( hFile );
        if( result == FALSE || read != fileSize )
        {
            return D3D11_ERROR_FILE_NOT_FOUND;
        }

        m_FileArray.emplace_back();
        auto& file = m_FileArray.back();

        file.first = filePath;
        int ofsFile( 0 );
        if( (buffer.size() != 0) && HasByteOrderMark( &buffer[0], fileSize ) )
        {
            ofsFile = 3;
        }
        std::move( buffer.begin() + ofsFile, buffer.end(), std::back_inserter( file.second ) );

        *ppData = file.second.size() > 0 ? &file.second[0] : nullptr;
        *pBytes = static_cast<UINT>( file.second.size() );

        return S_OK;
    }

    STDMETHOD(Close)(THIS_ LPCVOID pData)
    {
        // do nothing relying on destructor.
        NN_UNUSED( pData );
        return S_OK;
    }

public:
    ID3DInclude* GetD3dInclude()
    {
        return m_pD3dInclude;
    }

    void SetInitialSourceDirectory( const std::string& initialSourceDirectory )
    {
        m_InitialSourceDirectory = initialSourceDirectory;
        std::replace( m_InitialSourceDirectory.begin(), m_InitialSourceDirectory.end(), L'\\', L'/' );

        // remove last '/'
        if( m_InitialSourceDirectory.c_str()[m_InitialSourceDirectory.length() - 1] == '/' )
        {
            m_InitialSourceDirectory.erase( m_InitialSourceDirectory.begin() + m_InitialSourceDirectory.length() - 1 );
        }
    }

    void AddIncludeDirectory( const std::string& includeDirectory )
    {
        if( std::find( m_IncludeDirectories.begin(), m_IncludeDirectories.end(), includeDirectory ) == m_IncludeDirectories.end() )
        {
            m_IncludeDirectories.emplace_back();
            auto& str = m_IncludeDirectories.back();
            str = includeDirectory;
            std::replace( str.begin(), str.end(), L'\\', L'/' );

            // remove last '/'
            if( str.c_str()[str.length() - 1] == '/' )
            {
                str.erase( str.begin() + str.length() - 1 );
            }
        }
    }

private:
    ID3DInclude* m_pD3dInclude;
    std::string m_InitialSourceDirectory;
    std::string m_ExecutedDirectory;
    std::vector<std::string> m_IncludeDirectories;
    typedef std::pair<std::string, std::vector<char>> File;
    std::vector<File>	m_FileArray;
};

// Single include handler shared by every CompileShader() call in this translation unit;
// configured through the AddIncludeDirectory() / SetInitialSourceDirectory() free functions.
IncludeFileIo g_IncludeFileIo{};

void AddIncludeDirectory( const std::string& includeDirectory )
{
    g_IncludeFileIo.AddIncludeDirectory( includeDirectory );
}

void SetInitialSourceDirectory( const std::string& initialSourceDirectory )
{
    g_IncludeFileIo.SetInitialSourceDirectory( initialSourceDirectory );
}

/// Compiles an HLSL source file with D3DCompile.
///
/// @param srcFileStr          Path of the source file; also passed to D3DCompile as the source name.
/// @param entryPoint          Name of the shader entry point.
/// @param profile             Target profile string passed to D3DCompile.
/// @param blob                Receives the compiled blob on success. It is not released here,
///                            so the caller owns it.
/// @param compileOptionFlags  Passed to D3DCompile as Flags1.
/// @param pShaderMacro        Macro definitions passed to D3DCompile (may be nullptr).
///
/// Failures are reported through NN_GFXTOOL_THROW_MSG / NN_GFXTOOL_THROW, which this code
/// assumes do not return:
///   - nngfxToolResultCode_FailedToLoadFile when the file cannot be opened or read, or is
///     empty (including a file that contains only a UTF-8 BOM);
///   - nngfxToolResultCode_CompileError when compilation fails.
/// #include directives are resolved by g_IncludeFileIo.
void CompileShader( _In_ LPCSTR srcFileStr, _In_ LPCSTR entryPoint, _In_ LPCSTR profile, _Outptr_ ID3DBlob** blob, _In_ UINT compileOptionFlags, _In_ const D3D_SHADER_MACRO* pShaderMacro )
{
    NN_SDK_ASSERT_NOT_NULL(srcFileStr);
    NN_SDK_ASSERT_NOT_NULL(entryPoint);
    NN_SDK_ASSERT_NOT_NULL(profile);
    NN_SDK_ASSERT_NOT_NULL(blob);

    std::vector<char> buffer;
    {
        HANDLE hFile = ::CreateFileA( srcFileStr, GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr );
        if( hFile == INVALID_HANDLE_VALUE )
        {
            NN_GFXTOOL_THROW_MSG(nngfxToolResultCode_FailedToLoadFile, "Failed to open source file: %s\n", srcFileStr);
        }
        // Closes the handle on every exit from this scope, including throws (e.g. bad_alloc from resize).
        std::unique_ptr<void, decltype( &::CloseHandle )> fileCloser( hFile, &::CloseHandle );

        const DWORD fileSize = ::GetFileSize( hFile, nullptr );
        if( fileSize == 0 )
        {
            NN_GFXTOOL_THROW_MSG(nngfxToolResultCode_FailedToLoadFile, "Source code may be empty: %s\n", srcFileStr);
        }
        if( fileSize == INVALID_FILE_SIZE )
        {
            NN_GFXTOOL_THROW_MSG(nngfxToolResultCode_FailedToLoadFile, "Failed to read source file: %s\n", srcFileStr);
        }

        buffer.resize( fileSize );
        DWORD read = 0;
        // A short or failed read used to go unnoticed and compile a partially zero-filled buffer.
        if( ::ReadFile( hFile, buffer.data(), fileSize, &read, nullptr ) == FALSE || read != fileSize )
        {
            NN_GFXTOOL_THROW_MSG(nngfxToolResultCode_FailedToLoadFile, "Failed to read source file: %s\n", srcFileStr);
        }
    }

    if( HasByteOrderMark( buffer.data(), static_cast<int>( buffer.size() ) ) )
    {
        buffer.erase( buffer.begin(), buffer.begin() + 3 );
    }
    // A BOM-only file is as empty as a zero-byte one. Indexing the empty buffer would be UB.
    if( buffer.empty() )
    {
        NN_GFXTOOL_THROW_MSG(nngfxToolResultCode_FailedToLoadFile, "Source code may be empty: %s\n", srcFileStr);
    }

    ID3DBlob* shaderBlob = nullptr;
    ID3DBlob* errorBlob = nullptr;
    const HRESULT hr = D3DCompile( buffer.data(), buffer.size(), srcFileStr, pShaderMacro, g_IncludeFileIo.GetD3dInclude(),
                                   entryPoint, profile,
                                   compileOptionFlags, 0, &shaderBlob, &errorBlob );

    if( FAILED( hr ) )
    {
        // Report the HRESULT even when the compiler produced no diagnostic text.
        NN_GFXTOOL_PRINT_ERROR( "failed compiling shader: error=%08X\n", static_cast<unsigned int>( hr ) );
        if( errorBlob != nullptr )
        {
            NN_GFXTOOL_PRINT_ERROR( "%s", static_cast<const char*>( errorBlob->GetBufferPointer() ) );
            errorBlob->Release();
        }

        if( shaderBlob != nullptr )
        {
            shaderBlob->Release();
        }
        NN_GFXTOOL_THROW(nngfxToolResultCode_CompileError);
    }

    // On success the error blob may still hold warning text; release it rather than leak it.
    if( errorBlob != nullptr )
    {
        errorBlob->Release();
    }

    *blob = shaderBlob;
}

}}
