﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include "resolver_OptionsHelper.h"
#include <cstring>
#include <nn/socket/socket_ResolverOptionsPrivate.h>
#include <resolver_ClientServerShared.h>

//#define LOG_LEVEL LOG_LEVEL_MAX
#define LOG_MODULE_NAME "opt helper" // NOLINT(preprocessor/const)
#include <nn/socket/resolver/private/resolver_DebugLog.h>

namespace nn { namespace socket {

/**
 * @brief
 * This parameter is the current (and therefore highest)
 * version of the resolver options feature supported.
 *
 * @details
 * This constant is also used to size the table of default request
 * options (@ref DefaultRequestOptionTable), so raising it grows that
 * table by one row. @ref ValidateKeyVersion compares the
 * client-provided version against it for strict equality; this might
 * prove too strict in the future.
 */
const unsigned g_ResolverOptionVersion = 1;

/**
 * @brief
 * This structure is a simple pairing of a version number
 * along with an array of ResolverOptions. It is the row type of
 * @ref DefaultRequestOptionTable below.
 *
 * @details
 * The options member is declared as a fixed-size array because it
 * is convenient, and also because every version listed in the default
 * table must carry its own set of default values.
 *
 * The array length is the numeric distance between
 * ResolverOptionKey::RequestMinimumValue and
 * ResolverOptionKey::RequestMaximumValue. GetRequestOption treats
 * every element of this array as a searchable default entry.
 */
struct DefaultRequestOptions
{
    // Options version that this row of defaults applies to.
    uint32_t version;
    // Default option values for that version.
    nn::socket::ResolverOption options[static_cast<uint32_t>(ResolverOptionKey::RequestMaximumValue) -
                                       static_cast<uint32_t>(ResolverOptionKey::RequestMinimumValue)];
};

/**
 * @brief
 * This variable is the table of default @ref ResolverOption
 * values. It holds one @ref DefaultRequestOptions row per supported
 * options version.
 *
 * @details
 * Whether or not new functionality should be enabled or disabled
 * for previous versions is on a case-by-case basis. For example:
 * the cache feature is enabled by default even though prior to
 * version 4.x no cache existed.
 *
 * Each entry initializes the union member that matches its declared
 * type, so that the value is written through the same member that is
 * later read for that type. Any array slots not listed here are
 * zero-initialized by aggregate initialization.
 *
 * The table is intentionally non-const: non-release builds patch it
 * at run time through DisableCacheByDefault().
 */
DefaultRequestOptions DefaultRequestOptionTable[g_ResolverOptionVersion] = {
    {
        .version = 1,
        .options = {
            {
                .key = ResolverOptionKey::RequestEnableServiceDiscoveryBoolean,
                .type = ResolverOptionType::Boolean,
                .size = sizeof(bool),
                .data.booleanValue = true
            },
            {
                .key = ResolverOptionKey::RequestCancelHandleInteger,
                .type = ResolverOptionType::Integer,
                .size = sizeof(int),
                // Integer options are read back through integerValue.
                .data.integerValue = 0
            },
            {
                .key = ResolverOptionKey::RequestEnableDnsCacheBoolean,
                .type = ResolverOptionType::Boolean,
                .size = sizeof(bool),
                .data.booleanValue = true
            }
        }
    }
};

/**
 * @brief This value is the number of rows in
 * @ref DefaultRequestOptionTable.
 */
const size_t DefaultRequestOptionCount =
    sizeof(DefaultRequestOptionTable) / sizeof(DefaultRequestOptionTable[0]);

/**
 * @brief This api is called only in non-release builds for testing purposes only
 */
#if ! defined(NN_SDK_BUILD_RELEASE)
void DisableCacheByDefault()
{
    static int s_LogFlag = 0;
    if (++s_LogFlag == 1)
    {
        LogMajor("Globally disabling the resolver cache by default.\n");
    };
    DefaultRequestOptionTable[0].options[2].data.booleanValue = false;
};
#endif

/**
 * @brief Validate that the user-provided key is valid for this version.
 *
 * @details
 * The version must equal @ref g_ResolverOptionVersion; any other
 * version fails without the key being inspected. The key must then be
 * one of the keys enumerated below. Private keys are only accepted in
 * non-release builds.
 *
 * @param[in] key The @ref ResolverOptionKey to validate
 *
 * @param[in] version the version to validate against
 *
 * @return true if the key is accepted for the version, false otherwise.
 */
bool ValidateKeyVersion(ResolverOptionKey key, uint32_t version)
{
    bool isSupported = false;

    if (version == g_ResolverOptionVersion)
    {
        switch (static_cast<uint32_t>(key))
        {
        // get
        case static_cast<uint32_t>(ResolverOptionKey::GetCancelHandleInteger):
        // set
        case static_cast<uint32_t>(ResolverOptionKey::SetCancelHandleInteger):
        case static_cast<uint32_t>(ResolverOptionKey::SetRemoveDomainnameFromCachePointer):
        case static_cast<uint32_t>(ResolverOptionKey::SetRemoveIpAddressFromCacheUnsigned32):
        // request
        case static_cast<uint32_t>(ResolverOptionKey::RequestEnableServiceDiscoveryBoolean):
        case static_cast<uint32_t>(ResolverOptionKey::RequestCancelHandleInteger):
        case static_cast<uint32_t>(ResolverOptionKey::RequestEnableDnsCacheBoolean):
        // private
#if ! defined(NN_SDK_BUILD_RELEASE)
        case static_cast<uint32_t>(ResolverOptionPrivateKey::GetDnsServerAddressesPointer):
        case static_cast<uint32_t>(ResolverOptionPrivateKey::GetCacheEntryCountForDomainnamePointer):
        case static_cast<uint32_t>(ResolverOptionPrivateKey::GetCacheEntryCountForIpUnsigned32):
        case static_cast<uint32_t>(ResolverOptionPrivateKey::SetFlushCacheBoolean):
        case static_cast<uint32_t>(ResolverOptionPrivateKey::SetTimeToLiveForDomainnamePointer):
        case static_cast<uint32_t>(ResolverOptionPrivateKey::SetDnsServerAddressesPointer):
#endif
        // local are handled in the client shim
            isSupported = true;
            break;

        default:
            break;
        }
    }

    LogDebug("key: %s (%u), version: %d, returning %s\n",
             resolver::GetResolverOptionKeyString(key),
             key,
             version,
             isSupported ? "true" : "false");
    return isSupported;
}

/**
 * @brief This function validates the size and type field for a
 * @ref ResolverOption.
 *
 * @details
 * Every fixed-width type (Boolean, Integer, Unsigned32, Unsigned64,
 * Double) must carry a size equal to the size of the matching native
 * type. Pointer options are accepted with any size. Any other type
 * value is rejected.
 *
 * @param[in] option This parameter is the option to validate.
 *
 * @return true if the size field corresponds to the type field,
 * false otherwise.
 */
static
bool ValidateOptionContents(const ResolverOption& option)
{
    LogHex(&option.data,
           option.size,
           "option: %p, key: %s (%u), type: %s (%u), size: %zu, data: ",
           &option,
           resolver::GetResolverOptionKeyString(option.key),
           option.key,
           resolver::GetResolverOptionTypeString(option.type),
           option.type,
           option.size);

    // First work out what the declared type demands of the size field.
    size_t requiredSize = 0;
    bool isKnownType = true;
    bool isFixedSize = true;

    switch (option.type)
    {
    case ResolverOptionType::Boolean:
        requiredSize = sizeof(bool);
        break;
    case ResolverOptionType::Integer:
        requiredSize = sizeof(int);
        break;
    case ResolverOptionType::Unsigned32:
        requiredSize = sizeof(uint32_t);
        break;
    case ResolverOptionType::Unsigned64:
        requiredSize = sizeof(uint64_t);
        break;
    case ResolverOptionType::Double:
        requiredSize = sizeof(double);
        break;
    case ResolverOptionType::Pointer:
        // pointer payloads have no fixed size to compare against
        isFixedSize = false;
        break;
    default:
        isKnownType = false;
        break;
    }

    // Then decide; every path except a clean fixed-size match logs.
    bool isValid = false;
    if (!isKnownType)
    {
        LogDebug("\n");
    }
    else if (!isFixedSize)
    {
        LogDebug("\n");
        isValid = true;
    }
    else if (option.size == requiredSize)
    {
        isValid = true;
    }
    else
    {
        LogDebug("\n");
    }

    LogDebug("returning: %s\n", isValid ? "true" : "false");
    return isValid;
}

/**
 * @brief This function validates that a @ref ResolverOption key is
 * within the set of @ref OptionContext values represented by the mask.
 *
 * @details
 * This function takes two parameters. The first is a bit mask that
 * selects one or more option contexts; the second is the
 * @ref ResolverOptionKey to test. Each context owns a range of key
 * values delimited by a minimum and a maximum sentinel. The key is a
 * member of a context when its numeric value lies strictly between
 * that context's two sentinels (the sentinels themselves are not
 * valid keys). The function returns true if the key is a member of at
 * least one context selected by the mask, otherwise it logs the
 * failure and returns false.
 *
 * Local keys are never handled by the server, so no range is listed
 * for them and those keys always return false.
 *
 * @param[in] set A bit set of @ref OptionContext context values
 * that describe the context in which an @ref ResolverOption is being
 * used.
 *
 * @param[in] key The @ref ResolverOptionKey key value checked against
 * ranges of every given set.
 *
 * @return true if the key falls inside a selected context's range.
 */
static
bool SetContainsKey(OptionContext set, ResolverOptionKey key)
{
    LogDebug("set: %x: key: %s (%u)\n", set,
             resolver::GetResolverOptionKeyString(key), key);

    // One row per server-side context: its mask bit and exclusive bounds.
    struct KeyRange
    {
        uint32_t contextBit;
        uint32_t lowerSentinel;
        uint32_t upperSentinel;
    };

    const KeyRange ranges[] =
    {
        { static_cast<uint32_t>(OptionContext_Get),
          static_cast<uint32_t>(ResolverOptionKey::GetMinimumValue),
          static_cast<uint32_t>(ResolverOptionKey::GetMaximumValue) },

        { static_cast<uint32_t>(OptionContext_Set),
          static_cast<uint32_t>(ResolverOptionKey::SetMinimumValue),
          static_cast<uint32_t>(ResolverOptionKey::SetMaximumValue) },

        { static_cast<uint32_t>(OptionContext_Request),
          static_cast<uint32_t>(ResolverOptionKey::RequestMinimumValue),
          static_cast<uint32_t>(ResolverOptionKey::RequestMaximumValue) },

        { static_cast<uint32_t>(OptionContext_PrivateGet),
          static_cast<uint32_t>(ResolverOptionPrivateKey::GetMinimumValue),
          static_cast<uint32_t>(ResolverOptionPrivateKey::GetMaximumValue) },

        { static_cast<uint32_t>(OptionContext_PrivateSet),
          static_cast<uint32_t>(ResolverOptionPrivateKey::SetMinimumValue),
          static_cast<uint32_t>(ResolverOptionPrivateKey::SetMaximumValue) },

        { static_cast<uint32_t>(OptionContext_PrivateRequest),
          static_cast<uint32_t>(ResolverOptionPrivateKey::RequestMinimumValue),
          static_cast<uint32_t>(ResolverOptionPrivateKey::RequestMaximumValue) },
    };

    const uint32_t setU32 = static_cast<uint32_t>(set);
    const uint32_t keyU32 = static_cast<uint32_t>(key);

    bool isMember = false;
    for (const KeyRange& range : ranges)
    {
        const bool isSelected = (0 != (range.contextBit & setU32));
        if (isSelected &&
            keyU32 > range.lowerSentinel &&
            keyU32 < range.upperSentinel)
        {
            isMember = true;
            break;
        }
    }

    if (!isMember)
    {
        // The mask is printed in hex, matching the entry log above.
        LogMajor("ResolverOptionKey not in OptionContext set; "
                 "key: %s (%u), set: %x\n",
                 resolver::GetResolverOptionKeyString(key), key,
                 set);
    }

    return isMember;
}

/**
 * @brief Run the requested validation stages against a single option.
 *
 * @details
 * The stages run in a fixed order and stop at the first failure:
 * the key must belong to one of the contexts in @a mask; then, if
 * requested, the key must be valid for @a version; then, if
 * requested, the option's type and size must agree. A failure is
 * reported through LogMajor.
 *
 * @param[in] version The options version supplied by the client.
 * @param[in] mask The set of contexts the option may be used in.
 * @param[in] option The option to validate.
 * @param[in] validateKeyVersion Whether to run the key/version stage.
 * @param[in] validateOptionContents Whether to run the type/size stage.
 *
 * @return true if every stage that ran succeeded, false otherwise.
 */
bool ValidateOption(uint32_t version,
                    const OptionContext mask,
                    const ResolverOption& option,
                    bool validateKeyVersion,
                    bool validateOptionContents)
{
    LogDebug("version: %d, "
             "mask: %x: "
             "option: %p, "
             "validateKeyVersion: %d, "
             "validateOptionContents: %d\n",
             version,
             mask,
             &option,
             validateKeyVersion,
             validateOptionContents);

    // Short-circuit evaluation keeps the stage order: function, then form.
    const bool isValid =
        SetContainsKey(mask, option.key)
        && (!validateKeyVersion || ValidateKeyVersion(option.key, version))
        && (!validateOptionContents || ValidateOptionContents(option));

    if (!isValid)
    {
        LogMajor("Failed to validate Resolver option: %p, version: %u "
                 "(expected: %u) key: %s (%u), type: %s (%u), "
                 "mask: %x\n",
                 &option, version, g_ResolverOptionVersion,
                 resolver::GetResolverOptionKeyString(option.key),
                 option.key,
                 resolver::GetResolverOptionTypeString(option.type),
                 option.type,
                 mask);
    }

    return isValid;
}

/**
 * @brief Fully validate one option: context membership, key/version
 * and type/size agreement.
 *
 * @param[in] option The option to validate.
 * @param[in] version The options version supplied by the client.
 * @param[in] mask The set of contexts the option may be used in.
 *
 * @return The result of @ref ValidateOption with both optional
 * stages enabled.
 */
bool ValidateOptionForVersionAndMask(const ResolverOption& option,
                                     uint32_t version,
                                     const OptionContext mask)
{
    LogDebug("option: %p, option.key: %s, version: %d, mask: %x\n",
             &option,
             resolver::GetResolverOptionKeyString(option.key),
             version, mask);

    const bool checkKeyVersion = true;
    const bool checkContents = true;
    return ValidateOption(version, mask, option, checkKeyVersion, checkContents);
}

bool ValidateKeyForVersionAndMask(const ResolverOptionKey key,
                                  uint32_t version,
                                  const OptionContext mask)
{
    LogDebug("key: %s (%u), version: %d, mask: %x\n",
             resolver::GetResolverOptionKeyString(key),
             version, mask);

    struct ResolverOption option;
    option.key = key;
    return ValidateOption(version, mask, option, true, false);
};

bool ValidateOptionsArrayForVersionAndMask(const ResolverOption* pOptions,
                                           size_t count,
                                           const uint32_t version,
                                           OptionContext mask)
{
    LogDebug("pOptions: %p, count: %zu, version: %d, mask: %x\n",
             pOptions, count, version, mask);

    bool rc = false;
    errno = EUCLEAN;

    for (unsigned idx=0; idx<count; ++idx)
    {
        LogDebug("Validating option #%u\n", idx + 1);
        const ResolverOption& option = pOptions[idx];
        if (!ValidateOption(version, mask, option, true, true))
        {
            goto bail;
        };
    };

    rc = true;
    errno = 0;

bail:
    return rc;
};

// make request options
template <typename T>
static
int GetRequestOption(T& out,
                     const size_t outSize,
                     const unsigned version,
                     const ResolverOptionKey key,
                     const ResolverOptionType type,
                     const ResolverOption* pUserOptions,
                     const size_t userOptionsCount)
{
#define CopyValue(pOut, native, pOption, key, type_t, value)            \
    do                                                                  \
    {                                                                   \
        if (nullptr != pOption && type_t == pOption->type)              \
        {                                                               \
            LogDebug("Copying key: %s (%u), "                           \
                     "native: %s, type: %s, option: %p, "               \
                     "option { key: %u, type: %s, size: %d, "           \
                     "data.%s: %x. }\n",                                \
                     resolver::GetResolverOptionKeyString(key),         \
                     key,                                               \
                     #native, #type_t, pOption,                         \
                     pOption->key,                                      \
                     resolver::GetResolverOptionTypeString(pOption->type), \
                     pOption->size, #value, pOption->data.value);       \
            *reinterpret_cast<native>(&pOut) = pOption->data.value;     \
            rc = 0;                                                     \
        };                                                              \
    } while(NN_STATIC_CONDITION(false))

    struct OptionsAndCount
    {
        const struct ResolverOption* options;
        size_t count;
    };

    int rc = -1;
    const unsigned SearchMax = 2;
    OptionsAndCount searches[SearchMax] = { {0}, {0} };
    unsigned max = 0;
    const ResolverOption* pOption = nullptr;

    // always check the user array first
    if (nullptr != pUserOptions)
    {
        searches[max].options = pUserOptions;
        searches[max].count = userOptionsCount;
        ++max;
    };

    // and if it isn't provided then get it from defaults
    for (unsigned idx = 0; idx<DefaultRequestOptionCount; ++idx)
    {
        if (version == DefaultRequestOptionTable[idx].version)
        {
            searches[max].options = DefaultRequestOptionTable[idx].options;
            searches[max].count =
                static_cast<uint32_t>(ResolverOptionKey::RequestMaximumValue) -
                static_cast<uint32_t>(ResolverOptionKey::RequestMinimumValue);
            ++max;
            break;
        };
    };

    for (unsigned idx=0; idx<max; ++idx)
    {
        for (unsigned jdx=0; jdx<searches[idx].count; ++jdx)
        {
            pOption = &searches[idx].options[jdx];

            if (key == pOption->key)
            {
                if (type != pOption->type)
                {
                    LogMinor("Type check failed; "
                             "pOption: %p, "
                             "provided %s (%u), "
                             "expected: %s (%u)\n",
                             pOption,
                             resolver::GetResolverOptionTypeString(type), type,
                             resolver::GetResolverOptionTypeString(pOption->type), pOption->type);
                    goto bail;
                }
                else if (outSize < pOption->size)
                {
                    LogMinor("Size check failed: "
                             "pOption: %p, "
                             "provided %zu, "
                             "expected: %zu\n",
                             pOption,
                             outSize,
                             pOption->size);

                    goto bail;
                };
                goto copy;
            };
        };
    };

    // if the search for loop completes then bypass copying
    goto bail;

copy:
    CopyValue(out, bool*,     pOption, key, ResolverOptionType::Boolean,    booleanValue);
    CopyValue(out, int*,      pOption, key, ResolverOptionType::Integer,    integerValue);
    CopyValue(out, uint32_t*, pOption, key, ResolverOptionType::Unsigned32, unsigned32Value);
    CopyValue(out, uint64_t*, pOption, key, ResolverOptionType::Unsigned64, unsigned64Value);
    CopyValue(out, double*,   pOption, key, ResolverOptionType::Double,     doubleValue);

bail:
    return rc;
#undef CopyValue
};

int GetRequestOptionValue(bool& out,
                          const uint32_t version,
                          const ResolverOptionKey key,
                          const ResolverOption* pOptions,
                          size_t count)
{
    return GetRequestOption<bool>(out,
                                  sizeof(out),
                                  version,
                                  key,
                                  ResolverOptionType::Boolean,
                                  pOptions,
                                  count);
};

int GetRequestOptionValue(int& out,
                          const uint32_t version,
                          const ResolverOptionKey key,
                          const ResolverOption* pOptions,
                          size_t count)
{
    return GetRequestOption<int>(out,
                                 sizeof(out),
                                 version,
                                 key,
                                 ResolverOptionType::Integer,
                                 pOptions,
                                 count);
};

int GetRequestOptionValue(uint32_t& out,
                          const uint32_t version,
                          const ResolverOptionKey key,
                          const ResolverOption* pOptions,
                          size_t count)
{
    return GetRequestOption<uint32_t>(out,
                                      sizeof(out),
                                      version,
                                      key,
                                      ResolverOptionType::Unsigned32,
                                      pOptions,
                                      count);
};

int GetRequestOptionValue(uint64_t& out,
                          const uint32_t version,
                          const ResolverOptionKey key,
                          const ResolverOption* pOptions,
                          size_t count)
{
    return GetRequestOption<uint64_t>(out,
                                      sizeof(out),
                                      version,
                                      key,
                                      ResolverOptionType::Unsigned64,
                                      pOptions,
                                      count);
};

int GetRequestOptionValue(double& out,
                          const uint32_t version,
                          const ResolverOptionKey key,
                          const ResolverOption* pOptions,
                          size_t count)
{
    return GetRequestOption<double>(out,
                                    sizeof(out),
                                    version,
                                    key,
                                    ResolverOptionType::Double,
                                    pOptions,
                                    count);
};

/**
 * @brief Fetch a Pointer request option, from @a pOptions if present
 * there, otherwise from the defaults for @a version.
 *
 * @details
 * @a out is always overwritten: it receives the looked-up pointer, or
 * null when the lookup does not produce one.
 *
 * NOTE(review): GetRequestOption has no copy step for
 * ResolverOptionType::Pointer, so as written this overload returns -1
 * and sets @a out to null even when a matching Pointer option exists.
 * Nothing here checks that the pointer refers to a string.
 *
 * @return The result of GetRequestOption: 0 if a value was copied,
 * -1 otherwise.
 */
int GetRequestOptionValue(const char *& out,
                          const uint32_t version,
                          const ResolverOptionKey key,
                          const ResolverOption* pOptions,
                          size_t count)
{
    // validate that it's a string ?
    char* pointerValue = nullptr;
    int rc = GetRequestOption<char*>(pointerValue,
                                     sizeof(pointerValue),
                                     version,
                                     key,
                                     ResolverOptionType::Pointer,
                                     pOptions,
                                     count);
    out = pointerValue;
    return rc;
}

}}; // nn::socket
