﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <nn/nn_SdkAssert.h>
#include <nn/nn_SdkLog.h>
#include <nn/ssl/ssl_Types.h>
#include <nn/ssl/ssl_Context.h>

#include "detail/ssl_ServiceSession.h"
#include "server/ssl_MemoryManager.h"
#include "server/ssl_Util.h"


using namespace nn::ssl::detail;

namespace nn { namespace ssl {


// Constructs a Context whose identifier is 0; Create() is the member that
// later stores a real identifier in m_ContextId.
Context::Context() NN_NOEXCEPT
    : m_ContextId(0)
{
}

// Destructor. Performs no cleanup of its own; it only asserts that
// m_ContextId is 0, which is the value Destroy() stores after it has
// released the context.
Context::~Context() NN_NOEXCEPT
{
    // A non-zero identifier here means Destroy() was not completed successfully
    // before the object went away.
    NN_SDK_ASSERT(m_ContextId == 0);
}

// Asks the SSL service to create a context and stores the resulting
// identifier in m_ContextId.
//
// version : requested SSL version; cast to int32_t and forwarded to the
//           service in nn::ssl::sf::SslVersion::flags.
//
// Returns the result of ISslService::CreateContext, or
// ResultInsufficientMemory when storage for the context interface holder
// cannot be allocated. ResultLibraryNotInitialized is handed to
// NN_DETAIL_SSL_VALIDATE_SHARED_POINTER for the service session pointer
// (macro defined elsewhere; presumably it leaves the loop with that result
// when the pointer is unusable - confirm).
//
// NOTE(review): a non-zero m_ContextId is overwritten here without being
// released first - confirm callers never call Create() twice in a row.
nn::Result Context::Create(SslVersion version) NN_NOEXCEPT
{
    nn::Result                    result = nn::ResultSuccess();
    SharedPointer<ISslContext>*   pContext = nullptr;

    do
    {
        SharedPointer<ISslService>* pService = ServiceSession::GetServiceSession();
        NN_DETAIL_SSL_VALIDATE_SHARED_POINTER(pService, result, ResultLibraryNotInitialized());

        nn::ssl::sf::SslVersion sfVersion;
        sfVersion.flags = static_cast<int32_t>(version);

        // The holder lives in memory obtained from SslMemoryManager, so it is
        // placement-constructed here and destroyed by hand on the way out.
        void* pStorage = SslMemoryManager::AllocateChunk(sizeof(SharedPointer<ISslContext>), 0);
        if (pStorage == nullptr)
        {
            result = ResultInsufficientMemory();
            break;
        }
        pContext = new(pStorage) SharedPointer<ISslContext>();

        result = (*pService)->CreateContext(pContext, sfVersion, 0);
        if (result.IsFailure())
        {
            break;
        }

        // Only a successfully created context is published through m_ContextId.
        NN_DETAIL_SSL_GET_ID_FROM_PTR(m_ContextId, pContext);
    } while (NN_STATIC_CONDITION(false));

    // The holder was constructed but a later step failed: undo the
    // construction and return the storage.
    if ((pContext != nullptr) && result.IsFailure())
    {
        pContext->~SharedPointer<ISslContext>();
        SslMemoryManager::Free(pContext, 0);
    }

    return result;
}


// Releases the context referred to by m_ContextId and resets the identifier
// to 0.
//
// Returns:
//   ResultLibraryNotInitialized - ServiceSession::IsInitialized() is false.
//   ResultConnectionRemaining   - the context still reports more than zero
//                                 connections; nothing is released.
//   Any failure from ISslContext::GetConnectionCount is returned unchanged.
//   ResultInvalidContext is handed to NN_DETAIL_SSL_VALIDATE_SHARED_POINTER
//   (macro defined elsewhere; presumably reported when the pointer recovered
//   from m_ContextId is unusable - confirm).
nn::Result Context::Destroy() NN_NOEXCEPT
{
    if (!ServiceSession::IsInitialized())
    {
        return ResultLibraryNotInitialized();
    }

    nn::Result result = nn::ResultSuccess();

    do
    {
        SharedPointer<ISslContext>* pContext;
        NN_DETAIL_SSL_GET_PTR_FROM_ID(pContext, m_ContextId, SharedPointer<ISslContext>);
        NN_DETAIL_SSL_VALIDATE_SHARED_POINTER(pContext, result, ResultInvalidContext());

        uint32_t remaining;
        result = (*pContext)->GetConnectionCount(&remaining);
        if (result.IsFailure())
        {
            break;
        }

        // The context is kept alive while connections still exist on it.
        if (remaining > 0)
        {
            result = ResultConnectionRemaining();
            break;
        }

        // Clear the identifier first, then drop the interface reference and
        // manually tear down the holder that Create() placement-constructed.
        m_ContextId = 0;
        *pContext = nullptr;
        pContext->~SharedPointer<ISslContext>();
        SslMemoryManager::Free(pContext, 0);
    } while (NN_STATIC_CONDITION(false));

    return result;
}


// Sets one option on the context referred to by m_ContextId.
//
// optionName  : option to change; cast to uint32_t and sent as
//               nn::ssl::sf::ContextOption::name.
// optionValue : value forwarded unchanged to ISslContext::SetOption.
//
// Returns ResultLibraryNotInitialized when the service session is not
// initialized; otherwise the result of ISslContext::SetOption.
// ResultInvalidContext is handed to NN_DETAIL_SSL_VALIDATE_SHARED_POINTER
// (macro defined elsewhere; presumably reported for an unusable context
// pointer - confirm).
nn::Result Context::SetOption(ContextOption optionName, int optionValue) NN_NOEXCEPT
{
    if (!ServiceSession::IsInitialized())
    {
        return ResultLibraryNotInitialized();
    }

    nn::Result result = nn::ResultSuccess();

    do
    {
        SharedPointer<ISslContext>* pContext;
        NN_DETAIL_SSL_GET_PTR_FROM_ID(pContext, m_ContextId, SharedPointer<ISslContext>);
        NN_DETAIL_SSL_VALIDATE_SHARED_POINTER(pContext, result, ResultInvalidContext());

        nn::ssl::sf::ContextOption sfOption;
        sfOption.name = static_cast<uint32_t>(optionName);

        result = (*pContext)->SetOption(sfOption, optionValue);
    } while (NN_STATIC_CONDITION(false));

    return result;
}


// Reads one option from the context referred to by m_ContextId.
//
// pOutValue  : receives the option value; must not be null
//              (checked by NN_SDK_REQUIRES_NOT_NULL and again at run time).
// optionName : option to query; cast to uint32_t and sent as
//              nn::ssl::sf::ContextOption::name.
//
// Returns:
//   ResultLibraryNotInitialized - ServiceSession::IsInitialized() is false.
//   ResultInvalidPointer        - pOutValue is null.
//   Otherwise the result of ISslContext::GetOption.
//   ResultInvalidContext is handed to NN_DETAIL_SSL_VALIDATE_SHARED_POINTER
//   (macro defined elsewhere; presumably reported for an unusable context
//   pointer - confirm).
nn::Result Context::GetOption(int *pOutValue, ContextOption optionName) NN_NOEXCEPT
{
    // Initialized like the result variable of every other method in this file,
    // so the returned value never depends on the validation macro having
    // assigned it.
    nn::Result                  ret = nn::ResultSuccess();
    SharedPointer<ISslContext>  *ctx;

    NN_SDK_REQUIRES_NOT_NULL(pOutValue);

    do
    {
        if (!ServiceSession::IsInitialized())
        {
            ret = ResultLibraryNotInitialized();
            break;
        }

        if (pOutValue == nullptr)
        {
            ret = ResultInvalidPointer();
            break;
        }

        nn::ssl::sf::ContextOption opt;

        NN_DETAIL_SSL_GET_PTR_FROM_ID(ctx, m_ContextId, SharedPointer<ISslContext>);
        NN_DETAIL_SSL_VALIDATE_SHARED_POINTER(ctx, ret, ResultInvalidContext());

        opt.name = static_cast<uint32_t>(optionName);
        ret = (*ctx)->GetOption(pOutValue, opt);
    } while (NN_STATIC_CONDITION(false));

    return ret;
}


// Copies the current value of m_ContextId into *pOutValue.
//
// pOutValue : destination; must not be null (checked by
//             NN_SDK_REQUIRES_NOT_NULL and again at run time).
//
// Returns ResultInvalidPointer when pOutValue is null, success otherwise.
nn::Result Context::GetContextId(SslContextId* pOutValue) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pOutValue);

    nn::Result result = ResultInvalidPointer();

    if (pOutValue != nullptr)
    {
        *pOutValue = m_ContextId;
        result = nn::ResultSuccess();
    }

    return result;
}


// Imports server PKI data into the context referred to by m_ContextId.
//
// pOutCertId   : receives the certificate store id; must not be null.
// pInCertData  : certificate data; must not be null.
// certDataSize : size passed with pInCertData when building the InBuffer.
// certFormat   : format of the data; cast to uint32_t and sent as
//                nn::ssl::sf::CertificateFormat::value.
//
// Returns:
//   ResultLibraryNotInitialized - ServiceSession::IsInitialized() is false.
//   ResultInvalidPointer        - pOutCertId or pInCertData is null.
//   Otherwise the result of ISslContext::ImportServerPki.
//   ResultInvalidContext is handed to NN_DETAIL_SSL_VALIDATE_SHARED_POINTER
//   (macro defined elsewhere; presumably reported for an unusable context
//   pointer - confirm).
nn::Result Context::ImportServerPki(
        CertStoreId* pOutCertId,
        const char* pInCertData,
        uint32_t certDataSize,
        CertificateFormat certFormat) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pOutCertId);

    if (!ServiceSession::IsInitialized())
    {
        return ResultLibraryNotInitialized();
    }

    // Both pointer arguments are mandatory and map to the same result.
    if ((pOutCertId == nullptr) || (pInCertData == nullptr))
    {
        return ResultInvalidPointer();
    }

    nn::Result result = ResultSuccess();

    do
    {
        SharedPointer<ISslContext>* pContext;
        NN_DETAIL_SSL_GET_PTR_FROM_ID(pContext, m_ContextId, SharedPointer<ISslContext>);
        NN_DETAIL_SSL_VALIDATE_SHARED_POINTER(pContext, result, ResultInvalidContext());

        nn::ssl::sf::CertificateFormat sfFormat;
        sfFormat.value = static_cast<uint32_t>(certFormat);

        InBuffer certBuffer(pInCertData, certDataSize);
        result = (*pContext)->ImportServerPki(pOutCertId, certBuffer, sfFormat);
    } while (NN_STATIC_CONDITION(false));

    return result;
}


// Imports client PKI data (a PKCS#12-style blob plus optional password, going
// by the parameter names) into the context referred to by m_ContextId.
//
// pOutCertId  : receives the certificate store id; must not be null.
// pInP12Data  : P12 data; must not be null.
// pInPwData   : password data; may be null only when pwDataSize is 0.
// p12DataSize : size passed with pInP12Data when building the InBuffer.
// pwDataSize  : size passed with pInPwData; must be non-zero when pInPwData
//               is non-null.
//
// Returns:
//   ResultLibraryNotInitialized - ServiceSession::IsInitialized() is false.
//   ResultInvalidPointer        - pOutCertId or pInP12Data is null, or
//                                 pInPwData is null while pwDataSize > 0.
//   ResultInvalidPasswordSize   - pInPwData is non-null while pwDataSize == 0.
//   Otherwise the result of ISslContext::ImportClientPki.
//   ResultInvalidContext is handed to NN_DETAIL_SSL_VALIDATE_SHARED_POINTER
//   (macro defined elsewhere; presumably reported for an unusable context
//   pointer - confirm).
nn::Result Context::ImportClientPki(
    CertStoreId* pOutCertId,
    const char* pInP12Data,
    const char* pInPwData,
    uint32_t  p12DataSize,
    uint32_t  pwDataSize) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pOutCertId);

    if (!ServiceSession::IsInitialized())
    {
        return ResultLibraryNotInitialized();
    }

    // Pointer problems are reported before the password size mismatch.
    const bool isPasswordMissing = (pInPwData == nullptr) && (pwDataSize > 0);
    if ((pOutCertId == nullptr) || (pInP12Data == nullptr) || isPasswordMissing)
    {
        return ResultInvalidPointer();
    }

    if ((pInPwData != nullptr) && (pwDataSize == 0))
    {
        return ResultInvalidPasswordSize();
    }

    nn::Result result = ResultSuccess();

    do
    {
        SharedPointer<ISslContext>* pContext;
        NN_DETAIL_SSL_GET_PTR_FROM_ID(pContext, m_ContextId, SharedPointer<ISslContext>);
        NN_DETAIL_SSL_VALIDATE_SHARED_POINTER(pContext, result, ResultInvalidContext());

        InBuffer p12Buffer(pInP12Data, p12DataSize);
        InBuffer passwordBuffer(pInPwData, pwDataSize);
        result = (*pContext)->ImportClientPki(pOutCertId, p12Buffer, passwordBuffer);
    } while (NN_STATIC_CONDITION(false));

    return result;
}

// Removes the entry identified by certId from the context referred to by
// m_ContextId. The id is tried as a server PKI first, then as a client PKI,
// then as a CRL; each later attempt is made only when the previous one
// returned a result included in ResultInvalidCertStoreId.
//
// Returns ResultLibraryNotInitialized when the service session is not
// initialized; otherwise the result of the last removal call that was made.
// ResultInvalidContext is handed to NN_DETAIL_SSL_VALIDATE_SHARED_POINTER
// (macro defined elsewhere; presumably reported for an unusable context
// pointer - confirm).
nn::Result Context::RemovePki(CertStoreId certId) NN_NOEXCEPT
{
    if (!ServiceSession::IsInitialized())
    {
        return ResultLibraryNotInitialized();
    }

    nn::Result result = ResultSuccess();

    do
    {
        SharedPointer<ISslContext>* pContext;
        NN_DETAIL_SSL_GET_PTR_FROM_ID(pContext, m_ContextId, SharedPointer<ISslContext>);
        NN_DETAIL_SSL_VALIDATE_SHARED_POINTER(pContext, result, ResultInvalidContext());

        result = (*pContext)->RemoveServerPki(certId);
        if (!ResultInvalidCertStoreId::Includes(result))
        {
            break;
        }

        // Not a server PKI id: try the client PKI store.
        result = (*pContext)->RemoveClientPki(certId);
        if (!ResultInvalidCertStoreId::Includes(result))
        {
            break;
        }

        // Not a client PKI id either: last attempt is the CRL store.
        result = (*pContext)->RemoveCrl(certId);
    } while (NN_STATIC_CONDITION(false));

    return result;
}

// Registers an internal PKI with the context referred to by m_ContextId.
//
// pOutCertId : receives the certificate store id; must not be null
//              (checked by NN_SDK_REQUIRES_NOT_NULL and again at run time,
//              matching ImportServerPki / ImportClientPki / ImportCrl).
// pkiType    : internal PKI to register; cast to int32_t and sent as
//              nn::ssl::sf::InternalPki::pki.
//
// Returns:
//   ResultLibraryNotInitialized - ServiceSession::IsInitialized() is false.
//   ResultInvalidPointer        - pOutCertId is null.
//   Otherwise the result of ISslContext::RegisterInternalPki.
//   ResultInvalidContext is handed to NN_DETAIL_SSL_VALIDATE_SHARED_POINTER
//   (macro defined elsewhere; presumably reported for an unusable context
//   pointer - confirm).
nn::Result Context::RegisterInternalPki(CertStoreId* pOutCertId, InternalPki pkiType) NN_NOEXCEPT
{
    nn::Result                  ret = ResultSuccess();
    SharedPointer<ISslContext>  *ctx;
    nn::ssl::sf::InternalPki    sfPkiType;

    NN_SDK_REQUIRES_NOT_NULL(pOutCertId);

    do
    {
        if (!ServiceSession::IsInitialized())
        {
            ret = ResultLibraryNotInitialized();
            break;
        }

        // Reject a null output pointer here instead of handing it to the
        // service call below.
        if (pOutCertId == nullptr)
        {
            ret = ResultInvalidPointer();
            break;
        }

        NN_DETAIL_SSL_GET_PTR_FROM_ID(ctx, m_ContextId, SharedPointer<ISslContext>);
        NN_DETAIL_SSL_VALIDATE_SHARED_POINTER(ctx, ret, ResultInvalidContext());
        sfPkiType.pki = static_cast<int32_t>(pkiType);

        ret = (*ctx)->RegisterInternalPki(pOutCertId, sfPkiType);
    } while (NN_STATIC_CONDITION(false));

    return ret;
}

// Adds a policy OID string to the context referred to by m_ContextId.
//
// pInPolicyOIdString : OID string data; must not be null.
// stringBufferSize   : size passed with the string when building the
//                      InBuffer; must not exceed MaxPolicyOidStringLength.
//
// Returns:
//   ResultLibraryNotInitialized  - ServiceSession::IsInitialized() is false.
//   ResultInvalidPointer         - pInPolicyOIdString is null.
//   ResultPolicyOidStringTooLong - stringBufferSize > MaxPolicyOidStringLength.
//   Otherwise the result of ISslContext::AddPolicyOid.
//   ResultInvalidContext is handed to NN_DETAIL_SSL_VALIDATE_SHARED_POINTER
//   (macro defined elsewhere; presumably reported for an unusable context
//   pointer - confirm).
nn::Result Context::AddPolicyOid(const char* pInPolicyOIdString, uint32_t stringBufferSize) NN_NOEXCEPT
{
    if (!ServiceSession::IsInitialized())
    {
        return ResultLibraryNotInitialized();
    }

    if (pInPolicyOIdString == nullptr)
    {
        return ResultInvalidPointer();
    }

    if (stringBufferSize > MaxPolicyOidStringLength)
    {
        return ResultPolicyOidStringTooLong();
    }

    nn::Result result = ResultSuccess();

    do
    {
        SharedPointer<ISslContext>* pContext;
        NN_DETAIL_SSL_GET_PTR_FROM_ID(pContext, m_ContextId, SharedPointer<ISslContext>);
        NN_DETAIL_SSL_VALIDATE_SHARED_POINTER(pContext, result, ResultInvalidContext());

        InBuffer oidBuffer(pInPolicyOIdString, stringBufferSize);
        result = (*pContext)->AddPolicyOid(oidBuffer);
    } while (NN_STATIC_CONDITION(false));

    return result;
}

// Imports CRL data into the context referred to by m_ContextId.
//
// pOutCrlId     : receives the store id of the imported CRL; must not be null.
// pInCrlDerData : CRL data (DER, going by the parameter name); must not be null.
// crlDataSize   : size passed with pInCrlDerData when building the InBuffer.
//
// Returns:
//   ResultLibraryNotInitialized - ServiceSession::IsInitialized() is false.
//   ResultInvalidPointer        - pOutCrlId or pInCrlDerData is null.
//   Otherwise the result of ISslContext::ImportCrl.
//   ResultInvalidContext is handed to NN_DETAIL_SSL_VALIDATE_SHARED_POINTER
//   (macro defined elsewhere; presumably reported for an unusable context
//   pointer - confirm).
nn::Result Context::ImportCrl(
        CertStoreId* pOutCrlId,
        const char*  pInCrlDerData,
        uint32_t     crlDataSize) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pOutCrlId);

    if (!ServiceSession::IsInitialized())
    {
        return ResultLibraryNotInitialized();
    }

    // Both pointer arguments are mandatory and map to the same result.
    if ((pOutCrlId == nullptr) || (pInCrlDerData == nullptr))
    {
        return ResultInvalidPointer();
    }

    nn::Result result = ResultSuccess();

    do
    {
        SharedPointer<ISslContext>* pContext;
        NN_DETAIL_SSL_GET_PTR_FROM_ID(pContext, m_ContextId, SharedPointer<ISslContext>);
        NN_DETAIL_SSL_VALIDATE_SHARED_POINTER(pContext, result, ResultInvalidContext());

        InBuffer crlBuffer(pInCrlDerData, crlDataSize);
        result = (*pContext)->ImportCrl(pOutCrlId, crlBuffer);
    } while (NN_STATIC_CONDITION(false));

    return result;
}

} }
