﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include "./ngc_SmartBitmap.h"

namespace nn { namespace ngc { namespace detail {

// 立っているビットの数を返す
inline int PopCnt(uint32_t x) NN_NOEXCEPT
{
    return PopCnt32(x);
}
inline int PopCnt(uint64_t x) NN_NOEXCEPT
{
    return PopCnt64(x);
}

/**
 * @brief   pBv から Rank 用辞書を作成します。
 * @param[out]  pOutLvA     作成先辞書、レベル A(256ビット区切り)
 *                          256 ビットごとに、その位置での Rank1 の結果を保存
 * @param[out]  pOutLvB     作成先辞書、レベル B(32 ビット区切り)
 *                          32 ビットごとに、直前の lv_a からの Rank1 の差分を保存
 * @param[in]   wordCount   ブロック数(1 ブロック = 32 ビット)
 * @param[in]   pBv         作成元ビットベクトル
 */
template<class BIT>
inline void BuildRankDictionaryTempl(uint32_t* pOutLvA,
                                     uint8_t* pOutLvB,
                                     unsigned int wordCount,
                                     const BIT* pBv) NN_NOEXCEPT
{
    static const int kBitCount = sizeof(BIT) * 8;  // NOLINT
    static const int tmpLvA = 256 / kBitCount;  // 8(32bit), 4(64bit)

    unsigned int r = 0;     // 今までのビットが経っているものの累積
    for (size_t i = 0; i < wordCount; ++i)
    {
        if (i % tmpLvA == 0) pOutLvA[i / tmpLvA] = r;
        pOutLvB[i] = static_cast<unsigned char>(r - pOutLvA[i / tmpLvA]);
        r += PopCnt(pBv[i]);
    }
}

/**
 * @brief   Rank を求めます
 *          あらかじめ用意した Rank 辞書を利用して O(1) で求められます
 */
template<class BIT>
inline unsigned int CalcRank1Templ(unsigned int pos,
                                   const BIT* pBv,
                                   const uint32_t* pLvA,
                                   const uint8_t* pLvB) NN_NOEXCEPT
{
    // bv[0 ... pos]にある1ビットの数を返す。
    // posはbvの範囲内にあるものとする。
    static const int kBitCount = sizeof(BIT) * 8;  // NOLINT

    size_t remain = pos % kBitCount;
    size_t idxBmp = pos / kBitCount;
    size_t idxA = pos / 256;
    size_t idxB = pos / kBitCount;
    unsigned int r = pLvA[idxA];
    r += pLvB[idxB];
    r += PopCnt(pBv[idxBmp] & (static_cast<BIT>(-1) >> (kBitCount - 1 - remain)));
    return r;
}

/**
 * @brief       pBv から Rank 用辞書を作成します。
 * @param[out]  pOutLvA     作成先辞書、レベル A(256ビット区切り)
 * @param[out]  pOutLvB     作成先辞書、レベル B(32 ビット区切り)
 * @param[in]   wordCount   ブロック数(1 ブロック = 32 ビット)
 * @param[in]   pBv         作成元ビットベクトル
 */
void BuildRankDictionary(uint32_t* pOutLvA, uint8_t* pOutLvB,
                         unsigned int wordCount, const uint32_t* pBv) NN_NOEXCEPT
{
    return BuildRankDictionaryTempl(pOutLvA, pOutLvB, wordCount, pBv);
}

void BuildRankDictionary(uint32_t* pOutLvA, uint8_t* pOutLvB,
                         unsigned int wordCount, const uint64_t* pBv) NN_NOEXCEPT
{
    return BuildRankDictionaryTempl(pOutLvA, pOutLvB, wordCount, pBv);
}

/**
 * @brief       Rank 操作を行います。[0..pos] 内に含まれる 1 の数を返します。
 * @param[in]   pos     調べる範囲
 * @param[in]   pBv     検索元ビットベクトル
 * @param[in]   pLvA    Rank 辞書A
 * @param[in]   pLvB    Rank 辞書B
 */
unsigned int CalcRank1(unsigned int pos, const uint32_t* pBv, const uint32_t* pLvA,
                       const uint8_t* pLvB) NN_NOEXCEPT
{
    return CalcRank1Templ(pos, pBv, pLvA, pLvB);
}

unsigned int CalcRank1(unsigned int pos, const uint64_t* pBv, const uint32_t* pLvA,
                       const uint8_t* pLvB) NN_NOEXCEPT
{
    return CalcRank1Templ(pos, pBv, pLvA, pLvB);
}

// 4 ビットのビット列 c のうち、 idx 番目の 1 はどの位置にあるかを一発で返す配列
// 0<= idx <= 3
// idx << 4 + c をした位置に目的の数が入っている
static const unsigned char selectPosArray[] = {
    // idx: 2bit(count(0-3)), 4bit(c)
    // value: pos(0-3), error=4
    // 下の桁から数える
    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,  // (idx, c) = (0, c)
    4, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,  // (idx, c) = (1, c)
    4, 4, 4, 4, 4, 4, 4, 2, 4, 4, 4, 3, 4, 3, 3, 2,  // (idx, c) = (2, c)
    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3   // (idx, c) = (3, c)
};

/**
 * @brief       32bit のビット列のうち前(0ビット目)から p 番目の 1 はどこにあるかを返す
 * @param[in]   r   ビット列
 * @param[in]   q   何番目の 1 を探すか（0 はじまり）
 * @details     16 ビットごとに区切って popcnt16 を使ってカウントして位置を求めている
 */
int SelectPos(uint32_t r, size_t p) NN_NOEXCEPT
{
    int baseNum;
    uint32_t lb;
    size_t t;
    // 下の桁から数えていく
    // pがrの範囲外の場合は動作は不定
    t = PopCntArray[lb = r & 0xFF];
    if (p >= t)
    {
        p -= t;
        t = PopCntArray[lb = (r >> 8) & 0xFF];
        if (p >= t)
        {
            p -= t;
            t = PopCntArray[lb = (r >> 16) & 0xFF];
            if (p >= t)
            {
                p -= t;
                lb = (r >> 24);
                // lb = (r >> 24) & 0xFF;
                baseNum = 24;
            }
            else
            {
                baseNum = 16;
            }
        }
        else
        {
            baseNum = 8;
        }
    }
    else
    {
        baseNum = 0;
    }
    // lbに最後のバイトが入っている
    t = PopCntArray[lb & 0xF];
    if (p >= t)
    {
        p -= t;
        baseNum += 4;
        lb = (lb >> 4) & 0xF;
    }
    else
    {
        lb = lb & 0xF;
    }
    return baseNum + selectPosArray[(p << 4) + lb];
}

int SelectPos(uint64_t r, size_t p) NN_NOEXCEPT
{
    uint32_t lo = static_cast<uint32_t>(r & 0xFFFFFFFFU);
    int popLo;
    popLo = PopCnt32(lo);
    if (popLo > static_cast<int>(p))
    {
        return SelectPos(lo, p);
    }
    uint32_t hi = static_cast<uint32_t>(r >> 32);
    return 32 + SelectPos(hi, p - popLo);
}

/**
 * @brief       nth 番目の 1 ビットの場所を返します。 nth は 0 から開始します。
 * @param[in]   nth     nth 何番目の 1 ビットの場所が知りたいか
 * @param[in]   size    ビットベクトルの長さ（ビット数）
 * @param[in]   pBv     検索元ビットベクトル
 * @param[in]   pLvA    Rank 補助データA
 * @param[in]   pLvB    Rank 補助データB
 */
template<class BIT>
inline int CalcSelect1Templ(unsigned int nth, unsigned int size,
                            const BIT* pBv, const uint32_t* pLvA,
                            const uint8_t* pLvB) NN_NOEXCEPT
{
    static const int kBitCount = sizeof(BIT) * 8;  // NOLINT

    // nthに対応するビットは存在する必要がある
    // sizeはビットマップのサイズ
    // NOTE: off by 1地獄になっている。
    int low = 0;
    // pLvA[high - 1]が最後の要素
    int high = (size + 255) / 256;
    while (low + 1 < high)
    {
        unsigned int mid = (high + low) / 2;  // オーバーフローしない
        unsigned int rank = pLvA[mid];
        if (nth < rank)
        {
            high = mid;
        }
        else
        {
            low = mid;
        }
    }

    int lvA = low;                      // nth 直前の pLvA の場所
    int lvB = lvA * (256 / kBitCount);  // pLvB の配列の時の位置
    int p = nth - pLvA[lvA];            // tnh 直前の pLvA と目的の数の差分
    int i;
    int imax;
    if (lvA == static_cast<int>(size / 256))
    {
        // 一番後ろのブロック
        imax = ((size + kBitCount - 1) / kBitCount) % (256 / kBitCount);
        if (imax == 0) imax = (256 / kBitCount);
    }
    else
    {
        // 1ブロックにいる pLvB の数 = imax
        imax = (256 / kBitCount);
    }
    for (i = 0; i < imax; ++i)
    {
        if (p - pLvB[lvB + i] < 0)
        {
            p -= pLvB[lvB + i - 1];     // ここで p は nth 直前の lv_b の位置からの差分となる
            int idx = i - 1 + lvB;      // nth 直前の lv_a, lv_b で到達できる場所
            // この時点であとは kBitCount ビットのうちどの点にいるかを算出するだけ
            return idx * kBitCount + SelectPos(pBv[idx], p);
        }
    }
    // i = max のときここに来る？
    if (i == 0)
    {
        i = 1;
    }
    p -= pLvB[lvB + i - 1];
    int idx = i - 1 + lvB;
    int ans = idx * kBitCount + SelectPos(pBv[idx], p);
    return ans;
}

template<class BIT>
inline int CalcSelect0Templ(unsigned int nth, unsigned int size,
                            const BIT* pBv, const uint32_t* pLvA,
                            const uint8_t* pLvB) NN_NOEXCEPT
{
    static const int kBitCount = sizeof(BIT) * 8;  // NOLINT

    // nthに対応するビットは存在する必要がある
    // sizeはビットマップのサイズ
    // NOTE: off by 1地獄になっている。
    int low = 0;
    // pLvA[high - 1]が最後の要素
    int high = (size + 255) / 256;
    while (low + 1 < high)
    {
        unsigned int mid = (high + low) / 2;  // オーバーフローしない
        unsigned int rank = mid * 256 - pLvA[mid];
        if (nth < rank)
        {
            high = mid;
        }
        else
        {
            low = mid;
        }
    }

    int lvA = low;
    int lvB = lvA * (256 / kBitCount);
    int p = nth - (lvA * 256 - pLvA[lvA]);
    int i;
    int imax;
    if (lvA == static_cast<int>(size / 256))
    {
        imax = ((size + kBitCount - 1) / kBitCount) % (256 / kBitCount);
        if (imax == 0)
        {
            imax = (256 / kBitCount);
        }
    }
    else
    {
        imax = (256 / kBitCount);
    }
    for (i = 0; i < imax; ++i)
    {
        if (p - (i * kBitCount - pLvB[lvB + i]) < 0)
        {
            p -= (i - 1) * kBitCount - pLvB[lvB + i - 1];
            int idx = i - 1 + lvB;
            return idx * kBitCount + SelectPos(~pBv[idx], p);
        }
    }
    if (i == 0)
    {
        i = 1;
    }
    p -= (i - 1) * kBitCount - pLvB[lvB + i - 1];
    int idx = i - 1 + lvB;
    int ans = idx * kBitCount + SelectPos(~pBv[idx], p);
    return ans;
}

/**
 * @brief   Select を求めます
 *          nth 番目の 1 ビットの場所を返します。 nth は 0 から開始します。
 */
int CalcSelect1(unsigned int nth, unsigned int size, const uint32_t* pBv, const uint32_t* pLvA,
                const uint8_t* pLvB) NN_NOEXCEPT
{
    return CalcSelect1Templ(nth, size, pBv, pLvA, pLvB);
}

int CalcSelect1(unsigned int nth, unsigned int size, const uint64_t* pBv, const uint32_t* pLvA,
                const uint8_t* pLvB) NN_NOEXCEPT
{
    return CalcSelect1Templ(nth, size, pBv, pLvA, pLvB);
}

int CalcSelect0(unsigned int nth, unsigned int size, const uint32_t* pBv, const uint32_t* pLvA,
                const uint8_t* pLvB) NN_NOEXCEPT
{
    return CalcSelect0Templ(nth, size, pBv, pLvA, pLvB);
}

int CalcSelect0(unsigned int nth, unsigned int size, const uint64_t* pBv, const uint32_t* pLvA,
                const uint8_t* pLvB) NN_NOEXCEPT
{
    return CalcSelect0Templ(nth, size, pBv, pLvA, pLvB);
}

}}} // nn::ngc::detail
