﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/os.h>
#include <nn/nn_Abort.h>
#include <nn/nn_SdkLog.h>
#include <nn/nn_Common.h>
#include <nn/nn_Log.h>
#include <nnt.h>

#include <nn/usb/usb_Host.h>
#include <nn/cdmsc.h>

#include "../Common/CdmscTestUtil.h"

// Fallback definition, used only when the test framework has not already
// defined ASSERT_NO_FATAL_FAILURE. It executes stmt and then, if
// ::testing::Test::HasFatalFailure() reports a fatal failure, invokes FAIL()
// in the calling function as well. The do/while(0) wrapper makes the macro
// usable as a single statement.
#ifndef ASSERT_NO_FATAL_FAILURE
#define ASSERT_NO_FATAL_FAILURE(stmt)           \
    do {                                        \
        stmt;                                   \
        if (::testing::Test::HasFatalFailure()) \
        {                                       \
            FAIL();                             \
        }                                       \
    } while (0)
#endif

// Declares one value-parameterized test case on the CdmscDeviceTest fixture.
//
//   testCaseName - name of the test case; also selects the worker function,
//                  which must be named Run<testCaseName>.
//   dataSize     - forwarded unchanged as the worker's second argument.
//   xferSize     - array of values; the test is instantiated once per element
//                  (under the "Cdmsc" prefix) and each element is handed to the
//                  worker as its third argument via GetParam().
//
// The generated test, VariousSizes, calls
// Run<testCaseName>(&g_Profile, dataSize, GetParam()) inside
// ASSERT_NO_FATAL_FAILURE so that a fatal failure in the worker also fails
// the test body.
#define CDMSC_TEST_CASE_P(testCaseName, dataSize, xferSize)             \
    typedef CdmscDeviceTest testCaseName;                               \
    TEST_P(testCaseName, VariousSizes)                                  \
    {                                                                   \
        auto perXferSize = GetParam();                                  \
        ASSERT_NO_FATAL_FAILURE(                                        \
            Run##testCaseName(                                          \
                &g_Profile, dataSize, perXferSize                       \
            )                                                           \
        );                                                              \
    }                                                                   \
    INSTANTIATE_TEST_CASE_P(                                            \
        Cdmsc,                                                          \
        testCaseName,                                                   \
        ::testing::ValuesIn(xferSize)                                   \
    )

namespace nnt {
namespace usb {
namespace cdmsc {

// Auto-clearing event handed to nn::cdmsc::Initialize(); SetUpTestCase() waits
// on it before probing (presumably signaled when a device attaches).
nn::os::Event           g_AttachEvent(nn::os::EventClearMode_AutoClear);
// Auto-clearing event handed to nn::cdmsc::Probe(); nothing in this file waits on it.
nn::os::Event           g_DetachEvent(nn::os::EventClearMode_AutoClear);
// Profile of the device under test; passed to nn::cdmsc::Probe() in
// SetUpTestCase() and read by every test.
nn::cdmsc::UnitProfile  g_Profile;

// Upper bound on how long SetUpTestCase() waits on g_AttachEvent.
const auto WaitSecondsForDevice = nn::TimeSpan::FromSeconds(30);

// Number of bytes each test moves (256 MiB); also the size of the two buffers below.
const int  DataSize   = 1024 * 1024 * 256;
// Per-request transfer sizes in bytes; each test runs once for every entry.
const int  XferSize[] = {
    1024 * 64,
    1024 * 128,
    1024 * 512,
    1024 * 1024,
    1024 * 1024 * 2,
    1024 * 1024 * 4,
    1024 * 1024 * 8,
    nn::usb::HsLimitMaxUrbTransferSize
};

// Seed for the Galois test pattern; incremented after each pattern check in
// the sequential test.
int g_PatternSeed = 1;

// Source buffer for writes and destination buffer for reads.
NN_USB_DMA_ALIGN uint8_t g_WriteData[DataSize];
NN_USB_DMA_ALIGN uint8_t g_ReadData[DataSize];

// Fixture shared by every CDMSC device test. The int test parameter is the
// size in bytes of a single transfer request (one entry of XferSize).
class CdmscDeviceTest : public ::testing::TestWithParam<int>
{
public:
    // Once per test case: initialize the driver and probe the device.
    static void SetUpTestCase();
    // Once per test case: finalize the driver.
    static void TearDownTestCase();

    // Before each test: check that the probed device is usable for this parameter.
    virtual void SetUp();
    // After each test: nothing to do.
    virtual void TearDown();
};

// Per-test-case setup: brings up the CDMSC driver, waits up to
// WaitSecondsForDevice on the attach event, then probes into g_Profile.
// Any failing step raises a fatal test failure.
void CdmscDeviceTest::SetUpTestCase()
{
    // Deallocator handed to the driver; the size argument is not needed by free().
    const auto releaseMemory = [](void *p, size_t size) {
        NN_UNUSED(size);
        free(p);
    };

    NNT_ASSERT_RESULT_SUCCESS(
        nn::cdmsc::Initialize(g_AttachEvent.GetBase(), std::aligned_alloc, releaseMemory)
    );

    // Give up if the attach event is not signaled within the timeout.
    const bool isAttached = g_AttachEvent.TimedWait(WaitSecondsForDevice);
    ASSERT_TRUE(isAttached) << "Device was not Found.";

    NNT_ASSERT_RESULT_SUCCESS(
        nn::cdmsc::Probe(g_DetachEvent.GetBase(), &g_Profile)
    );
}

// Per-test-case teardown: finalizes the CDMSC driver that SetUpTestCase()
// initialized and fails the test case if finalization does not succeed.
void CdmscDeviceTest::TearDownTestCase()
{
    NNT_ASSERT_RESULT_SUCCESS(nn::cdmsc::Finalize());
}

// Per-test precondition checks on the probed device (g_Profile). A test is
// aborted with a fatal failure unless:
//   - the profile holds a non-zero handle,
//   - the device reports a non-zero block count and a non-zero block size,
//   - the transfer size under test (GetParam()) is a whole number of blocks,
//   - the device is not write protected.
void CdmscDeviceTest::SetUp()
{
    ASSERT_NE(g_Profile.handle, 0) << "Device enumeration failure.";
    ASSERT_NE(g_Profile.blockCount, 0) << "Device block count cannot be zero.";
    // Must be checked before the modulo below, which would otherwise divide by zero.
    ASSERT_NE(g_Profile.blockSize, 0) << "Device block size cannot be zero.";
    ASSERT_EQ(
        GetParam() % g_Profile.blockSize, 0
    ) << "Device block size is not compatible with this test.";
    ASSERT_FALSE(g_Profile.isWriteProtected) << "Device is write protected.";
}

// Per-test teardown: intentionally empty, the tests leave nothing to release.
void CdmscDeviceTest::TearDown()
{
}

// Returns a start LBA from which blocksToTransfer blocks can be transferred
// without running past the end of a device that has blockCount blocks.
//
//   startLba         - requested first block.
//   blocksToTransfer - number of blocks the caller is about to transfer.
//   blockCount       - total number of blocks on the device.
//
// If the requested range fits, startLba is returned unchanged. If it would
// run past the end, the start is pulled back to blockCount - blocksToTransfer
// so that the range ends exactly at the last block. If the transfer is larger
// than the whole device, 0 is returned (no start can fit it). A startLba
// greater than blockCount is returned unchanged.
static uint64_t CalculateLbaNoWrapping(uint64_t startLba,
                                       uint64_t blocksToTransfer,
                                       uint64_t blockCount)
{
    if (startLba > blockCount)
    {
        return startLba;
    }

    // Without this check the subtraction below would wrap around to a huge LBA.
    if (blocksToTransfer > blockCount)
    {
        return 0;
    }

    // Same test as "startLba + blocksToTransfer > blockCount", written so that
    // the addition cannot overflow.
    if (blocksToTransfer > blockCount - startLba)
    {
        return blockCount - blocksToTransfer;
    }

    return startLba;
}

// Writes blocksToTransfer blocks from the start of g_WriteData to the device,
// issuing requests of at most blocksPerTransfer blocks each.
//
//   profile           - device to write to.
//   blocksToTransfer  - total number of blocks to write.
//   blocksPerTransfer - blocks per request; the last request may be shorter.
//   lba               - first block of the sequential run (ignored when isRandom).
//   isRandom          - when true, every request goes to a freshly drawn LBA
//                       instead of continuing after the previous one.
//
// A failing request raises a fatal test failure and returns immediately.
static void RunWrite(nn::cdmsc::UnitProfile   *profile,
                     uint32_t                  blocksToTransfer,
                     int                       blocksPerTransfer,
                     uint64_t                  lba,
                     bool                      isRandom)
{
    uint8_t* source    = g_WriteData;
    uint64_t targetLba = lba;

    for (uint32_t remaining = blocksToTransfer; remaining > 0; remaining -= blocksPerTransfer)
    {
        if (isRandom)
        {
            // Draw a block and pull it back far enough that all remaining
            // blocks would still fit on the device.
            targetLba = CalculateLbaNoWrapping(
                rand() % profile->blockCount, remaining, profile->blockCount
            );
        }

        // Shrink the final request to whatever is left.
        if (remaining < blocksPerTransfer)
        {
            blocksPerTransfer = remaining;
        }

        NNT_ASSERT_RESULT_SUCCESS(
            nn::cdmsc::Write(source, profile->handle, targetLba, blocksPerTransfer)
        );

        targetLba += blocksPerTransfer;
        source    += blocksPerTransfer * profile->blockSize;
    }
}

// Reads blocksToReceive blocks from the device into the start of g_ReadData,
// issuing requests of at most blocksPerTransfer blocks each.
//
//   profile           - device to read from.
//   blocksToReceive   - total number of blocks to read.
//   blocksPerTransfer - blocks per request; the last request may be shorter.
//   lba               - first block of the sequential run (ignored when isRandom).
//   isRandom          - when true, every request goes to a freshly drawn LBA
//                       instead of continuing after the previous one.
//
// A failing request raises a fatal test failure and returns immediately.
static void RunRead(nn::cdmsc::UnitProfile      *profile,
                    uint32_t                     blocksToReceive,
                    int                          blocksPerTransfer,
                    uint64_t                     lba,
                    bool                         isRandom)
{
    uint8_t* destination = g_ReadData;
    uint64_t sourceLba   = lba;

    for (uint32_t remaining = blocksToReceive; remaining > 0; remaining -= blocksPerTransfer)
    {
        if (isRandom)
        {
            // Draw a block and pull it back far enough that all remaining
            // blocks would still fit on the device.
            sourceLba = CalculateLbaNoWrapping(
                rand() % profile->blockCount, remaining, profile->blockCount
            );
        }

        // Shrink the final request to whatever is left.
        if (remaining < blocksPerTransfer)
        {
            blocksPerTransfer = remaining;
        }

        NNT_ASSERT_RESULT_SUCCESS(
            nn::cdmsc::Read(destination, profile->handle, sourceLba, blocksPerTransfer)
        );

        sourceLba   += blocksPerTransfer;
        destination += (blocksPerTransfer * profile->blockSize);
    }
}

// Sequential write/read throughput test with data verification.
//
// Each round fills g_WriteData with a Galois pattern, writes transferSize bytes
// to the device in requests of blockSizePerTransfer bytes, reads the same range
// back into g_ReadData, logs the timing of both phases and checks the pattern.
// Three rounds are run: the first at LBA 0, each following one half of the
// device's block count further on (pulled back so the range stays on the
// device). The throughput averaged over all rounds is logged at the end.
//
//   profile              - device under test.
//   transferSize         - bytes moved per round; must not exceed DataSize and
//                          is reduced to the device capacity on smaller devices.
//   blockSizePerTransfer - bytes per individual write/read request.
static void RunSequentialWriteAndReadTest(nn::cdmsc::UnitProfile *profile,
                                          int                     transferSize,
                                          int                     blockSizePerTransfer)
{
    int      iterations = 3;
    uint32_t blocksToTransfer = 1;
    uint32_t blocksToReceive  = 1;
    uint64_t startLba = 0;
    int      blocksPerTransfer = 1;

    nn::os::Tick startTick, stopTick, totalTicks;
    int64_t microseconds = 0, averageWriteSum = 0, averageReadSum = 0;

    ASSERT_LE(transferSize, DataSize);

    // Work out the capacity explicitly in 64 bits so that the product cannot
    // overflow and its width matches the format specifier used below.
    const uint64_t capacity = static_cast<uint64_t>(profile->blockSize) * profile->blockCount;

    if (static_cast<uint64_t>(transferSize) > capacity)
    {
        NN_LOG("Warning: Device capacity is less than transfer size of %d. "
               "Transfer size set to max capacity of %llu\n",
               transferSize, static_cast<unsigned long long>(capacity));
        // Safe to narrow: capacity is below transferSize here.
        transferSize = static_cast<int>(capacity);
    }

    if (transferSize > profile->blockSize)
    {
        // Round up to whole blocks.
        blocksToTransfer = (transferSize + profile->blockSize - 1) / profile->blockSize;
        blocksToReceive = blocksToTransfer;
    }

    blocksPerTransfer = blockSizePerTransfer / profile->blockSize;

    // Test low, middle and high logical block addresses and take average
    for (int i = 0; i < iterations; i++)
    {
        // Create data
        MakeGaloisPattern(g_WriteData, transferSize, g_PatternSeed);

        startTick = nn::os::GetSystemTick();
        ASSERT_NO_FATAL_FAILURE(
            RunWrite(profile, blocksToTransfer, blocksPerTransfer, startLba, false)
        );
        stopTick = nn::os::GetSystemTick();

        totalTicks       = stopTick - startTick;
        microseconds     = totalTicks.ToTimeSpan().GetMicroSeconds();
        averageWriteSum += microseconds;

        NN_LOG(
            "Write LBA-%-10llu / %6.2fMB / %6.2fKB / %6.2fs / %6.2fMBps\n",
            static_cast<unsigned long long>(startLba),
            transferSize / 1024.0 / 1024.0,
            blockSizePerTransfer / 1024.0,
            microseconds / 1000000.0,
            transferSize / (microseconds / 1000000.0) / 1024.0 / 1024.0
        );

        // Check data
        startTick = nn::os::GetSystemTick();
        ASSERT_NO_FATAL_FAILURE(
            RunRead(profile, blocksToReceive, blocksPerTransfer, startLba, false)
        );
        stopTick = nn::os::GetSystemTick();

        totalTicks      = stopTick - startTick;
        microseconds    = totalTicks.ToTimeSpan().GetMicroSeconds();
        averageReadSum += microseconds;

        NN_LOG(
            "Read  LBA-%-10llu / %6.2fMB / %6.2fKB / %6.2fs / %6.2fMBps\n",
            static_cast<unsigned long long>(startLba),
            transferSize / 1024.0 / 1024.0,
            blockSizePerTransfer / 1024.0,
            microseconds / 1000000.0,
            transferSize / (microseconds / 1000000.0) / 1024.0 / 1024.0
        );

        // The seed advances after every check so that each round uses a new pattern.
        ASSERT_TRUE(
            CheckGaloisPattern(g_ReadData, transferSize, g_PatternSeed++)
        ) << "Data corrupted";

        // Calculate next starting lba without wrapping
        startLba += profile->blockCount / 2;
        startLba = CalculateLbaNoWrapping(startLba, blocksToTransfer, profile->blockCount);
    }

    // Calculate average (total bytes are formed in floating point so the
    // product cannot overflow int)
    NN_LOG(
        "Write Average %6.2fMB / %6.2fKB / %6.2fMBps\n",
        transferSize / 1024.0 / 1024.0,
        blockSizePerTransfer / 1024.0,
        static_cast<double>(transferSize) * iterations / (averageWriteSum / 1000000.0) / 1024.0 / 1024.0
    );
    NN_LOG(
        "Read  Average %6.2fMB / %6.2fKB / %6.2fMBps\n",
        transferSize / 1024.0 / 1024.0,
        blockSizePerTransfer / 1024.0,
        static_cast<double>(transferSize) * iterations / (averageReadSum  / 1000000.0) / 1024.0 / 1024.0
    );
}

// Random-access write/read throughput test.
//
// Writes transferSize bytes from g_WriteData and then reads transferSize bytes
// into g_ReadData, in requests of blockSizePerTransfer bytes that each go to a
// randomly chosen LBA. Only timing is logged; because the write and read
// addresses are drawn independently, the data read back is not verified.
//
//   profile              - device under test.
//   transferSize         - bytes moved per phase; must not exceed DataSize and
//                          is reduced to the device capacity on smaller devices.
//   blockSizePerTransfer - bytes per individual write/read request.
static void RunRandomWriteAndReadTest(nn::cdmsc::UnitProfile *profile,
                                      int                     transferSize,
                                      int                     blockSizePerTransfer)
{
    uint32_t blocksToTransfer = 1;
    uint32_t blocksToReceive = 1;
    int      blocksPerTransfer = 1;
    int64_t  microseconds = 0;

    nn::os::Tick startTick, stopTick, totalTicks;

    // Different LBA sequence on every run.
    srand(time(nullptr));

    ASSERT_LE(transferSize, DataSize);

    // Work out the capacity explicitly in 64 bits so that the product cannot
    // overflow and its width matches the format specifier used below.
    const uint64_t capacity = static_cast<uint64_t>(profile->blockSize) * profile->blockCount;

    if (static_cast<uint64_t>(transferSize) > capacity)
    {
        NN_LOG("Warning: Device capacity is less than transfer size of %d. "
               "Transfer size set to max capacity of %llu\n",
               transferSize, static_cast<unsigned long long>(capacity));
        // Safe to narrow: capacity is below transferSize here.
        transferSize = static_cast<int>(capacity);
    }
    if (transferSize > profile->blockSize)
    {
        // Round up to whole blocks.
        blocksToTransfer = (transferSize + profile->blockSize - 1) / profile->blockSize;
        blocksToReceive = blocksToTransfer;
    }

    blocksPerTransfer = blockSizePerTransfer / profile->blockSize;

    startTick = nn::os::GetSystemTick();
    ASSERT_NO_FATAL_FAILURE(
        RunWrite(profile, blocksToTransfer, blocksPerTransfer, 0, true)
    );
    stopTick = nn::os::GetSystemTick();

    totalTicks   = stopTick - startTick;
    microseconds = totalTicks.ToTimeSpan().GetMicroSeconds();

    NN_LOG(
        "Random Write %6.2fMB / %6.2fKB / %6.2fs / %6.2fMBps\n",
        transferSize / 1024.0 / 1024.0,
        blockSizePerTransfer / 1024.0,
        microseconds / 1000000.0,
        transferSize / (microseconds / 1000000.0) / 1024.0 / 1024.0
    );

    startTick = nn::os::GetSystemTick();
    ASSERT_NO_FATAL_FAILURE(
        RunRead(profile, blocksToReceive, blocksPerTransfer, 0, true)
    );
    stopTick = nn::os::GetSystemTick();

    totalTicks   = stopTick - startTick;
    microseconds = totalTicks.ToTimeSpan().GetMicroSeconds();
    NN_LOG(
        "Random Read  %6.2fMB / %6.2fKB / %6.2fs / %6.2fMBps\n",
        transferSize / 1024.0 / 1024.0,
        blockSizePerTransfer / 1024.0,
        microseconds / 1000000.0,
        transferSize / (microseconds / 1000000.0) / 1024.0 / 1024.0
    );
}

// Mixed random read/write throughput test.
//
// Moves transferSize bytes in requests of blockSizePerTransfer bytes. Every
// request goes to a random LBA and is a read with a probability of
// ReadPercent percent, otherwise a write. Time and byte counts are accumulated
// separately for reads and writes and logged at the end; data is not verified.
//
//   profile              - device under test.
//   transferSize         - total bytes to move; must not exceed DataSize and is
//                          reduced to the device capacity on smaller devices.
//   blockSizePerTransfer - bytes per individual request.
static void RunMixedWriteAndReadTest(nn::cdmsc::UnitProfile *profile,
                                     int                     transferSize,
                                     int                     blockSizePerTransfer)
{
    const int ReadPercent = 70;
    int blocksToTransfer  = 1;
    int blocksPerTransfer = 1;
    int64_t readSum = 0, writeSum = 0;
    int readSize = 0, writeSize = 0;

    nn::os::Tick startTick, stopTick, totalTicks;

    // Throughput in MB/s. Reports 0 when no time was accumulated (for example
    // when no request of that kind was issued) instead of dividing by zero.
    const auto megabytesPerSecond = [](int bytes, int64_t elapsedMicroseconds) -> double
    {
        if (elapsedMicroseconds <= 0)
        {
            return 0.0;
        }
        return bytes / (elapsedMicroseconds / 1000000.0) / 1024.0 / 1024.0;
    };

    // Different request sequence on every run.
    srand(time(nullptr));

    ASSERT_LE(transferSize, DataSize);
    ASSERT_GT(profile->blockCount, 0);

    // Work out the capacity explicitly in 64 bits so that the product cannot
    // overflow and its width matches the format specifier used below.
    const uint64_t capacity = static_cast<uint64_t>(profile->blockSize) * profile->blockCount;

    if (static_cast<uint64_t>(transferSize) > capacity)
    {
        NN_LOG("Warning: Device capacity is less than transfer size of %d. "
               "Transfer size set to max capacity of %llu\n",
               transferSize, static_cast<unsigned long long>(capacity));
        // Safe to narrow: capacity is below transferSize here.
        transferSize = static_cast<int>(capacity);
    }
    if (transferSize > profile->blockSize)
    {
        // Round up to whole blocks.
        blocksToTransfer = (transferSize + profile->blockSize - 1) / profile->blockSize;
    }

    blocksPerTransfer = blockSizePerTransfer / profile->blockSize;

    while (blocksToTransfer > 0)
    {
        // rand() % 100 + 1 is in 1..100, so this is true ReadPercent times out of 100.
        bool isRead = ((rand() % 100 + 1) <= ReadPercent);

        // Shrink the final request to whatever is left.
        if (blocksToTransfer < blocksPerTransfer)
        {
            blocksPerTransfer = blocksToTransfer;
        }

        if (isRead)
        {
            startTick = nn::os::GetSystemTick();
            ASSERT_NO_FATAL_FAILURE(
                RunRead(profile, blocksPerTransfer, blocksPerTransfer, 0, true)
            );
            stopTick = nn::os::GetSystemTick();

            totalTicks   = stopTick - startTick;
            readSum     += totalTicks.ToTimeSpan().GetMicroSeconds();
            readSize    += blocksPerTransfer * profile->blockSize;
        }
        else
        {
            startTick = nn::os::GetSystemTick();
            ASSERT_NO_FATAL_FAILURE(
                RunWrite(profile, blocksPerTransfer, blocksPerTransfer, 0, true)
            );
            stopTick = nn::os::GetSystemTick();

            totalTicks   = stopTick - startTick;
            writeSum    += totalTicks.ToTimeSpan().GetMicroSeconds();
            writeSize   += blocksPerTransfer * profile->blockSize;
        }

        blocksToTransfer -= blocksPerTransfer;
    }

    NN_LOG(
        "Mixed Write %6.2fMB / %6.2fKB / %6.2fs / %6.2fMBps\n",
        writeSize / 1024.0 / 1024.0,
        blockSizePerTransfer / 1024.0,
        writeSum / 1000000.0,
        megabytesPerSecond(writeSize, writeSum)
    );
    NN_LOG(
        "Mixed Read  %6.2fMB / %6.2fKB / %6.2fs / %6.2fMBps\n",
        readSize / 1024.0 / 1024.0,
        blockSizePerTransfer / 1024.0,
        readSum / 1000000.0,
        megabytesPerSecond(readSize, readSum)
    );
}

// Instantiate the three test cases. Each one moves DataSize bytes and runs
// once for every per-request size listed in XferSize.
CDMSC_TEST_CASE_P(SequentialWriteAndReadTest, DataSize, XferSize);
CDMSC_TEST_CASE_P(RandomWriteAndReadTest,     DataSize, XferSize);
CDMSC_TEST_CASE_P(MixedWriteAndReadTest,      DataSize, XferSize);


} // cdmsc
} // usb
} // nnt

// Program entry point: hands the host command line to GoogleTest, runs every
// registered test and exits with the result of RUN_ALL_TESTS().
extern "C" void nnMain()
{
    // InitGoogleTest takes the argument count by pointer, so it needs a
    // named local to point at.
    int    hostArgc = ::nnt::GetHostArgc();
    char** hostArgv = ::nnt::GetHostArgv();
    ::testing::InitGoogleTest(&hostArgc, hostArgv);

    ::nnt::Exit(RUN_ALL_TESTS());
}
