﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/socket/resolver/sfdl/resolver.sfdl.h>
#include <nn/os/os_Thread.h>
#include <nn/os/os_SdkThreadLocalStorage.h>
#include <nn/os/os_MutexApi.h>
#include <nn/os.h>
#include <nn/socket/resolver/private/resolver_PrivateApi.h>
#include <nn/os/os_Result.public.h>
#include "resolver_ThreadLocalStorage.h"
#include <nn/nn_SdkLog.h>
#include "../../detail/socket_Allocator.h"
#include <vector>
#include <mutex>

namespace nn { namespace socket { namespace resolver { namespace tls {

// Uncomment to enable allocation tracing: NN_SOCKET_ALLOCATION_LOG then forwards to NN_SDK_LOG.
//#define SOCKET_TRACE_ALLOCATIONS

#ifdef SOCKET_TRACE_ALLOCATIONS
#define NN_SOCKET_ALLOCATION_LOG NN_SDK_LOG
#else
// Tracing disabled: the macro expands to nothing, so its arguments are not evaluated.
#define NN_SOCKET_ALLOCATION_LOG(...)
#endif

/** @brief TLS slot holding each thread's generation-tagged ResolverThreadLocalStoragePair pointer; allocated by Initialize() */
nn::os::TlsSlot g_TlsSlot;
/** @brief set to true once Initialize() has successfully allocated g_TlsSlot; never reset in this file */
bool g_DidAlreadyAllocateTlsSlot = false;

/**
 * @brief this structure maintains the client and server thread local structure
 *
 * Each side (client / server) is an opaque pointer plus an optional destructor
 * callback. Replacing a side with a different pointer first runs the previous
 * pointer's destructor (when both the pointer and its destructor are non-NULL);
 * destroying the pair releases both sides the same way. Because the pair owns
 * those pointers it is non-copyable.
 */
class ResolverThreadLocalStoragePair
{
private:
    /**
     * @brief a pointer to the client tls
     */
    void* m_ClientTLS;

    /**
     * @brief the client tls destructor
     * @note this is provided as a function pointer so that the
     * client does not depend on server / bionic linkage.
     */
    void (*m_ClientTLSDestructor)(void* clientTlsToDestroy);

    /**
     * @brief a pointer to the server tls
     */
    void* m_ServerTLS;

    /**
     * @brief the server tls destructor
     * @note this is provided as a function pointer so that the
     * client does not depend on server / bionic linkage.
     */
    void (*m_ServerTLSDestructor)(void* serverTlsToDestroy);

    // A copy would run the same destructors on the same pointers a second time
    // (double free), so copying is disabled.
    ResolverThreadLocalStoragePair(const ResolverThreadLocalStoragePair&) = delete;
    ResolverThreadLocalStoragePair& operator=(const ResolverThreadLocalStoragePair&) = delete;

public:
    /**
     * @brief allocate from the socket heap with MinimumHeapAlignment
     * @note the alignment leaves the low pointer bits free for the heap
     * generation tag. Callers check the returned pointer for NULL.
     */
    void* operator new(size_t size) NN_NOEXCEPT
    {
        return nn::socket::detail::AllocAligned(size, nn::socket::detail::MinimumHeapAlignment);
    }

    /** @brief return memory obtained from operator new to the socket heap */
    void operator delete(void* pointer, size_t /*size*/) NN_NOEXCEPT
    {
        nn::socket::detail::Free(pointer);
    }

    /** @brief constructor: both sides start empty with no destructor */
    ResolverThreadLocalStoragePair() NN_NOEXCEPT :
        m_ClientTLS( NULL ),
        m_ClientTLSDestructor( NULL ),
        m_ServerTLS( NULL ),
        m_ServerTLSDestructor( NULL )
    {
    }

    /** @brief destructor: releases both sides through their registered destructors */
    ~ResolverThreadLocalStoragePair() NN_NOEXCEPT
    {
        NN_SOCKET_ALLOCATION_LOG("~ResolverThreadLocalStoragePair() \n");
        NN_SOCKET_ALLOCATION_LOG("%p %p %p %p \n",m_ClientTLS, m_ClientTLSDestructor, m_ServerTLS, m_ServerTLSDestructor);

        // Replacing with NULL runs each side's destructor (if any) on its current pointer.
        SetClient(NULL, NULL);
        SetServer(NULL, NULL);
    }

    /** @returns the client tls pointer (may be NULL) */
    void* GetClient() const
    {
        return m_ClientTLS;
    }

    /**
     * @brief replace the client tls
     * @param tls the new client tls pointer (may be NULL)
     * @param destructor called on @a tls when it is later replaced or the pair is destroyed (may be NULL)
     * @note if @a tls equals the current pointer only the destructor is updated;
     * otherwise the current pointer is destroyed first (when it and its destructor are non-NULL).
     */
    void SetClient(void * tls, void (*destructor)(void* tlsToDestroy) )
    {
        NN_SOCKET_ALLOCATION_LOG("SetClient %p %p \n", tls, destructor);

        if (m_ClientTLS == tls)
        {
            NN_SOCKET_ALLOCATION_LOG("SetClient 1 \n");
            m_ClientTLSDestructor = destructor;
            return;
        }

        if (NULL != m_ClientTLS && NULL != m_ClientTLSDestructor)
        {
            NN_SOCKET_ALLOCATION_LOG("SetClient 2 %p(%p)\n", m_ClientTLSDestructor, m_ClientTLS);
            m_ClientTLSDestructor(m_ClientTLS);
        }

        NN_SOCKET_ALLOCATION_LOG("SetClient End \n");
        m_ClientTLS = tls;
        m_ClientTLSDestructor = destructor;
    }

    /** @returns the server tls pointer (may be NULL) */
    void* GetServer() const
    {
        return m_ServerTLS;
    }

    /**
     * @brief replace the server tls
     * @param tls the new server tls pointer (may be NULL)
     * @param destructor called on @a tls when it is later replaced or the pair is destroyed (may be NULL)
     * @note if @a tls equals the current pointer only the destructor is updated;
     * otherwise the current pointer is destroyed first (when it and its destructor are non-NULL).
     */
    void SetServer(void * tls, void (*destructor)(void* tlsToDestroy) )
    {
        if (m_ServerTLS == tls)
        {
            m_ServerTLSDestructor = destructor;
            return;
        }

        if (NULL != m_ServerTLS && NULL != m_ServerTLSDestructor)
        {
            m_ServerTLSDestructor(m_ServerTLS);
        }

        m_ServerTLS = tls;
        m_ServerTLSDestructor = destructor;
    }
};

/**
 * @brief extract the heap generation tag from a TLS pair pointer
 * @param value a pointer as stored in the TLS slot (generation bits OR'd into its low bits)
 * @returns the bits of @a value below MinimumHeapAlignment
 * @note assumes MinimumHeapAlignment is a power of two so that (alignment - 1) is a low-bit mask
 */
int GetHeapGeneration(ResolverThreadLocalStoragePair *value)
{
    // Retrieve the lower bits (generation); pointer arithmetic on bits is done in uintptr_t.
    const uintptr_t generationMask =
        static_cast<uintptr_t>(nn::socket::detail::MinimumHeapAlignment) - 1;
    return static_cast<int>(reinterpret_cast<uintptr_t>(value) & generationMask);
}

/**
 * @brief strip the heap generation tag from a TLS pair pointer, in place
 * @param value on input a tagged pointer; on output the same pointer with the
 * bits below MinimumHeapAlignment cleared
 */
void GetTlsPairPointer(ResolverThreadLocalStoragePair *& value)
{
    // Remove the generation bits.
    const uintptr_t generationMask =
        static_cast<uintptr_t>(nn::socket::detail::MinimumHeapAlignment) - 1;
    value = reinterpret_cast<ResolverThreadLocalStoragePair*>(
        reinterpret_cast<uintptr_t>(value) & ~generationMask);
}

/**
 * @brief tag a freshly allocated pair pointer with the current heap generation, in place
 * @param value an untagged pair pointer; on output the current
 * socket::detail::GetHeapGeneration() value is OR'd into its low bits
 * @note relies on the generation fitting below MinimumHeapAlignment - TODO confirm
 * against socket::detail::GetHeapGeneration()
 */
void AddGenerationToTlsValue(ResolverThreadLocalStoragePair *& value)
{
    // Use the lower bits to store heap generation.
    value = reinterpret_cast<ResolverThreadLocalStoragePair*>(
        reinterpret_cast<uintptr_t>(value) |
        static_cast<uintptr_t>(socket::detail::GetHeapGeneration()));
}

/**
 * @brief get the calling thread's ResolverThreadLocalStoragePair, creating it on first use
 * @param pTlsPairOut receives the untagged pair pointer on success; set to NULL on failure
 * @returns ResultSuccess() on success; ResultInternalError() if Initialize() has not
 * allocated the TLS slot or the pair could not be allocated
 * @note the TLS slot stores the pair pointer with the heap generation in its low bits.
 * If that generation's heap is no longer available the stored pointer is abandoned
 * (not freed) and a new pair is created.
 */
Result GetCreateTLSPair(ResolverThreadLocalStoragePair *& pTlsPairOut)
{
    pTlsPairOut = NULL;

    if (!g_DidAlreadyAllocateTlsSlot)
    {
        // Initialize() has not allocated the slot yet.
        return ResultInternalError();
    }

    // Read the generation-tagged pointer from TLS.
    ResolverThreadLocalStoragePair* pTagged =
        reinterpret_cast<ResolverThreadLocalStoragePair*>(nn::os::GetTlsValue(g_TlsSlot));

    if (!socket::detail::HeapIsAvailable(GetHeapGeneration(pTagged)))
    {
        // The heap this pair came from is gone, so it cannot be freed; drop it.
        NN_SOCKET_ALLOCATION_LOG("**TLS %d abandoned %p\n", g_TlsSlot, pTagged);
        nn::os::SetTlsValue(g_TlsSlot, 0);
        pTagged = NULL;
    }

    if (NULL == pTagged)
    {
        // If not there yet, create it.
        pTagged = new ResolverThreadLocalStoragePair();
        if (NULL == pTagged)
        {
            NN_SDK_LOG("Unable to allocate ResolverThreadLocalStoragePair\n");
            NN_SDK_ASSERT(false);
            return ResultInternalError();
        }

        // Add in the heap generation before publishing the pointer.
        AddGenerationToTlsValue(pTagged);

        NN_SOCKET_ALLOCATION_LOG("**g_TlsSlot %d receiging a new pair %p\n", g_TlsSlot, pTagged);
        nn::os::SetTlsValue(g_TlsSlot, reinterpret_cast<uintptr_t>(pTagged));
    }

    // Retrieve the pure pointer back.
    GetTlsPairPointer(pTagged);
    pTlsPairOut = pTagged;
    return ResultSuccess();
}

/**
 * @brief TLS slot destructor, registered by Initialize() and invoked when a thread exits
 * @param value the raw slot value: a pair pointer with the heap generation in its low bits
 * @note the pair is deleted only when the heap of its generation is still available;
 * otherwise it is left alone.
 */
void TlsDestructFunction(uintptr_t value) NN_NOEXCEPT
{
    NN_SOCKET_ALLOCATION_LOG("TlsDestructFunction (%p) \n", value);

    ResolverThreadLocalStoragePair* pair = reinterpret_cast<ResolverThreadLocalStoragePair*>(value);

    // The generation must be read before the tag bits are stripped.
    const int heapGeneration = GetHeapGeneration(pair);
    GetTlsPairPointer(pair);

    NN_SOCKET_ALLOCATION_LOG("TlsDestructFunction (%p) \n", pair);

    if (pair == NULL)
    {
        return;
    }

    if (!socket::detail::HeapIsAvailable(heapGeneration))
    {
        NN_SOCKET_ALLOCATION_LOG("Nothing to do\n");
        return;
    }

    NN_SOCKET_ALLOCATION_LOG("deleting pTlsPair %p\n", pair);
    delete pair; // the pair's destructor releases both client and server TLS
}

/**
 * @brief initialize the Resolver thread local storage implementation
 * @note allocates the resolver TLS slot (with TlsDestructFunction as its thread-exit
 * destructor) on the first successful call; later calls return success immediately.
 * NOTE(review): g_DidAlreadyAllocateTlsSlot is a plain bool with no lock here, so
 * concurrent first calls are not synchronized by this function.
 * @returns result code indicating success or failure
 */
Result Initialize() NN_NOEXCEPT
{
    if (g_DidAlreadyAllocateTlsSlot)
    {
        return ResultSuccess();
    }

    Result allocationResult = nn::os::SdkAllocateTlsSlot(&g_TlsSlot, TlsDestructFunction);
    if (allocationResult.IsFailure())
    {
        NN_SDK_LOG("Unable to allocate resolver thread-local storage slot\n");
        return allocationResult;
    }

    // set the current value to 0 / NULL
    nn::os::SetTlsValue(g_TlsSlot, 0);
    g_DidAlreadyAllocateTlsSlot = true;

    return allocationResult;
}

/**
 * @brief finalize the Resolver thread local storage implementation
 * @note this only logs (when SOCKET_TRACE_ALLOCATIONS is defined); the TLS slot is
 * not released and g_DidAlreadyAllocateTlsSlot is left unchanged.
 * @returns always ResultSuccess()
 */
Result Finalize() NN_NOEXCEPT
{
    NN_SOCKET_ALLOCATION_LOG("resolver TLS Finalize\n");
    return ResultSuccess();
}


/**
 * @brief get server tls
 * @param pServerTLSOut receives the calling thread's server TLS pointer; NULL on
 * failure or when none has been set
 * @returns result code indicating success or failure (ResultInternalError() if the
 * thread's TLS pair is unavailable, e.g. Initialize() has not been called)
 */
Result GetServerTLS(void *& pServerTLSOut) NN_NOEXCEPT
{
    ResolverThreadLocalStoragePair* pTlsPair = NULL;
    pServerTLSOut = NULL;

    Result result = GetCreateTLSPair(pTlsPair);
    if (result.IsFailure())
    {
        // pTlsPair is NULL here; it must not be dereferenced.
        return result;
    }

    if (NULL == pTlsPair)
    {
        return ResultInternalError();
    }

    pServerTLSOut = pTlsPair->GetServer();
    return result;
}

/**
 * @brief set server tls
 * @param destructor a function pointer called to destroy the thread local storage
 * (may be NULL, in which case the pair never frees @a pServerTLSIn)
 * @param pServerTLSIn the TLS structure
 * @returns result code indicating success or failure (ResultInternalError() if the
 * thread's TLS pair is unavailable)
 */
Result SetServerTLS(TLSDestructorFunction destructor, void * pServerTLSIn) NN_NOEXCEPT
{
    ResolverThreadLocalStoragePair* pTlsPair = NULL;

    Result result = GetCreateTLSPair(pTlsPair);
    if (result.IsFailure())
    {
        return result;
    }

    if (NULL == pTlsPair)
    {
        // Nothing was stored, so do not report success.
        return ResultInternalError();
    }

    pTlsPair->SetServer(pServerTLSIn, destructor);
    return ResultSuccess();
}

/**
 * @brief get client tls
 * @param pClientTlsOut receives the calling thread's client TLS pointer; NULL on
 * failure or when none has been set
 * @returns result code indicating success or failure (ResultInternalError() if the
 * thread's TLS pair is unavailable, e.g. Initialize() has not been called)
 */
Result GetClientTLS(void *& pClientTlsOut) NN_NOEXCEPT
{
    ResolverThreadLocalStoragePair* pTlsPair = NULL;
    pClientTlsOut = NULL;

    Result result = GetCreateTLSPair(pTlsPair);
    if (result.IsFailure())
    {
        // pTlsPair is NULL here; it must not be dereferenced.
        return result;
    }

    if (NULL == pTlsPair)
    {
        return ResultInternalError();
    }

    pClientTlsOut = pTlsPair->GetClient();
    return result;
}


/**
 * @brief set client tls
 * @param destructor a function pointer called to destroy the thread local storage
 * (may be NULL, in which case the pair never frees @a pClientTLSIn)
 * @param pClientTLSIn the TLS structure
 * @returns result code indicating success or failure (ResultInternalError() if the
 * thread's TLS pair is unavailable)
 */
Result SetClientTLS(TLSDestructorFunction destructor, void * pClientTLSIn) NN_NOEXCEPT
{
    ResolverThreadLocalStoragePair* pTlsPair = NULL;

    Result result = GetCreateTLSPair(pTlsPair);
    if (result.IsFailure())
    {
        return result;
    }

    if (NULL == pTlsPair)
    {
        // Guard against dereferencing NULL; nothing was stored.
        return ResultInternalError();
    }

    pTlsPair->SetClient(pClientTLSIn, destructor);
    return ResultSuccess();
}

}}}} // nn::socket::resolver::tls


/**
 * called by bionic
 */
extern "C"
{
/**
 * @brief wrapper for pthread_getspecific
 * @returns the calling thread's server TLS pointer as written by GetServerTLS
 * (which starts it at NULL); the Result is discarded
 */
void *nnResolverGetspecificPrivate()
{
    void* pServerTls = NULL;
    (void)nn::socket::resolver::tls::GetServerTLS(pServerTls);
    return pServerTls;
}

/**
 * @brief wrapper for pthread_setspecific
 * @param pValueIn stored as the server TLS with no destructor (NULL), so the pair
 * never frees it itself
 * @returns 0 on success, -1 on failure
 */
int nnResolverSetspecificPrivate(const void *pValueIn)
{
    return nn::socket::resolver::tls::SetServerTLS(NULL, const_cast<void*>(pValueIn)).IsSuccess()
        ? 0
        : -1;
}

}
