﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <nn/nn_Assert.h>
#include <nn/nn_Result.h>
#include <nn/sf/sf_ExpHeapAllocator.h>
#include <nn/sf/sf_ObjectFactory.h>

#include <nn/ssl/ssl_Types.h>

#include "server/ssl_MemoryManager.h"
#include "server/ssl_NssCore.h"
#include "server/ssl_ServiceDatabase.h"
#include "server/ssl_SslContextImpl.h"
#include "server/ssl_InternalPki.h"
#include "server/ssl_SslTrustedCertManager.h"
#include "server/ssl_Util.h"


using namespace nn::sf;
using namespace nn::ssl::sf;

namespace nn { namespace ssl { namespace detail {

// ------------------------------------------------------------------------------------------------
// Static member variables
// ------------------------------------------------------------------------------------------------

// ------------------------------------------------------------------------------------------------
// Member functions
// ------------------------------------------------------------------------------------------------
//  Constructs a context that belongs to `parent` and uses `sslVersion`.
//  CRL import date checking starts out enabled.  The context's CertStore is
//  tagged with an owner id produced from `this` by
//  NN_DETAIL_SSL_GET_ID_FROM_PTR.
SslContextImpl::SslContextImpl(SslServiceImpl *parent, int32_t sslVersion) NN_NOEXCEPT
    : m_Parent(parent, true)
    , m_SslVersion(sslVersion)
    , m_CrlDateCheckEnabled(true)
{
    uint64_t ownerId = 0;
    NN_DETAIL_SSL_GET_ID_FROM_PTR(ownerId, this);

    m_CertStore.SetOwnerId(ownerId);
    NN_DETAIL_SSL_DBG_PRINT("[SslContextImpl] init %p, id %llu\n", this, ownerId);
}


//  No explicit teardown is performed here; any cleanup is left to the
//  destructors of the individual members.
SslContextImpl::~SslContextImpl() NN_NOEXCEPT
{
    //  Intentionally empty.
}


//  Returns the SslServiceImpl this context was created for, as a raw
//  (non-owning) pointer taken from m_Parent.
SslServiceImpl *SslContextImpl::GetParentSslService()
{
    SslServiceImpl *pOwner = m_Parent.Get();
    return pOwner;
}


//  Returns the SSL version supplied at construction, converted back to the
//  public Context::SslVersion enumeration.
nn::ssl::Context::SslVersion SslContextImpl::GetSslVersion()
{
    typedef nn::ssl::Context::SslVersion VersionType;
    return static_cast<VersionType>(m_SslVersion);
}


//  Gives callers direct (non-owning) access to this context's CertStore.
CertStore* SslContextImpl::GetCertStore()
{
    CertStore *const pStore = &m_CertStore;
    return pStore;
}

//  Gives callers direct (non-owning) access to the policy OIDs registered on
//  this context via AddPolicyOid.
EvCertUtil::PolicyOidInfo* SslContextImpl::GetPolicyOidInfo()
{
    EvCertUtil::PolicyOidInfo *const pOidInfo = &m_PolicyOidInfo;
    return pOidInfo;
}


//  Sets a context option.
//
//  Only ContextOption_CrlImportDateCheckEnable is supported.  Its value must
//  be exactly 1 (enable) or 0 (disable).
//
//  Returns ResultInvalidOptionType for any other option, ResultInvalidArgument
//  for any other value, and ResultSuccess otherwise.
nn::Result SslContextImpl::SetOption(ContextOption option, int32_t value) NN_NOEXCEPT
{
    const auto requested = static_cast<nn::ssl::Context::ContextOption>(option.name);

    NN_DETAIL_SSL_DBG_PRINT("[SetOption] option %u, value %d\n",
                            option.name,
                            value);

    if (requested != Context::ContextOption_CrlImportDateCheckEnable)
    {
        NN_DETAIL_SSL_DBG_PRINT("[SetOption] invalid option: %u\n",
                                option.name);
        return ResultInvalidOptionType();
    }

    //  The flag is boolean; reject anything that is not exactly 0 or 1
    //  without touching the current setting.
    if ((value != 0) && (value != 1))
    {
        NN_DETAIL_SSL_DBG_PRINT("[SetOption] invalid CRL date check enable value: %d\n", value);
        return ResultInvalidArgument();
    }

    m_CrlDateCheckEnabled = (value == 1);

    NN_DETAIL_SSL_DBG_PRINT("[SetOption] CRL date check now %s\n",
                            m_CrlDateCheckEnabled ? "TRUE" : "FALSE");
    return ResultSuccess();
}


//  Reads a context option into *outValue.
//
//  Only ContextOption_CrlImportDateCheckEnable is supported; it is reported
//  as 1 when CRL import date checking is enabled and 0 otherwise.  Any other
//  option yields ResultInvalidOptionType and leaves *outValue untouched.
nn::Result SslContextImpl::GetOption(Out<int32_t> outValue, ContextOption option) NN_NOEXCEPT
{
    const auto requested = static_cast<nn::ssl::Context::ContextOption>(option.name);

    NN_DETAIL_SSL_DBG_PRINT("[GetOption] option %u\n", option.name);

    if (requested == Context::ContextOption_CrlImportDateCheckEnable)
    {
        *outValue = m_CrlDateCheckEnabled ? 1 : 0;
        return ResultSuccess();
    }

    NN_DETAIL_SSL_DBG_PRINT("[GetOption] invalid option: %u\n",
                            option.name);
    return ResultInvalidOptionType();
}


//  Imports a server certificate (CA) into this context's CertStore.
//
//  outCertId  : receives the id assigned to the imported certificate.
//  inCertData : raw certificate bytes.
//  fmt        : encoding of inCertData, forwarded to the CertStore as a
//               nn::ssl::CertificateFormat.
//
//  Fails with ResultInvalidPointer when outCertId is null, with
//  ResultMaxServerPkiRegistered when MaxServerPkiImportCount entries are
//  already present, or with whatever CertStore::ImportServerPki reports.
//  The result is converted to its external form before it is returned.
nn::Result SslContextImpl::ImportServerPki(Out<uint64_t>                   outCertId,
                                           InBuffer                        inCertData,
                                           nn::ssl::sf::CertificateFormat  fmt)
{
    nn::Result                  ret = ResultSuccess();

    NN_DETAIL_SSL_DBG_PRINT("[ImportServerPki] enter\n");

    do
    {
        if (outCertId.GetPointer() == nullptr)
        {
            NN_DETAIL_SSL_DBG_PRINT("[ImportServerPki] no certId pointer\n");
            ret = ResultInvalidPointer();
            break;
        }

        //  Check the number of ServerCertEntry already imported.  Once the
        //  count has reached MaxServerPkiImportCount no further import is
        //  allowed, so the comparison must be '>='.  With '>', a context
        //  already holding exactly the maximum would accept one more entry.
        //  TODO: When we support importing multiple server certificates in one ImportServerPki
        //        call, we need to consider revising this.
        if (m_CertStore.GetServerCertEntryCount() >= MaxServerPkiImportCount)
        {
            NN_DETAIL_SSL_DBG_PRINT("[ImportServerPki] The maximum number of server PKI is already registered\n");
            ret = ResultMaxServerPkiRegistered();
            break;
        }

        const char                  *pCertData = inCertData.GetPointerUnsafe();
        auto                        certDataSize = static_cast<uint32_t>(inCertData.GetSize());
        uint64_t                    certId;
        nn::ssl::CertificateFormat  certFmt =
            static_cast<nn::ssl::CertificateFormat>(fmt.value);

        ret = m_CertStore.ImportServerPki(&certId,
                                          pCertData,
                                          certDataSize,
                                          certFmt);
        if (ret.IsFailure())
        {
            NN_DETAIL_SSL_DBG_PRINT("[ImportServerPki] faild to import to the certstore: %d-%d\n",
                                    ret.GetModule(),
                                    ret.GetDescription());
            break;
        }

        //  Only publish the id once the import has fully succeeded.
        *outCertId = certId;
    } while (NN_STATIC_CONDITION(false));

    NN_DETAIL_SSL_DBG_PRINT("[ImportServerPki] exit\n");
    return Util::ConvertResultFromInternalToExternal(ret);
}


//  Imports a client PKI (PKCS#12 data plus its password) into this context's
//  CertStore and reports the id assigned to it.
//
//  outCertId : receives the new certificate id on success.
//  inP12Data : the PKCS#12 data.
//  inPwData  : the password for the PKCS#12 data.
//
//  Fails with ResultInvalidPointer when outCertId is null, with
//  ResultErrorLower when TrustedCertManager deferred init reports
//  InitStatus_InitFail, or with whatever CertStore::ImportClientPki reports.
//  The result is converted to its external form before it is returned.
nn::Result SslContextImpl::ImportClientPki(Out<uint64_t> outCertId,
                                           InBuffer      inP12Data,
                                           InBuffer      inPwData)
{
    NN_DETAIL_SSL_DBG_PRINT("[ImportClientPki] enter\n");

    //  Every failure returns early from this lambda, so the exit logging and
    //  result conversion below always run exactly once.
    auto importBody = [&]() -> nn::Result
    {
        if (outCertId.GetPointer() == nullptr)
        {
            NN_DETAIL_SSL_DBG_PRINT("[ImportClientPki] no certId pointer\n");
            return ResultInvalidPointer();
        }

        //  Trusted CAs are only available once the TrustedCertManager has
        //  finished its deferred initialization.  (The `true` argument
        //  presumably waits for that to complete; confirm in NssCore.)
        const NssCore::InitStatus tcmStatus =
            NssCore::GetInitStatus(TrustedCertManager::g_TcmDeferredInitId, true);
        if (tcmStatus == NssCore::InitStatus_InitFail)
        {
            NN_DETAIL_SSL_DBG_PRINT("[ImportClientPki] TCM failed to init\n");
            return ResultErrorLower();
        }

        uint64_t   newId = 0;
        nn::Result importResult =
            m_CertStore.ImportClientPki(&newId,
                                        inP12Data.GetPointerUnsafe(),
                                        inPwData.GetPointerUnsafe(),
                                        static_cast<uint32_t>(inP12Data.GetSize()),
                                        static_cast<uint32_t>(inPwData.GetSize()));
        if (importResult.IsFailure())
        {
            NN_DETAIL_SSL_DBG_PRINT("[ImportClientPki] faild to import to the certstore: %d-%d\n",
                                    importResult.GetModule(),
                                    importResult.GetDescription());
            return importResult;
        }

        *outCertId = newId;
        return importResult;
    };

    nn::Result ret = importBody();

    NN_DETAIL_SSL_DBG_PRINT("[ImportClientPki] exit\n");

    return Util::ConvertResultFromInternalToExternal(ret);
}


//  Removes the server certificate identified by `certId` from this context's
//  CertStore.  Returns the CertStore result converted to its external form.
nn::Result SslContextImpl::RemoveServerPki(uint64_t certId)
{
    NN_DETAIL_SSL_DBG_PRINT("[RemoveServerPki] enter\n");

    nn::Result removeResult = m_CertStore.RemoveServerPki(certId);

    NN_DETAIL_SSL_DBG_PRINT("[RemoveServerPki] exit\n");

    return Util::ConvertResultFromInternalToExternal(removeResult);
}


//  Removes the client PKI identified by `certId` from this context's
//  CertStore.
//
//  Fails with ResultErrorLower when TrustedCertManager deferred init reports
//  InitStatus_InitFail; otherwise returns the CertStore result.  Either way
//  the result is converted to its external form.
nn::Result SslContextImpl::RemoveClientPki(uint64_t certId)
{
    NN_DETAIL_SSL_DBG_PRINT("[RemoveClientPki] enter\n");

    nn::Result ret;

    //  Client PKI operations require the TrustedCertManager to have finished
    //  initializing, otherwise the trusted CAs are not available.
    const NssCore::InitStatus tcmStatus =
        NssCore::GetInitStatus(TrustedCertManager::g_TcmDeferredInitId, true);
    if (tcmStatus != NssCore::InitStatus_InitFail)
    {
        ret = m_CertStore.RemoveClientPki(certId);
    }
    else
    {
        NN_DETAIL_SSL_DBG_PRINT("[RemoveClientPki] TCM failed to init\n");
        ret = ResultErrorLower();
    }

    NN_DETAIL_SSL_DBG_PRINT("[RemoveClientPki] exit\n");

    return Util::ConvertResultFromInternalToExternal(ret);
}


//  Registers one of the built-in (device-unique) client PKIs with this
//  context and reports the id assigned to it.
//
//  outCertId : receives the new certificate id on success.
//  sfPki     : which internal PKI to register; InternalPkiManager maps it
//              to a nickname that CertStore::ImportDeviceUniqueClientPki
//              consumes.
//
//  Fails with ResultInvalidPointer when outCertId is null, with
//  ResultInvalidInternalPkiType when no nickname exists for the type, or with
//  whatever the CertStore import reports.  The result is converted to its
//  external form before it is returned.
nn::Result SslContextImpl::RegisterInternalPki(Out<uint64_t> outCertId, InternalPki sfPki)
{
    NN_DETAIL_SSL_DBG_PRINT("[RegisterInternalPki] enter\n");

    //  Failures return early from the lambda; the shared exit path below
    //  always logs and converts the result.
    auto registerBody = [&]() -> nn::Result
    {
        if (outCertId.GetPointer() == nullptr)
        {
            NN_DETAIL_SSL_DBG_PRINT("[RegisterInternalPki] no certId pointer\n");
            return ResultInvalidPointer();
        }

        //  A nullptr nickname means the requested type is not a valid
        //  internal PKI.
        nn::ssl::Context::InternalPki pkiType =
            static_cast<nn::ssl::Context::InternalPki>(sfPki.pki);
        const char *pNickname = InternalPkiManager::GetNickname(pkiType);
        if (pNickname == nullptr)
        {
            NN_DETAIL_SSL_DBG_PRINT("[RegisterInternalPki] invalid pki type (%X)\n", sfPki.pki);
            return ResultInvalidInternalPkiType();
        }

        uint64_t   newId = 0;
        nn::Result importResult = m_CertStore.ImportDeviceUniqueClientPki(&newId, pNickname);
        if (importResult.IsFailure())
        {
            NN_DETAIL_SSL_DBG_PRINT("[RegisterInternalPki] failed to import device PKI\n");
            return importResult;
        }

        *outCertId = newId;
        return importResult;
    };

    nn::Result ret = registerBody();

    NN_DETAIL_SSL_DBG_PRINT("[RegisterInternalPki] exit\n");

    return Util::ConvertResultFromInternalToExternal(ret);
}


//  Adds a certificate policy OID (given as a NUL-terminated string) to the
//  set of policy OIDs registered on this context.
//
//  stringBuffer : buffer that must contain a NUL-terminated OID string
//                 shorter than MaxPolicyOidStringLength.
//
//  Fails with:
//    - ResultInvalidPolicyOidStringBufferLength when the buffer has no NUL.
//    - ResultPolicyOidStringTooLong when the string is too long.
//    - ResultMaxPolicyOidRegistered when MaxPolicyOidStringCount OIDs are
//      already registered.
//    - Whatever EncodePolicyOidString, PolicyOidInfo::Setup, or
//      PolicyOidInfo::AddOids report.
//  Registering an OID that is already present succeeds without adding a
//  duplicate.  The result is converted to its external form before it is
//  returned.
nn::Result SslContextImpl::AddPolicyOid(const InBuffer& stringBuffer) NN_NOEXCEPT
{
    nn::Result                  ret = ResultSuccess();

    NN_DETAIL_SSL_DBG_PRINT("[AddPolicyOId] enter\n");

    do
    {
        //  Bounds check the incoming buffer.  A zero-length buffer cannot hold
        //  a NUL terminator, and its pointer may be null, so strnlen is not
        //  called on it.  The length is treated as 0, which is rejected just
        //  below exactly as before.
        size_t      bufferSize   = stringBuffer.GetSize();
        const char* oidString    = stringBuffer.GetPointerUnsafe();
        size_t      stringLength = (bufferSize == 0) ? 0 : strnlen(oidString, bufferSize);

        if (stringLength == bufferSize)
        {
            NN_DETAIL_SSL_DBG_PRINT("[AddPolicyOId] Provided string buffer doesn't include null termination\n");
            ret = ResultInvalidPolicyOidStringBufferLength();
            break;
        }

        if (stringLength >= MaxPolicyOidStringLength)
        {
            NN_DETAIL_SSL_DBG_PRINT("[AddPolicyOId] Provided string is too long.\n");
            ret = ResultPolicyOidStringTooLong();
            break;
        }

        if (m_PolicyOidInfo.GetCount() >= MaxPolicyOidStringCount)
        {
            NN_DETAIL_SSL_DBG_PRINT("[AddPolicyOId] Max OID strings are already registered.\n");
            ret = ResultMaxPolicyOidRegistered();
            break;
        }

        SECOidTag tmpOidTag = SEC_OID_UNKNOWN;
        ret = EvCertUtil::EncodePolicyOidString(&tmpOidTag, oidString);
        if (ret.IsFailure())
        {
            NN_DETAIL_SSL_DBG_PRINT("[AddPolicyOId] EncodePolicyOidString failed.\n");
            break;
        }

        //  The OID storage is set up lazily on first use.
        if (m_PolicyOidInfo.GetHead() == nullptr)
        {
            ret = m_PolicyOidInfo.Setup(MaxPolicyOidStringCount);
            if (ret.IsFailure())
            {
                NN_DETAIL_SSL_DBG_PRINT("[AddPolicyOId] PolicyOidInfo setup failed.\n");
                break;
            }
        }

        if (m_PolicyOidInfo.IsOidSet(tmpOidTag) == false)
        {
            ret = m_PolicyOidInfo.AddOids(&tmpOidTag, 1);
            if (ret.IsFailure())
            {
                //  Do not fall through to the "Added OID" log on failure.
                NN_DETAIL_SSL_DBG_PRINT("[AddPolicyOId] AddOids failed.\n");
                break;
            }
        }

        NN_DETAIL_SSL_DBG_PRINT("[AddPolicyOId] Added OID: %d (%s)\n", tmpOidTag, oidString);
    } while (NN_STATIC_CONDITION(false));

    NN_DETAIL_SSL_DBG_PRINT("[AddPolicyOId] exit\n");

    return Util::ConvertResultFromInternalToExternal(ret);
}


//  Imports a DER-encoded CRL into this context's CertStore and reports the
//  id assigned to it.
//
//  outCrlId  : receives the new CRL id on success.
//  inCrlData : DER-encoded CRL bytes.
//
//  Fails with ResultInvalidPointer when outCrlId is null, with
//  ResultErrorLower when TrustedCertManager deferred init reports
//  InitStatus_InitFail, or with whatever CertStore::ImportCrl reports.
//  The result is converted to its external form before it is returned.
nn::Result SslContextImpl::ImportCrl(Out<uint64_t> outCrlId, InBuffer inCrlData) NN_NOEXCEPT
{
    NN_DETAIL_SSL_DBG_PRINT("[ImportCrl] enter\n");

    //  Failures return early from the lambda; the exit logging and result
    //  conversion below run exactly once on every path.
    auto importBody = [&]() -> nn::Result
    {
        if (outCrlId.GetPointer() == nullptr)
        {
            NN_DETAIL_SSL_DBG_PRINT("[ImportCrl] no crlId pointer\n");
            return ResultInvalidPointer();
        }

        //  No import is allowed until the TrustedCertManager has finished
        //  initializing; without it the trusted CAs are missing and the
        //  store could end up inconsistent.
        const NssCore::InitStatus tcmStatus =
            NssCore::GetInitStatus(TrustedCertManager::g_TcmDeferredInitId, true);
        if (tcmStatus == NssCore::InitStatus_InitFail)
        {
            NN_DETAIL_SSL_DBG_PRINT("[ImportCrl] TCM failed to init\n");
            return ResultErrorLower();
        }

        const uint8_t *pDer    = reinterpret_cast<const uint8_t *>(inCrlData.GetPointerUnsafe());
        const uint32_t derSize = static_cast<uint32_t>(inCrlData.GetSize());
        uint64_t       newId   = 0;

        nn::Result importResult = m_CertStore.ImportCrl(&newId, pDer, derSize);
        if (importResult.IsFailure())
        {
            NN_DETAIL_SSL_DBG_PRINT("[ImportCrl] failed to import to the certstore: %d-%d\n",
                                    importResult.GetModule(),
                                    importResult.GetDescription());
            return importResult;
        }

        *outCrlId = newId;
        return importResult;
    };

    nn::Result ret = importBody();

    NN_DETAIL_SSL_DBG_PRINT("[ImportCrl] exit\n");

    return Util::ConvertResultFromInternalToExternal(ret);
}


//  Removes the CRL identified by `crlId` from this context's CertStore.
//  Returns the CertStore result converted to its external form.
nn::Result SslContextImpl::RemoveCrl(uint64_t crlId) NN_NOEXCEPT
{
    NN_DETAIL_SSL_DBG_PRINT("[RemoveCrl] enter\n");

    nn::Result removeResult = m_CertStore.RemoveCrl(crlId);

    NN_DETAIL_SSL_DBG_PRINT("[RemoveCrl] exit\n");

    return Util::ConvertResultFromInternalToExternal(removeResult);
}


//  Creates a new SSL connection service object bound to this context and
//  hands it to the caller through outValue.
//
//  The object is allocated with SslMemoryManager's sf allocator and
//  registered with SslServiceDatabase.  Fails with ResultInsufficientMemory
//  when the allocation fails, or with the SslServiceDatabase result when the
//  connection cannot be tracked.  On any failure, *outValue is left untouched.
//  The result is converted to its external form before it is returned.
nn::Result SslContextImpl::CreateConnection(Out<SharedPointer<ISslConnection>> outValue) NN_NOEXCEPT
{
    nn::Result                  ret = ResultSuccess();
    typedef ObjectFactory<ExpHeapAllocator::Policy>   Factory;

    NN_DETAIL_SSL_DBG_PRINT("[CreateConnection] enter\n");

    do
    {
        EmplacedRef<ISslConnection, SslConnectionImpl> er =
            Factory::CreateSharedEmplaced<ISslConnection, SslConnectionImpl>(SslMemoryManager::GetSfAllocator(),
                                                                             this);
        if ((er == nullptr) || (er.Get() == nullptr))
        {
            NN_DETAIL_SSL_DBG_PRINT("[CreateConnection] failed to create new connection\n");
            ret = ResultInsufficientMemory();
            break;
        }

        //  Track the new connection in the database.  If this add fails it is
        //  because we have hit a resource limit.  Nothing has been handed to
        //  the caller yet, so leaving the scope via `break` destroys `er` and
        //  releases its reference to the new object.
        //  (The previous `static_cast<SharedPointer<ISslConnection>>(er) = nullptr;`
        //  only assigned to a temporary copy and never cleared `er`; it has
        //  been removed.)
        SslConnectionImpl &conn = er.GetImpl();
        ret = SslServiceDatabase::AddSslConnection(&conn);
        if (ret.IsFailure())
        {
            NN_DETAIL_SSL_DBG_PRINT("[CreateConnection] failed to track new connection: (ret %d:%d)\n",
                                    ret.GetModule(),
                                    ret.GetDescription());
            break;
        }

        //  Return the new service object to the caller, we're good to go.
        *outValue = er;
    } while (NN_STATIC_CONDITION(false));

    NN_DETAIL_SSL_DBG_PRINT("[CreateConnection] exit\n");

    return Util::ConvertResultFromInternalToExternal(ret);
}


//  Reports, through outValue, the number of connections that
//  SslServiceDatabase currently tracks for this context.  On failure
//  *outValue is left untouched.  The database result is converted to its
//  external form before it is returned.
nn::Result SslContextImpl::GetConnectionCount(Out<uint32_t> outValue) NN_NOEXCEPT
{
    NN_DETAIL_SSL_DBG_PRINT("[GetConnectionCount] enter\n");

    uint32_t   liveConnections = 0;
    nn::Result ret = SslServiceDatabase::GetConnectionCount(this, &liveConnections);

    if (ret.IsFailure())
    {
        NN_DETAIL_SSL_DBG_PRINT("[GetConnectionCount] err: %d:%d\n",
                                ret.GetModule(),
                                ret.GetDescription());
    }
    else
    {
        *outValue = liveConnections;
    }

    NN_DETAIL_SSL_DBG_PRINT("[GetConnectionCount] exit\n");

    return Util::ConvertResultFromInternalToExternal(ret);
}


//  True when CRL import date checking is on.  It starts enabled and is
//  changed via SetOption(ContextOption_CrlImportDateCheckEnable).
bool SslContextImpl::IsCrlDateCheckEnabled()
{
    const bool enabled = m_CrlDateCheckEnabled;
    return enabled;
}

} } }
