﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/


#define _USE_MATH_DEFINES
#include "Butterworth.h"
#include <nn/nn_Assert.h>
#include <nn/nn_Log.h>
#include <nn/os/os_Result.h>
#include <cmath>
#include <complex>
#include <cstdlib>


// Whether to build routines specialized for certain filter orders
#ifndef BUTTERWORTH_SPECIALIZE
    #define BUTTERWORTH_SPECIALIZE  1 // NOLINT(preprocessor/const)
#endif


// Whether to unroll filter loops
#ifndef BUTTERWORTH_UNROLL
    #define BUTTERWORTH_UNROLL      1 // NOLINT(preprocessor/const)
#endif


namespace nns
{


// fma() is nice as it brings higher precision, but unless we have hardware support for it, it slows things down.
// So we only use it on ARM64 and replace it with a simple mul-add on other platforms.
#ifdef NN_BUILD_CONFIG_CPU_ARM64
using std::fma;
#else
/**
 * @brief Plain multiply-add stand-in for std::fma()
 *
 * @return `multiplicand * multiplier + addend`, evaluated with the ordinary operators of `T`
 *         (unlike std::fma(), a single rounding of the result is not guaranteed).
 */
template<typename T>
static NN_FORCEINLINE T fma(T multiplicand, T multiplier, T addend)
{
    return multiplicand * multiplier + addend;
}
#endif


// Double-precision complex number used for all pole/zero and polynomial computations in this file
using Complex = std::complex<double>;


/**
 * @brief Multiply, in place, the polynomial stored in `coeffs` by the factor `(z - w)`
 *
 * @param[in]     w       Root to multiply in.
 * @param[in]     npz     Index of the highest coefficient to update; `coeffs` must hold `npz + 1` entries.
 * @param[in,out] coeffs  Polynomial coefficients, `coeffs[i]` being the coefficient of `z^i`.
 */
static void MultIn(const Complex& w, uint32_t npz, Complex coeffs[]) NN_NOEXCEPT
{
    const Complex negatedRoot = -w;

    // Walk from the highest coefficient down so that coeffs[k - 1] is still the old value when it is read
    uint32_t k = npz;
    while (k > 0)
    {
        coeffs[k] = (negatedRoot * coeffs[k]) + coeffs[k - 1];
        --k;
    }

    // The constant term has no lower neighbour to pull in
    coeffs[0] = negatedRoot * coeffs[0];
}


/**
 * @brief Expand the product of `(z - input[i])` factors into polynomial coefficients
 *
 * @param[in]  input   Roots (poles or zeros) of the polynomial; `order` entries.
 * @param[in]  order   Number of roots, i.e. degree of the resulting polynomial.
 * @param[out] output  Receives `order + 1` coefficients, `output[i]` being the coefficient of `z^i`.
 */
static void ExpandPolynomial(const Complex input[], uint32_t order, Complex output[]) NN_NOEXCEPT
{
    // Start from the constant polynomial "1"
    for (uint32_t k = order; k > 0; --k)
    {
        output[k] = 0.0;
    }
    output[0] = 1.0;

    // Then multiply the roots in, one factor at a time
    const Complex* const lastRoot = input + order;
    for (const Complex* root = input; root != lastRoot; ++root)
    {
        MultIn(*root, order, output);
    }
}


/**
 * @brief Compute the value of a polynomial at the point `z`
 *
 * @param[in] coeffs  Polynomial coefficients, `coeffs[i]` being the coefficient of `z^i`; `order + 1` entries.
 * @param[in] order   Degree of the polynomial.
 * @param[in] z       Point at which the polynomial is evaluated.
 *
 * @return The value of the polynomial at `z`.
 */
static Complex EvaluatePoly(const Complex coeffs[], uint32_t order, const Complex z) NN_NOEXCEPT
{
    // Horner scheme, starting from the highest-degree coefficient
    Complex accumulator(0.0);
    int degree = order;
    while (degree >= 0)
    {
        accumulator = (accumulator * z) + coeffs[degree];
        --degree;
    }
    return accumulator;
}


/**
 * @brief Compute the filter coefficients for the requested response and reset the filter memory
 *
 * @param[in] samplingRate  Sampling rate of the signal, in Hz.
 * @param[in] cutoff        Cutoff frequency, in Hz; must not be negative and must be less than half the sampling rate.
 * @param[in] order         Filter order; must be in range [1; MaxOrder].
 * @param[in] highPass      `true` to build a high-pass filter, `false` to build a low-pass filter.
 *
 * @return nn::ResultSuccess() on success, nn::os::ResultInvalidParameter() when a parameter is rejected.
 *         When a parameter is rejected, `m_order` is 0 and `m_applyFunction` is null.
 */
nn::Result Butterworth::DoInit(double samplingRate, double cutoff, uint32_t order, bool highPass) NN_NOEXCEPT
{
    // Mark the filter as unusable until initialization has fully succeeded
    m_order         = 0;
    m_applyFunction = nullptr;

    if (!(0 < order && order <= MaxOrder))
    {
        NN_LOG("Invalid filter order: %u (must be in range [1; %u])\n", order, unsigned(MaxOrder));
        return nn::os::ResultInvalidParameter();
    }

    // The cutoff tests are written in negated form so that NaN arguments are rejected as well.
    // A negative cutoff must be refused: it would mirror the poles into the right half-plane
    // and produce an unstable filter.
    if (!(cutoff >= 0.0))
    {
        NN_LOG("Invalid cutoff frequency: %3.1f Hz (must not be negative)\n", cutoff);
        return nn::os::ResultInvalidParameter();
    }
    if (!(cutoff < 0.5 * samplingRate))
    {
        NN_LOG("Invalid cutoff frequency: %3.1f Hz (must be less than %3.1f Hz)\n", cutoff, 0.5 * samplingRate);
        return nn::os::ResultInvalidParameter();
    }

    Complex sPoles[MaxOrder];

    // Poles are evenly scattered on the left half of the unit circle.
    // We generate them in increasing cosine order so as to reduce numerical instability:
    // the angle index alternates around `order` with offsets 0, -1, +1, -2, +2, ...
    const double thetaOffset = (order % 2) ? 0.0 : 0.5;
    for (int i=0; i<int(order); ++i)
    {
        // Odd iterations step below `order`, even iterations step above it.
        // This is computed with plain arithmetic rather than by smearing bit 0 of `i` through
        // `(i << n) >> n`: shifting set bits out of a signed int is undefined behavior before C++20.
        const int     halfStep = (i + 1) >> 1;
        const int     index = int(order) + ((i & 1) ? -halfStep : halfStep);
        const double  theta = ((index + thetaOffset) * M_PI) / order;
        const Complex pole(cos(theta), sin(theta));
        sPoles[i] = pole;
        NN_ASSERT((i == 0) || (float(sPoles[i - 1].real()) <= float(pole.real())));
    }
    const uint32_t poleCount = order;

    // Compute warped cutoff ratio (pre-warping for the bilinear transform below)
    const double rawAlpha1 = cutoff / samplingRate;
    const double warpedAlpha1 = tan(M_PI * rawAlpha1) / M_PI;

    // Normalize S-plane poles: scale them for a low-pass, invert them for a high-pass
    const double w1 = 2.0 * M_PI * warpedAlpha1;
    if (!highPass)
    {
        for (uint32_t i=0; i<poleCount; ++i)
        {
            sPoles[i] = w1 * sPoles[i];
        }
    }
    else
    {
        for (uint32_t i=0; i<poleCount; ++i)
        {
            sPoles[i] = w1 / sPoles[i];
        }
    }

    // Compute Z-plane poles by bilinear transform
    Complex zPoles[MaxOrder];
    for (uint32_t i=0; i<poleCount; ++i)
    {
        zPoles[i] = (2.0 + sPoles[i]) / (2.0 - sPoles[i]);
    }

    // All Z-plane zeros sit at the same location: -1 for a low-pass, +1 for a high-pass
    Complex zZeroes[MaxOrder];
    const Complex zZero((highPass) ? 1.0 : -1.0);
    for (uint32_t i=0; i<poleCount; ++i)
    {
        zZeroes[i] = zZero;
    }

    // Expand Z-plane polynomials
    Complex numerator[2 * MaxOrder + 1];
    Complex denominator[2 * MaxOrder + 1];
    ExpandPolynomial(zZeroes, poleCount, numerator);
    ExpandPolynomial(zPoles,  poleCount, denominator);

    // Store the coefficients relative to the leading denominator coefficient.
    // The denominator is stored negated so that the filter routines only have to accumulate products.
    const double leadingCoeff = denominator[poleCount].real();
    for (uint32_t i=0; i<=poleCount; ++i)
    {
        m_b[i] = +(numerator[i].real() / leadingCoeff);
    }
    for (uint32_t i=0; i<=poleCount; ++i)
    {
        m_a[i] = -(denominator[i].real() / leadingCoeff);
    }

    // Compute gain, measured at z = 1 for a low-pass and at z = -1 for a high-pass.
    // Its inverse is what the filter routines multiply every input sample by.
    const Complex gainLocation((highPass) ? -1.0 : 1.0);
    const Complex gain = EvaluatePoly(numerator, order, gainLocation) / EvaluatePoly(denominator, order, gainLocation);
    m_invGain = static_cast<Float>(1.0 / hypot(gain.real(), gain.imag()));

    // Init filter memory
    for (uint32_t i=0; i<=order; ++i)
    {
        m_x[i] = 0.0;
        m_y[i] = 0.0;
    }

    m_order         = order;
    m_applyFunction = &Butterworth::ApplyGeneric;

#if BUTTERWORTH_SPECIALIZE
    // Use a dedicated routine for the orders that have one
    switch (order)
    {
        case 2:
            m_applyFunction = &Butterworth::Apply2;
            break;
        case 4:
            m_applyFunction = &Butterworth::Apply4;
            break;
        case 6:
            m_applyFunction = &Butterworth::Apply6;
            break;
        default:
            break;
    }
#endif

    return nn::ResultSuccess();
}


/**
 * @brief Convert a value to int16_t, saturating it to the range [-32768; 32767]
 *
 * Values inside the range are converted with a plain static_cast.
 */
template<typename T>
static inline int16_t Clamp16(T val) NN_NOEXCEPT
{
    const T upperBound = T(32767);
    const T lowerBound = T(-32768);

    const T bounded = (val > upperBound) ? upperBound
                    : (val < lowerBound) ? lowerBound
                    : val;
    return static_cast<int16_t>(bounded);
}


/**
 * @brief Filter routine that works for every supported filter order
 *
 * Reads `length` samples from `srce`, advancing by `stride` elements after each sample,
 * and writes one clamped output sample per input sample, contiguously, to `dest`.
 */
void Butterworth::ApplyGeneric(const int16_t* srce, int16_t* dest, size_t length, size_t stride) NN_NOEXCEPT
{
    const uint32_t tapCount    = m_order;
    const Float    inputScale  = m_invGain;
    const Float*   feedback    = m_a;
    const Float*   feedForward = m_b;
    Float*         inputs      = m_x;
    Float*         outputs     = m_y;

    // One tap of the filter: shift both histories down by one slot,
    // then accumulate the feed-forward and feedback products for that slot.
    // There is deliberately no `break`: the switch below is entered at `case tapCount`
    // and falls through every lower case, so exactly `tapCount` taps are evaluated.
    #define BUTTERWORTH_TAP(j)                                                          \
        case j:                                                                         \
            inputs[tapCount - j]  = inputs[tapCount - j + 1];                           \
            outputs[tapCount - j] = outputs[tapCount - j + 1];                          \
            inputSum  = fma(feedForward[tapCount - j], inputs[tapCount - j],  inputSum);  \
            outputSum = fma(feedback[tapCount - j],    outputs[tapCount - j], outputSum);

    // The switch below only provides cases up to 10
    static_assert(MaxOrder <= 10, "");

    for (size_t remaining = length; remaining != 0; --remaining)
    {
        Float sample = *srce;
        sample *= inputScale;
        srce += stride;

        // The newest input enters the sum directly, without a coefficient
        Float inputSum  = sample;
        Float outputSum = 0;

        switch (tapCount)
        {
            BUTTERWORTH_TAP(10)
            BUTTERWORTH_TAP(9)
            BUTTERWORTH_TAP(8)
            BUTTERWORTH_TAP(7)
            BUTTERWORTH_TAP(6)
            BUTTERWORTH_TAP(5)
            BUTTERWORTH_TAP(4)
            BUTTERWORTH_TAP(3)
            BUTTERWORTH_TAP(2)
            BUTTERWORTH_TAP(1)
                break;
            default:
                NN_UNEXPECTED_DEFAULT;
        }

        // Store the newest input and output in the last history slot
        const Float result = inputSum + outputSum;
        inputs[tapCount]  = sample;
        outputs[tapCount] = result;

        *dest = Clamp16(result);
        ++dest;
    }

    #undef BUTTERWORTH_TAP
}


//=================================================================================================


#if BUTTERWORTH_SPECIALIZE

/**
 * @brief Filter function for 2nd order filter
 *
 * Reads `length` samples from `srce`, advancing by `stride` elements after each sample,
 * and writes one clamped output sample per input sample, contiguously, to `dest`.
 *
 * Only `b1` is read from the numerator: the oldest and newest inputs are added without a coefficient.
 * This matches the numerator built by DoInit(), whose zeros all sit at +1 or all at -1, so that for
 * an even order its coefficients are symmetric and its end coefficients are 1.
 */
void Butterworth::Apply2(const int16_t* srce, int16_t* dest, size_t length, size_t stride) NN_NOEXCEPT
{
    NN_ASSERT(m_order == 2u);

    const Float a0 = m_a[0];
    const Float a1 = m_a[1];
    const Float b1 = m_b[1];

    // Work on local copies of the filter memory (index 2 is the most recent sample);
    // they are written back at the end of the function.
    Float x0 = m_x[0];
    Float x1 = m_x[1];
    Float x2 = m_x[2];

    Float y0 = m_y[0];
    Float y1 = m_y[1];
    Float y2 = m_y[2];

    const Float invGain = m_invGain;

#if BUTTERWORTH_UNROLL
    // Unrolled path: three samples per iteration, with the computations of the three outputs interleaved.
    // Within an iteration xN / yN are first used as accumulators for the input / output sums of output N,
    // then overwritten with the scaled input / final output, so that they hold the filter memory again
    // once the iteration completes. The statement order matters.
    while (length >= 3u)
    {
        Float s0 = *srce; srce += stride;
        Float s1 = *srce; srce += stride;
        Float s2 = *srce; srce += stride;

        s0 *= invGain;
        s1 *= invGain;
        s2 *= invGain;

        y0 = a0 * y1;           x0 = x1 + s0;
        y1 = a0 * y2;           x1 = x2 + s1;
        y0 = fma(a1, y2, y0);   x0 = fma(b1, x2, x0);
        y0 = x0 + y0;           x0 = s0;
        y1 = fma(a1, y0, y1);   x1 = fma(b1, x0, x1);
        y2 = a0 * y0;           x2 = x0 + s2;
        y1 = x1 + y1;           x1 = s1;
        y2 = fma(a1, y1, y2);   x2 = fma(b1, x1, x2);
        y2 = x2 + y2;           x2 = s2;

        *dest++ = Clamp16(y0);
        *dest++ = Clamp16(y1);
        *dest++ = Clamp16(y2);

        length -= 3;
    }
#endif

    // One sample at a time: every sample when unrolling is disabled, otherwise the leftover ones
    for (size_t i=0; i<length; ++i)
    {
        Float s = *srce;
        srce += stride;

        s *= invGain;

        // Shift the filter memory by one sample
        x0 = x1;    y0 = y1;
        x1 = x2;    y1 = y2;
        x2 = s;

        // Input and output sums are accumulated separately, then added together
        Float rx, ry;
        rx = x0 + x2;
        ry = a0 * y0;
        rx = fma(b1, x1, rx);
        ry = fma(a1, y1, ry);

        y2      = rx + ry;
        *dest++ = Clamp16(y2);
    }

    // Save the filter memory for the next call
    m_x[0] = x0;    m_y[0] = y0;
    m_x[1] = x1;    m_y[1] = y1;
    m_x[2] = x2;    m_y[2] = y2;
}


/**
 * @brief Filter function for 4th order filter
 *
 * Reads `length` samples from `srce`, advancing by `stride` elements after each sample,
 * and writes one clamped output sample per input sample, contiguously, to `dest`.
 *
 * Only `b1` and `b2` are read from the numerator: `b1` is applied to both inputs next to the ends,
 * and the oldest and newest inputs are added without a coefficient.
 * This matches the numerator built by DoInit(), whose zeros all sit at +1 or all at -1, so that for
 * an even order its coefficients are symmetric and its end coefficients are 1.
 */
void Butterworth::Apply4(const int16_t* srce, int16_t* dest, size_t length, size_t stride) NN_NOEXCEPT
{
    NN_ASSERT(m_order == 4u);

    const Float a0 = m_a[0];
    const Float a1 = m_a[1];
    const Float a2 = m_a[2];
    const Float a3 = m_a[3];

    const Float b1 = m_b[1];
    const Float b2 = m_b[2];

    // Work on local copies of the filter memory (index 4 is the most recent sample);
    // they are written back at the end of the function.
    Float x0 = m_x[0];
    Float x1 = m_x[1];
    Float x2 = m_x[2];
    Float x3 = m_x[3];
    Float x4 = m_x[4];

    Float y0 = m_y[0];
    Float y1 = m_y[1];
    Float y2 = m_y[2];
    Float y3 = m_y[3];
    Float y4 = m_y[4];

    const Float invGain = m_invGain;

#if BUTTERWORTH_UNROLL
    // Unrolled path: five samples per iteration, with the computations of the five outputs interleaved.
    // Within an iteration xN / yN are first used as accumulators for the input / output sums of output N,
    // then overwritten with the scaled input / final output, so that they hold the filter memory again
    // once the iteration completes. The statement order matters: each accumulator must consume the old
    // value of a variable before that variable is recycled for a later output.
    while (length >= 5u)
    {
        Float s0 = *srce; srce += stride;
        Float s1 = *srce; srce += stride;
        Float s2 = *srce; srce += stride;
        Float s3 = *srce; srce += stride;
        Float s4 = *srce; srce += stride;

        s0 *= invGain;
        s1 *= invGain;
        s2 *= invGain;
        s3 *= invGain;
        s4 *= invGain;

        y0 = a0 * y1;            x0 = x1 + s0;
        y1 = a0 * y2;            x1 = x2 + s1;
        y0 = fma(a1, y2, y0);    x0 = fma(b1, x2, x0);
        y2 = a0 * y3;            x2 = x3 + s2;
        y1 = fma(a1, y3, y1);    x1 = fma(b1, x3, x1);
        y2 = fma(a1, y4, y2);    x2 = fma(b1, x4, x2);
        y0 = fma(a2, y3, y0);    x0 = fma(b2, x3, x0);
        y3 = a0 * y4;            x3 = x4 + s3;
        y1 = fma(a2, y4, y1);    x1 = fma(b2, x4, x1);
        y0 = fma(a3, y4, y0);    x0 = fma(b1, x4, x0);
        y0 = x0 + y0;            x0 = s0;
        y1 = fma(a3, y0, y1);    x1 = fma(b1, s0, x1);
        y2 = fma(a2, y0, y2);    x2 = fma(b2, s0, x2);
        y3 = fma(a1, y0, y3);    x3 = fma(b1, s0, x3);
        *dest++ = Clamp16(y0);
        y4 = a0 * y0;            x4 = x0 + s4;
        y1 = x1 + y1;            x1 = s1;
        y2 = fma(a3, y1, y2);    x2 = fma(b1, x1, x2);
        y3 = fma(a2, y1, y3);    x3 = fma(b2, x1, x3);
        y4 = fma(a1, y1, y4);    x4 = fma(b1, x1, x4);
        *dest++ = Clamp16(y1);
        y2 = x2 + y2;            x2 = s2;
        y3 = fma(a3, y2, y3);    x3 = fma(b1, x2, x3);
        y4 = fma(a2, y2, y4);    x4 = fma(b2, x2, x4);
        y3 = x3 + y3;            x3 = s3;
        *dest++ = Clamp16(y2);
        y4 = fma(a3, y3, y4);    x4 = fma(b1, x3, x4);
        y4 = x4 + y4;            x4 = s4;
        *dest++ = Clamp16(y3);
        *dest++ = Clamp16(y4);

        length -= 5;
    }
#endif

    // One sample at a time: every sample when unrolling is disabled, otherwise the leftover ones
    for (size_t i=0; i<length; ++i)
    {
        Float s = *srce;
        srce += stride;

        s *= invGain;

        // Shift the filter memory by one sample
        x0 = x1;    y0 = y1;
        x1 = x2;    y1 = y2;
        x2 = x3;    y2 = y3;
        x3 = x4;    y3 = y4;
        x4 = s;

        // Both inputs multiplied by b1 are summed first so that a single product is needed
        Float x13 = x1 + x3;

        // The sum is split over three independent accumulators that are added together at the end
        Float r0, r1, r2;
        r0 = fma(a0, y0, x0);
        r1 = fma(a1, y1, x4);
        r2 = a2 * y2;
        r0 = fma(a3, y3,  r0);
        r1 = fma(b1, x13, r1);
        r2 = fma(b2, x2,  r2);

        y4      = r0 + r1 + r2;
        *dest++ = Clamp16(y4);
    }

    // Save the filter memory for the next call
    m_x[0] = x0;    m_y[0] = y0;
    m_x[1] = x1;    m_y[1] = y1;
    m_x[2] = x2;    m_y[2] = y2;
    m_x[3] = x3;    m_y[3] = y3;
    m_x[4] = x4;    m_y[4] = y4;
}


/**
 * @brief Filter function for 6th order filter
 *
 * Reads `length` samples from `srce`, advancing by `stride` elements after each sample,
 * and writes one clamped output sample per input sample, contiguously, to `dest`.
 *
 * Only `b1`, `b2` and `b3` are read from the numerator: `b1` and `b2` are each applied to two inputs
 * placed symmetrically around the middle one, and the oldest and newest inputs are added without a coefficient.
 * This matches the numerator built by DoInit(), whose zeros all sit at +1 or all at -1, so that for
 * an even order its coefficients are symmetric and its end coefficients are 1.
 */
void Butterworth::Apply6(const int16_t* srce, int16_t* dest, size_t length, size_t stride) NN_NOEXCEPT
{
    NN_ASSERT(m_order == 6u);

    const Float a0 = m_a[0];
    const Float a1 = m_a[1];
    const Float a2 = m_a[2];
    const Float a3 = m_a[3];
    const Float a4 = m_a[4];
    const Float a5 = m_a[5];

    const Float b1 = m_b[1];
    const Float b2 = m_b[2];
    const Float b3 = m_b[3];

    // Work on local copies of the filter memory (index 6 is the most recent sample);
    // they are written back at the end of the function.
    Float x0 = m_x[0];
    Float x1 = m_x[1];
    Float x2 = m_x[2];
    Float x3 = m_x[3];
    Float x4 = m_x[4];
    Float x5 = m_x[5];
    Float x6 = m_x[6];

    Float y0 = m_y[0];
    Float y1 = m_y[1];
    Float y2 = m_y[2];
    Float y3 = m_y[3];
    Float y4 = m_y[4];
    Float y5 = m_y[5];
    Float y6 = m_y[6];

    const Float invGain = m_invGain;

#if BUTTERWORTH_UNROLL
    // Unrolled path: seven samples per iteration, with the computations of the seven outputs interleaved.
    // Within an iteration xN / yN are first used as accumulators for the input / output sums of output N,
    // then overwritten with the scaled input / final output, so that they hold the filter memory again
    // once the iteration completes. The statement order matters: each accumulator must consume the old
    // value of a variable before that variable is recycled for a later output.
    while (length >= 7u)
    {
        Float s0 = *srce; srce += stride;
        Float s1 = *srce; srce += stride;
        Float s2 = *srce; srce += stride;
        Float s3 = *srce; srce += stride;
        Float s4 = *srce; srce += stride;
        Float s5 = *srce; srce += stride;
        Float s6 = *srce; srce += stride;

        s0 *= invGain;
        s1 *= invGain;
        s2 *= invGain;
        s3 *= invGain;
        s4 *= invGain;
        s5 *= invGain;
        s6 *= invGain;

        y0 = a0 * y1;            x0 = x1 + s0;
        y1 = a0 * y2;            x1 = x2 + s1;
        y0 = fma(a1, y2, y0);    x0 = fma(b1, x2, x0);
        y1 = fma(a1, y3, y1);    x1 = fma(b1, x3, x1);
        y2 = a0 * y3;            x2 = x3 + s2;
        y0 = fma(a2, y3, y0);    x0 = fma(b2, x3, x0);
        y1 = fma(a2, y4, y1);    x1 = fma(b2, x4, x1);
        y2 = fma(a1, y4, y2);    x2 = fma(b1, x4, x2);
        y3 = a0 * y4;            x3 = x4 + s3;
        y0 = fma(a3, y4, y0);    x0 = fma(b3, x4, x0);
        y1 = fma(a3, y5, y1);    x1 = fma(b3, x5, x1);
        y2 = fma(a2, y5, y2);    x2 = fma(b2, x5, x2);
        y3 = fma(a1, y5, y3);    x3 = fma(b1, x5, x3);
        y0 = fma(a4, y5, y0);    x0 = fma(b2, x5, x0);
        y1 = fma(a4, y6, y1);    x1 = fma(b2, x6, x1);
        y4 = a0 * y5;            x4 = x5 + s4;
        y2 = fma(a3, y6, y2);    x2 = fma(b3, x6, x2);
        y0 = fma(a5, y6, y0);    x0 = fma(b1, x6, x0);
        y3 = fma(a2, y6, y3);    x3 = fma(b2, x6, x3);
        y4 = fma(a1, y6, y4);    x4 = fma(b1, x6, x4);
        y5 = a0 * y6;            x5 = x6 + s5;
        y0 = x0 + y0;            x0 = s0;
        y1 = fma(a5, y0, y1);    x1 = fma(b1, x0, x1);
        y2 = fma(a4, y0, y2);    x2 = fma(b2, x0, x2);
        y3 = fma(a3, y0, y3);    x3 = fma(b3, x0, x3);
        y4 = fma(a2, y0, y4);    x4 = fma(b2, x0, x4);
        y1 = x1 + y1;            x1 = s1;
        *dest++ = Clamp16(y0);
        y5 = fma(a1, y0, y5);    x5 = fma(b1, x0, x5);
        y6 = a0 * y0;            x6 = x0 + s6;
        y2 = fma(a5, y1, y2);    x2 = fma(b1, x1, x2);
        y3 = fma(a4, y1, y3);    x3 = fma(b2, x1, x3);
        y4 = fma(a3, y1, y4);    x4 = fma(b3, x1, x4);
        y5 = fma(a2, y1, y5);    x5 = fma(b2, x1, x5);
        y2 = x2 + y2;            x2 = s2;
        *dest++ = Clamp16(y1);
        y6 = fma(a1, y1, y6);    x6 = fma(b1, x1, x6);
        y3 = fma(a5, y2, y3);    x3 = fma(b1, x2, x3);
        y4 = fma(a4, y2, y4);    x4 = fma(b2, x2, x4);
        y5 = fma(a3, y2, y5);    x5 = fma(b3, x2, x5);
        y6 = fma(a2, y2, y6);    x6 = fma(b2, x2, x6);
        y3 = x3 + y3;            x3 = s3;
        *dest++ = Clamp16(y2);
        y4 = fma(a5, y3, y4);    x4 = fma(b1, x3, x4);
        y5 = fma(a4, y3, y5);    x5 = fma(b2, x3, x5);
        y6 = fma(a3, y3, y6);    x6 = fma(b3, x3, x6);
        y4 = x4 + y4;            x4 = s4;
        y5 = fma(a5, y4, y5);    x5 = fma(b1, x4, x5);
        y6 = fma(a4, y4, y6);    x6 = fma(b2, x4, x6);
        y5 = x5 + y5;            x5 = s5;
        y6 = fma(a5, y5, y6);    x6 = fma(b1, x5, x6);
        y6 = x6 + y6;            x6 = s6;
        *dest++ = Clamp16(y3);
        *dest++ = Clamp16(y4);
        *dest++ = Clamp16(y5);
        *dest++ = Clamp16(y6);

        length -= 7;
    }
#endif

    // One sample at a time: every sample when unrolling is disabled, otherwise the leftover ones
    for (size_t i=0; i<length; ++i)
    {
        Float s = *srce;
        srce += stride;

        s *= invGain;

        // Shift the filter memory by one sample
        x0 = x1;    y0 = y1;
        x1 = x2;    y1 = y2;
        x2 = x3;    y2 = y3;
        x3 = x4;    y3 = y4;
        x4 = x5;    y4 = y5;
        x5 = x6;    y5 = y6;
        x6 = s;

        // Inputs sharing a coefficient are summed first so that a single product is needed per pair
        Float x15 = x1 + x5;
        Float x24 = x2 + x4;

        // The sum is split over four independent accumulators that are added together at the end
        Float r0, r1, r2, r3;
        r0  = fma(a0, y0, x0);
        r1  = fma(a1, y1, x6);
        r2  = a2 * y2;
        r3  = a3 * y3;
        r0  = fma(a4, y4,  r0);
        r1  = fma(a5, y5,  r1);
        r2  = fma(b1, x15, r2);
        r3  = fma(b2, x24, r3);
        r0  = fma(b3, x3,  r0);

        y6      = (r1 + r2) + (r3 + r0);
        *dest++ = Clamp16(y6);
    }

    // Save the filter memory for the next call
    m_x[0] = x0;    m_y[0] = y0;
    m_x[1] = x1;    m_y[1] = y1;
    m_x[2] = x2;    m_y[2] = y2;
    m_x[3] = x3;    m_y[3] = y3;
    m_x[4] = x4;    m_y[4] = y4;
    m_x[5] = x5;    m_y[5] = y5;
    m_x[6] = x6;    m_y[6] = y6;
} // NOLINT(impl/function_size)


#endif // BUTTERWORTH_SPECIALIZE


} // namespace nns
