﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <algorithm>    // std::find
#include <cstddef>
#include <cstdint>
#include <cstdlib>      // std::malloc, std::free
#include <forward_list> // std::forward_list
#include <memory>       // std::unique_ptr
#include <string>       // std::string, std::to_string

#include <nn/nn_Assert.h>

#include <nn/codec.h>

#include <nnt.h>
#include <nnt/codecUtil/testCodec_Util.h>
#include <nnt/codecUtil/testCodec_NativeOpusDecoder.h>

// Global-scope forward declaration of OpusMSDecoder (presumably the libopus
// multistream decoder type). In this file it is only used as the template
// argument of nnt::codec::NativeOpusDecoder<::OpusMSDecoder>.
struct OpusMSDecoder;

// Note: the tests below hard-code enumerator values. For example, the
// "Bandwidth=N" parameter strings are cast directly to
// nn::codec::OpusBandwidth. Pin those values at compile time so that a change
// in the enum breaks the build instead of silently changing what is tested.
NN_STATIC_ASSERT(nn::codec::OpusCodingMode_Celt == 2);
NN_STATIC_ASSERT(nn::codec::OpusCodingMode_Silk == 3);
NN_STATIC_ASSERT(nn::codec::OpusCodingMode_Hybrid == 4);

NN_STATIC_ASSERT(nn::codec::OpusBandwidth_NarrowBand == 2);
NN_STATIC_ASSERT(nn::codec::OpusBandwidth_MediumBand == 3);
NN_STATIC_ASSERT(nn::codec::OpusBandwidth_WideBand == 4);
NN_STATIC_ASSERT(nn::codec::OpusBandwidth_SuperWideBand == 5);
NN_STATIC_ASSERT(nn::codec::OpusBandwidth_FullBand == 6);

namespace {

// Test parameter strings covering every combination of frame size and bandwidth.
//   Frame     : frame size in microseconds (2500, 5000, 10000, 20000, 40000, 60000).
//   Bandwidth : integer value cast directly to nn::codec::OpusBandwidth
//               (2 = NarrowBand, 3 = MediumBand, 4 = WideBand,
//                5 = SuperWideBand, 6 = FullBand).
// Every coding-mode fixture is instantiated over the whole table and decides
// per entry whether the combination is valid for its mode.
// The table is read-only, so both the strings and the pointers are const.
const char* const TestParameterListStrings[] =
{
    "Frame=2500,  Bandwidth=2",
    "Frame=2500,  Bandwidth=3",
    "Frame=2500,  Bandwidth=4",
    "Frame=2500,  Bandwidth=5",
    "Frame=2500,  Bandwidth=6",
    "Frame=5000,  Bandwidth=2",
    "Frame=5000,  Bandwidth=3",
    "Frame=5000,  Bandwidth=4",
    "Frame=5000,  Bandwidth=5",
    "Frame=5000,  Bandwidth=6",
    "Frame=10000, Bandwidth=2",
    "Frame=10000, Bandwidth=3",
    "Frame=10000, Bandwidth=4",
    "Frame=10000, Bandwidth=5",
    "Frame=10000, Bandwidth=6",
    "Frame=20000, Bandwidth=2",
    "Frame=20000, Bandwidth=3",
    "Frame=20000, Bandwidth=4",
    "Frame=20000, Bandwidth=5",
    "Frame=20000, Bandwidth=6",
    "Frame=40000, Bandwidth=2",
    "Frame=40000, Bandwidth=3",
    "Frame=40000, Bandwidth=4",
    "Frame=40000, Bandwidth=5",
    "Frame=40000, Bandwidth=6",
    "Frame=60000, Bandwidth=2",
    "Frame=60000, Bandwidth=3",
    "Frame=60000, Bandwidth=4",
    "Frame=60000, Bandwidth=5",
    "Frame=60000, Bandwidth=6",
};

/**
 * @brief   Holds the frame sizes and bandwidths accepted by one Opus coding mode.
 *
 * @details Frame sizes and bandwidths are registered independently with Add().
 *          IsAcceptable() accepts a (frame, bandwidth) pair when the frame size
 *          and the bandwidth were each registered.
 */
class ConfigurationManager
{
public:
    /**
     * @brief       Registers an acceptable frame size.
     * @param[in]   frame   Frame size in microseconds. Must be one of 2500, 5000,
     *                      10000, 20000, 40000 or 60000 (checked with NN_ASSERT).
     */
    void Add(int frame) NN_NOEXCEPT
    {
        NN_ASSERT(frame == 2500
            || frame == 5000
            || frame == 10000
            || frame == 20000
            || frame == 40000
            || frame == 60000);
        m_AcceptableFrames.push_front(frame);
    }

    /**
     * @brief       Registers an acceptable bandwidth.
     * @param[in]   bandwidth   One of OpusBandwidth_NarrowBand, _MediumBand,
     *                          _WideBand, _SuperWideBand or _FullBand
     *                          (checked with NN_ASSERT).
     */
    void Add(nn::codec::OpusBandwidth bandwidth) NN_NOEXCEPT
    {
        NN_ASSERT(bandwidth == nn::codec::OpusBandwidth_NarrowBand
            || bandwidth == nn::codec::OpusBandwidth_MediumBand
            || bandwidth == nn::codec::OpusBandwidth_WideBand
            || bandwidth == nn::codec::OpusBandwidth_SuperWideBand
            || bandwidth == nn::codec::OpusBandwidth_FullBand);
        m_AcceptableBandwidths.push_front(bandwidth);
    }

    /**
     * @brief       Returns whether a frame size / bandwidth pair is acceptable.
     * @param[in]   frame       Frame size in microseconds.
     * @param[in]   bandwidth   Bandwidth.
     * @return      true if both @a frame and @a bandwidth were registered with Add().
     *
     * @details     This is a pure query, so it is const and callable on const objects.
     */
    bool IsAcceptable(int frame, nn::codec::OpusBandwidth bandwidth) const NN_NOEXCEPT
    {
        const bool isFrameAcceptable =
            std::find(m_AcceptableFrames.begin(), m_AcceptableFrames.end(), frame)
                != m_AcceptableFrames.end();
        const bool isBandwidthAcceptable =
            std::find(m_AcceptableBandwidths.begin(), m_AcceptableBandwidths.end(), bandwidth)
                != m_AcceptableBandwidths.end();
        return isFrameAcceptable && isBandwidthAcceptable;
    }

private:
    std::forward_list<int> m_AcceptableFrames;
    std::forward_list<nn::codec::OpusBandwidth> m_AcceptableBandwidths;
};

/**
 * @brief   Common state shared by the Opus packet maker test fixtures.
 *
 * @details Initialize() reads the frame size and the bandwidth from a test
 *          parameter string such as "Frame=2500,  Bandwidth=2". codingMode is
 *          not set here; each derived fixture assigns it in its SetUp().
 */
class OpusPacketMakerTestBase
{
public:
    /**
     * @brief       Parses "Frame" and "Bandwidth" out of a test parameter string.
     * @param[in]   stringValue   Parameter string containing "Frame=<us>" and
     *                            "Bandwidth=<OpusBandwidth value>".
     *
     * @details     The parsed values must be a valid Opus frame size and a valid
     *              OpusBandwidth enumerator (checked with NN_ASSERT).
     */
    void Initialize(const char* stringValue) NN_NOEXCEPT
    {
        frame = nnt::codec::util::GetIntegerValueLabeledWith(stringValue, "Frame");
        bandwidth =
            static_cast<nn::codec::OpusBandwidth>(
                nnt::codec::util::GetIntegerValueLabeledWith(stringValue, "Bandwidth") );
        NN_ASSERT(frame == 2500
            || frame == 5000
            || frame == 10000
            || frame == 20000
            || frame == 40000
            || frame == 60000);
        NN_ASSERT(bandwidth == nn::codec::OpusBandwidth_NarrowBand
            || bandwidth == nn::codec::OpusBandwidth_MediumBand
            || bandwidth == nn::codec::OpusBandwidth_WideBand
            || bandwidth == nn::codec::OpusBandwidth_SuperWideBand
            || bandwidth == nn::codec::OpusBandwidth_FullBand);
    }

protected:
    // Value-initialized so the members are never indeterminate before
    // Initialize() / the derived SetUp() assign them.
    nn::codec::OpusCodingMode codingMode{};
    int frame{};                              // Frame size in microseconds.
    nn::codec::OpusBandwidth bandwidth{};
};

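// Valid frame size / bandwidth combinations per coding mode (cf. RFC 6716, TOC config table):
// Hybrid: 10, 20 ms              / SWB, FB
// Silk:   10, 20, 40, 60 ms      / NB, MB, WB
// Celt:   2.5, 5, 10, 20 ms      / NB, WB, SWB, FB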

/**
 * @brief   Value-parameterized fixture for the Celt coding mode.
 *
 * @details Accepted frame sizes: 2500, 5000, 10000, 20000 us.
 *          Accepted bandwidths: NarrowBand, WideBand, SuperWideBand, FullBand.
 */
class OpusPacketMakerTestForCodingModeCelt
    : public OpusPacketMakerTestBase
    , public ::testing::TestWithParam<const char*>
{
protected:
    // True when the frame size and the bandwidth are both accepted for Celt.
    inline bool IsAcceptable(int _frame, nn::codec::OpusBandwidth _bandwidth) NN_NOEXCEPT
    {
        const bool isAcceptable = m_ConfigurationManager.IsAcceptable(_frame, _bandwidth);
        return isAcceptable;
    }

    virtual void SetUp()
    {
        OpusPacketMakerTestBase::Initialize(GetParam());
        codingMode = nn::codec::OpusCodingMode_Celt;

        const int acceptableFrames[] = { 2500, 5000, 10000, 20000 };
        for (const int acceptableFrame : acceptableFrames)
        {
            m_ConfigurationManager.Add(acceptableFrame);
        }

        const nn::codec::OpusBandwidth acceptableBandwidths[] =
        {
            nn::codec::OpusBandwidth_NarrowBand,
            nn::codec::OpusBandwidth_WideBand,
            nn::codec::OpusBandwidth_SuperWideBand,
            nn::codec::OpusBandwidth_FullBand,
        };
        for (const nn::codec::OpusBandwidth acceptableBandwidth : acceptableBandwidths)
        {
            m_ConfigurationManager.Add(acceptableBandwidth);
        }
    }

protected:
    ConfigurationManager m_ConfigurationManager;
};

/**
 * @brief   Value-parameterized fixture for the Silk coding mode.
 *
 * @details Accepted frame sizes: 10000, 20000, 40000, 60000 us.
 *          Accepted bandwidths: NarrowBand, MediumBand, WideBand.
 */
class OpusPacketMakerTestForCodingModeSilk
    : public OpusPacketMakerTestBase
    , public ::testing::TestWithParam<const char*>
{
protected:
    // True when the frame size and the bandwidth are both accepted for Silk.
    inline bool IsAcceptable(int _frame, nn::codec::OpusBandwidth _bandwidth) NN_NOEXCEPT
    {
        const bool isAcceptable = m_ConfigurationManager.IsAcceptable(_frame, _bandwidth);
        return isAcceptable;
    }

    virtual void SetUp()
    {
        OpusPacketMakerTestBase::Initialize(GetParam());
        codingMode = nn::codec::OpusCodingMode_Silk;

        const int acceptableFrames[] = { 10000, 20000, 40000, 60000 };
        for (const int acceptableFrame : acceptableFrames)
        {
            m_ConfigurationManager.Add(acceptableFrame);
        }

        const nn::codec::OpusBandwidth acceptableBandwidths[] =
        {
            nn::codec::OpusBandwidth_NarrowBand,
            nn::codec::OpusBandwidth_MediumBand,
            nn::codec::OpusBandwidth_WideBand,
        };
        for (const nn::codec::OpusBandwidth acceptableBandwidth : acceptableBandwidths)
        {
            m_ConfigurationManager.Add(acceptableBandwidth);
        }
    }

protected:
    ConfigurationManager m_ConfigurationManager;
};

/**
 * @brief   Value-parameterized fixture for the Hybrid coding mode.
 *
 * @details Accepted frame sizes: 10000, 20000 us.
 *          Accepted bandwidths: SuperWideBand, FullBand.
 */
class OpusPacketMakerTestForCodingModeHybrid
    : public OpusPacketMakerTestBase
    , public ::testing::TestWithParam<const char*>
{
protected:
    // True when the frame size and the bandwidth are both accepted for Hybrid.
    inline bool IsAcceptable(int _frame, nn::codec::OpusBandwidth _bandwidth) NN_NOEXCEPT
    {
        const bool isAcceptable = m_ConfigurationManager.IsAcceptable(_frame, _bandwidth);
        return isAcceptable;
    }

    virtual void SetUp()
    {
        OpusPacketMakerTestBase::Initialize(GetParam());
        codingMode = nn::codec::OpusCodingMode_Hybrid;

        const int acceptableFrames[] = { 10000, 20000 };
        for (const int acceptableFrame : acceptableFrames)
        {
            m_ConfigurationManager.Add(acceptableFrame);
        }

        const nn::codec::OpusBandwidth acceptableBandwidths[] =
        {
            nn::codec::OpusBandwidth_SuperWideBand,
            nn::codec::OpusBandwidth_FullBand,
        };
        for (const nn::codec::OpusBandwidth acceptableBandwidth : acceptableBandwidths)
        {
            m_ConfigurationManager.Add(acceptableBandwidth);
        }
    }

protected:
    ConfigurationManager m_ConfigurationManager;
};

}

/**
 * @brief       Base test for the success cases (valid configurations).
 *
 * @details     For every channel count in { 1, 2 } and every stream count in
 *              [1, nnt::codec::util::OpusFrameCountMax), this test:
 *              - builds a packet with nnt::codec::util::MakeOpusPacket();
 *              - reads the coding mode, frame size, frame count, bandwidth and
 *                channel count back from the packet and compares them with the
 *                requested values (the frame count is expected to be 1);
 *              - decodes the packet with NativeOpusDecoder<OpusMSDecoder>, set up
 *                with streamCount streams and no coupled streams.
 *
 * @param[in]   codingMode          Coding mode requested for the packet.
 * @param[in]   frame               Frame size in microseconds.
 * @param[in]   bandwidth           Bandwidth requested for the packet.
 * @param[in]   parameterStringBase Test parameter string; prefix of the SCOPED_TRACE message.
 */
void OpusPacketMakerTestFunction(
    nn::codec::OpusCodingMode codingMode,
    int frame,
    nn::codec::OpusBandwidth bandwidth,
    const char* parameterStringBase
) NN_NOEXCEPT
{
    // Deleter for memory obtained with std::malloc(), so std::unique_ptr can own it.
    struct FreeDeleter
    {
        void operator()(void* pointer) const NN_NOEXCEPT
        {
            std::free(pointer);
        }
    };

    struct nnt::codec::util::OpusPacketConfiguration configuration;
    nnt::codec::util::DefaultOpusPacket(&configuration);

    configuration.codingMode = codingMode;
    configuration.frameInMicroSeconds = frame;
    configuration.bandwidth = bandwidth;

    const int channelCounts[] = { 1 , 2 };
    for (const auto channelCount : channelCounts)
    {
        // NOTE(review): the stream count is bounded by OpusFrameCountMax, while
        // channelMapping below is sized by OpusStreamCountMax. Confirm the intended bound.
        for (auto streamCount = 1; streamCount < nnt::codec::util::OpusFrameCountMax; ++streamCount)
        {
            auto traceString = std::string(parameterStringBase);
            traceString += std::string(", ChannelCount=") + std::to_string(channelCount)
                + std::string(", StreamCount=") + std::to_string(streamCount);
            SCOPED_TRACE(traceString.c_str());
            std::size_t encodedDataSize = 0;
            const std::size_t encodeBufferSize = 1600;
            uint8_t encodeBuffer[encodeBufferSize];

            configuration.channelCount = channelCount;
            configuration.streamCount = streamCount;
            // Build the packet.
            ASSERT_EQ(
                true,
                nnt::codec::util::MakeOpusPacket(
                    &encodedDataSize,
                    encodeBuffer,
                    encodeBufferSize,
                    configuration)
            );
            //------------------------------------------------------
            // Check each value stored in the packet
            //------------------------------------------------------
            // Coding mode
            nn::codec::OpusCodingMode codingModeInToc;
            ASSERT_EQ(
                nn::codec::OpusResult_Success,
                nnt::codec::util::GetOpusPacketCodingMode(
                    &codingModeInToc,
                    encodeBuffer,
                    encodedDataSize)
            );
            EXPECT_EQ(codingMode, codingModeInToc);
            // Frame size
            int frameInToc;
            ASSERT_EQ(
                nn::codec::OpusResult_Success,
                nnt::codec::util::GetOpusPacketFrameInMicroSeconds(
                    &frameInToc,
                    encodeBuffer,
                    encodedDataSize)
            );
            EXPECT_EQ(frame, frameInToc);
            // Frame count
            int frameCountInToc;
            ASSERT_EQ(
                nn::codec::OpusResult_Success,
                nnt::codec::util::GetOpusPacketFrameCount(
                    &frameCountInToc,
                    encodeBuffer,
                    encodedDataSize)
            );
            // Only a single frame per packet is currently supported.
            EXPECT_EQ(1, frameCountInToc);
            // Bandwidth
            nn::codec::OpusBandwidth bandwidthInToc;
            ASSERT_EQ(
                nn::codec::OpusResult_Success,
                nnt::codec::util::GetOpusPacketBandwidth(
                    &bandwidthInToc,
                    encodeBuffer,
                    encodedDataSize)
            );
            EXPECT_EQ(bandwidth, bandwidthInToc);
            // Channel count
            int channelCountInToc;
            ASSERT_EQ(
                nn::codec::OpusResult_Success,
                nnt::codec::util::GetOpusPacketChannelCount(
                    &channelCountInToc,
                    encodeBuffer,
                    encodedDataSize)
            );
            EXPECT_EQ(channelCount, channelCountInToc);
            // Decode test
            const int sampleRate = 48000;
            const int coupledStreamCount = 0;
            const int outputChannelCount = streamCount + coupledStreamCount;
            // Guard the fixed-size channelMapping array against out-of-bounds writes.
            ASSERT_LE(streamCount, static_cast<int>(nnt::codec::util::OpusStreamCountMax));
            uint8_t channelMapping[nnt::codec::util::OpusStreamCountMax];
            // Identity mapping (channelMapping[i] == i). An int index is used so the
            // loop cannot wrap around the way a uint8_t counter could.
            for (int i = 0; i < streamCount; ++i)
            {
                channelMapping[i] = static_cast<uint8_t>(i);
            }
            nnt::codec::NativeOpusDecoder<::OpusMSDecoder> decoder;
            auto workBufferSize = decoder.GetWorkBufferSize(sampleRate, streamCount, 0);
            ASSERT_GT(workBufferSize, 0u);
            // Owned by unique_ptr so that a failing ASSERT_* (which returns from this
            // function) cannot leak it. It is declared after `decoder`, so it is still
            // released before `decoder` is destroyed.
            std::unique_ptr<void, FreeDeleter> workBuffer(std::malloc(workBufferSize));
            ASSERT_NE(nullptr, workBuffer.get());
            ASSERT_EQ(
                nn::codec::OpusResult_Success,
                decoder.Initialize(
                    sampleRate,
                    outputChannelCount,
                    streamCount,
                    coupledStreamCount,
                    channelMapping,
                    workBuffer.get(),
                    workBufferSize)
            );
            int sampleCount;
            std::size_t consumed;
            // NOTE(review): sized as streamCount * outputChannelCount * samples-per-frame.
            // Since outputChannelCount == streamCount here, this may be larger than
            // needed. Confirm before shrinking.
            const auto pcmBufferCount = static_cast<int>(
                streamCount * outputChannelCount * (static_cast<int64_t>(configuration.frameInMicroSeconds) * sampleRate / 1000 / 1000) );
            // Owned by unique_ptr for the same early-return reason as workBuffer.
            std::unique_ptr<int16_t[]> pcmBuffer(new int16_t[pcmBufferCount]);
            ASSERT_NE(nullptr, pcmBuffer.get());
            ASSERT_EQ(
                nn::codec::OpusResult_Success,
                decoder.DecodeInterleaved(
                    &consumed,
                    &sampleCount,
                    pcmBuffer.get(),
                    pcmBufferCount * sizeof(int16_t),
                    encodeBuffer,
                    encodedDataSize)
            );
            NN_UNUSED(consumed);
            NN_UNUSED(sampleCount);
            // pcmBuffer and then workBuffer are released automatically at the end of this scope.
        }
    }
} // NOLINT(readability/fn_size)

/**
 * @brief       Base test for the error cases (invalid configurations).
 *
 * @details     For each channel count in { 1, 2 }, expects
 *              nnt::codec::util::MakeOpusPacket() to terminate the process when
 *              given the requested coding mode / frame size / bandwidth.
 *              The check is compiled only when NN_SDK_BUILD_RELEASE is not
 *              defined; in release builds this function does nothing.
 *
 * @param[in]   codingMode          Coding mode requested for the packet.
 * @param[in]   frame               Frame size in microseconds.
 * @param[in]   bandwidth           Bandwidth requested for the packet.
 * @param[in]   parameterStringBase Test parameter string; prefix of the SCOPED_TRACE message.
 */
void OpusPacketMakerDeathTestFunction(
    nn::codec::OpusCodingMode codingMode,
    int frame,
    nn::codec::OpusBandwidth bandwidth,
    const char* parameterStringBase
) NN_NOEXCEPT
{
#if defined(NN_SDK_BUILD_RELEASE)
    // The death test is skipped in release builds; only silence unused-parameter warnings.
    NN_UNUSED(codingMode);
    NN_UNUSED(frame);
    NN_UNUSED(bandwidth);
    NN_UNUSED(parameterStringBase);
#else
    struct nnt::codec::util::OpusPacketConfiguration configuration;
    nnt::codec::util::DefaultOpusPacket(&configuration);
    configuration.codingMode = codingMode;
    configuration.frameInMicroSeconds = frame;
    configuration.bandwidth = bandwidth;

    const std::size_t encodeBufferSize = 1600;
    const int testedChannelCounts[] = { 1, 2 };
    for (const int channelCount : testedChannelCounts)
    {
        std::string traceString(parameterStringBase);
        traceString += ", ChannelCount=";
        traceString += std::to_string(channelCount);
        SCOPED_TRACE(traceString.c_str());

        configuration.channelCount = channelCount;

        std::size_t encodedDataSize = 0;
        uint8_t encodeBuffer[encodeBufferSize];

        EXPECT_DEATH_IF_SUPPORTED(
            nnt::codec::util::MakeOpusPacket(
                &encodedDataSize,
                encodeBuffer,
                encodeBufferSize,
                configuration),
            ""
        );
    }
#endif // defined(NN_SDK_BUILD_RELEASE)
}

/**
 * @brief       Test for the Celt coding mode.
 *
 * @details     Runs the success-path test for combinations accepted by the
 *              fixture and the death test for every other combination.
 *              NOTE(review): "Configration" is misspelled, but renaming it would
 *              change the registered test name.
 */
TEST_P(OpusPacketMakerTestForCodingModeCelt, AllCombinationOfConfigration)
{
    const char* const parameterString = GetParam();
    if (!IsAcceptable(frame, bandwidth))
    {
        OpusPacketMakerDeathTestFunction(codingMode, frame, bandwidth, parameterString);
        return;
    }
    OpusPacketMakerTestFunction(codingMode, frame, bandwidth, parameterString);
}

/**
 * @brief       Test for the Silk coding mode.
 *
 * @details     Runs the success-path test for combinations accepted by the
 *              fixture and the death test for every other combination.
 *              NOTE(review): "Configration" is misspelled, but renaming it would
 *              change the registered test name.
 */
TEST_P(OpusPacketMakerTestForCodingModeSilk, AllCombinationOfConfigration)
{
    const char* const parameterString = GetParam();
    if (!IsAcceptable(frame, bandwidth))
    {
        OpusPacketMakerDeathTestFunction(codingMode, frame, bandwidth, parameterString);
        return;
    }
    OpusPacketMakerTestFunction(codingMode, frame, bandwidth, parameterString);
}

/**
 * @brief       Test for the Hybrid coding mode.
 *
 * @details     Runs the success-path test for combinations accepted by the
 *              fixture and the death test for every other combination.
 *              NOTE(review): "Configration" is misspelled, but renaming it would
 *              change the registered test name.
 */
TEST_P(OpusPacketMakerTestForCodingModeHybrid, AllCombinationOfConfigration)
{
    const char* const parameterString = GetParam();
    if (!IsAcceptable(frame, bandwidth))
    {
        OpusPacketMakerDeathTestFunction(codingMode, frame, bandwidth, parameterString);
        return;
    }
    OpusPacketMakerTestFunction(codingMode, frame, bandwidth, parameterString);
}

// Instantiate each coding-mode fixture over every entry of TestParameterListStrings
// ("RoundRobin": all frame size x bandwidth combinations). Each TEST_P decides per
// parameter whether to run the success-path test or the death test.
INSTANTIATE_TEST_CASE_P(
    RoundRobin,
    OpusPacketMakerTestForCodingModeCelt,
    ::testing::ValuesIn(TestParameterListStrings)
);

INSTANTIATE_TEST_CASE_P(
    RoundRobin,
    OpusPacketMakerTestForCodingModeSilk,
    ::testing::ValuesIn(TestParameterListStrings)
);

INSTANTIATE_TEST_CASE_P(
    RoundRobin,
    OpusPacketMakerTestForCodingModeHybrid,
    ::testing::ValuesIn(TestParameterListStrings)
);

