﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

// Test that starts with every slot already acquired, then has one thread
// release slots one after another while multiple threads concurrently call
// AcquireIndex().

#include <nn/nn_Common.h>
#include <nn/nn_Assert.h>
#include <nn/nn_Log.h>
#include <nn/os.h>
#include <nn/init.h>
#include <nnt.h>
#include <cstdlib>
#include <algorithm>

#include <nn/gfx/util/detail/gfx_IndexRingBuffer.h>
#include "testGfxUtil_WorkerThreads.h"

namespace {
    struct ThreadParameter
    {
        nn::gfx::util::detail::IndexRingBuffer* pRing;
        int acquireCount;
        int* acquiredIndices;
        int releaseCount;
        nn::gfx::util::detail::IndexRange releaseRange;
    };
    void ThreadFunction(int tid, ThreadParameter* pParam)
    {
        const int InvalidIndex = nn::gfx::util::detail::IndexRingBuffer::InvalidIndex;
        if(tid == 0)
        {
            nn::gfx::util::detail::IndexRange range = pParam->releaseRange;
            int acquiredCount = 0;
            while(range.count > 0 || acquiredCount < pParam->acquireCount)
            {
                if(range.count > 0)
                {
                    nn::gfx::util::detail::IndexRange r;
                    r.base  = range.base % pParam->pRing->GetIndexCount();
                    r.count = std::min(pParam->releaseCount, range.count);
                    pParam->pRing->ReleaseIndexRange(&r);
                    range.base  += r.count;
                    range.count -= r.count;
                }

                if(acquiredCount < pParam->acquireCount)
                {
                    int idx = pParam->pRing->AcquireIndex();
                    if(idx != InvalidIndex)
                    {
                        pParam->acquiredIndices[acquiredCount++] = idx;
                    }
                }
            }
        }
        else
        {
            for(int i = 0; i < pParam->acquireCount; i++)
            {
                int idx = InvalidIndex;
                do
                {
                    idx = pParam->pRing->AcquireIndex();
                }
                while(idx == InvalidIndex);
                pParam->acquiredIndices[i] = idx;
            }
        }
    }
}

// Starts from a ring in which all usable slots are already acquired. Worker
// thread 0 then releases slots back in small chunks while all worker threads
// call AcquireIndex() concurrently. Afterwards, the range returned by End() must
// consist exactly of the indices the threads obtained.
TEST(IndexRingBuffer, MultiThreadRelease)
{
    using namespace nnt::gfxUtil;
    static const int AcquireCount = 4096;
    static const int StackSize = 1024 * 1024;
    static const int MaxNumberOfThreads = 64;
    // Each thread's output buffer holds AcquireCount indices plus one trailing
    // InvalidIndex sentinel. The verification cursor can therefore safely point
    // one past a thread's last real entry.
    static const int IndexSlotsPerThread = AcquireCount + 1;
    NN_STATIC_ASSERT(StackSize % nn::os::ThreadStackAlignment == 0);

    const int InvalidIndex = nn::gfx::util::detail::IndexRingBuffer::InvalidIndex;
    nn::gfx::util::detail::IndexRingBuffer ring;

    // Prepare the worker threads.
    WorkerThreads<ThreadParameter> wthreads;
    wthreads.Initialize(ThreadFunction, StackSize);
    const int numThreads = wthreads.GetNumberOfThreads();
    // params[] and nextIndices[] below are fixed-size arrays indexed by thread id.
    NN_ASSERT(numThreads <= MaxNumberOfThreads);

    // Set up the per-thread parameters.
    ThreadParameter params[MaxNumberOfThreads];
    const int firstIndex = (AcquireCount * numThreads) / 2;
    for(int t = 0; t < numThreads; t++)
    {
        ThreadParameter& param = params[t];
        param.pRing = &ring;
        param.acquireCount = AcquireCount;
        // Allocate IndexSlotsPerThread ints (AcquireCount entries plus the sentinel).
        param.acquiredIndices = static_cast<int*>(AlignedAllocate(sizeof(int) * IndexSlotsPerThread, NN_ALIGNOF(int)));
        NN_ASSERT(param.acquiredIndices != NULL);
        std::fill_n(param.acquiredIndices, IndexSlotsPerThread, InvalidIndex);
        // The release settings are read only by thread 0 in ThreadFunction.
        param.releaseCount = 2;
        param.releaseRange.base  = firstIndex + 1;
        param.releaseRange.count = AcquireCount * numThreads;
        wthreads.SetThreadParameter(t, &param);
    }
    wthreads.PrepareWorker();

    int ringSize = AcquireCount * numThreads + 1;
    ring.Initialize(123, ringSize);
    // Force the ring into a state where n - 1 indices are acquired.
    {
        nn::gfx::util::detail::IndexRange range;
        ring.Begin();                                      // head = 0             , tail = ringSize - 1
        ring.AcquireIndexRange(firstIndex + 1);            // head = firstIndex + 1, tail = ringSize - 1
        ring.End(&range);
        ring.ReleaseIndexRange(&range);                    // head = firstIndex + 1, tail = ringSize + firstIndex
        ring.Begin();
        ring.AcquireIndexRange(ringSize - firstIndex - 1); // head = ringSize             , tail = ringSize + firstIndex
        ring.AcquireIndexRange(firstIndex);                // head = ringSize + firstIndex, tail = ringSize + firstIndex
        ring.End(&range);                                  // head = firstIndex           , tail = firstIndex
    }
    ring.Begin();
    {
        wthreads.StartWorker();
        wthreads.WaitWorkerComplete();
    }
    nn::gfx::util::detail::IndexRange range;
    ring.End(&range);

    // Verify that the indices were acquired correctly.
    {
        const int totalCount = AcquireCount * numThreads;
        // Check that the acquired range starts where expected and has the expected size.
        EXPECT_EQ(firstIndex, range.base);
        EXPECT_EQ(totalCount, range.count);
        // Check the acquired contents. Every index in the range must be the next
        // unconsumed entry of some thread's output. Only the front of each list is
        // compared, so this relies on each thread's indices appearing in ring order.
        {
            int* nextIndices[MaxNumberOfThreads];
            for(int t = 0; t < numThreads; t++)
            {
                nextIndices[t] = params[t].acquiredIndices;
            }
            for(int i = 0; i < totalCount; i++)
            {
                int idx = (range.base + i) % ring.GetIndexCount() + ring.GetBaseIndex();
                // Some thread must have acquired it.
                bool found = false;
                for(int t = 0; t < numThreads; t++)
                {
                    // Once a thread's entries are used up, this reads its InvalidIndex sentinel.
                    int value = *nextIndices[t];
                    if(value == idx)
                    {
                        found = true;
                        nextIndices[t]++;
                        break;
                    }
                }
                if(!found)
                {
                    // No thread acquired this index, which is an error.
                    NN_LOG("Index %d (%d-th) is not found\n", idx, i);
                    // Unlike FAIL(), ADD_FAILURE() does not return from the test.
                    // The cleanup below still frees the buffers and finalizes the threads.
                    ADD_FAILURE();
                    break;
                }
            }
        }
    }

    // Clean up.
    for(int t = 0; t < numThreads; t++)
    {
        AlignedFree(params[t].acquiredIndices);
    }
    wthreads.Finalize();
}
