﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <algorithm>
#include <vector>
#include <locale>

#include <nn/util/util_StringView.h>

#include <nn/gfxTool/gfxTool_Util.h>

#include <gfxTool_SimplePreprocess.h>

namespace nn {
namespace gfxTool {

// Scans a source text for preprocessor directives and lets a callback rewrite them.
//
// The scan skips C-style ( /* ... */ ) and line ( // ... ) comments, then treats every
// '#' that is the first non-blank character of its line as the start of a directive.
// The directive text is flattened onto a single line (C-style comments, backslashes and
// line breaks become spaces) and the part after '#' is handed to pCallback. When the
// callback returns true, the directive in the source is replaced by the string the
// callback produced; otherwise the source text is kept untouched.
//
// pSource   : source text to process. nullptr yields an empty string.
// pCallback : directive handler; may be nullptr, in which case nothing is replaced.
// Returns the processed source.
// Raises nngfxToolResultCode_SyntaxError (through NN_GFXTOOL_THROW_MSG) when a directive
// contains an unterminated C-style comment or consists of nothing but '#'.
Custom< std::string >::Type SimplePreprocess( const nn::util::string_view* pSource,
    PreprocessorDirectiveCallbackBase* pCallback )
{
    typedef nn::util::string_view SourceStringType;

    if( pSource == nullptr )
    {
        return Custom< std::string >::Type();
    }

    Custom< std::string >::Type ret;
    auto& source = *pSource;
    const SourceStringType::size_type sourceLength = source.length();
    SourceStringType::size_type substrStart = 0; // first source offset not yet copied to ret
    SourceStringType::size_type start = 0;       // where the next search begins
    for( ; ; )
    {
        // Search for '#' while skipping comments.
        auto sharp = source.find_first_of( "#/", start );
        if( sharp == SourceStringType::npos )
        {
            break;
        }
        else if( source[ sharp ] == '/' )
        {
            // A '/' that is the last character of the source cannot open a comment;
            // never read past the end of the view.
            const char next = ( sharp + 1 < sourceLength ) ? source[ sharp + 1 ] : '\0';
            if( next == '*' )
            {
                auto end = source.find( "*/", sharp + 2 );
                if( end == SourceStringType::npos )
                {
                    break;
                }
                start = end + 2;
            }
            else if( next == '/' )
            {
                auto end = source.find_first_of( "\r\n", sharp + 2 );
                if( end == SourceStringType::npos )
                {
                    break;
                }
                start = end + 1;
            }
            else
            {
                start = sharp + 1;
            }
            continue;
        }
        start = sharp + 1;

        // A '#' inside a span that was already replaced (for example on a continuation
        // line of the previous directive) must not be handled again; doing so would
        // make "sharp - substrStart" underflow below.
        if( sharp < substrStart )
        {
            continue;
        }

        // Beginning of the line: only blanks may precede the '#'.
        auto head = source.find_last_of( "\r\n", sharp );
        auto headChar = source.find_first_not_of( " \t\n", head == SourceStringType::npos ? 0 : head + 1 );
        if( headChar != sharp )
        {
            continue;
        }

        // End of the directive: the first line break that is not escaped by a backslash.
        auto tail = sharp;
        while( ( tail = source.find_first_of( "\r\n", tail + 1 ) ) != SourceStringType::npos )
        {
            auto prev = tail - 1;
            // For a CRLF pair, the backslash of a line continuation sits before the '\r'.
            if( source[ tail ] == '\n' && source[ prev ] == '\r' && prev > sharp )
            {
                --prev;
            }
            if( source[ prev ] != '\\' )
            {
                break;
            }
        }
        if( tail == SourceStringType::npos )
        {
            // The directive is on the last line and has no trailing line break:
            // it extends to the very end of the source.
            tail = sourceLength;
        }

        // Remove C-style comments and line breaks so the callback sees a single line.
        Custom< std::string >::Type directive;
        directive.reserve( tail - sharp + 1 );
        for( auto pos = sharp; pos < tail; ++pos )
        {
            const char ch = source[ pos ];
            if( ch == '/' && pos + 1 < tail && source[ pos + 1 ] == '*' )
            {
                // Search after the opening "/*" so that "/*/" is not taken for a whole comment.
                const auto commentEnd = source.find( "*/", pos + 2 );
                if( commentEnd == SourceStringType::npos )
                {
                    NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_SyntaxError,
                        "Failed to detect the end of comment.\n%s", directive.data() );
                }
                // A comment separates tokens, so it is replaced by one space.
                directive += ' ';
                // Together with ++pos this resumes right after the closing "*/".
                // NOTE(review): a comment that closes beyond tail simply ends the
                // directive here, as before.
                pos = commentEnd + 1;
            }
            else if( ch == '\\' || ch == '\n' || ch == '\r' )
            {
                directive += ' ';
            }
            else
            {
                directive += ch;
            }
        }

        // Skip the '#' and any blanks after it.
        auto valueStart = directive.find_first_not_of( " \t", 1 );
        if( valueStart == Custom< std::string >::Type::npos )
        {
            NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_SyntaxError,
                "Invalid preprocessor directive." );
        }
        nn::util::string_view value( directive.data() + valueStart );
        Custom< std::string >::Type result;
        if( pCallback && ( *pCallback )( &result, &value ) )
        {
            // Copy everything up to the directive, then the replacement text.
            ret.append( source.data() + substrStart, sharp - substrStart );
            ret.append( result );
            substrStart = tail;
        }
    }
    // Copy whatever follows the last replaced directive.
    if( substrStart < sourceLength )
    {
        ret.append( source.data() + substrStart, sourceLength - substrStart );
    }

    return ret;
} // NOLINT

// Extracts the body of the variation constant buffer from a shader source and marks
// the variation constants declared in it as unused.
//
// The source is searched for the block keyword ("uniform" for GLSL, "cbuffer" for HLSL)
// followed by definition.variationConstantBufferName and a brace-delimited body. Inside
// a copy of that body, the first whole-word occurrence of every variation constant name
// gets the suffix "_nngfx_unused".
//
// pOutBeforeVariationBuffer : optional; receives the source text before the opening '{'.
//                             Set to "" when no buffer is found.
// pOutAfterVariationBuffer  : optional; receives the source text from the closing '}' on.
//                             Set to the whole source when no buffer is found.
// pSource                   : shader source. nullptr yields an empty string and leaves
//                             the output views untouched.
// pDefinition               : variation definition. nullptr yields an empty string.
// Returns the rewritten body, starting at '{' and excluding the closing '}', or an
// empty string when the buffer was not found.
template< nngfxToolShaderCompilerShaderSourceFormat SourceFormat >
Custom< std::string >::Type CreateVariationBufferSource(
    nn::util::string_view* pOutBeforeVariationBuffer, nn::util::string_view* pOutAfterVariationBuffer,
    const nn::util::string_view* pSource, const nngfxToolShaderCompilerVariationDefinition* pDefinition )
{
    typedef nn::util::string_view SourceStringType;
    typedef Custom< std::string >::Type StringType;

    // Provisional implementation.
    // TODO: Analyze the source while actually running the preprocessor.
    if( pSource == nullptr )
    {
        return StringType();
    }
    auto& source = *pSource;
    const std::locale& locale = std::locale::classic();
    if( pOutBeforeVariationBuffer )
    {
        *pOutBeforeVariationBuffer = "";
    }
    if( pOutAfterVariationBuffer )
    {
        *pOutAfterVariationBuffer = nn::util::string_view( source.data(), source.size() );
    }
    if( pDefinition == nullptr )
    {
        return StringType();
    }
    auto& definition = *pDefinition;
    nn::util::string_view strUniform;

    if( NN_STATIC_CONDITION( SourceFormat == nngfxToolShaderCompilerShaderSourceFormat_Glsl ) )
    {
        strUniform = nn::util::string_view( "uniform" );
    }
    else if( NN_STATIC_CONDITION( SourceFormat == nngfxToolShaderCompilerShaderSourceFormat_Hlsl ) )
    {
        strUniform = nn::util::string_view( "cbuffer" );
    }
    const SourceStringType::size_type keywordLength = strUniform.length();
    if( keywordLength == 0 )
    {
        // No block keyword is known for this source format, so there is nothing to search for.
        return StringType();
    }

    // True for characters that can be part of an identifier.
    const auto isIdentifierChar = [ &locale ]( char ch )
    {
        return std::isalpha( ch, locale ) || std::isdigit( ch, locale ) || ch == '_';
    };

    for( SourceStringType::size_type uniform = source.find( strUniform.data(), 0 );
        uniform != SourceStringType::npos;
        uniform = source.find( strUniform.data(), uniform + keywordLength ) )
    {
        // The keyword must be a token of its own: preceded by whitespace (or the start
        // of the source) and followed by whitespace. Line breaks count as whitespace so
        // that a keyword at the beginning or end of a line is recognized.
        const auto keywordEnd = uniform + keywordLength;
        if( ( uniform > 0 && !std::isspace( source[ uniform - 1 ], locale ) ) ||
            keywordEnd >= source.length() || !std::isspace( source[ keywordEnd ], locale ) )
        {
            continue;
        }

        // The next token must be exactly the variation constant buffer name.
        const auto name = source.find_first_not_of( " \t\r\n", keywordEnd );
        if( name == SourceStringType::npos )
        {
            continue;
        }
        const auto nameEnd = name + definition.variationConstantBufferName.length;
        if( nameEnd >= source.length() ||
            std::strncmp( source.data() + name, definition.variationConstantBufferName.pValue,
            definition.variationConstantBufferName.length ) != 0 ||
            isIdentifierChar( source[ nameEnd ] ) )
        {
            continue;
        }

        // Locate the body and its matching closing brace.
        const auto bufferStart = source.find( '{', nameEnd );
        if( bufferStart == SourceStringType::npos )
        {
            continue;
        }
        auto bufferEnd = SourceStringType::npos;
        int level = 1;
        // The search must resume after the brace that was just counted; restarting at
        // the same position would count a nested '{' forever.
        for( auto pos = source.find_first_of( "{}", bufferStart + 1 );
            pos != SourceStringType::npos; pos = source.find_first_of( "{}", pos + 1 ) )
        {
            level += source[ pos ] == '{' ? 1 : -1;
            if( level <= 0 )
            {
                bufferEnd = pos;
                break;
            }
        }
        if( bufferEnd == SourceStringType::npos )
        {
            continue;
        }

        const char* const pUnusedString = "_nngfx_unused";
        StringType buffer( source.data() + bufferStart, bufferEnd - bufferStart );
        for( int idxVariationConstant = 0, variationConstantCount = NumericCastAuto(
            definition.variationConstantDefinitionCount ); idxVariationConstant
            < variationConstantCount; ++idxVariationConstant )
        {
            auto& variationConstant = definition.pVariationConstantDefinitionArray[ idxVariationConstant ];
            const auto constantNameLength = variationConstant.name.length;
            if( constantNameLength == 0 )
            {
                // An empty name matches everywhere and would never advance the search.
                continue;
            }
            for( auto pos = buffer.find( variationConstant.name.pValue, 0, constantNameLength );
                pos != StringType::npos; pos = buffer.find( variationConstant.name.pValue,
                pos + constantNameLength, constantNameLength ) )
            {
                // Accept whole-word matches only, so that a name embedded in a longer
                // identifier (for example "color" in "v2color") is left alone.
                const auto nameEndIndex = pos + constantNameLength;
                const bool startsToken = pos == 0 || !isIdentifierChar( buffer[ pos - 1 ] );
                const bool endsToken = nameEndIndex >= buffer.length() ||
                    !isIdentifierChar( buffer[ nameEndIndex ] );
                if( startsToken && endsToken )
                {
                    // Only the first whole-word occurrence is renamed.
                    buffer.insert( nameEndIndex, pUnusedString );
                    break;
                }
            }
        }

        if( pOutBeforeVariationBuffer )
        {
            *pOutBeforeVariationBuffer = nn::util::string_view( source.data(), bufferStart );
        }
        if( pOutAfterVariationBuffer )
        {
            *pOutAfterVariationBuffer = nn::util::string_view(
                source.data() + bufferEnd, source.length() - bufferEnd );
        }
        return buffer;
    }

    return StringType();
}

// Explicit instantiation for GLSL sources; in this variant the template searches for
// the "uniform" keyword.
template
Custom< std::string >::Type CreateVariationBufferSource<
    nngfxToolShaderCompilerShaderSourceFormat_Glsl >(
    nn::util::string_view* pOutBeforeVariationBuffer, nn::util::string_view* pOutAfterVariationBuffer,
    const nn::util::string_view* pSource, const nngfxToolShaderCompilerVariationDefinition* pDefinition );

// Explicit instantiation for HLSL sources; in this variant the template searches for
// the "cbuffer" keyword.
template
Custom< std::string >::Type CreateVariationBufferSource<
    nngfxToolShaderCompilerShaderSourceFormat_Hlsl >(
    nn::util::string_view* pOutBeforeVariationBuffer, nn::util::string_view* pOutAfterVariationBuffer,
    const nn::util::string_view* pSource, const nngfxToolShaderCompilerVariationDefinition* pDefinition );

// Expands an "include" directive into the preprocessed contents of the included file.
//
// pOut       : receives the expanded text when the directive is handled.
// pDirective : directive text without the leading '#'.
// Returns true when the directive was an include and has been expanded; false when the
// directive is not an include or no quoted / bracketed file name could be found.
// Raises nngfxToolResultCode_FailedToLoadFile (through NN_GFXTOOL_THROW_MSG) when no
// load callback is set or the callback fails to provide the file.
bool ExpandIncludeCallback::operator()( Custom< std::string >::Type* pOut,
    const nn::util::string_view* pDirective )
{
    // TODO: Detect circular includes.
    const nn::util::string_view strInclude( "include" );
    if( pDirective->length() < strInclude.length() ||
        std::strncmp( pDirective->data(), strInclude.data(), strInclude.length() ) != 0 )
    {
        return false;
    }

    // The file name is whatever sits between the first '"' or '<' and the next '"' or '>'.
    auto valueStart = pDirective->find_first_of( "\"<" );
    if( valueStart == nn::util::string_view::npos )
    {
        return false;
    }
    ++valueStart;
    const auto valueEnd = pDirective->find_first_of( "\">", valueStart );
    if( valueEnd == nn::util::string_view::npos )
    {
        return false;
    }
    Custom< std::string >::Type filename( pDirective->data() + valueStart, valueEnd - valueStart );

    // Initialize the outputs so a callback that leaves them untouched cannot hand back garbage.
    void* pExpanded = nullptr;
    size_t size = 0;
    if( m_pCallback == nullptr || !m_pCallback( &pExpanded, &size, filename.data(), m_pCallbackParam ) ||
        ( pExpanded == nullptr && size != 0 ) )
    {
        NN_GFXTOOL_THROW_MSG( nngfxToolResultCode_FailedToLoadFile,
            "Failed to load %s.", filename.data() );
    }
    if( size == 0 )
    {
        // An empty file expands to nothing.
        pOut->clear();
        return true;
    }

    // TODO (left unspecified by the original author)
    auto pStart = static_cast< const char* >( pExpanded );
    // Skip a leading byte order mark (three bytes, presumably the UTF-8 BOM). The size
    // check keeps IsBom from looking past a shorter file and the subtraction from wrapping.
    if( size >= 3 && IsBom( pStart ) )
    {
        pStart += 3;
        size -= 3;
    }

    // Run the included text through the same preprocessing so that directives inside it
    // are handled as well.
    nn::util::string_view expandedSource( pStart, size );
    *pOut = SimplePreprocess( &expandedSource, m_pSimplePreprocessCallback );
    return true;
}

// Comments out a "define" directive whose macro name is listed in *m_pMacros.
//
// pOut       : receives "// commented out by gfx: #<directive>" when the macro matches.
// pDirective : directive text without the leading '#'.
// Returns true when the directive defines one of the listed macros; false otherwise.
bool CommentOutMacroCallback::operator()( Custom< std::string >::Type* pOut,
    const nn::util::string_view* pDirective )
{
    const nn::util::string_view strDefine( "define" );
    if( pDirective->length() <= strDefine.length() ||
        std::strncmp( pDirective->data(), strDefine.data(), strDefine.length() ) != 0 )
    {
        return false;
    }

    // The macro name starts at the first non-blank character after "define". At least
    // one blank is required, otherwise the keyword is merely the prefix of another word.
    const auto start = pDirective->find_first_not_of( " \t", strDefine.length() );
    if( start == nn::util::string_view::npos || start == strDefine.length() )
    {
        return false;
    }

    // The name ends at the next blank or at the end of the directive.
    auto end = pDirective->find_first_of( " \t", start );
    if( end == nn::util::string_view::npos )
    {
        end = pDirective->length();
    }
    const auto length = end - start;

    for( const auto& macro : *m_pMacros )
    {
        if( length == macro.length() && std::strncmp(
            pDirective->data() + start, macro.c_str(), length ) == 0 )
        {
            *pOut = "// commented out by gfx: ";
            pOut->append( "#" );
            // Append by length so the view does not have to be null-terminated.
            pOut->append( pDirective->data(), pDirective->length() );
            return true;
        }
    }

    return false;
}

// Offers the directive to each registered callback in order and stops at the first
// one that handles it.
//
// pOut       : forwarded to the callbacks; filled by the callback that handles the directive.
// pDirective : forwarded to the callbacks unchanged.
// Returns true as soon as one callback returns true; false when none of them does.
bool PreprocessorDirectiveMultiCallback::operator()( Custom< std::string >::Type* pOut,
    const nn::util::string_view* pDirective )
{
    auto& callbacks = *m_pCallbacks;
    auto current = std::begin( callbacks );
    const auto last = std::end( callbacks );
    while( current != last )
    {
        auto& handler = *current;
        const bool handled = ( *handler )( pOut, pDirective );
        if( handled )
        {
            return true;
        }
        ++current;
    }
    return false;
}

}
}
