﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <mutex>
#include <nn/nn_Abort.h>
#include <nn/nn_SdkAssert.h>
#include <nn/os/os_Mutex.h>
#include <nn/result/result_HandlingUtility.h>
#include <nn/audio/audio_Result.h>

#include "audio_AudioOutClientImpl.win32.h"

#include "audio_Deleter.win32.h"
#include "audio_ScopedHandle.win32.h"
#include "audio_CharEncodingConverter-os.win32.h"

namespace nn {
namespace audio {

namespace {

class DeviceEnumerator
{
public:
    DeviceEnumerator()
        : m_Mutex(false),
          m_RefCount(0)
    {
    }

    ~DeviceEnumerator()
    {
    }

    IMMDeviceEnumerator* Get() const
    {
        return m_pEnumerator;
    }

    void Initialize()
    {
        std::lock_guard<nn::os::Mutex> lock(m_Mutex);
        ++m_RefCount;
        if( m_RefCount == 1 )
        {
            HRESULT hr = CoInitializeEx(nullptr, COINIT_MULTITHREADED);
            NN_ABORT_UNLESS(SUCCEEDED(hr) || hr == RPC_E_CHANGED_MODE);

            NN_ABORT_UNLESS(
                SUCCEEDED(
                    CoCreateInstance(
                        __uuidof(MMDeviceEnumerator),
                        nullptr,
                        CLSCTX_INPROC_SERVER,
                        IID_PPV_ARGS(&m_pEnumerator))));
        }
    }

    void Finalize()
    {
        std::lock_guard<nn::os::Mutex> lock(m_Mutex);
        NN_SDK_ASSERT(m_RefCount > 0);
        --m_RefCount;
        if( m_RefCount == 0 )
        {
            if (m_pEnumerator)
            {
                m_pEnumerator->Release();
            }
            CoUninitialize();
        }
    }

private:
    nn::os::Mutex m_Mutex;
    int m_RefCount;
    IMMDeviceEnumerator* m_pEnumerator;
};

// Shared enumerator instance used by every AudioOutClientImplByWin32.
DeviceEnumerator g_DeviceEnumerator;


// Shorthand accessor for the shared IMMDeviceEnumerator. The value is only
// meaningful between g_DeviceEnumerator.Initialize() and the matching Finalize().
IMMDeviceEnumerator* GetDeviceEnumerator()
{
    IMMDeviceEnumerator* const pEnumerator = g_DeviceEnumerator.Get();
    return pEnumerator;
}

// Initialises pAudioClient in shared, event-driven mode with a 16-bit PCM
// format derived from the device mix format (channel count and, unless
// sampleRate > 0 overrides it, the sample rate).
//
// pAudioClient   : activated but not yet initialised client.
// pSampleRate    : receives the sample rate that was requested.
// pChannelCount  : receives the channel count that was requested.
// pBitsPerSample : receives the bits per sample that was requested (16).
// sampleRate     : optional override; 0 keeps the mix-format rate.
//
// Returns ResultOperationFailed() if any WASAPI call fails.
Result InitializeAudioClient(IAudioClient* pAudioClient, int32_t* pSampleRate, int32_t* pChannelCount, int32_t* pBitsPerSample, int sampleRate = 0)
{
    NN_SDK_ASSERT(pAudioClient);
    NN_SDK_ASSERT(pSampleRate);
    NN_SDK_ASSERT(pChannelCount);
    NN_SDK_ASSERT(pBitsPerSample);

    REFERENCE_TIME hnsDefaultDevicePeriod;
    REFERENCE_TIME hnsMinimumDevicePeriod;
    NN_RESULT_THROW_UNLESS(
        SUCCEEDED(
            pAudioClient->GetDevicePeriod(&hnsDefaultDevicePeriod, &hnsMinimumDevicePeriod)),
        ResultOperationFailed());

    WAVEFORMATEX* pMixFormat = nullptr;
    NN_RESULT_THROW_UNLESS(
        SUCCEEDED(
            pAudioClient->GetMixFormat(&pMixFormat)),
        ResultOperationFailed());

    // GetMixFormat() may return a plain WAVEFORMATEX rather than a
    // WAVEFORMATEXTENSIBLE, so never write extensible fields into its buffer.
    // Copy what is actually there into local storage and free the COM
    // allocation right away, so no later early return can leak it.
    const WORD extensibleExtraBytes =
        static_cast<WORD>(sizeof(WAVEFORMATEXTENSIBLE) - sizeof(WAVEFORMATEX));
    WAVEFORMATEXTENSIBLE waveFormat = {};
    if (pMixFormat->wFormatTag == WAVE_FORMAT_EXTENSIBLE && pMixFormat->cbSize >= extensibleExtraBytes)
    {
        waveFormat = *reinterpret_cast<const WAVEFORMATEXTENSIBLE*>(pMixFormat);
    }
    else
    {
        waveFormat.Format = *pMixFormat;
    }
    CoTaskMemFree(pMixFormat);
    pMixFormat = nullptr;

    waveFormat.Format.wFormatTag = WAVE_FORMAT_EXTENSIBLE;
    waveFormat.Format.cbSize = extensibleExtraBytes;   // must match WAVE_FORMAT_EXTENSIBLE
    waveFormat.Format.wBitsPerSample = 16;
    if (sampleRate > 0)
    {
        waveFormat.Format.nSamplesPerSec = sampleRate;
    }
    waveFormat.Format.nBlockAlign = waveFormat.Format.wBitsPerSample / CHAR_BIT * waveFormat.Format.nChannels;
    waveFormat.Format.nAvgBytesPerSec = waveFormat.Format.nSamplesPerSec * waveFormat.Format.nBlockAlign;
    waveFormat.Samples.wValidBitsPerSample = waveFormat.Format.wBitsPerSample;
    waveFormat.SubFormat = KSDATAFORMAT_SUBTYPE_PCM;

    const HRESULT hr = pAudioClient->Initialize(
        AUDCLNT_SHAREMODE_SHARED,
        AUDCLNT_STREAMFLAGS_EVENTCALLBACK | AUDCLNT_STREAMFLAGS_NOPERSIST,
        hnsDefaultDevicePeriod,
        0,
        &waveFormat.Format,
        &GUID_NULL);
    NN_RESULT_THROW_UNLESS(SUCCEEDED(hr), ResultOperationFailed());

    *pSampleRate = waveFormat.Format.nSamplesPerSec;
    *pChannelCount = waveFormat.Format.nChannels;

    NN_SDK_ASSERT(waveFormat.SubFormat == KSDATAFORMAT_SUBTYPE_PCM);
    *pBitsPerSample = waveFormat.Format.wBitsPerSample;

    NN_RESULT_SUCCESS;
}

}

// Opens a render endpoint and starts the stream on it.
//
// name : empty string selects the default console render endpoint; otherwise
//        the active render endpoint whose name (as reported by
//        detail::GetDeviceName) equals `name` is used.
//
// Returns ResultNotFound() when no active endpoint matches `name`,
// ResultOperationFailed() on WASAPI failures, or whatever InitializeCommon()
// returns. On success this object keeps the IMMDevice reference in m_pDevice.
Result AudioOutClientImplByWin32::Initialize(const char* name) NN_NOEXCEPT
{
    NN_SDK_ASSERT(name);

    IMMDevice* pDevice = nullptr;

    g_DeviceEnumerator.Initialize();
    if (name[0] == '\0')
    {
        NN_RESULT_THROW_UNLESS(
            SUCCEEDED(GetDeviceEnumerator()->GetDefaultAudioEndpoint(eRender, eConsole, &pDevice)),
            ResultOperationFailed()
        );
    }
    else
    {
        IMMDeviceCollection* pDeviceCollection_ = nullptr;
        NN_RESULT_THROW_UNLESS(
            SUCCEEDED(GetDeviceEnumerator()->EnumAudioEndpoints(eRender, DEVICE_STATE_ACTIVE, &pDeviceCollection_)),
            ResultOperationFailed()
        );
        std::unique_ptr<IMMDeviceCollection> pDeviceCollection(pDeviceCollection_);

        UINT deviceCount;
        NN_RESULT_THROW_UNLESS(
            SUCCEEDED(pDeviceCollection->GetCount(&deviceCount)),
            ResultOperationFailed()
        );

        // Every successful Item() call hands out a new reference: keep the
        // matching device and release every candidate that does not match.
        for (UINT i = 0; i < deviceCount && pDevice == nullptr; ++i)
        {
            IMMDevice* pCandidate = nullptr;
            if (FAILED(pDeviceCollection->Item(i, &pCandidate)))
            {
                continue;
            }

            const int NameLength = 256;
            char deviceName[NameLength];
            deviceName[0] = '\0';   // stay well-defined even if the name lookup fails
            detail::GetDeviceName(pCandidate, deviceName, NameLength);
            if (std::strncmp(deviceName, name, NameLength) == 0)
            {
                pDevice = pCandidate;
            }
            else
            {
                pCandidate->Release();
            }
        }
        NN_RESULT_THROW_UNLESS(pDevice != nullptr, ResultNotFound());
    }

    NN_RESULT_TRY(InitializeCommon(pDevice, 0))
    NN_RESULT_CATCH_ALL
    {
        pDevice->Release();
        NN_RESULT_RETHROW;
    }
    NN_RESULT_END_TRY

    m_pDevice = pDevice;

    NN_RESULT_SUCCESS;
}

// Activates and initialises an IAudioClient on pDevice, registers an
// auto-reset event for buffer notifications, pre-fills the whole buffer with
// silence and starts the stream.
//
// pDevice     : endpoint to render to (ownership stays with the caller).
// sampleRate_ : optional sample-rate override forwarded to
//               InitializeAudioClient(); 0 keeps the mix-format rate.
//
// IAudioClient initialisation is attempted up to MaxAttempts times, each on a
// freshly activated client; the last failure is returned if all attempts fail.
// On success the client, render client and event are stored in members.
Result AudioOutClientImplByWin32::InitializeCommon(IMMDevice* pDevice, int sampleRate_) NN_NOEXCEPT
{
    NN_SDK_ASSERT(pDevice);

    // Activates a new IAudioClient on pDevice into *ppAudioClient.
    auto activateAudioClient = [pDevice](IAudioClient** ppAudioClient) -> bool
    {
        *ppAudioClient = nullptr;
        return SUCCEEDED(
            pDevice->Activate(
                __uuidof(IAudioClient),
                CLSCTX_INPROC_SERVER,
                nullptr,
                reinterpret_cast<void**>(ppAudioClient)));
    };

    IAudioClient* pAudioClient_ = nullptr;
    NN_RESULT_THROW_UNLESS(activateAudioClient(&pAudioClient_), ResultOperationFailed());
    std::unique_ptr<IAudioClient> pAudioClient(pAudioClient_);

    int32_t sampleRate = 0;
    int32_t channelCount = 0;
    int32_t bitsPerSample = 0;

    const int MaxAttempts = 4;
    Result initializeResult = InitializeAudioClient(pAudioClient.get(), &sampleRate, &channelCount, &bitsPerSample, sampleRate_);
    for (int attempt = 1; attempt < MaxAttempts && !initializeResult.IsSuccess(); ++attempt)
    {
        // A client whose Initialize() failed is not reused. Let the owning
        // unique_ptr release it (never Release() it behind the owner's back)
        // and retry on a newly activated client.
        pAudioClient.reset();
        NN_RESULT_THROW_UNLESS(activateAudioClient(&pAudioClient_), ResultOperationFailed());
        pAudioClient.reset(pAudioClient_);

        initializeResult = InitializeAudioClient(pAudioClient.get(), &sampleRate, &channelCount, &bitsPerSample, sampleRate_);
    }
    NN_RESULT_DO(initializeResult);

    UINT32 numBufferFrames;
    NN_RESULT_THROW_UNLESS(
        SUCCEEDED(
            pAudioClient->GetBufferSize(&numBufferFrames)),
        ResultOperationFailed());

    HANDLE hEvent_ = CreateEvent(NULL, FALSE, FALSE, NULL);
    NN_ABORT_UNLESS(hEvent_);
    detail::ScopedHandle hEvent(hEvent_);

    NN_RESULT_THROW_UNLESS(
        SUCCEEDED(
            pAudioClient->SetEventHandle(hEvent)),
        ResultOperationFailed());

    IAudioRenderClient* pAudioRenderClient;
    NN_RESULT_THROW_UNLESS(
        SUCCEEDED(
            pAudioClient->GetService(IID_PPV_ARGS(&pAudioRenderClient))),
        ResultOperationFailed());

    m_pAudioClient = pAudioClient.release();
    m_pAudioRenderClient = pAudioRenderClient;
    m_hEvent = hEvent.Release();
    m_SampleRate = sampleRate;
    m_ChannelCount = channelCount;
    m_BufferSampleCount = numBufferFrames / BufferSeparationCount;
    m_BitsPerSample = bitsPerSample;

    // Fill the entire endpoint buffer with silence before starting the stream.
    BYTE* pData;
    NN_RESULT_THROW_UNLESS(
        SUCCEEDED(
            m_pAudioRenderClient->GetBuffer(numBufferFrames, &pData)),
        ResultOperationFailed());

    NN_RESULT_THROW_UNLESS(
        SUCCEEDED(
            m_pAudioRenderClient->ReleaseBuffer(numBufferFrames, AUDCLNT_BUFFERFLAGS_SILENT)),
        ResultOperationFailed());

    NN_RESULT_THROW_UNLESS(
        SUCCEEDED(
            m_pAudioClient->Start()),
        ResultOperationFailed());

    NN_RESULT_SUCCESS;
}

// Stops the stream and releases every resource acquired by Initialize().
//
// Teardown always runs to completion, including the g_DeviceEnumerator
// reference taken by Initialize(). If IAudioClient::Stop() failed (for example
// because the endpoint was removed), ResultOperationFailed() is returned, but
// only after everything has been released.
Result AudioOutClientImplByWin32::Finalize() NN_NOEXCEPT
{
    bool isStopFailed = false;
    if (m_pAudioClient != nullptr)
    {
        isStopFailed = FAILED(m_pAudioClient->Stop());
    }

    if (m_pAudioRenderClient != nullptr)
    {
        m_pAudioRenderClient->Release();
        m_pAudioRenderClient = nullptr;
    }
    if (m_hEvent != nullptr)
    {
        CloseHandle(m_hEvent);
        m_hEvent = nullptr;
    }
    if (m_pAudioClient != nullptr)
    {
        m_pAudioClient->Release();
        m_pAudioClient = nullptr;
    }
    if (m_pDevice != nullptr)
    {
        m_pDevice->Release();
        m_pDevice = nullptr;
    }

    g_DeviceEnumerator.Finalize();

    NN_RESULT_THROW_UNLESS(!isStopFailed, ResultOperationFailed());
    NN_RESULT_SUCCESS;
}

// Sample rate stored by InitializeCommon().
int AudioOutClientImplByWin32::GetSampleRate() const NN_NOEXCEPT
{
    return this->m_SampleRate;
}

// Channel count stored by InitializeCommon().
int AudioOutClientImplByWin32::GetChannelCount() const NN_NOEXCEPT
{
    return this->m_ChannelCount;
}

// Bits per sample stored by InitializeCommon().
int AudioOutClientImplByWin32::GetBitsPerSample() const NN_NOEXCEPT
{
    return this->m_BitsPerSample;
}

// Blocks, with no timeout, until m_hEvent (the handle registered with
// IAudioClient::SetEventHandle in InitializeCommon) becomes signalled.
void AudioOutClientImplByWin32::Wait() NN_NOEXCEPT
{
    const DWORD signalState = WaitForSingleObject(m_hEvent, INFINITE);
    NN_UNUSED(signalState);   // only inspected by the debug assertion below
    NN_SDK_ASSERT(signalState == WAIT_OBJECT_0);
}

// Returns a writable region of sampleCount frames in the endpoint buffer, or
// nullptr when the request is negative, there is not enough free space, or a
// WASAPI call fails. The region must be handed back with ReleaseBuffer().
// Note: for sampleCount == 0 this returns nullptr, because
// IAudioRenderClient::GetBuffer does not write *ppData for a 0-frame request.
void* AudioOutClientImplByWin32::AcquireBuffer(int sampleCount) NN_NOEXCEPT
{
    // A negative count would wrap to a huge UINT32 in GetBuffer().
    if (sampleCount < 0)
    {
        return nullptr;
    }

    UINT32 numPaddingFrames;
    if (FAILED(m_pAudioClient->GetCurrentPadding(&numPaddingFrames)))
    {
        return nullptr;
    }

    // Free space = usable capacity minus queued (padding) frames. Compare in
    // the unsigned domain so that padding > capacity cannot wrap around.
    const UINT32 capacityFrames = static_cast<UINT32>(m_BufferSampleCount * BufferSeparationCount);
    if (numPaddingFrames > capacityFrames ||
        capacityFrames - numPaddingFrames < static_cast<UINT32>(sampleCount))
    {
        return nullptr;
    }

    BYTE* pData = nullptr;   // GetBuffer() leaves this untouched for 0 frames
    if (FAILED(m_pAudioRenderClient->GetBuffer(static_cast<UINT32>(sampleCount), &pData)))
    {
        return nullptr;
    }

    return pData;
}

// Commits sampleCount frames previously obtained from AcquireBuffer() to the
// render client, with no buffer flags.
void AudioOutClientImplByWin32::ReleaseBuffer(int sampleCount) NN_NOEXCEPT
{
    const HRESULT releaseResult = m_pAudioRenderClient->ReleaseBuffer(sampleCount, 0);
    // TODO: a failed ReleaseBuffer() is currently ignored.
    NN_UNUSED(releaseResult);
}

}  // namespace audio
}  // namespace nn
