﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <cstring>
#include <cstdio>
#include <cstdlib>

#include <windows.h>
#include <cassert>      // assert
#include <objbase.h>    // IObjects
#include <algorithm>

#pragma warning( disable : 4458 )
using namespace std;
#include <gdiplus.h>    // This doesn't work unless you use namespace std... AND it causes multiple C4458 warnings

// When true, PRINT() output is suppressed (set by the -silent command-line flag).
bool silentMode = false;

// printf() unless silentMode is set. The body is wrapped in do { } while (0) so that "PRINT(...);" is
// exactly one statement: a bare if-block would make "if (c) PRINT(...); else ..." fail to compile
// (stray ';' before else) and could capture a following else.
#define PRINT(...)              \
do                              \
{                               \
    if (!silentMode)            \
    {                           \
        printf(__VA_ARGS__);    \
    }                           \
} while (0)

// One pixel in the memory order of a 32-bpp BI_RGB DIB: blue, green, red, then alpha/reserved.
// Pixel buffers are reinterpret_cast to arrays of Color, so the struct must be exactly four bytes.
// Plain byte members (rather than 8-bit bit-fields, whose allocation is implementation-defined) keep
// the member order obvious, and the static_assert below locks the size in.
struct Color
{
    unsigned char b;
    unsigned char g;
    unsigned char r;
    unsigned char a;
};

static_assert(sizeof(Color) == 4, "Color must match the 4-byte layout of a 32-bpp DIB pixel");

/// Source from https://msdn.microsoft.com/en-us/library/windows/desktop/dd407288(v=vs.85).aspx

// Writes a bitmap file
//  pszFileName:  Output file name. An existing file is overwritten (CREATE_ALWAYS).
//  pBMI:         Bitmap format information (including palette).
//  cbBMI:        Size of the BITMAPINFOHEADER, including palette, if present.
//  pData:        Pointer to the bitmap bits.
//  cbData:       Size of the bitmap, in bytes.
//  Returns S_OK on success, otherwise an HRESULT describing the first failure.

HRESULT WriteBitmap(PCSTR pszFileName, BITMAPINFOHEADER *pBMI, DWORD cbBMI, BYTE *pData, DWORD cbData)
{
    HANDLE hFile = CreateFileA(pszFileName, GENERIC_WRITE, 0, NULL,
        CREATE_ALWAYS, 0, NULL);
    // CreateFile signals failure with INVALID_HANDLE_VALUE, not NULL.
    if (hFile == INVALID_HANDLE_VALUE)
    {
        return HRESULT_FROM_WIN32(GetLastError());
    }

    BITMAPFILEHEADER bmf = {};

    // "BM" stored as a little-endian WORD; avoids the implementation-defined multi-character literal 'MB'.
    bmf.bfType = 0x4D42;
    bmf.bfSize = static_cast<DWORD>(cbBMI + cbData + sizeof(bmf));
    bmf.bfOffBits = static_cast<DWORD>(sizeof(bmf) + cbBMI);

    HRESULT hr = S_OK;

    // Appends one chunk to the file. Once a write has failed, later chunks are skipped so that hr keeps
    // the first error.
    auto writeChunk = [&](const void* data, DWORD size)
    {
        if (FAILED(hr))
        {
            return;
        }

        DWORD cbWritten = 0;
        if (!WriteFile(hFile, data, size, &cbWritten, NULL))
        {
            hr = HRESULT_FROM_WIN32(GetLastError());
        }
        else if (cbWritten != size)
        {
            // A call that "succeeds" but writes fewer bytes leaves no meaningful GetLastError value, so
            // report an explicit error instead of returning S_OK for a truncated file.
            hr = HRESULT_FROM_WIN32(ERROR_WRITE_FAULT);
        }
    };

    writeChunk(&bmf, static_cast<DWORD>(sizeof(bmf)));
    writeChunk(pBMI, cbBMI);
    writeChunk(pData, cbData);

    CloseHandle(hFile);

    return hr;
}

/// Source from https://msdn.microsoft.com/en-us/library/windows/desktop/ms533843(v=vs.85).aspx

// Looks up the CLSID of the installed GDI+ image encoder for a MIME type.
//  format:  MIME type of the wanted encoder, e.g. L"image/png".
//  pClsid:  Output CLSID; written only on success.
//  Returns the encoder's index on success, or -1 on failure (pClsid is then left untouched).

int GetEncoderClsid(const WCHAR* format, CLSID* pClsid)
{
    UINT num = 0;   // number of image encoders
    UINT size = 0;  // size of the image encoder array in bytes

    // num/size are only meaningful if the query itself succeeded.
    if (Gdiplus::GetImageEncodersSize(&num, &size) != Gdiplus::Ok || size == 0)
    {
        return -1;  // Failure
    }

    Gdiplus::ImageCodecInfo* pImageCodecInfo = static_cast<Gdiplus::ImageCodecInfo*>(malloc(size));
    if (pImageCodecInfo == NULL)
    {
        return -1;  // Failure
    }

    int found = -1;

    // If the encoder list could not be fetched the buffer holds garbage, so its MimeType pointers must
    // not be dereferenced.
    if (Gdiplus::GetImageEncoders(num, size, pImageCodecInfo) == Gdiplus::Ok)
    {
        for (UINT j = 0; j < num; ++j)
        {
            if (wcscmp(pImageCodecInfo[j].MimeType, format) == 0)
            {
                *pClsid = pImageCodecInfo[j].Clsid;
                found = static_cast<int>(j);  // Success
                break;
            }
        }
    }

    free(pImageCodecInfo);
    return found;
}

// Re-encodes the image file at bmp as a PNG at png, replacing any existing file at png.
// Failures are only reported through PRINT; the function returns nothing to the caller.
//  bmp:  Path of the source image (any format GDI+ can decode).
//  png:  Output path.
void ConvertBMPtoPNG(const wchar_t* bmp, const wchar_t* png)
{
    CLSID encoderClsid;
    // Without this check an uninitialized CLSID would be handed to Save().
    if (GetEncoderClsid(L"image/png", &encoderClsid) < 0)
    {
        PRINT("Failed to find the PNG encoder\n");
        return;
    }

    // Scoped object: GDI+ keeps the source file open for the Image's lifetime, so it is released when
    // this function returns.
    Gdiplus::Image image(bmp);
    if (image.GetLastStatus() != Gdiplus::Ok)
    {
        PRINT("Failed to load image: status = %d\n", static_cast<int>(image.GetLastStatus()));
        return;
    }

    DeleteFileW(png);
    Gdiplus::Status status = image.Save(png, &encoderClsid, NULL);

    if (status != Gdiplus::Ok)
    {
        PRINT("Failed to save image: status = %d\n", static_cast<int>(status));
    }
}

const wchar_t* Towstring(const char* str)
{
    int len = static_cast<int>(strlen(str));
    int wLen = MultiByteToWideChar(CP_UTF8, 0, str, len, nullptr, 0) + 1; // +1 for null
    wchar_t* wStr = new wchar_t[wLen];
    memset(wStr, 0, sizeof(*wStr) * wLen);
    MultiByteToWideChar(CP_UTF8, 0, str, len, wStr, wLen);
    return wStr;
}

void CreateDiffImage(const char* diffFile, BITMAP* pInput, BITMAP* pGolden, BITMAP* pMask, int threshold)
{
    // A temporary array the same size as the golden image
    Color* diffBits = new Color[pGolden->bmWidth * pGolden->bmHeight];

    Color* inputBits = reinterpret_cast<Color*>(pInput->bmBits);
    Color* goldenBits = reinterpret_cast<Color*>(pGolden->bmBits);
    char* maskBits = reinterpret_cast<char*>(pMask->bmBits);

    for (int i = 0; i < pGolden->bmWidth * pGolden->bmHeight; ++i)
    {
        // The mask is in bits, not bytes
        if (maskBits && !(maskBits[i >> 3] & (0x1 << (i & 0x7))))
        {
            // If the mask is triggered, pure white
            diffBits[i].r = 0xFF;
            diffBits[i].g = 0xFF;
            diffBits[i].b = 0xFF;
            continue;
        }

        if (abs(static_cast<int>(inputBits[i].r) - goldenBits[i].r) > threshold ||
            abs(static_cast<int>(inputBits[i].g) - goldenBits[i].g) > threshold ||
            abs(static_cast<int>(inputBits[i].b) - goldenBits[i].b) > threshold)
        {
            // If it's outside the threshold, make it pure red
            diffBits[i].r = 0xFF;
            diffBits[i].g = 0;
            diffBits[i].b = 0;
        }
        else if (inputBits[i].r != goldenBits[i].r || inputBits[i].g != goldenBits[i].g || inputBits[i].b != goldenBits[i].b)
        {
            // If it doesn't match but is within the threshold, make it faded redscale
            unsigned char average = static_cast<unsigned char>(((static_cast<int>(goldenBits[i].r) + goldenBits[i].g + goldenBits[i].b) / 3));

            diffBits[i].r = 0x80 + (average / 2);
            diffBits[i].g = 0x40;
            diffBits[i].b = 0x40;
        }
        else
        {
            // Otherwise, make it a faded grayscale
            unsigned char average = static_cast<unsigned char>(((static_cast<int>(goldenBits[i].r) + goldenBits[i].g + goldenBits[i].b) / 3));

            diffBits[i].r = diffBits[i].g = diffBits[i].b = 0x80 + (average / 2);
        }
    }

    // Save the diff
    BITMAPINFOHEADER diffHeader = { 0 };

    diffHeader.biSize = sizeof(BITMAPINFOHEADER);
    diffHeader.biBitCount = 32;
    diffHeader.biPlanes = 1;
    diffHeader.biCompression = BI_RGB;
    diffHeader.biWidth = pGolden->bmWidth;
    diffHeader.biHeight = pGolden->bmHeight;
    diffHeader.biSizeImage = ((((diffHeader.biWidth * diffHeader.biBitCount) + 31) & ~31) >> 3) * diffHeader.biHeight;

    WriteBitmap("temp.tmp", &diffHeader, diffHeader.biSize, reinterpret_cast<BYTE*>(diffBits), diffHeader.biSizeImage);
    const wchar_t* wDiffFile = Towstring(diffFile);
    ConvertBMPtoPNG(L"temp.tmp", wDiffFile);
    PRINT("Wrote %s\n", diffFile);

    DeleteFileA("temp.tmp");
    delete[] diffBits;
    delete[] wDiffFile;
}

// Decodes an image file in any format GDI+ supports and returns it as a GDI bitmap handle.
//  pFilePath:  Path to the image (narrow string, converted with Towstring).
//  Returns the new HBITMAP, or nullptr if the file could not be decoded or converted.
//  Per the GDI+ GetHBITMAP documentation, the handle should be released with DeleteObject.
HBITMAP LoadAnyImage(const char* pFilePath)
{
    const wchar_t* wFilePath = Towstring(pFilePath);

    HBITMAP result = nullptr;
    {
        // GDI+ constructors cannot fail outright; a missing or undecodable file is only visible
        // through GetLastStatus().
        Gdiplus::Bitmap bitmap(wFilePath, FALSE);
        if (bitmap.GetLastStatus() == Gdiplus::Ok)
        {
            // White is the background that non-opaque pixels are blended against; GDI+ ignores it for
            // fully opaque images.
            if (bitmap.GetHBITMAP(Gdiplus::Color(255, 255, 255), &result) != Gdiplus::Ok)
            {
                result = nullptr;
            }
        }
    }

    delete[] wFilePath;

    return result;
}

int main(int argc, char** argv)
{
    // Initialize GDI+.
    Gdiplus::GdiplusStartupInput gdiplusStartupInput;
    ULONG_PTR gdiplusToken;
    Gdiplus::GdiplusStartup(&gdiplusToken, &gdiplusStartupInput, NULL);

    const char* inputFile = "input.png";
    const char* goldenFile = "golden.png";
    const char* diffFile = "diff.png";
    const char* maskFile = nullptr;
    const char* thresholdStr = "5";

    silentMode = false;

    // Check to see if we should just take a picture and quit
    for (int i = 1; i < argc; ++i)
    {

#define CHECK_ARG(name)                                     \
if (!_stricmp(argv[i], "-" #name))                          \
{                                                           \
        if (i + 1 >= argc)                                  \
        {                                                   \
            PRINT("Error: expected additional argument after -" #name "\n");    \
            return -1;                                      \
        }                                                   \
        ++i;                                                \
        name = argv[i];                                     \
        continue;                                           \
}

        CHECK_ARG(inputFile);
        CHECK_ARG(goldenFile);
        CHECK_ARG(maskFile);
        CHECK_ARG(diffFile);
        CHECK_ARG(thresholdStr);

        if (!_stricmp(argv[i], "-silent"))
        {
            silentMode = true;
            continue;
        }

#undef CHECK_ARG
    }

    BITMAP inputBitmap = { 0 };
    BITMAP goldenBitmap = { 0 };
    BITMAP maskBitmap = { 0 };
    HBITMAP hBitmap;

    int threshold = atoi(thresholdStr);

    if (maskFile)
    {
        PRINT("Comparing %s with %s with mask %s\n", inputFile, goldenFile, maskFile);
    }
    else
    {
        PRINT("Comparing %s with %s\n", inputFile, goldenFile);
    }

    hBitmap = LoadAnyImage(inputFile);
    if (!hBitmap || !GetObject(hBitmap, sizeof(BITMAP), &inputBitmap))
    {
        PRINT("Failed to load %s\n", inputFile);
        return -1;
    }
    hBitmap = LoadAnyImage(goldenFile);
    if (!hBitmap || !GetObject(hBitmap, sizeof(BITMAP), &goldenBitmap))
    {
        PRINT("Failed to load %s\n", goldenFile);
        return -1;
    }
    if (maskFile)
    {
        hBitmap = reinterpret_cast<HBITMAP>(LoadImageA(NULL, maskFile, IMAGE_BITMAP, 0, 0, LR_LOADFROMFILE | LR_CREATEDIBSECTION));
        if (!hBitmap || !GetObject(hBitmap, sizeof(BITMAP), &maskBitmap))
        {
            PRINT("Failed to load %s\n", maskFile);
            return -1;
        }
    }

    // Just a sanity check
    if (inputBitmap.bmBitsPixel != 32)
    {
        PRINT("Unexpected BBP of input bitmap (expected 32, got %i)\n", inputBitmap.bmBitsPixel);
        return -1;
    }
    if (goldenBitmap.bmBitsPixel != 32)
    {
        PRINT("Unexpected BBP of golden bitmap (expected 32, got %i)\n", goldenBitmap.bmBitsPixel);
        return -1;
    }
    if (maskFile && maskBitmap.bmBitsPixel != 1)
    {
        PRINT("Unexpected BBP of mask bitmap (expected 1, got %i)\n", maskBitmap.bmBitsPixel);
        return -1;
    }

    // Another sanity check
    if (inputBitmap.bmWidth != goldenBitmap.bmWidth)
    {
        PRINT("Image width differs (%i vs %i)\n", inputBitmap.bmWidth, goldenBitmap.bmWidth);
        return 1;
    }
    if (inputBitmap.bmHeight != goldenBitmap.bmHeight)
    {
        PRINT("Image height differs (%i vs %i)\n", inputBitmap.bmHeight, goldenBitmap.bmHeight);
        return 1;
    }
    if (maskFile && (maskBitmap.bmWidth != goldenBitmap.bmWidth || maskBitmap.bmHeight != goldenBitmap.bmHeight))
    {
        PRINT("Mask dimensions do not match golden image ((%i, %i) vs (%i, %i))\n", maskBitmap.bmWidth, maskBitmap.bmHeight, goldenBitmap.bmWidth, goldenBitmap.bmHeight);
        return 1;
    }

    Color* inputBits = reinterpret_cast<Color*>(inputBitmap.bmBits);
    Color* goldenBits = reinterpret_cast<Color*>(goldenBitmap.bmBits);
    char* maskBits = reinterpret_cast<char*>(maskBitmap.bmBits);

    for (int i = 0; i < goldenBitmap.bmWidth * goldenBitmap.bmHeight; ++i)
    {
        // The mask is in bits, not bytes
        if (maskBits && !(maskBits[i >> 3] & (0x1 << (i & 0x7))))
        {
            continue;
        }

        if (abs(static_cast<int>(inputBits[i].r) - goldenBits[i].r) > threshold ||
            abs(static_cast<int>(inputBits[i].g) - goldenBits[i].g) > threshold ||
            abs(static_cast<int>(inputBits[i].b) - goldenBits[i].b) > threshold)
        {
            PRINT("Difference found, creating %s...\n", diffFile);
            CreateDiffImage(diffFile, &inputBitmap, &goldenBitmap, &maskBitmap, threshold);
            return 1;
        }
    }

    Gdiplus::GdiplusShutdown(gdiplusToken);

    PRINT("Images match!\n");
    return 0;
} // NOLINT(impl/function_size)
