﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include "cdmsc_Driver.h"

namespace nn {
namespace cdmsc {
namespace driver {

// Creates an empty registry: the ring cursor starts at slot 0 and every
// record in the table is zero-filled (vid 0, pid 0, capability 0).
// NOTE(review): because empty slots are all-zero, a lookup for vid/pid 0:0
// will match an unused slot rather than miss — confirm that is acceptable.
DeviceRegistry::DeviceRegistry()
    : m_Index(0)
{
    std::memset(&m_Registry[0], 0, sizeof m_Registry);
}

// Destructor. The registry allocates nothing in this file (its records live
// in the m_Registry table), so there is nothing to release.
DeviceRegistry::~DeviceRegistry()
{
    // do nothing
}

void
DeviceRegistry::Update(uint16_t vid, uint16_t pid, uint32_t capability)
{
    for (auto& record : m_Registry)
    {
        if (record.vid == vid && record.pid == pid)
        {
            record.capability &= capability;
            return;
        }
    }

    Add(vid, pid, capability);
}

// Returns the capability recorded for the device identified by (vid, pid).
// An unknown device is entered into the registry with DefaultCapability —
// overwriting whatever record occupies the next ring slot used by Add() —
// and DefaultCapability is returned.
uint32_t
DeviceRegistry::Get(uint16_t vid, uint16_t pid)
{
    for (const auto& entry : m_Registry)
    {
        const bool isMatch = (entry.vid == vid) && (entry.pid == pid);
        if (isMatch)
        {
            return entry.capability;
        }
    }

    // First sighting of this device: remember it with the default set.
    Add(vid, pid, DefaultCapability);
    return DefaultCapability;
}

// Writes a (vid, pid, capability) record into the slot at the ring cursor,
// replacing whatever record was there, then advances the cursor, wrapping
// back to slot 0 after RegistrySize entries.
void
DeviceRegistry::Add(uint16_t vid, uint16_t pid, uint32_t capability)
{
    const auto slot = m_Index;

    m_Registry[slot] = {
        .vid        = vid,
        .pid        = pid,
        .capability = capability
    };

    // Round-robin replacement: the table holds at most RegistrySize devices.
    m_Index = (slot + 1) % RegistrySize;
}

// Constructs an idle driver: no devices, no main thread yet, and a wait set
// that initially contains only the break event used to stop MainThread().
// USB resources are acquired later, in Initialize().
Driver::Driver()
    : m_IsMainThreadRunning(false)
    , m_BreakEvent(nn::os::EventClearMode_ManualClear)
    , m_pDeviceAvailableEvent(nullptr)
    , m_DeviceSequence(0)
{
    // Every device slot starts out free.
    for (auto& pSlot : m_pDevice)
    {
        pSlot = nullptr;
    }

    // Match only interfaces whose class is mass storage and whose protocol is
    // MscProtocol_Bbb (presumably Bulk-Only Transport).
    m_UsbDeviceFilter.matchFlags = nn::usb::DeviceFilterMatchFlags_InterfaceClass |
                                   nn::usb::DeviceFilterMatchFlags_InterfaceProtocol;
    m_UsbDeviceFilter.bInterfaceClass    = nn::usb::UsbClass_MassStorage;
    m_UsbDeviceFilter.bInterfaceProtocol = MscProtocol_Bbb;

    // Build the wait set and put the break event in it.
    nn::os::InitializeMultiWait(&m_MultiWait);
    nn::os::InitializeMultiWaitHolder(&m_BreakEventHolder, m_BreakEvent.GetBase());
    nn::os::LinkMultiWaitHolder(&m_MultiWait, &m_BreakEventHolder);
}

// Tears down the wait set built in the constructor: the break-event holder is
// unlinked and finalized before the multi-wait object itself is finalized.
// NOTE(review): m_UsbIfAvailableEventHolder is only unlinked in Finalize();
// presumably Finalize() must run before destruction whenever Initialize()
// succeeded — confirm.
Driver::~Driver()
{
    nn::os::UnlinkMultiWaitHolder(&m_BreakEventHolder);
    nn::os::FinalizeMultiWaitHolder(&m_BreakEventHolder);
    nn::os::FinalizeMultiWait(&m_MultiWait);
}

// Brings the driver online.
//
// pDeviceAvailableEvent  Client event; it is cleared here and signaled each
//                        time a logical unit is registered. Must not be null.
// alloc, dealloc         Allocator hooks handed to detail::SetAllocator().
//
// Every failure aborts via the NN_CDMSC_* macros, so the only value this
// function returns is ResultSuccess().
Result
Driver::Initialize(nn::os::EventType *pDeviceAvailableEvent,
                   AllocateFunction   alloc,
                   DeallocateFunction dealloc)
{
    NN_CDMSC_ABORT_IF_NULL(pDeviceAvailableEvent);

    // 1) Become a client of the USB host stack.
    NN_CDMSC_ABORT_UPON_ERROR(m_UsbHost.Initialize());

    // 2) Get notified when interfaces matching m_UsbDeviceFilter appear,
    //    and let MainThread() wait on that notification.
    NN_CDMSC_ABORT_UPON_ERROR(
        m_UsbHost.CreateInterfaceAvailableEvent(&m_UsbIfAvailableEvent,
                                                nn::os::EventClearMode_ManualClear,
                                                0,
                                                &m_UsbDeviceFilter)
    );
    nn::os::InitializeMultiWaitHolder(&m_UsbIfAvailableEventHolder,
                                      &m_UsbIfAvailableEvent);
    nn::os::LinkMultiWaitHolder(&m_MultiWait, &m_UsbIfAvailableEventHolder);

    // 3) Remember the client's event, starting from a cleared state.
    m_pDeviceAvailableEvent = pDeviceAvailableEvent;
    nn::os::ClearEvent(m_pDeviceAvailableEvent);

    detail::SetAllocator(alloc, dealloc);

    // 4) Launch the main thread. The break event is cleared and the run flag
    //    raised before StartThread() so the loop does not exit immediately.
    NN_CDMSC_ABORT_UPON_ERROR(
        nn::os::CreateThread(&m_MainThread,
                             MainThreadEntry,
                             this,
                             m_MainThreadStack,
                             sizeof(m_MainThreadStack),
                             NN_SYSTEM_THREAD_PRIORITY(cdmsc, MainThread))
    );
    nn::os::SetThreadNamePointer(&m_MainThread,
                                 NN_SYSTEM_THREAD_NAME(cdmsc, MainThread));
    m_BreakEvent.Clear();
    m_IsMainThreadRunning = true;
    nn::os::StartThread(&m_MainThread);

    return ResultSuccess();
}

// Shuts the driver down, undoing Initialize() in reverse order. The main
// thread destroys any remaining devices on its way out (see MainThread()).
// A failure from the host-stack Finalize() aborts; otherwise ResultSuccess()
// is returned.
Result
Driver::Finalize()
{
    // Lower the run flag first, then wake the thread so that it sees the
    // flag on its next loop check; wait for it to exit before destroying it.
    m_IsMainThreadRunning = false;
    m_BreakEvent.Signal();
    nn::os::WaitThread(&m_MainThread);
    nn::os::DestroyThread(&m_MainThread);

    m_pDeviceAvailableEvent = nullptr;

    // Remove the interface-available event from the wait set and release it.
    // NOTE(review): DestroyInterfaceAvailableEvent()'s Result is not checked.
    nn::os::UnlinkMultiWaitHolder(&m_UsbIfAvailableEventHolder);
    nn::os::FinalizeMultiWaitHolder(&m_UsbIfAvailableEventHolder);
    m_UsbHost.DestroyInterfaceAvailableEvent(&m_UsbIfAvailableEvent, 0);

    // Leave the USB host stack.
    NN_CDMSC_ABORT_UPON_ERROR(m_UsbHost.Finalize());

    return ResultSuccess();
}

// Looks for a newly available logical unit and copies its profile out.
//
// pDetachEvent  Cleared on entry, then handed to the handle manager's
//               Discover() together with the unit.
// pOutProfile   Receives the unit's profile on success.
//
// Returns Discover()'s failure as-is, ResultDeviceNotAvailable() when the
// discovered handle can no longer be acquired, and otherwise the (successful)
// Discover() result.
Result
Driver::Probe(nn::os::EventType *pDetachEvent, UnitProfile *pOutProfile)
{
    UnitHandle discovered;

    nn::os::ClearEvent(pDetachEvent);

    Result result = m_LogicalUnitHandleManager.Discover(pDetachEvent, &discovered);
    if (result.IsFailure())
    {
        return result;
    }

    LogicalUnit* pUnit = m_LogicalUnitHandleManager.Acquire(discovered);
    if (pUnit == nullptr)
    {
        // NOTE(review): pDetachEvent was already cleared on entry; the intent
        // of clearing it again here is unclear — confirm.
        nn::os::ClearEvent(pDetachEvent);
        return ResultDeviceNotAvailable();
    }

    pUnit->GetProfile(pOutProfile);
    m_LogicalUnitHandleManager.Release(discovered);

    return result;
}

// Reads from the logical unit behind `handle` into pOutBuffer, starting at
// `lba`, with `block` (presumably a block count) passed through unchanged.
// Returns ResultInvalidUnitHandle() when the handle cannot be acquired,
// otherwise the unit's Read() result. The unit is released after the read.
Result
Driver::Read(void* pOutBuffer, UnitHandle handle,
             uint64_t lba, uint32_t block)
{
    LogicalUnit* const pUnit = m_LogicalUnitHandleManager.Acquire(handle);
    if (pUnit == nullptr)
    {
        return ResultInvalidUnitHandle();
    }

    Result readResult = pUnit->Read(lba, block, pOutBuffer);
    m_LogicalUnitHandleManager.Release(handle);

    return readResult;
}

// Writes pInBuffer to the logical unit behind `handle`, starting at `lba`,
// with `block` (presumably a block count) passed through unchanged.
// Returns ResultInvalidUnitHandle() when the handle cannot be acquired,
// otherwise the unit's Write() result. The unit is released after the write.
Result
Driver::Write(const void* pInBuffer, UnitHandle handle,
              uint64_t lba, uint32_t block)
{
    LogicalUnit* const pUnit = m_LogicalUnitHandleManager.Acquire(handle);
    if (pUnit == nullptr)
    {
        return ResultInvalidUnitHandle();
    }

    Result writeResult = pUnit->Write(lba, block, pInBuffer);
    m_LogicalUnitHandleManager.Release(handle);

    return writeResult;
}

// Asks the logical unit behind `handle` to flush the range described by
// (`lba`, `block`). Returns ResultInvalidUnitHandle() when the handle cannot
// be acquired, otherwise the unit's Flush() result. The unit is released
// afterwards.
Result
Driver::Flush(UnitHandle handle, uint64_t lba, uint32_t block)
{
    LogicalUnit* const pUnit = m_LogicalUnitHandleManager.Acquire(handle);
    if (pUnit == nullptr)
    {
        return ResultInvalidUnitHandle();
    }

    Result flushResult = pUnit->Flush(lba, block);
    m_LogicalUnitHandleManager.Release(handle);

    return flushResult;
}

// Looks up (vid, pid) in the capability registry and stores the result in
// *pOutCapability. An unknown device is added to the registry as a side
// effect of the lookup.
void
Driver::GetCapability(uint16_t vid, uint16_t pid, uint32_t *pOutCapability)
{
    const uint32_t capability = m_Registry.Get(vid, pid);
    *pOutCapability = capability;
}

// Forwards to the capability registry's Update(): for a device already in
// the registry, its recorded capability is ANDed with `capability`.
void
Driver::UpdateCapability(uint16_t vid, uint16_t pid, uint32_t capability)
{
    m_Registry.Update(vid, pid, capability);
}

// Adds a device-owned holder to the main thread's wait set. When such a
// holder is signaled, MainThread() treats the holder's user data as a
// Device* and destroys that device.
// NOTE(review): m_MultiWait is modified without locking; presumably this is
// only called on the main thread (e.g. while a device is being created) —
// confirm.
void
Driver::RegisterDeviceEvent(nn::os::MultiWaitHolderType& eventHolder)
{
    nn::os::LinkMultiWaitHolder(&m_MultiWait, &eventHolder);
}

// Removes a holder from the wait set it is linked into; presumably the
// counterpart of RegisterDeviceEvent() — confirm.
void
Driver::UnregisterDeviceEvent(nn::os::MultiWaitHolderType& eventHolder)
{
    nn::os::UnlinkMultiWaitHolder(&eventHolder);
}

// Assigns a handle to `pLu` through the handle manager. On success the
// handle is stored in the unit and the client's device-available event is
// signaled. Returns the handle manager's Register() result.
Result
Driver::RegisterLogicalUnit(LogicalUnit *pLu)
{
    UnitHandle assigned;

    const Result result = m_LogicalUnitHandleManager.Register(pLu, &assigned);
    if (result.IsFailure())
    {
        return result;
    }

    pLu->SetHandle(assigned);
    nn::os::SignalEvent(m_pDeviceAvailableEvent);

    return result;
}

// Releases the handle previously assigned to `pLu` by RegisterLogicalUnit().
void
Driver::UnregisterLogicalUnit(LogicalUnit *pLu)
{
    m_LogicalUnitHandleManager.Unregister(pLu->GetHandle());
}

void
Driver::MainThread()
{
    while (m_IsMainThreadRunning)
    {
        nn::os::MultiWaitHolderType *holder = nn::os::WaitAny(&m_MultiWait);

        if (holder == &m_BreakEventHolder)
        {
            // m_MainThreadRun will be checked, breaking out of loop
            m_BreakEvent.Clear();
        }
        else if (holder == &m_UsbIfAvailableEventHolder)
        {
            nn::os::ClearSystemEvent(&m_UsbIfAvailableEvent);
            CreateAttached();
        }
        else // Device detach or watchdog timeout, destroy the device either way
        {
            Device *pDevice = reinterpret_cast<Device*>(
                nn::os::GetMultiWaitHolderUserData(holder)
            );
            DestroyDevice(pDevice);
        }
    }

    for (auto& pDevice : m_pDevice)
    {
        if (pDevice != nullptr)
        {
            DestroyDevice(pDevice);
        }
    }
}

void
Driver::CreateAttached()
{
    int32_t                       ifCount = 0;
    nn::usb::InterfaceQueryOutput outBuffer[DeviceCountMax];

    Result result = m_UsbHost.QueryAvailableInterfaces(
        &ifCount, outBuffer, sizeof(outBuffer), &m_UsbDeviceFilter
    );
    if (result.IsSuccess())
    {
        // Look through what is available
        for (int i = 0; i < ifCount; i++)
        {
            CreateDevice(&outBuffer[i]);
        }
    }
}

// Creates and initializes a Device for the interface described by pProfile
// in the first free slot of m_pDevice.
// - When every slot is taken, a warning is logged and nothing is created.
// - When allocation yields nullptr, a warning is logged.
// - When Device::Initialize() fails, the device is deleted and its slot freed.
// Every allocation attempt consumes a sequence number.
void
Driver::CreateDevice(nn::usb::InterfaceQueryOutput *pProfile)
{
    uint32_t slot = 0;
    while (slot < DeviceCountMax && m_pDevice[slot] != nullptr)
    {
        ++slot;
    }

    if (slot >= DeviceCountMax)
    {
        NN_CDMSC_WARN("MSC: Maximum device count (%d) reached.\n", DeviceCountMax);
        return;
    }

    // The slot is filled before Initialize() runs and cleared again on failure.
    Device* pCreated = new Device(m_UsbHost, pProfile, *this, m_DeviceSequence++);
    m_pDevice[slot] = pCreated;
    if (pCreated == nullptr)
    {
        NN_CDMSC_WARN("MSC: Unable to create new device: OOM!\n");
        return;
    }

    if (pCreated->Initialize().IsFailure())
    {
        delete pCreated;
        m_pDevice[slot] = nullptr;
    }
}

// Finalizes and deletes `pDevice`, then frees its slot in m_pDevice.
// Aborts if pDevice is null, if its Finalize() fails, or if the pointer is
// not found in the device table.
void
Driver::DestroyDevice(Device *pDevice)
{
    NN_CDMSC_ABORT_IF_NULL(pDevice);

    uint32_t slot = 0;
    while (slot < DeviceCountMax && m_pDevice[slot] != pDevice)
    {
        ++slot;
    }

    const bool isOwned = (slot < DeviceCountMax);
    if (isOwned)
    {
        NN_CDMSC_ABORT_UNLESS_SUCCESS(m_pDevice[slot]->Finalize());
        delete m_pDevice[slot];
        m_pDevice[slot] = nullptr;
    }

    // Destroying a device this driver does not own is a fatal logic error.
    NN_CDMSC_ABORT_UNLESS(slot < DeviceCountMax);
}

} // driver
} // cdmsc
} // nn
