﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

// GetDataSize ConvertFormat IsGpuEncodingAvailable
// ConvertLevelsFormat
// UsesNvtt InitializeNvtt ConvertByNvtt

//=============================================================================
// include
//=============================================================================
#include "Encoder.h"
#include "SimpleFormat.h"
#include "BcFormat.h"
#include "PvrFormat.h"
#include "AstcFormat.h"
#include "../NvttDll/NvttDll.h"

#ifdef _OPENMP
    #include <omp.h>
#endif

#define DLLEXPORT extern "C" __declspec(dllexport)

using namespace std;
using namespace nn::gfx::tool::texenc;

//=============================================================================
//! @brief DLL から関数を取得するためのマクロです。
//=============================================================================
#define PROC_ADDRESS(handle, name)                                   \
*reinterpret_cast<void**>(&name) = GetProcAddress(handle, #name);    \
    if (name == nullptr)                                             \
    {                                                                \
        cerr << "Error: Cannot find function: " << #name << endl;    \
    }                                                                \

//-----------------------------------------------------------------------------
// 無名名前空間を開始します。
namespace
{

//=============================================================================
// variables
//=============================================================================
std::string g_DllFolderPath; //!< この dll ファイルが存在するフォルダのパスです。

bool g_IsOpenMpUsed = false; //!< OpenMP を使用済みなら true です。

bool g_IsNvttInitialized = false; //!< nvtt を初期化済みなら true です（初期化に失敗した場合も true）。
HMODULE g_hNvttDll = nullptr; //!< nvtt 用 DLL のインスタンスハンドルです。
NvttConvertFunction NvttConvert = nullptr; //!< nvtt で画像のフォーマットを変換する関数へのポインタです。
NvttIsGpuEncodingAvailableFunction NvttIsGpuEncodingAvailable = nullptr; //!< nvtt で GPU によるエンコーディングが可能なら true を返す関数へのポインタです。

//-----------------------------------------------------------------------------
//! @brief エンコーダとして nvtt を使用するか判定します。
//!
//! @param[in] dstFormatStr 変換後のフォーマット文字列です。
//! @param[in] qualityStr エンコード品質です。
//! @param[in] encodeFlag エンコードフラグです。
//!
//! @return エンコーダとして nvtt を使用するなら true を返します。
//-----------------------------------------------------------------------------
bool UsesNvtt(
    const std::string& dstFormatStr,
    const std::string& qualityStr,
    const int encodeFlag
)
{
    const bool usesGpu = ((encodeFlag & EncodeFlag_GpuEncoding) != 0);

    if (IsBc123Format(dstFormatStr))
    {
        return (qualityStr.find("dxtex") == std::string::npos &&
            GetQualityLevel(qualityStr) >= 1 &&
            GetEnvVariable("NINTENDO_TEXTURE_CONVERTER_NVTT_BC123") != "0");
    }
    else if (dstFormatStr.find("_bc6") != std::string::npos)
    {
        return true;
    }
    else if (dstFormatStr.find("_bc7") != std::string::npos)
    {
        return true;
    }
    else if (IsAstcFormat(dstFormatStr))
    {
        return (usesGpu || (
            qualityStr.find("evaluation") == std::string::npos &&
            GetEnvVariable("NINTENDO_TEXTURE_CONVERTER_NVTT_ASTC_CPU") != "0")); // 環境変数は非公開機能
    }
    else
    {
        return false;
    }
}

//-----------------------------------------------------------------------------
//! @brief nvtt が未初期化であれば初期化します。
//!
//! @return nvtt を利用可能なら true を返します。
//-----------------------------------------------------------------------------
bool InitializeNvtt()
{
    if (!g_IsNvttInitialized)
    {
        g_IsNvttInitialized = true;
        g_hNvttDll = nullptr;
        NvttConvert = nullptr;
        NvttIsGpuEncodingAvailable = nullptr;

        const std::string nvttDllPath = g_DllFolderPath + "\\TextureConverterNvtt.dll";
        if (FileExists(nvttDllPath))
        {
            g_hNvttDll = LoadLibraryEx(nvttDllPath.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH);
            if (g_hNvttDll != nullptr)
            {
                PROC_ADDRESS(g_hNvttDll, NvttConvert);
                PROC_ADDRESS(g_hNvttDll, NvttIsGpuEncodingAvailable);
            }
            else
            {
                cerr << "Error: Cannot load dll: " << nvttDllPath
                     << " (" << GetWindowsLastErrorMessage() << ")" << endl;
            }
        }
    }
    return (NvttConvert != nullptr);
}

//-----------------------------------------------------------------------------
//! @brief nvtt を終了します。
//-----------------------------------------------------------------------------
void FinalizeNvtt()
{
    g_IsNvttInitialized = false;
    if (g_hNvttDll != nullptr)
    {
        FreeLibrary(g_hNvttDll);
        g_hNvttDll = nullptr;
    }
    NvttConvert = nullptr;
    NvttIsGpuEncodingAvailable = nullptr;
}

//-----------------------------------------------------------------------------
//! @brief nvtt でフォーマットを変換します。
//!
//! @param[out] pDst 変換後のデータを格納します。
//! @param[in] pSrc 変換前のデータです。
//! @param[in] dstFormatStr 変換後のフォーマット文字列です。
//! @param[in] srcFormatStr 変換前のフォーマット文字列です。
//! @param[in] qualityStr エンコード品質です。
//! @param[in] encodeFlag エンコードフラグです。
//! @param[in] dimension 次元です。
//! @param[in] imageW 画像の幅です。
//! @param[in] imageH 画像の高さです。
//! @param[in] imageD 画像の奥行きです。
//! @param[in] mipCount ミップマップのレベル数です。
//!
//! @return 処理成功なら true を返します。
//-----------------------------------------------------------------------------
bool ConvertByNvtt(
    void* pDst,
    const void* pSrc,
    const std::string& dstFormatStr,
    const std::string& srcFormatStr,
    const std::string& qualityStr,
    const int encodeFlag,
    const int dimension,
    const int imageW,
    const int imageH,
    const int imageD,
    const int mipCount
)
{
    //cerr << "convert by nvtt: " << qualityStr << ": flag = 0x" << hex << encodeFlag << dec << endl;
    return NvttConvert(pDst, pSrc,
        GetUnicodeFromAnsi(dstFormatStr).c_str(),
        GetUnicodeFromAnsi(srcFormatStr).c_str(),
        GetUnicodeFromAnsi(qualityStr).c_str(),
        encodeFlag, dimension, imageW, imageH, imageD, mipCount);
}

//-----------------------------------------------------------------------------
//! @brief 全ミップマップレベルのフォーマットを変換します。
//!
//! @param[out] pDst 変換後のデータを格納します。
//! @param[in] pSrc 変換前のデータです。
//! @param[in] dstFormatStr 変換後のフォーマット文字列です。
//! @param[in] srcFormatStr 変換前のフォーマット文字列です。
//! @param[in] qualityStr エンコード品質です。
//! @param[in] encodeFlag エンコードフラグです。
//! @param[in] dimension 次元です。
//! @param[in] imageW 画像の幅です。
//! @param[in] imageH 画像の高さです。
//! @param[in] imageD 画像の奥行きです。
//! @param[in] mipCount ミップマップのレベル数です。
//!
//! @return 処理成功なら true を返します。
//-----------------------------------------------------------------------------
bool ConvertLevelsFormat(
    void* pDst,
    const void* pSrc,
    const std::string& dstFormatStr,
    const std::string& srcFormatStr,
    const std::string& qualityStr,
    const int encodeFlag,
    const int dimension,
    const int imageW,
    const int imageH,
    const int imageD,
    const int mipCount
)
{
    //-----------------------------------------------------------------------------
    // 変換前または変換後がプレーンな RGBA フォーマットでなければ変換しません。
    if (!IsPlainRgbaFormat(srcFormatStr) &&
        !IsPlainRgbaFormat(dstFormatStr))
    {
        return false;
    }

    //-----------------------------------------------------------------------------
    // フォーマットの情報を取得します。
    const size_t srcBpp = GetBitsPerPixel(srcFormatStr);
    const size_t dstBpp = GetBitsPerPixel(dstFormatStr);
    int srcMinW = 1;
    int srcMinH = 1;
    int srcMinD = 1;
    GetMinimumWhd(&srcMinW, &srcMinH, &srcMinD, srcFormatStr);
    int dstMinW = 1;
    int dstMinH = 1;
    int dstMinD = 1;
    GetMinimumWhd(&dstMinW, &dstMinH, &dstMinD, dstFormatStr);

    const bool isSrcBc = IsBcFormat(srcFormatStr);
    const bool isDstBc = IsBcFormat(dstFormatStr);
    const bool isSrcEtcEacPvrtc =
        IsEtcFormat(srcFormatStr) || IsEacFormat(srcFormatStr) || IsPvrtcFormat(srcFormatStr);
    const bool isDstEtcEacPvrtc =
        IsEtcFormat(dstFormatStr) || IsEacFormat(dstFormatStr) || IsPvrtcFormat(dstFormatStr);
    const bool isSrcAstc = IsAstcFormat(srcFormatStr);
    const bool isDstAstc = IsAstcFormat(dstFormatStr);

    const bool usesNvtt = UsesNvtt(dstFormatStr, qualityStr, encodeFlag) &&
        InitializeNvtt();
    const bool usesDirectXTex = !usesNvtt && (isSrcBc || isDstBc);
    const bool usesPvrTexLib = (isSrcEtcEacPvrtc || isDstEtcEacPvrtc);
    const bool usesAstcCodec = !usesNvtt && (isSrcAstc || isDstAstc);

    //-----------------------------------------------------------------------------
    // nvtt を使用する場合、全レベルを一括変換します。
    EncTimeMeasure tm;
    bool isSucceeded = true;
    if (usesNvtt && GetEnvVariable("NINTENDO_TEXTURE_CONVERTER_NVTT_BATCH") != "0")
    {
        isSucceeded = ConvertByNvtt(pDst, pSrc, dstFormatStr, srcFormatStr, qualityStr, encodeFlag,
            dimension, imageW, imageH, imageD, mipCount);
        //cerr << "nvtt batch: " << tm.GetMilliSec() << " ms" << endl;
        //NoteTrace("nvtt batch: %fms", tm.GetMilliSec());
        return isSucceeded;
    }

    //-----------------------------------------------------------------------------
    // レベルごとに変換します。
    const uint8_t* pLevelSrc = reinterpret_cast<const uint8_t*>(pSrc);
    uint8_t* pLevelDst = reinterpret_cast<uint8_t*>(pDst);
    for (int level = 0; level < mipCount; ++level)
    {
        //-----------------------------------------------------------------------------
        const int levelW = std::max(imageW >> level, 1);
        const int levelH = std::max(imageH >> level, 1);
        const int levelD = (dimension == ImageDimension_3d) ?
            std::max(imageD >> level, 1) : imageD;

        //-----------------------------------------------------------------------------
        // 各ライブラリでエンコードします。
        if (usesNvtt)
        {
            isSucceeded = ConvertByNvtt(pLevelDst, pLevelSrc, dstFormatStr, srcFormatStr, qualityStr, encodeFlag,
                dimension, levelW, levelH, levelD, 1);
        }
        else if (usesDirectXTex)
        {
            isSucceeded = ConvertByDirectXTex(&g_IsOpenMpUsed,
                pLevelDst, pLevelSrc, dstFormatStr, srcFormatStr, qualityStr, encodeFlag,
                levelW, levelH, levelD);
        }
        else if (usesPvrTexLib)
        {
            isSucceeded = ConvertByPvrTexLib(pLevelDst, pLevelSrc, dstFormatStr, srcFormatStr, qualityStr, encodeFlag,
                levelW, levelH, levelD);
        }
        else if (usesAstcCodec)
        {
            isSucceeded = ConvertByAstcEvaluationCodec(pLevelDst, pLevelSrc, dstFormatStr, srcFormatStr, qualityStr, encodeFlag,
                levelW, levelH, levelD);
        }
        else
        {
            isSucceeded = ConvertSimpleFormat(pLevelDst, pLevelSrc, dstFormatStr, srcFormatStr, encodeFlag,
                levelW, levelH, levelD);
        }

        if (!isSucceeded)
        {
            break;
        }

        pLevelSrc += GetLevelDataSize(levelW, levelH, levelD, isSrcAstc, srcBpp, srcMinW, srcMinH, srcMinD);
        pLevelDst += GetLevelDataSize(levelW, levelH, levelD, isDstAstc, dstBpp, dstMinW, dstMinH, dstMinD);
    }
    //cerr << "convert format: " << tm.GetMilliSec() << " ms" << endl;
    //NoteTrace("convert format: %fms", tm.GetMilliSec());

    return isSucceeded;
}

//-----------------------------------------------------------------------------
// 無名名前空間を終了します。
} // unnamed namespace

//-----------------------------------------------------------------------------
//! @brief 構造化例外の変換関数です。
//-----------------------------------------------------------------------------
void TranslateSe(unsigned int code, struct _EXCEPTION_POINTERS* ep)
{
    throw ep; // 標準 C++ の例外を発生させます。
    ENC_UNUSED_VARIABLE(code);
}

//-----------------------------------------------------------------------------
//! @brief DLL のメイン関数です。
//-----------------------------------------------------------------------------
BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved)
{
    switch (fdwReason)
    {
        case DLL_PROCESS_ATTACH:
            _set_se_translator(TranslateSe); // このスレッドでのみ有効です。

            char dllPath[MAX_PATH];
            GetModuleFileNameA(hinstDLL, dllPath, sizeof(dllPath));
            g_DllFolderPath = GetFolderFromFilePath(dllPath);

            g_IsOpenMpUsed = false;
            break;

        case DLL_THREAD_ATTACH:
            break;

        case DLL_THREAD_DETACH:
            break;

        case DLL_PROCESS_DETACH:
        {
            //cerr << "encoder process detach: " << g_IsOpenMpUsed << endl;
            if (g_IsOpenMpUsed)
            {
                // OpenMP 使用後にすぐに FreeLibrary を呼ぶとアクセス違反が発生する問題に対処
                Sleep(1000);
            }
            FinalizeNvtt();
            break;
        }

        default:
            break;
    }
    ENC_UNUSED_VARIABLE(lpvReserved);
    return TRUE;
}

//-----------------------------------------------------------------------------
//! @brief エンコード後のデータサイズ（バイト数）を取得します。
//!
//! @param[in] format 中間ファイルのフォーマット文字列です。
//! @param[in] dimension 次元です。
//! @param[in] imageW 画像の幅です。
//! @param[in] imageH 画像の高さです。
//! @param[in] imageD 画像の奥行きです。
//! @param[in] mipCount ミップマップのレベル数です。
//!
//! @return エンコード後のデータサイズを返します。
//-----------------------------------------------------------------------------
DLLEXPORT size_t GetDataSize(
    const wchar_t* format,
    const int dimension,
    const int imageW,
    const int imageH,
    const int imageD,
    const int mipCount
)
{
    if (format == nullptr)
    {
        return 0;
    }
    const std::string formatStr = GetAnsiFromUnicode(format);
    const bool isAstc = IsAstcFormat(formatStr);
    const size_t bpp = GetBitsPerPixel(formatStr);
    int minW = 1;
    int minH = 1;
    int minD = 1;
    GetMinimumWhd(&minW, &minH, &minD, formatStr);

    size_t dataSize = 0;
    for (int level = 0; level < mipCount; ++level)
    {
        const int levelW = std::max(imageW >> level, 1);
        const int levelH = std::max(imageH >> level, 1);
        const int levelD = (dimension == ImageDimension_3d) ?
            std::max(imageD >> level, 1) : imageD;
        dataSize += GetLevelDataSize(levelW, levelH, levelD, isAstc, bpp, minW, minH, minD);
    }
    return dataSize;
}

//-----------------------------------------------------------------------------
//! @brief 画像のフォーマットを変換します。
//!
//! @param[out] pDst 変換後のデータを格納します。
//! @param[in] pSrc 変換前のデータです。
//! @param[in] dstFormat 変換後のフォーマット文字列です。
//! @param[in] srcFormat 変換前のフォーマット文字列です。
//! @param[in] quality エンコード品質文字列です。
//! @param[in] encodeFlag エンコードフラグです。
//! @param[in] dimension 次元です。
//! @param[in] imageW 画像の幅です。
//! @param[in] imageH 画像の高さです。
//! @param[in] imageD 画像の奥行きです。
//! @param[in] mipCount ミップマップのレベル数です。
//!
//! @return 処理成功なら true を返します。
//-----------------------------------------------------------------------------
DLLEXPORT bool ConvertFormat(
    void* pDst,
    const void* pSrc,
    const wchar_t* dstFormat,
    const wchar_t* srcFormat,
    const wchar_t* quality,
    const int encodeFlag,
    const int dimension,
    const int imageW,
    const int imageH,
    const int imageD,
    const int mipCount
)
{
    //-----------------------------------------------------------------------------
    // 引数をチェックします。
    if (pDst      == nullptr ||
        pSrc      == nullptr ||
        dstFormat == nullptr ||
        srcFormat == nullptr ||
        quality   == nullptr)
    {
        return false;
    }
    const std::string dstFormatStr = GetAnsiFromUnicode(dstFormat);
    const std::string srcFormatStr = GetAnsiFromUnicode(srcFormat);
    const std::string qualityStr   = GetAnsiFromUnicode(quality);

    //-----------------------------------------------------------------------------
    // 全ミップマップレベルのフォーマットを変換します。
    return ConvertLevelsFormat(pDst, pSrc, dstFormatStr, srcFormatStr,
        qualityStr, encodeFlag, dimension, imageW, imageH, imageD, mipCount);
}

//-----------------------------------------------------------------------------
//! @brief 現在の環境で GPU によるエンコーディングが可能なら true を返します。
//!
//! @return GPU によるエンコーディングが可能なら true を返します。
//-----------------------------------------------------------------------------
DLLEXPORT bool IsGpuEncodingAvailable()
{
    if (InitializeNvtt())
    {
        if (NvttIsGpuEncodingAvailable())
        {
            return true;
        }
    }
    return false;
}

