﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/


#include "VibrationEncoder.h"
#include <nn/os/os_Result.h>
#include <nn/nn_Assert.h>
#include <nn/nn_Log.h>
#include <algorithm>
#include <cmath>


namespace nns
{


// Out-of-class definitions of VibrationEncoder's static constants.
// No initializer is given here, so the values are presumably supplied at the
// in-class declarations in VibrationEncoder.h (TODO confirm). These definitions
// give the constants storage so they can be bound to references / have their address taken.
const uint32_t VibrationEncoder::MinSamplingRate;
const uint32_t VibrationEncoder::MaxSamplingRate;
const uint32_t VibrationEncoder::DefaultResamplingRate;
const uint32_t VibrationEncoder::ChunkDurationMs;
const size_t   VibrationEncoder::MaxChunkLength;


/**
 * @brief Initializes the encoder
 *
 * This function must be called before using an encoder instance. You can call it again to reset the encoder.
 *
 * The parameters are validated in the order listed below; the first failing check logs a message and
 * aborts the initialization before any encoder state is modified.
 *
 * @param [in] samplingRate   The sampling rate for the input waveform (in Hz). Must be within [MinSamplingRate; MaxSamplingRate] and a multiple of `resamplingRate`.
 * @param [in] lowFreq1       Lower  cutoff frequency (in Hz) for the 1st pass-band filter. In the range [10; 2500] and less than `highFreq1`.
 * @param [in] highFreq1      Higher cutoff frequency (in Hz) for the 1st pass-band filter. In the range [10; 2500] and less than `resamplingRate / 2`.
 * @param [in] lowFreq2       Lower  cutoff frequency (in Hz) for the 2nd pass-band filter. In the range [10; 2500] and less than `highFreq2`.
 * @param [in] highFreq2      Higher cutoff frequency (in Hz) for the 2nd pass-band filter. In the range [10; 2500] and less than `resamplingRate / 2`.
 * @param [in] resamplingRate The encoder's internal sampling rate (in Hz). Must be a multiple of 200 and within [MinSamplingRate; MaxSamplingRate] (documented as [4000; 48000]). The default is 8000.
 *
 * @return `nn::os::ResultInvalidParameter` if a parameter is rejected, the failing result of a filter or
 *         decoder initialization if one of those fails, `nn::ResultSuccess` otherwise.
 */
nn::Result VibrationEncoder::Init(uint32_t samplingRate, float lowFreq1, float highFreq1, float lowFreq2, float highFreq2, uint32_t resamplingRate) NN_NOEXCEPT
{
    // ---- Sampling rates ----
    if (samplingRate < MinSamplingRate || samplingRate > MaxSamplingRate)
    {
        NN_LOG("Sampling rate (%u) must be in [%u; %u]\n", unsigned(samplingRate), unsigned(MinSamplingRate), unsigned(MaxSamplingRate));
        return nn::os::ResultInvalidParameter();
    }
    if (resamplingRate < MinSamplingRate || resamplingRate > MaxSamplingRate)
    {
        NN_LOG("Resampling rate (%u) must be in [%u; %u]\n", unsigned(resamplingRate), unsigned(MinSamplingRate), unsigned(MaxSamplingRate));
        return nn::os::ResultInvalidParameter();
    }
    if ((resamplingRate % 200u) != 0u)
    {
        NN_LOG("Resampling rate (%u) must be a multiple of 200\n", unsigned(resamplingRate));
        return nn::os::ResultInvalidParameter();
    }
    if ((samplingRate % resamplingRate) != 0)
    {
        NN_LOG("Sampling rate (%u) must be a multiple of resampling rate (%u)\n", unsigned(samplingRate), unsigned(resamplingRate));
        return nn::os::ResultInvalidParameter();
    }

    // ---- Cutoff frequencies ----
    const bool band1Inverted = (lowFreq1 >= highFreq1);
    const bool band2Inverted = (lowFreq2 >= highFreq2);
    if (band1Inverted || band2Inverted)
    {
        NN_LOG("Low cutoff frequencies must be less than high cutoff frequencies\n");
        return nn::os::ResultInvalidParameter();
    }

    // Expressed as "is in range" and then negated so that a NaN cutoff is rejected here
    const bool band1InRange = (10.0f <= lowFreq1) && (highFreq1 <= 2500.0);
    const bool band2InRange = (10.0f <= lowFreq2) && (highFreq2 <= 2500.0);
    if (!band1InRange || !band2InRange)
    {
        NN_LOG("Cutoff frequencies must be in the range [10; 2500]\n");
        return nn::os::ResultInvalidParameter();
    }

    const uint32_t halfResamplingRate = resamplingRate / 2;
    if (highFreq1 >= halfResamplingRate || highFreq2 >= halfResamplingRate)
    {
        NN_LOG("High cutoff frequencies must be less than resamplingRate/2\n");
        return nn::os::ResultInvalidParameter();
    }

    // ---- Rates and derived sizes ----
    m_samplingRate       = samplingRate;
    m_resamplingRate     = resamplingRate;
    m_downsamplingFactor = samplingRate / resamplingRate;
    m_chunkLength        = (resamplingRate * ChunkDurationMs) / 1000;

    // ---- Reset the analysis state of both bands ----
    m_samplesProcessed = 0;
    for (int bandIndex = 0; bandIndex < 2; ++bandIndex)
    {
        Band& band = m_bands[bandIndex];
        band.lastCrossing  = 0;
        band.prevSample    = 0;
        band.prevAmplitude = 0.0f;
        band.prevPitch     = 0.0f;
    }

    // ---- Select which band's low-pass filter runs before downsampling ----
    // The band with the larger high cutoff is used; -1 means no downsampling takes place.
    if (m_downsamplingFactor != 1)
    {
        m_downsamplingFilter = (highFreq2 > highFreq1) ? 1 : 0;
    }
    else
    {
        m_downsamplingFilter = -1;
    }

    // ---- Set up the high-pass / low-pass pair of each band ----
    const uint32_t filterOrder   = 6;
    const double   cutoffLow[2]  = {lowFreq1,  lowFreq2};
    const double   cutoffHigh[2] = {highFreq1, highFreq2};
    for (int bandIndex = 0; bandIndex < 2; ++bandIndex)
    {
        Band& band = m_bands[bandIndex];

        nn::Result highPassResult = band.highPass.InitButterworthHighPass(resamplingRate, cutoffLow[bandIndex], filterOrder);
        if (highPassResult.IsFailure())
        {
            return highPassResult;
        }

        // The low-pass filter applied before downsampling works at the input rate,
        // the other one works at the internal rate.
        uint32_t lowPassRate = resamplingRate;
        if (bandIndex == m_downsamplingFilter)
        {
            lowPassRate = samplingRate;
        }

        nn::Result lowPassResult = band.lowPass.InitButterworthLowPass(lowPassRate, cutoffHigh[bandIndex], filterOrder);
        if (lowPassResult.IsFailure())
        {
            return lowPassResult;
        }
    }

    // ---- Set up the decoder (runs at the internal rate) ----
    nn::Result decoderResult = m_decoder.Init(resamplingRate);
    if (decoderResult.IsFailure())
    {
        return decoderResult;
    }

    return nn::ResultSuccess();
}


/**
 * @brief Initializes the encoder with an optional trace stream
 *
 * This function must be called before using an encoder instance. You can call it again to reset the encoder.
 *
 * Forwards to the overload without a trace stream, then (when a stream is given) writes one line
 * with the arguments and whether the initialization succeeded.
 *
 * @param [in] samplingRate   The sampling rate for the input waveform (in Hz). Must be a multiple of `resamplingRate`.
 * @param [in] lowFreq1       Lower  cutoff frequency (in Hz) for the 1st pass-band filter. In the range [10; 2500].
 * @param [in] highFreq1      Higher cutoff frequency (in Hz) for the 1st pass-band filter. In the range [10; 2500].
 * @param [in] lowFreq2       Lower  cutoff frequency (in Hz) for the 2nd pass-band filter. In the range [10; 2500].
 * @param [in] highFreq2      Higher cutoff frequency (in Hz) for the 2nd pass-band filter. In the range [10; 2500].
 * @param [in] resamplingRate The encoder's internal sampling rate (in Hz). Must be a multiple of 200 and within the range [4000; 48000].
 * @param [in] trace          A stream to write debug traces to. If `NULL` no trace is written.
 *
 * @return The result of the underlying initialization, unchanged.
 */
nn::Result VibrationEncoder::Init(uint32_t samplingRate, float lowFreq1, float highFreq1, float lowFreq2, float highFreq2, uint32_t resamplingRate, FILE* trace) NN_NOEXCEPT
{
    nn::Result initResult = Init(samplingRate, lowFreq1, highFreq1, lowFreq2, highFreq2, resamplingRate);

    // Tracing is optional
    if (trace == NULL)
    {
        return initResult;
    }

    const char* outcome = initResult.IsSuccess() ? "Success" : "Failure";
    fprintf(trace, "[VibrationEncoder] Init(%u, %3.1f, %3.1f, %3.1f, %3.1f, %u) --> %s\n", unsigned(samplingRate), lowFreq1, highFreq1, lowFreq2, highFreq2, unsigned(resamplingRate), outcome);
    return initResult;
}


/**
 * @brief Initializes the encoder with default pass-band parameters
 *
 * This function must be called before using an encoder instance. You can call it again to reset the encoder.
 *
 * The first  band is set to [128; 200] Hz
 * The second band is set to [256; 400] Hz
 *
 * @param [in] samplingRate   The sampling rate for the input waveform. Must be a multiple of the resampling rate.
 * @param [in] resamplingRate The encoder's internal sampling rate (in Hz). Must be a multiple of 200. The default is 8000.
 *
 * @return The result of the full initialization performed with the default bands.
 */
nn::Result VibrationEncoder::Init(uint32_t samplingRate, uint32_t resamplingRate) NN_NOEXCEPT
{
    // Default pass-bands (in Hz)
    const float band1Low  = 128.0f;
    const float band1High = 200.0f;
    const float band2Low  = 256.0f;
    const float band2High = 400.0f;

    return Init(samplingRate, band1Low, band1High, band2Low, band2High, resamplingRate);
}


/**
 * @brief Initializes the encoder with default pass-band parameters and an optional trace stream
 *
 * This function must be called before using an encoder instance. You can call it again to reset the encoder.
 *
 * The first  band is set to [128; 200] Hz
 * The second band is set to [256; 400] Hz
 *
 * @param [in] samplingRate   The sampling rate for the input waveform. Must be a multiple of the resampling rate.
 * @param [in] resamplingRate The encoder's internal sampling rate (in Hz). Must be a multiple of 200.
 * @param [in] trace          A stream to write debug traces to. If `NULL` no trace is written.
 *
 * @return The result of the full, traced initialization performed with the default bands.
 */
nn::Result VibrationEncoder::Init(uint32_t samplingRate, uint32_t resamplingRate, FILE* trace) NN_NOEXCEPT
{
    // Default pass-bands (in Hz)
    const float band1Low  = 128.0f;
    const float band1High = 200.0f;
    const float band2Low  = 256.0f;
    const float band2High = 400.0f;

    return Init(samplingRate, band1Low, band1High, band2Low, band2High, resamplingRate, trace);
}


/**
 * @brief Processes a 5 ms chunk
 *
 * Both bands are analyzed, the resulting vibration value is decoded back into a waveform, and the
 * amplitudes are then rescaled so that the decoded energy matches the energy of the input chunk.
 *
 * @param [in]  samples The samples to analyze, in mono PCM. There must be exactly 5 ms worth of data
 *                      (`m_chunkLength * m_downsamplingFactor` samples are read).
 * @param [out] result  The computed vibration data. Must not be `NULL`.
 */
void VibrationEncoder::ProcessChunk(const int16_t* samples, nn::hid::VibrationValue* result) NN_NOEXCEPT
{
    NN_ASSERT(result != NULL);

    // When the input rate differs from the internal rate, low-pass the input first
    // and run the band analysis on the filtered copy instead of the raw samples.
    const int16_t* bandInput = samples;
    if (m_downsamplingFactor != 1u)
    {
        NN_ASSERT(m_downsamplingFilter >= 0);
        m_bands[m_downsamplingFilter].lowPass.Apply(samples, m_tmp2, m_chunkLength * m_downsamplingFactor);
        bandInput = m_tmp2;
    }

    // Analyze the two bands
    ExtractBand(bandInput, 0, result->amplitudeLow,  result->frequencyLow);
    ExtractBand(bandInput, 1, result->amplitudeHigh, result->frequencyHigh);
    m_samplesProcessed += m_chunkLength;

    // Synthesize the waveform described by the result (into m_tmp1)
    m_decoder.GenerateChunk(*result, m_tmp1);

    // Energy of the raw input chunk
    uint64_t inputEnergy = 0;
    for (uint32_t n = 0; n < m_chunkLength * m_downsamplingFactor; ++n)
    {
        inputEnergy += samples[n] * samples[n];
    }

    // Energy of the synthesized chunk, scaled up because it holds fewer samples than the input
    uint64_t decodedEnergy = 0;
    for (uint32_t n = 0; n < m_chunkLength; ++n)
    {
        decodedEnergy += m_tmp1[n] * m_tmp1[n];
    }
    decodedEnergy *= m_downsamplingFactor;

    // Rescale the amplitudes so that the decoded energy matches the input energy
    if (inputEnergy != 0 && decodedEnergy != 0)
    {
        const float gain = sqrtf(static_cast<float>(inputEnergy) / static_cast<float>(decodedEnergy));
        result->amplitudeLow  = std::min(1.0f, result->amplitudeLow  * gain);
        result->amplitudeHigh = std::min(1.0f, result->amplitudeHigh * gain);
    }
    else
    {
        // Silent input or silent synthesis: output no vibration
        result->amplitudeLow  = 0.0f;
        result->amplitudeHigh = 0.0f;
    }
}


namespace
{

/**
 * @brief Advances a write offset into a text buffer by the value returned from snprintf
 *
 * snprintf returns the length the text would have had without truncation (or a negative value on
 * error), so its result cannot be added to the offset blindly. The returned offset never exceeds
 * `capacity - 1`, which keeps `buffer + offset` and `capacity - offset` valid for the next call.
 *
 * @param [in] offset   Current write offset. Must be less than `capacity`.
 * @param [in] capacity Total size of the buffer in bytes. Must be greater than 0.
 * @param [in] written  Value returned by snprintf for the text appended at `offset`.
 *
 * @return The offset at which the next piece of text must be written.
 */
size_t AdvanceTraceOffset(size_t offset, size_t capacity, int written) NN_NOEXCEPT
{
    if (written < 0)
    {
        // Formatting error: nothing usable was appended
        return offset;
    }
    return std::min(offset + static_cast<size_t>(written), capacity - 1);
}

} // anonymous namespace


/**
 * @brief Processes a 5 ms chunk and optionally writes traces to a stream
 *
 * @warning Enabling traces considerably slows down processing
 *
 * The trace (input samples followed by the computed vibration value) is formatted into a local
 * buffer and written with a single call. If the text does not fit, it is truncated instead of
 * overrunning the buffer.
 *
 * @param [in]  samples The samples to analyze, in mono PCM. There must be exactly 5 ms worth of data.
 * @param [out] result  The computed vibration data
 * @param [in]  trace   A stream to write debug traces to. If `NULL` no trace is written.
 */
void VibrationEncoder::ProcessChunk(const int16_t* samples, nn::hid::VibrationValue* result, FILE* trace) NN_NOEXCEPT
{
    ProcessChunk(samples, result);
    if (trace != NULL)
    {
        // Room for the fixed text plus up to 7 characters (" -32768") per sample
        char buffer[100 + MaxChunkLength * 7];
        const size_t bufferLength = sizeof(buffer);

        size_t length = AdvanceTraceOffset(0, bufferLength, snprintf(buffer, bufferLength, "[VibrationEncoder] Input: "));
        for (uint32_t i=0; i<m_chunkLength*m_downsamplingFactor; ++i)
        {
            length = AdvanceTraceOffset(length, bufferLength, snprintf(buffer + length, bufferLength - length, " %d", int(samples[i])));
        }
        length = AdvanceTraceOffset(length, bufferLength, snprintf(buffer + length, bufferLength - length, "\n[VibrationEncoder] Output: Low(%3.1f, %4.2f) High(%3.1f, %4.2f)\n", result->frequencyLow, result->amplitudeLow, result->frequencyHigh, result->amplitudeHigh));

        // `length` is always within the buffer here, even if the text was truncated
        fwrite(buffer, 1, length, trace);
    }
}


/**
 * @brief Performs the analysis of a single band
 *
 * The chunk is band-pass filtered into `m_tmp1`, then the amplitude is taken from the largest
 * absolute sample and the pitch is estimated from the spacing of the zero crossings. When the chunk
 * contains no zero crossing, the amplitude and pitch of the previous chunk are reused.
 *
 * @param [in]  samples    The chunk to analyze (at the input sampling rate).
 * @param [in]  bandNumber Index of the band in `m_bands` (0 or 1).
 * @param [out] amplitude  Estimated amplitude; a full-scale sample (32768) maps to 1.0.
 * @param [out] pitch      Estimated frequency, replaced by an entry of the `g_freqs` table.
 */
void VibrationEncoder::ExtractBand(const int16_t* samples, uint32_t bandNumber, float& amplitude, float& pitch) NN_NOEXCEPT
{
    Band& band = m_bands[bandNumber];

    // High-pass the input, keeping only one sample out of m_downsamplingFactor
    band.highPass.Apply(samples, m_tmp1, m_chunkLength, m_downsamplingFactor);

    // Low-pass in place, unless this band's low-pass filter is the one used before downsampling
    const bool lowPassPending = (static_cast<int>(bandNumber) != m_downsamplingFilter);
    if (lowPassPending)
    {
        band.lowPass.Apply(m_tmp1, m_tmp1, m_chunkLength);
    }

    // Scan the filtered chunk: record where the sign changes and track the largest magnitude.
    // For the very first chunk there is no previous sample, so the scan starts from the first one.
    uint32_t crossingPositions[MaxChunkLength];
    int      crossingCount = 0;
    int      peak          = 0;
    int      previous      = (m_samplesProcessed!=0) ? band.prevSample : m_tmp1[0];
    for (uint32_t pos = 0; pos < m_chunkLength; ++pos)
    {
        const int current = m_tmp1[pos];

        // A zero crossing is a change of sign between consecutive samples (0 counts as positive)
        if ((previous < 0) != (current < 0))
        {
            crossingPositions[crossingCount] = pos;
            ++crossingCount;
        }

        const int magnitude = abs(current);
        if (magnitude > peak)
        {
            peak = magnitude;
        }

        previous = current;
    }

    if (crossingCount == 0)
    {
        // Nothing to measure in this chunk: hold the previous estimate
        amplitude = band.prevAmplitude;
        pitch     = band.prevPitch;
    }
    else
    {
        const uint32_t newestCrossing   = crossingPositions[crossingCount - 1];
        const size_t   absoluteCrossing = m_samplesProcessed + newestCrossing;

        // Number of half periods observed, and the number of samples they span
        size_t span;
        int    halfPeriods;
        if (crossingCount == 1)
        {
            // A single crossing: measure from the last crossing seen in an earlier chunk
            span        = absoluteCrossing - band.lastCrossing;
            halfPeriods = 1;
        }
        else
        {
            span        = newestCrossing - crossingPositions[0];
            halfPeriods = crossingCount - 1;
        }

        amplitude = float(peak) * (1.0f / 32768.0f);
        pitch     = float((m_resamplingRate / 2) * halfPeriods) / float(span);

        band.lastCrossing = absoluteCrossing;
    }

#if 0
    // Convert pitch to an index in the frequency table and convert it back to a frequency.
    // This is to ensure that the pitch we return is one of the allowed values.
    int pitchIndex = int(roundf(logf(pitch / 10.0f) * float(32.0 / M_LN2)));
    pitchIndex = std::min(std::max(0, pitchIndex), 255);
    pitch = 10.0f * powf(2.0f, pitchIndex / 32.0f);
#else
    // Snap the pitch to an entry of the frequency table using a binary search, which is faster
    // than the formula above. The search picks the first entry that is not less than the pitch
    // (clamped to the last entry), assuming g_freqs holds 256 ascending values - TODO confirm.
    // NOTE(review): the disabled formula rounds to the nearest index instead, so the two
    // variants may select neighbouring entries.
    const uint32_t tablePosition = static_cast<uint32_t>(std::lower_bound(g_freqs, g_freqs + 256, pitch) - g_freqs);
    const uint32_t pitchIndex    = std::min(255u, tablePosition);
    pitch = g_freqs[pitchIndex];
#endif

    // Remember the state needed to analyze the next chunk
    band.prevSample    = previous;
    band.prevAmplitude = amplitude;
    band.prevPitch     = pitch;
}


} // namespace nns
