﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <cstdint>
#include <cstddef>
#include <algorithm>
#include <limits>
#include <random>

#include <nn/os.h>
#include <nn/nn_Assert.h>

#include <nn/codec.h>
#include <nn/codec/codec_OpusEncoder.h>
#include <nn/codec/detail/codec_OpusPacketInternal.h> // nn::codec::detail::OpusPacketInternal

#include <nnt.h>
#include <nnt/codecUtil/testCodec_Util.h>

// #define ENABLE_DEPRECATED_TEST

namespace {

// Parameter strings fed to the value-parameterized tests in this file.
// Each entry carries three labeled integers -- "SampleRate", "ChannelCount" and
// "Frame" -- which OpusPacketApiTestBase::Initialize() extracts with
// nnt::codec::util::GetIntegerValueLabeledWith().
const char* TestParameterListStrings[] =
{
    "SampleRate=48000, ChannelCount=1, Frame=2500",
    "SampleRate=48000, ChannelCount=2, Frame=2500",
    "SampleRate=24000, ChannelCount=1, Frame=2500",
    "SampleRate=24000, ChannelCount=2, Frame=2500",
    "SampleRate=16000, ChannelCount=1, Frame=2500",
    "SampleRate=16000, ChannelCount=2, Frame=2500",
    "SampleRate=12000, ChannelCount=1, Frame=2500",
    "SampleRate=12000, ChannelCount=2, Frame=2500",
    "SampleRate=8000,  ChannelCount=1, Frame=2500",
    "SampleRate=8000,  ChannelCount=2, Frame=2500",
};

}

/**
 * @brief   Shared state for the Opus packet API tests.
 *
 * Holds the encoder, the parameters parsed from the test string, and the
 * malloc'ed work / output / input buffers. Buffers are allocated in SetUp()
 * and released in TearDown(); the destructor releases anything still allocated.
 */
class OpusPacketApiTestBase
{
private:
    /**
     * @brief   Frees @a pointer when it is not null.
     * @return  Always nullptr, so the caller can reset its pointer in the same statement.
     */
    static std::nullptr_t FreeIfNotNull(void* pointer)
    {
        if(nullptr != pointer)
        {
            std::free(pointer);
        }
        return nullptr;
    }
protected:
    OpusPacketApiTestBase()
        : encoder()
        , sampleRate(0)
        , channelCount(0)
        , frame(0)
        , workBufferAddress(nullptr)
        , workBufferSize(0)
        , outputBufferAddress(nullptr)
        , outputBufferSize(0)
        , inputBufferAddress(nullptr)
        , inputBufferSize(0)
        , outputDataSize(0)
        , inputSampleCount(0)
    {}

    virtual ~OpusPacketApiTestBase()
    {
        // Safety net in case TearDown() was not reached.
        workBufferAddress   = FreeIfNotNull(workBufferAddress);
        outputBufferAddress = FreeIfNotNull(outputBufferAddress);
        inputBufferAddress  = FreeIfNotNull(inputBufferAddress);
    }

    /**
     * @brief   Reads the "SampleRate", "ChannelCount" and "Frame" values out of
     *          a parameter string.
     * @param[in]   stringValue     Parameter string containing the three labels.
     */
    virtual void Initialize(const char* stringValue)
    {
        sampleRate = nnt::codec::util::GetIntegerValueLabeledWith(stringValue, "SampleRate");
        channelCount = nnt::codec::util::GetIntegerValueLabeledWith(stringValue, "ChannelCount");
        frame = nnt::codec::util::GetIntegerValueLabeledWith(stringValue, "Frame");
    }

    /**
     * @brief   Allocates the work, output and input buffers and fills the input
     *          buffer with random PCM samples.
     *
     * Initialize() must have been called first (all three parameters non-zero).
     * The ASSERT_* checks here only return from this function on failure, so
     * callers should check for a fatal failure before continuing.
     */
    virtual void SetUp()
    {
        NN_ASSERT(sampleRate != 0);
        NN_ASSERT(channelCount != 0);
        NN_ASSERT(frame != 0);

        workBufferSize = encoder.GetWorkBufferSize(sampleRate, channelCount);
        workBufferAddress = std::malloc(workBufferSize);
        ASSERT_NE(nullptr, workBufferAddress);

        outputBufferSize = nn::codec::OpusPacketSizeMaximum;
        outputBufferAddress = static_cast<uint8_t*>(std::malloc(outputBufferSize));
        ASSERT_NE(nullptr, outputBufferAddress);

        // frame is treated as a duration in microseconds here
        // (samples = frame * sampleRate / 1,000,000). Widen before multiplying:
        // the product overflows int for long frames at high sample rates
        // (e.g. 60000 * 48000).
        inputSampleCount = static_cast<int>(
            static_cast<int64_t>(frame) * sampleRate / 1000 / 1000);
        inputBufferSize =
            static_cast<std::size_t>(inputSampleCount) * channelCount * sizeof(int16_t);
        inputBufferAddress = static_cast<int16_t*>(std::malloc(inputBufferSize));
        ASSERT_NE(nullptr, inputBufferAddress);

        // Any input will do, so fill the buffer with random samples covering the
        // whole int16_t range. The generator is seeded from the system tick.
        const std::size_t sampleCount = inputBufferSize / sizeof(int16_t);
        std::mt19937 engine(static_cast<uint32_t>(nn::os::GetSystemTick().GetInt64Value()));
        std::uniform_int_distribution<int> sampleDistribution(
            std::numeric_limits<int16_t>::min(),
            std::numeric_limits<int16_t>::max());
        for (std::size_t k = 0; k < sampleCount; ++k)
        {
            inputBufferAddress[k] = static_cast<int16_t>(sampleDistribution(engine));
        }
    }

    /**
     * @brief   Releases the buffers allocated by SetUp().
     */
    virtual void TearDown()
    {
        workBufferAddress   = FreeIfNotNull(workBufferAddress);
        outputBufferAddress = FreeIfNotNull(outputBufferAddress);
        inputBufferAddress  = FreeIfNotNull(inputBufferAddress);
    }


protected:
    nn::codec::OpusEncoder encoder;     // Encoder used to size the work buffer (and by derived fixtures to encode).
    int sampleRate;                     // Value of the "SampleRate" label.
    int channelCount;                   // Value of the "ChannelCount" label.
    int frame;                          // Value of the "Frame" label.
    void* workBufferAddress;            // Encoder work buffer (malloc'ed).
    std::size_t workBufferSize;         // Size of the work buffer in bytes.
    uint8_t* outputBufferAddress;       // Destination buffer for encoded data (malloc'ed).
    std::size_t outputBufferSize;       // Size of the output buffer in bytes.
    int16_t* inputBufferAddress;        // Interleaved 16-bit input samples (malloc'ed).
    std::size_t inputBufferSize;        // Size of the input buffer in bytes.
    std::size_t outputDataSize;         // Encoded size; not written by this class.
    int inputSampleCount;               // Samples per channel in the input buffer.
};


/**
 * @brief   Value-parameterized fixture for the Opus packet API tests.
 *
 * SetUp() parses the parameter string, allocates the buffers, initializes the
 * encoder and encodes one frame of input into outputBufferAddress, so each
 * test starts with an encoded packet to inspect.
 */
class OpusPacketApiTest
    : public OpusPacketApiTestBase
    , public ::testing::TestWithParam<const char*>
{
public:
    OpusPacketApiTest()
        : OpusPacketApiTestBase()
    {}

    virtual void SetUp()
    {
        OpusPacketApiTestBase::Initialize(GetParam());
        // A fatal ASSERT_* inside the base SetUp() only returns from that
        // function. Stop here as well, otherwise the encoder would be handed
        // buffers that were never allocated.
        ASSERT_NO_FATAL_FAILURE(OpusPacketApiTestBase::SetUp());
        ASSERT_EQ(
            nn::codec::OpusResult_Success,
            encoder.Initialize(sampleRate, channelCount, workBufferAddress, workBufferSize)
        );
        // Specify the maximum bit rate so that stereo input is not compressed
        // down to monaural.
        encoder.SetBitRate(nn::codec::GetOpusBitRateMax(channelCount));
        ASSERT_EQ(
            nn::codec::OpusResult_Success,
            encoder.EncodeInterleaved(
                &outputDataSize,
                outputBufferAddress,
                outputBufferSize,
                inputBufferAddress,
                inputSampleCount
            )
        );
        ASSERT_GT(outputDataSize, 0u);
    }

    virtual void TearDown()
    {
        // SetUp() may have failed before the encoder was initialized.
        if (encoder.IsInitialized())
        {
            encoder.Finalize();
        }
        OpusPacketApiTestBase::TearDown();
    }
};

/**
 * @brief       Normal-case test for GetOpusPacketSampleCountInPacket().
 *
 * Expects the sample count reported for the encoded packet to equal
 * encoder.CalculateFrameSampleCount(frame).
 */
TEST_P(OpusPacketApiTest, GetOpusPacketSampleCountInPacket)
{
    SCOPED_TRACE(GetParam());

    const auto  correct = encoder.CalculateFrameSampleCount(frame);
    // Start from a sentinel so the comparison below never reads an
    // indeterminate value when the call fails without writing the output.
    int value = -1;
    EXPECT_EQ(
        nn::codec::OpusResult_Success,
        nn::codec::GetOpusPacketSampleCountInPacket(
            &value, outputBufferAddress, outputBufferSize, sampleRate) );
    EXPECT_EQ(correct, value);

#if defined(ENABLE_DEPRECATED_TEST)
    EXPECT_EQ(
        correct,
        nn::codec::GetOpusPacketSampleCountInPacket(
            outputBufferAddress, outputBufferSize, sampleRate) ); // deprecated
    // Smaller buffer size
    const std::size_t smallerSize = nn::codec::detail::OpusPacketInternal::HeaderSize + 2 - 1;
    EXPECT_EQ(
        nn::codec::OpusPacketErrorCode,
        nn::codec::GetOpusPacketSampleCountInPacket(
            outputBufferAddress, smallerSize, sampleRate) ); // deprecated
#endif
}

/**
 * @brief       Normal-case test for GetOpusPacketSampleCountPerFrame().
 *
 * Expects the per-frame sample count reported for the encoded packet to equal
 * encoder.CalculateFrameSampleCount(frame).
 *
 * @note        The test name keeps its existing spelling ("PerFramce") so that
 *              test filters referring to it continue to match.
 */
TEST_P(OpusPacketApiTest, GetOpusPacketSampleCountPerFramce)
{
    SCOPED_TRACE(GetParam());

    const auto  correct = encoder.CalculateFrameSampleCount(frame);
    // Start from a sentinel so the comparison below never reads an
    // indeterminate value when the call fails without writing the output.
    int value = -1;
    EXPECT_EQ(
        nn::codec::OpusResult_Success,
        nn::codec::GetOpusPacketSampleCountPerFrame(
            &value, outputBufferAddress, outputBufferSize, sampleRate) );
    EXPECT_EQ(correct, value);

#if defined(ENABLE_DEPRECATED_TEST)
    EXPECT_EQ(
        correct,
        nn::codec::GetOpusPacketSampleCountPerFrame(
            outputBufferAddress, outputBufferSize, sampleRate) ); // deprecated
    // Smaller buffer size
    const std::size_t smallerSize = nn::codec::detail::OpusPacketInternal::HeaderSize + 1 - 1;
    EXPECT_EQ(
        nn::codec::OpusPacketErrorCode,
        nn::codec::GetOpusPacketSampleCountPerFrame(
            outputBufferAddress, smallerSize, sampleRate) ); // deprecated
#endif
}

/**
 * @brief       Normal-case test for GetOpusPacketFrameCount().
 *
 * Expects the encoded packet to report exactly one frame.
 */
TEST_P(OpusPacketApiTest, GetOpusPacketFrameCount)
{
    SCOPED_TRACE(GetParam());

    const auto  correct = 1;
    // Start from a sentinel so the comparison below never reads an
    // indeterminate value when the call fails without writing the output.
    int value = -1;
    EXPECT_EQ(
        nn::codec::OpusResult_Success,
        nn::codec::GetOpusPacketFrameCount(
            &value, outputBufferAddress, outputBufferSize) );
    EXPECT_EQ(correct, value);

#if defined(ENABLE_DEPRECATED_TEST)
    EXPECT_EQ(
        correct,
        nn::codec::GetOpusPacketFrameCount(
            outputBufferAddress, outputBufferSize) ); // deprecated
    // Smaller buffer size
    const std::size_t smallerSize = nn::codec::detail::OpusPacketInternal::HeaderSize + 2 - 1;
    EXPECT_EQ(
        nn::codec::OpusPacketErrorCode,
        nn::codec::GetOpusPacketFrameCount(
            outputBufferAddress, smallerSize) ); // deprecated
#endif
}

/**
 * @brief       Normal-case test for GetOpusPacketChannelCount().
 *
 * Expects the encoded packet to report the channel count the encoder was
 * initialized with.
 */
TEST_P(OpusPacketApiTest, GetOpusPacketChannelCount)
{
    SCOPED_TRACE(GetParam());

    const auto  correct = channelCount;
    // Start from a sentinel so the comparison below never reads an
    // indeterminate value when the call fails without writing the output.
    int value = -1;
    EXPECT_EQ(
        nn::codec::OpusResult_Success,
        nn::codec::GetOpusPacketChannelCount(
            &value, outputBufferAddress, outputBufferSize) );
    EXPECT_EQ(correct, value);

#if defined(ENABLE_DEPRECATED_TEST)
    EXPECT_EQ(
        correct,
        nn::codec::GetOpusPacketChannelCount(
            outputBufferAddress, outputBufferSize) ); // deprecated
    // Smaller buffer size
    const std::size_t smallerSize = nn::codec::detail::OpusPacketInternal::HeaderSize + 1 - 1;
    EXPECT_EQ(
        nn::codec::OpusPacketErrorCode,
        nn::codec::GetOpusPacketChannelCount(
            outputBufferAddress, smallerSize) ); // deprecated
#endif
}

// Run every OpusPacketApiTest case once per entry of TestParameterListStrings.
INSTANTIATE_TEST_CASE_P(RoundRobin, OpusPacketApiTest, ::testing::ValuesIn(TestParameterListStrings));

#if !defined(NN_SDK_BUILD_RELEASE)
// Same fixture under a name ending in "DeathTest", used by the abnormal-case tests below.
using OpusPacketApiDeathTest = OpusPacketApiTest;

/**
 * @brief       Abnormal-case (death) test for GetOpusPacketSampleCountInPacket().
 */
TEST_P(OpusPacketApiDeathTest, GetOpusPacketSampleCountInPacket)
{
    SCOPED_TRACE(GetParam());

    int result;

    // Sample rates one below and one above the configured rate are unsupported.
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketSampleCountInPacket(
            &result, outputBufferAddress, outputBufferSize, sampleRate - 1),
        "");
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketSampleCountInPacket(
            &result, outputBufferAddress, outputBufferSize, sampleRate + 1),
        "");
#if defined(ENABLE_DEPRECATED_TEST)
    EXPECT_DEATH_IF_SUPPORTED(nn::codec::GetOpusPacketSampleCountInPacket(outputBufferAddress, outputBufferSize, sampleRate - 1), ""); // deprecated
    EXPECT_DEATH_IF_SUPPORTED(nn::codec::GetOpusPacketSampleCountInPacket(outputBufferAddress, outputBufferSize, sampleRate + 1), ""); // deprecated
#endif

    // Packet size too small.
    const std::size_t truncatedSize = nn::codec::detail::OpusPacketInternal::HeaderSize + 1 - 1;
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketSampleCountInPacket(
            &result, outputBufferAddress, truncatedSize, sampleRate),
        "");

    // Null packet pointer, then null output pointer.
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketSampleCountInPacket(
            &result, nullptr, outputBufferSize, sampleRate),
        "");
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketSampleCountInPacket(
            nullptr, outputBufferAddress, outputBufferSize, sampleRate),
        "");
#if defined(ENABLE_DEPRECATED_TEST)
    EXPECT_DEATH_IF_SUPPORTED(nn::codec::GetOpusPacketSampleCountInPacket(nullptr, outputBufferSize, sampleRate), ""); // deprecated
#endif
}

/**
 * @brief       Abnormal-case (death) test for GetOpusPacketSampleCountPerFrame().
 */
TEST_P(OpusPacketApiDeathTest, GetOpusPacketSampleCountPerFramce)
{
    SCOPED_TRACE(GetParam());

    int result;

    // Sample rates one below and one above the configured rate are unsupported.
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketSampleCountPerFrame(
            &result, outputBufferAddress, outputBufferSize, sampleRate - 1),
        "");
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketSampleCountPerFrame(
            &result, outputBufferAddress, outputBufferSize, sampleRate + 1),
        "");
#if defined(ENABLE_DEPRECATED_TEST)
    EXPECT_DEATH_IF_SUPPORTED(nn::codec::GetOpusPacketSampleCountPerFrame(outputBufferAddress, outputBufferSize, sampleRate - 1), ""); // deprecated
    EXPECT_DEATH_IF_SUPPORTED(nn::codec::GetOpusPacketSampleCountPerFrame(outputBufferAddress, outputBufferSize, sampleRate + 1), ""); // deprecated
#endif

    // Packet size too small.
    const std::size_t truncatedSize = nn::codec::detail::OpusPacketInternal::HeaderSize + 1 - 1;
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketSampleCountPerFrame(
            &result, outputBufferAddress, truncatedSize, sampleRate),
        "");

    // Null packet pointer, then null output pointer.
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketSampleCountPerFrame(
            &result, nullptr, outputBufferSize, sampleRate),
        "");
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketSampleCountPerFrame(
            nullptr, outputBufferAddress, outputBufferSize, sampleRate),
        "");
#if defined(ENABLE_DEPRECATED_TEST)
    EXPECT_DEATH_IF_SUPPORTED(nn::codec::GetOpusPacketSampleCountPerFrame(nullptr, outputBufferSize, sampleRate), ""); // deprecated
#endif
}

/**
 * @brief       Abnormal-case (death) test for GetOpusPacketFrameCount().
 */
TEST_P(OpusPacketApiDeathTest, GetOpusPacketFrameCount)
{
    SCOPED_TRACE(GetParam());

    int result;

    // Packet size too small.
    const std::size_t truncatedSize = nn::codec::detail::OpusPacketInternal::HeaderSize + 1 - 1;
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketFrameCount(&result, outputBufferAddress, truncatedSize),
        "");

    // Null packet pointer, then null output pointer.
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketFrameCount(&result, nullptr, outputBufferSize),
        "");
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketFrameCount(nullptr, outputBufferAddress, outputBufferSize),
        "");
#if defined(ENABLE_DEPRECATED_TEST)
    EXPECT_DEATH_IF_SUPPORTED(nn::codec::GetOpusPacketFrameCount(nullptr, outputBufferSize), ""); // deprecated
#endif
}

/**
 * @brief       Abnormal-case (death) test for GetOpusPacketChannelCount().
 */
TEST_P(OpusPacketApiDeathTest, GetOpusPacketChannelCount)
{
    SCOPED_TRACE(GetParam());

    int result;

    // Packet size too small.
    const std::size_t truncatedSize = nn::codec::detail::OpusPacketInternal::HeaderSize + 1 - 1;
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketChannelCount(&result, outputBufferAddress, truncatedSize),
        "");

    // Null packet pointer, then null output pointer.
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketChannelCount(&result, nullptr, outputBufferSize),
        "");
    EXPECT_DEATH_IF_SUPPORTED(
        nn::codec::GetOpusPacketChannelCount(nullptr, outputBufferAddress, outputBufferSize),
        "");
#if defined(ENABLE_DEPRECATED_TEST)
    EXPECT_DEATH_IF_SUPPORTED(nn::codec::GetOpusPacketChannelCount(nullptr, outputBufferSize), "");  // deprecated
#endif
}

// Run every OpusPacketApiDeathTest case once per entry of TestParameterListStrings.
INSTANTIATE_TEST_CASE_P(RoundRobin, OpusPacketApiDeathTest, ::testing::ValuesIn(TestParameterListStrings));

#endif // !defined(NN_SDK_BUILD_RELEASE)
