﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <nn/nn_Assert.h>
#include <nn/nn_Result.h>
#include <nn/sf/sf_ExpHeapAllocator.h>
#include <nn/sf/sf_ObjectFactory.h>

#include <nn/ssl/ssl_BuiltInManager.h>
#include <nn/ssl/ssl_SessionCache.h>

#include "server/ssl_NssCommon.h"
#include "server/ssl_NssCore.h"
#include "server/ssl_MemoryManager.h"
#include "server/ssl_ServiceDatabase.h"
#include "server/ssl_SslServiceImpl.h"
#include "server/ssl_SslContextImpl.h"
#include "server/ssl_SslTrustedCertManager.h"
#include "server/ssl_Util.h"
#include "detail/ssl_ISslServiceFactory.h"

using namespace nn::sf;
using namespace nn::ssl::sf;

namespace nn { namespace ssl { namespace detail {


//  Construct a service instance that is not yet bound to a client process
//  (process id 0) and that reports the launch-time interface version until
//  SetInterfaceVersion() is called.
SslServiceImpl::SslServiceImpl() NN_NOEXCEPT
    : m_ClientProcessId(0)
    , m_ISslServiceVersion(ISslServiceFactory::ISslServiceVersionType_Launch)
{
}

//  Tear down the service: flush every session cache entry tied to the client
//  process this service recorded (a null host name selects all hosts).
//  NOTE(review): if CreateContext never succeeded, m_ClientProcessId is still
//  the constructor's 0 here — confirm that flushing for id 0 is harmless.
SslServiceImpl::~SslServiceImpl() NN_NOEXCEPT
{
    const nn::Bit64 ownerProcessId = m_ClientProcessId;
    nnsdkNssPortClearProcessSessionCache(ownerProcessId, nullptr);
}

//  Create a new SSL context object for the client, register it in the
//  service database and hand its interface back through outContext.
//
//  outContext : receives the new ISslContext on success (untouched on failure)
//  version    : requested SSL/TLS version flags, forwarded to SslContextImpl
//  processId  : client process id; recorded on success so that session
//               caches can later be flushed per process
//
//  Returns ResultInsufficientMemory if the object could not be allocated,
//  otherwise the database result; all results are converted from internal
//  to external form before returning.
nn::Result SslServiceImpl::CreateContext(Out<SharedPointer<ISslContext>> outContext,
                                         SslVersion version,
                                         nn::Bit64 processId) NN_NOEXCEPT
{
    nn::Result                  ret = ResultSuccess();
    typedef ObjectFactory<ExpHeapAllocator::Policy>   Factory;

    do
    {
        NN_DETAIL_SSL_DBG_PRINT("[CreateContext] creating context with version flags 0x%X\n",
                            version.flags);

        EmplacedRef<ISslContext, SslContextImpl> er =
            Factory::CreateSharedEmplaced<ISslContext, SslContextImpl>(SslMemoryManager::GetSfAllocator(),
                                                                       this,
                                                                       version.flags);
        if ((er == nullptr) || (er.Get() == nullptr))
        {
            NN_DETAIL_SSL_DBG_PRINT("[CreateContext] failed to create ctx\n");
            ret = ResultInsufficientMemory();
            break;
        }

        //  Track the new context in the database.  If this add fails it usually
        //  means we have hit a resource limit.  "er" is declared inside this
        //  loop body, so breaking out destroys it and releases this function's
        //  reference to the new service object.  (Assigning nullptr through a
        //  static_cast to a non-reference SharedPointer type would only clear
        //  a temporary copy, not "er" itself.)
        SslContextImpl &ctx = er.GetImpl();
        ret = SslServiceDatabase::AddSslContext(&ctx);
        if (ret.IsFailure())
        {
            NN_DETAIL_SSL_DBG_PRINT("[CreateContext] failed to track new ctx: (ret %d:%d)\n",
                                    ret.GetModule(),
                                    ret.GetDescription());
            break;
        }

        //  Return the new service object interface to the caller
        *outContext = er;
        m_ClientProcessId = processId;
    } while (NN_STATIC_CONDITION(false));

    return Util::ConvertResultFromInternalToExternal(ret);
}


//  Report how many SSL contexts the service database currently tracks for
//  this service instance.
//
//  outValue : receives the count; set to 0 if the database lookup fails
//
//  Returns the database result, converted from internal to external form.
nn::Result SslServiceImpl::GetContextCount(Out<uint32_t> outValue) NN_NOEXCEPT
{
    //  Start from 0 so a failed lookup never copies an indeterminate value
    //  back to the client.
    uint32_t   count = 0;
    nn::Result ret   = SslServiceDatabase::GetContextCount(this, &count);

    *outValue = count;

    return Util::ConvertResultFromInternalToExternal(ret);
}


//  Fetch every trusted CA certificate id from TrustedCertManager and return
//  them as a freshly new[]-allocated array of uint32_t.
//
//  pOutIds    : receives the array on success; the caller must delete[] it
//  pOutNumIds : receives the number of entries in that array
//
//  Returns ResultErrorLower if the id list cannot be obtained or the scratch
//  array cannot be allocated.  The id list returned by TrustedCertManager is
//  always released with delete[] before returning.
nn::Result SslServiceImpl::GetAllIds(uint32_t **pOutIds, uint32_t *pOutNumIds)
{
    CaCertificateId *pSourceIds  = nullptr;
    uint32_t        sourceCount  = 0;
    nn::Result      result       = TrustedCertManager::GetAllIds(&pSourceIds, &sourceCount);

    if (result.IsFailure())
    {
        result = ResultErrorLower();
    }
    else
    {
        uint32_t *pConverted = new uint32_t[sourceCount];
        if (pConverted == nullptr)
        {
            NN_DETAIL_SSL_DBG_PRINT("[GetAllIds] unable to get ID scratch buf\n");
            result = ResultErrorLower();
        }
        else
        {
            //  Narrow each CaCertificateId into the plain integer form used
            //  by the IPC-facing callers.
            for (uint32_t idx = 0; idx < sourceCount; ++idx)
            {
                pConverted[idx] = static_cast<uint32_t>(pSourceIds[idx]);
            }

            //  Ownership of the converted array passes to the caller.
            *pOutIds    = pConverted;
            *pOutNumIds = sourceCount;
        }
    }

    if (pSourceIds != nullptr)
    {
        delete[] pSourceIds;
    }

    return result;
}


//  Compute the output buffer size needed to hold the certificates for a list
//  of CA certificate ids: one BuiltInCertificateInfo per id (plus one extra
//  terminator entry when "all" was requested) followed by each certificate's
//  data, each padded to a uint32_t boundary.
//
//  pOutSize : receives the total size in bytes (written only on success)
//  pOutAll  : receives true if any id in pIds was the "all" id (success only)
//  pIds     : caller-supplied id array; may come straight from client IPC
//  idCount  : number of entries in pIds
//
//  Ids that TrustedCertManager reports as ResultInvalidCertStoreId still get
//  an info entry but contribute no data.  Any other lookup failure is
//  returned as-is.  Returns nn::ssl::ResultInvalidArgument if the total would
//  not fit in 32 bits (e.g. a valid id repeated many times), instead of
//  letting the size wrap and under-report the buffer needed.
nn::Result SslServiceImpl::GetCertificateBufSize(uint32_t        *pOutSize,
                                                 bool            *pOutAll,
                                                 const uint32_t  *pIds,
                                                 uint32_t        idCount)
{
    //  Largest total that can be reported through the 32-bit size out param.
    const uint64_t              MaxBufSize = 0xFFFFFFFFull;
    nn::Result                  result = ResultSuccess();
    uint64_t                    size = 0;
    bool                        needAll = false;
    uint32_t                    *pAllIds = nullptr;
    uint64_t                    numBci = idCount;

    do
    {
        //  Walk the array of IDs and see if "all" has been requested
        for (uint32_t i = 0; (i < idCount) && !needAll; i++)
        {
            needAll = TrustedCertManager::IsAllId(static_cast<CaCertificateId>(pIds[i]));
        }

        //  If "all" has been requested, replace the passed in ID array
        //  with the list of all IDs.
        if (needAll)
        {
            result = GetAllIds(&pAllIds, &idCount);
            if (result.IsFailure())
            {
                break;
            }

            pIds = pAllIds;

            //  We need 1 extra BuiltInCertificateInfo entry for "all"
            //  so the caller knows where the end is.
            numBci = static_cast<uint64_t>(idCount) + 1;
        }

        //  Accumulate in 64 bits so an oversized request is detected rather
        //  than silently wrapping around.
        size = static_cast<uint64_t>(sizeof(BuiltInManager::BuiltInCertificateInfo)) * numBci;

        //  Iterate over the IDs, get the size of each cert and add to our
        //  total.  Stop early once the total can no longer be represented.
        for (uint32_t i = 0; (i < idCount) && (size <= MaxBufSize); i++)
        {
            BuiltinDataInfo     *pBdi;

            result =
                TrustedCertManager::GetTrustedCertInfo(&pBdi,
                                                       static_cast<CaCertificateId>(pIds[i]));
            if (result.IsFailure())
            {
                if (ResultInvalidCertStoreId::Includes(result))
                {
                    //  Skip over invalid IDs, don't count any cert data
                    //  in the size, but we still need the BuiltInCertificateInfo
                    //  entry for it.
                    result = ResultSuccess();
                    continue;
                }

                NN_DETAIL_SSL_DBG_PRINT("[GetCertificateBufSize] failed to get cert size for id %X\n",
                                        pIds[i]);
                break;
            }

            size += static_cast<uint64_t>(NN_DETAIL_SSL_CALC_ALIGN(pBdi->GetDataSize(), sizeof(uint32_t)));
        }

        //  A lookup failure inside the loop only left the for-loop; do not
        //  report a partial size.
        if (result.IsFailure())
        {
            break;
        }

        if (size > MaxBufSize)
        {
            NN_DETAIL_SSL_DBG_PRINT("[GetCertificateBufSize] required size does not fit in 32 bits\n");
            result = nn::ssl::ResultInvalidArgument();
            break;
        }

        *pOutSize = static_cast<uint32_t>(size);
        *pOutAll = needAll;
    } while(NN_STATIC_CONDITION(false));

    if (pAllIds != nullptr)
    {
        delete[] pAllIds;
        pAllIds = nullptr;
    }

    return result;
}


//  Copy the requested trusted CA certificates into the client's output
//  buffer.
//
//  Buffer layout written: an array of BuiltInCertificateInfo (one per id,
//  plus a terminator entry with id CaCertificateId_All when "all" was
//  requested) followed by the certificate data.  Each certificate's data is
//  padded to a uint32_t boundary, and its offset from the buffer start is
//  stored in data.priv.reserved1.
//
//  outCertsData : client output buffer
//  outIdCount   : receives the number of info entries excluding the "all"
//                 terminator; written only on success
//  inCertIds    : client array of uint32_t certificate ids
//
//  Ids reported as ResultInvalidCertStoreId get an entry with status
//  TrustedCertStatus_Invalid, size 0 and offset 0.  Returns:
//  - ResultBufferTooShort if the buffer cannot hold the result.
//  - ResultErrorLower if NSS init failed or certificate data could not be
//    copied.
//  - Otherwise, the first failing lookup result.
nn::Result SslServiceImpl::GetCertificates(const OutBuffer  &outCertsData,
                                           Out<uint32_t>    outIdCount,
                                           const InBuffer   &inCertIds)
{
    nn::Result                              result = ResultSuccess();
    uint8_t                                 *pOutBuf = nullptr;
    BuiltInManager::BuiltInCertificateInfo  *pCertInfoArray = nullptr;
    uint32_t                                outBufSize;
    const uint32_t                          *pIds;
    uint32_t                                idCount;
    uint64_t                                numBci;
    uint32_t                                sizeNeeded;
    uint64_t                                offset;
    bool                                    needAll = false;
    uint32_t                                *pAllIds = nullptr;

    do
    {
        //  Wait for NSS core to be ready before we allow trusted cert ops
        NssCore::InitStatus nssStatus =
            NssCore::GetInitStatus(TrustedCertManager::g_TcmDeferredInitId, true);
        if (nssStatus == NssCore::InitStatus_InitFail)
        {
            NN_DETAIL_SSL_DBG_PRINT("[GetCertificates] NSS init failed\n");
            result = ResultErrorLower();
            break;
        }

        //  Determine what size output buffer is needed.  Note that
        //  GetCertificateBufSize will handle the "all" case, but we will also
        //  need to do it later in order to get the real data.
        idCount = static_cast<uint32_t>(inCertIds.GetSize() /
                                        sizeof(uint32_t));
        numBci = idCount;
        pIds = reinterpret_cast<const uint32_t *>(inCertIds.GetPointerUnsafe());
        result = SslServiceImpl::GetCertificateBufSize(&sizeNeeded,
                                                       &needAll,
                                                       pIds,
                                                       idCount);
        if (result.IsFailure())
        {
            NN_DETAIL_SSL_DBG_PRINT("[GetCertificates] failed to get size needed for provided IDs: %d-%d\n",
                                    result.GetModule(),
                                    result.GetDescription());
            break;
        }

        outBufSize = static_cast<uint32_t>(outCertsData.GetSize());
        if (outBufSize < sizeNeeded)
        {
            NN_DETAIL_SSL_DBG_PRINT("[GetCertificates] provided buffer too small (got %u, need %u)\n",
                                    outBufSize,
                                    sizeNeeded);
            result = ResultBufferTooShort();
            break;
        }

        //  If "all" has been requested, replace the passed in ID array
        //  with the list of all IDs.
        if (needAll)
        {
            result = GetAllIds(&pAllIds, &idCount);
            if (result.IsFailure())
            {
                break;
            }

            pIds = pAllIds;

            //  We need 1 extra BuiltInCertificateInfo entry for "all"
            //  so the caller knows where the end is.
            numBci = static_cast<uint64_t>(idCount) + 1;
        }

        //  The info array occupies the start of the buffer; certificate data
        //  follows it.  Offsets are tracked in 64 bits and checked against
        //  the real buffer size here rather than trusting sizeNeeded alone,
        //  since the "all" list was fetched again above.
        offset = static_cast<uint64_t>(sizeof(BuiltInManager::BuiltInCertificateInfo)) * numBci;
        if (offset > outBufSize)
        {
            NN_DETAIL_SSL_DBG_PRINT("[GetCertificates] buffer too small for cert info array (got %u)\n",
                                    outBufSize);
            result = ResultBufferTooShort();
            break;
        }

        pOutBuf = reinterpret_cast<uint8_t *>(outCertsData.GetPointerUnsafe());
        pCertInfoArray =
            reinterpret_cast<BuiltInManager::BuiltInCertificateInfo *>(pOutBuf);
        for (uint32_t i = 0; i < idCount; i++)
        {
            BuiltinDataInfo                 *pBdi;
            int                             certDataStatus;
            uint32_t                        tmp;
            nn::ssl::TrustedCertStatus      status = nn::ssl::TrustedCertStatus_Invalid;
            uint32_t                        size = 0;
            uint32_t                        curOffset = 0;
            CaCertificateId                 curId = static_cast<CaCertificateId>(pIds[i]);

            result = TrustedCertManager::GetTrustedCertInfo(&pBdi, curId);
            if (result.IsFailure())
            {
                //  If the certificate ID provided is not valid, just leave
                //  the size and offset alone but allow this continue so an
                //  entry for this cert is created.  The app can see that the
                //  ID is invalid by looking at the cert status.
                if (!ResultInvalidCertStoreId::Includes(result))
                {
                    NN_DETAIL_SSL_DBG_PRINT("[GetCertificates] failed to get info for cert id %X, %d-%d\n",
                                            pIds[i],
                                            result.GetModule(),
                                            result.GetDescription());
                    break;
                }

                result = ResultSuccess();
            }
            else
            {
                //  Never let the data cursor pass the end of the buffer; this
                //  also keeps "outBufSize - offset" from wrapping around.
                if (offset > outBufSize)
                {
                    NN_DETAIL_SSL_DBG_PRINT("[GetCertificates] no room left for cert data for id %u\n", pIds[i]);
                    result = ResultBufferTooShort();
                    break;
                }

                status    = static_cast<nn::ssl::TrustedCertStatus>(pBdi->GetStatus());
                size      = pBdi->GetDataSize();
                curOffset = static_cast<uint32_t>(offset);

                certDataStatus = pBdi->GetData(pOutBuf + offset,
                                               static_cast<uint32_t>(outBufSize - offset),
                                               &tmp);
                if (certDataStatus != 0)
                {
                    NN_DETAIL_SSL_DBG_PRINT("[GetCertificates] failed to get cert data for id %u\n", pIds[i]);
                    //  Report the failure instead of returning success with a
                    //  partially populated buffer.
                    result = ResultErrorLower();
                    break;
                }
            }

            pCertInfoArray[i].id                  = curId;
            pCertInfoArray[i].status              = status;
            pCertInfoArray[i].certificateSize     = size;
            pCertInfoArray[i].data.priv.reserved1 = curOffset;

            offset += NN_DETAIL_SSL_CALC_ALIGN(size, sizeof(uint32_t));
        }

        //  Any failure inside the loop only left the for-loop; do not finish
        //  the buffer or report a count in that case.
        if (result.IsFailure())
        {
            break;
        }

        if (needAll)
        {
            //  Mark the dummy entry at the end with the ALL id so the caller
            //  knows it is the end.
            pCertInfoArray[idCount].id                  = nn::ssl::CaCertificateId_All;
            pCertInfoArray[idCount].status              = nn::ssl::TrustedCertStatus_Invalid;
            pCertInfoArray[idCount].certificateSize     = 0;
            pCertInfoArray[idCount].data.priv.reserved1 = 0;
        }

        *outIdCount = idCount;
    } while (NN_STATIC_CONDITION(false));

    if (pAllIds != nullptr)
    {
        delete[] pAllIds;
        pAllIds = nullptr;
    }

    return result;
}    //  NOLINT(impl/function_size)


//  IPC entry point: report how large an output buffer GetCertificates needs
//  for the certificate ids in inCertIds.  The size covers one
//  BuiltInCertificateInfo per certificate plus each certificate's data with
//  alignment padding.
//
//  outSize   : receives the required size in bytes (written only on success)
//  inCertIds : client array of uint32_t certificate ids
//
//  Returns ResultErrorLower if NSS core initialization failed; otherwise the
//  result of the internal size computation.
nn::Result SslServiceImpl::GetCertificateBufSize(Out<uint32_t>   outSize,
                                                 const InBuffer  &inCertIds)
{
    //  Trusted certificate operations must wait for the NSS core.
    if (NssCore::GetInitStatus(TrustedCertManager::g_TcmDeferredInitId, true) ==
        NssCore::InitStatus_InitFail)
    {
        NN_DETAIL_SSL_DBG_PRINT("[GetCertificateBufSize] NSS init failed\n");
        return ResultErrorLower();
    }

    const uint32_t requestedCount =
        static_cast<uint32_t>(inCertIds.GetSize() / sizeof(uint32_t));
    const uint32_t *pRequestedIds =
        reinterpret_cast<const uint32_t *>(inCertIds.GetPointerUnsafe());
    uint32_t       totalSize   = 0;
    bool           includesAll = false;

    nn::Result result = SslServiceImpl::GetCertificateBufSize(&totalSize,
                                                              &includesAll,
                                                              pRequestedIds,
                                                              requestedCount);
    if (result.IsSuccess())
    {
        NN_DETAIL_SSL_DBG_PRINT("[GetCertificateBufSize] need %u byte buffer for %u certs\n",
                                totalSize,
                                requestedCount);
        *outSize = totalSize;
    }

    return result;
}

//  Flush session cache entries belonging to this service's client process.
//
//  outEntriesDeletedCount : receives the number of entries removed (written
//                           only when a flush is performed)
//  inHostName             : NUL-terminated host name; used only for the
//                           SingleHost option
//  sfOption               : AllHosts flushes every entry for the process,
//                           SingleHost flushes entries for inHostName only
//
//  Returns nn::ssl::ResultInvalidArgument for an unknown option, or for a
//  SingleHost request whose name is empty or not NUL-terminated within the
//  buffer.  The result is converted from internal to external form.
nn::Result SslServiceImpl::FlushSessionCache(Out<uint32_t> outEntriesDeletedCount, InBuffer inHostName, nn::ssl::sf::FlushSessionCacheOptionType sfOption) NN_NOEXCEPT
{
    const char* pName      = inHostName.GetPointerUnsafe();
    auto        nameBufLen = inHostName.GetSize();
    nn::Result  ret        = ResultSuccess();

    const nn::ssl::FlushSessionCacheOptionType option =
        static_cast<nn::ssl::FlushSessionCacheOptionType>(sfOption.option);

    NN_DETAIL_SSL_DBG_PRINT("[FlushSessionCache] enter\n");

    switch (option)
    {
    case nn::ssl::FlushSessionCacheOptionType_AllHosts:
        //  A null host name selects every entry owned by the process.
        *outEntriesDeletedCount = nnsdkNssPortClearProcessSessionCache(GetClientProcessId(), nullptr);
        break;

    case nn::ssl::FlushSessionCacheOptionType_SingleHost:
        {
            //  The client-supplied name must be non-empty and terminated
            //  inside its own buffer.
            auto nameLen = strnlen(pName, nameBufLen);
            if (nameLen == nameBufLen)
            {
                NN_DETAIL_SSL_DBG_PRINT("[FlushSessionCache] Hostname string is not null terminated!\n");
                ret = nn::ssl::ResultInvalidArgument();
            }
            else if (nameLen == 0)
            {
                NN_DETAIL_SSL_DBG_PRINT("[FlushSessionCache] Hostname string is empty!\n");
                ret = nn::ssl::ResultInvalidArgument();
            }
            else
            {
                *outEntriesDeletedCount = nnsdkNssPortClearProcessSessionCache(GetClientProcessId(), pName);
            }
        }
        break;

    default:
        NN_DETAIL_SSL_DBG_PRINT("[FlushSessionCache] Invalid value for FlushSessionCacheOptionType\n");
        ret = nn::ssl::ResultInvalidArgument();
        break;
    }

    NN_DETAIL_SSL_DBG_PRINT("[FlushSessionCache] exit\n");

    return Util::ConvertResultFromInternalToExternal(ret);
}


//  Process id recorded by the last successful CreateContext (0 before that).
nn::Bit64 SslServiceImpl::GetClientProcessId() NN_NOEXCEPT
{
    const nn::Bit64 processId = this->m_ClientProcessId;
    return processId;
}

//  Record the ISslService interface version negotiated by the client.
//  Any value is accepted; always returns success.
nn::Result SslServiceImpl::SetInterfaceVersion(uint32_t inValue) NN_NOEXCEPT
{
    this->m_ISslServiceVersion = inValue;
    return nn::ResultSuccess();
}

//  Interface version last set via SetInterfaceVersion (launch version by
//  default).
uint32_t SslServiceImpl::GetInterfaceVersion() NN_NOEXCEPT
{
    const uint32_t version = this->m_ISslServiceVersion;
    return version;
}

} } }    //  namespace nn::ssl::detail


#ifdef NN_DETAIL_SSL_ENABLE_PROCESS_DEBUG
// ------------------------------------------------------------------------------------------------
// Debug interface
// ------------------------------------------------------------------------------------------------
#include <nn/ssl/ssl_Api.debug.h>
#include "debug/ssl_DebugImpl.h"

namespace nn { namespace ssl { namespace detail {

//  Debug-build ioctl: wrap the client's input/output buffers in the debug
//  descriptor structs and forward the command to DebugImpl::Ioctl.
//
//  outData : buffer for command output
//  inData  : command input payload
//  cmd     : raw command value, interpreted as Debug::IoctlCommand
//
//  Returns whatever DebugImpl::Ioctl returns.
nn::Result SslServiceImpl::DebugIoctl(const OutBuffer &outData,
                                      const InBuffer  &inData,
                                      uint64_t cmd) NN_NOEXCEPT
{
    Debug::Input  request;
    request.pBuffer     = inData.GetPointerUnsafe();
    request.bufferSize  = inData.GetSize();

    Debug::Output response;
    response.pBuffer    = outData.GetPointerUnsafe();
    response.bufferSize = outData.GetSize();

    const Debug::IoctlCommand command = static_cast<Debug::IoctlCommand>(cmd);

    return DebugImpl::Ioctl(&response, &request, command);
}


} } }    //  namespace nn::ssl::detail
#else
namespace nn { namespace ssl { namespace detail {
//  Debug ioctl stub used when NN_DETAIL_SSL_ENABLE_PROCESS_DEBUG is not
//  defined: all arguments are ignored and ResultErrorLower is returned.
nn::Result SslServiceImpl::DebugIoctl(const OutBuffer &outData,
                                      const InBuffer  &inData,
                                      uint64_t cmd) NN_NOEXCEPT
{
    NN_UNUSED(cmd);
    NN_UNUSED(inData);
    NN_UNUSED(outData);

    return ResultErrorLower();
}
} } }    //  namespace nn::ssl::detail
#endif // NN_DETAIL_SSL_ENABLE_PROCESS_DEBUG
