﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/socket/socket_Api.h>
#include <nn/dns/parser.h>

//#define NN_DNSPARSER_LOG_LEVEL NN_DNSPARSER_LOG_LEVEL_HEX
#define NN_DNSPARSER_LOG_MODULE_NAME "LabelValidate" // NOLINT(preprocessor/const)
#include "dns_ParserLog.h"
#include "dns_ParserMacros.h"

namespace nn { namespace dns { namespace parser {

// Helpers defined outside this file (NOTE(review): definitions not visible here).
// Presumably tests whether a label length byte carries the compression-pointer tag — confirm.
bool IsLabelTypePointer(const uint8_t value) NN_NOEXCEPT;
// Presumably decodes the offset stored in the two-byte compression pointer at
// pCursor; the "Unsafe" suffix suggests it performs no bounds checking — confirm.
uint16_t LabelGetOffsetUnsafe(const uint8_t* pCursor) NN_NOEXCEPT;

/**
 * Singly linked list node describing the byte range [pBegin, pEnd) that one
 * label, root byte, or label pointer occupies inside a DNS message.
 * Ranges are half-open: a node with pEnd <= pBegin is considered malformed.
 * The list is terminated by a node whose pNext is nullptr.
 */
struct rangelist
{
    const uint8_t* pBegin;   // first byte of the range
    const uint8_t* pEnd;     // one past the last byte of the range
    struct rangelist* pNext; // next node, or nullptr at the end of the list
};

/**
 * @brief Reports whether any two ranges in a rangelist intersect.
 *
 * Every node is treated as the half-open interval [pBegin, pEnd), so ranges
 * that merely touch do not intersect. Take a reference range C-F within the
 * positions A..H:
 *
 *   A-B, B-C, A-C and F-G stay outside C-F and are accepted;
 *   C-D, D-E, E-F, C-F, A-D, A-E, A-F, ... intersect C-F and are rejected.
 *
 * A node whose pBegin or pEnd is nullptr, or whose pEnd is not greater than
 * its pBegin, is malformed and also makes the function report true.
 *
 * @param pHead first node of the list; nullptr denotes an empty list.
 * @return true if a malformed node or an intersecting pair exists, else false.
 **/
static
bool PointerRangesOverlap(struct rangelist* pHead) NN_NOEXCEPT
{
    NN_DNSPARSER_LOG_DEBUG("pHead: %p\n", pHead);

    struct rangelist* pOuter = pHead;
    while (nullptr != pOuter)
    {
        const bool isOuterMalformed = (nullptr == pOuter->pBegin)
                                   || (nullptr == pOuter->pEnd)
                                   || !(pOuter->pBegin < pOuter->pEnd);
        if (isOuterMalformed)
        {
            NN_DNSPARSER_LOG_DEBUG("\n");
            return true;
        }

        // compare the current node against every node after it
        for (struct rangelist* pInner = pOuter->pNext; nullptr != pInner; pInner = pInner->pNext)
        {
            NN_DNSPARSER_LOG_DEBUG("h.pBegin: %p h.pEnd: %p, n.pBegin: %p, n.pEnd: %p\n",
                                   pOuter->pBegin, pOuter->pEnd, pInner->pBegin, pInner->pEnd);

            const bool isInnerMalformed = (nullptr == pInner->pBegin)
                                       || (nullptr == pInner->pEnd)
                                       || !(pInner->pBegin < pInner->pEnd);
            if (isInnerMalformed)
            {
                NN_DNSPARSER_LOG_DEBUG("\n");
                return true;
            }

            // two half-open ranges intersect when each one starts before the other ends
            const bool isIntersecting = (pInner->pBegin < pOuter->pEnd)
                                     && (pOuter->pBegin < pInner->pEnd);
            if (isIntersecting)
            {
                NN_DNSPARSER_LOG_DEBUG("pHaystack: {begin: %p, end:%p contains pNeedle: {begin: %p, end: %p}\n. ",
                                       pOuter->pBegin, pOuter->pEnd, pInner->pBegin, pInner->pEnd);
                return true;
            }
        }

        pOuter = pOuter->pNext;
    }

    return false;
}

/**
 * @brief Recursively validates one label of an encoded DNS name and every
 * label that follows it, including labels reached through compression
 * pointers.
 *
 * Each visited label, root byte, or pointer is recorded in a stack-allocated
 * rangelist node linked after pPrevious. Before a pointer is followed, and
 * again when the root (zero length) label is reached, the whole list is
 * checked with PointerRangesOverlap(). This rejects pointer loops and
 * overlapping labels.
 *
 * Every byte that is read, and every pointer that is formed, is first
 * checked against [pMessageData, pEndOfMessage). A name that runs off the end
 * of the message, or points outside it, is therefore rejected rather than
 * read out of bounds.
 *
 * @param pMessageData  first byte of the DNS message.
 * @param pEndOfMessage one past the last byte of the DNS message.
 * @param pCursor       label length byte or pointer to validate.
 * @param pHeadIn       head of the range list, or nullptr on the first call.
 * @param pPrevious     tail of the range list, or nullptr on the first call.
 * @return true if the name is well formed and stays inside the message.
 */
static
bool IsLabelValidInternal(const uint8_t* pMessageData,
                          const uint8_t* pEndOfMessage,
                          const uint8_t* pCursor,
                          struct rangelist* pHeadIn,
                          struct rangelist* pPrevious) NN_NOEXCEPT
{
    NN_DNSPARSER_LOG_DEBUG("pMessageData: %p, pEndOfMessage: %p, pCursor: %p, pHeadIn: %p, pPrevious: %p\n",
                           pMessageData, pEndOfMessage, pCursor, pHeadIn, pPrevious);

    // RFC 1035 section 2.3.4: a label is at most 63 octets long. Length bytes
    // 64..191 carry the reserved 01/10 tag bits (section 4.1.4).
    const uint8_t MaximumLabelLength = 63;

    // all locals are declared before the first goto so no initialization is skipped
    bool rc = false;
    uint8_t typeLength = 0;
    size_t rangeSize = 0;
    size_t remaining = 0;
    uint16_t offset = 0;
    struct rangelist range = { 0 };
    struct rangelist* pHead = (pHeadIn == nullptr ? &range : pHeadIn);

    // the cursor must reference a byte inside the message before it is read
    if (nullptr == pMessageData || nullptr == pEndOfMessage || nullptr == pCursor)
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        goto bail;
    }
    else if (pCursor < pMessageData || pCursor >= pEndOfMessage)
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        goto bail;
    };

    // non-negative and at least 1 because pCursor < pEndOfMessage
    remaining = static_cast<size_t>(pEndOfMessage - pCursor);
    typeLength = *pCursor;
    range.pBegin = pCursor;

    if (nullptr != pPrevious)
    {
        pPrevious->pNext = &range;
    };

    if (IsLabelTypePointer(typeLength))
    {
        rangeSize = sizeof(uint16_t);

        // check to see that there is at least another byte; compared as
        // size_t so a negative distance cannot wrap around
        if (remaining < rangeSize)
        {
            NN_DNSPARSER_LOG_DEBUG("\n");
            goto bail;
        };

        range.pEnd = pCursor + rangeSize;
        NN_DNSPARSER_LOG_HEX(range.pBegin, rangeSize,
                             "begin: %p, end: %p, label pointer hex: ",
                             range.pBegin, range.pEnd);

        offset = LabelGetOffsetUnsafe(pCursor);

        // the pointer target must lie inside the message; checked before
        // forming pMessageData + offset to avoid out-of-range pointer arithmetic
        if (offset >= static_cast<size_t>(pEndOfMessage - pMessageData))
        {
            NN_DNSPARSER_LOG_DEBUG("\n");
            goto bail;
        };

        // self referential pointer is not allowed
        if (pCursor == pMessageData + offset)
        {
            NN_DNSPARSER_LOG_DEBUG("\n");
            goto bail;
        };

        if (PointerRangesOverlap(pHead))
        {
            NN_DNSPARSER_LOG_DEBUG("\n");
            goto bail;
        };

        rc = IsLabelValidInternal(pMessageData, pEndOfMessage, pMessageData + offset, pHead, &range);
        goto bail;
    }
    // typelength of zero indicates the root
    else if (0 == typeLength)
    {
        // remaining >= 1, so the single root byte is inside the message
        rangeSize = sizeof(uint8_t);
        range.pEnd = pCursor + rangeSize;
        NN_DNSPARSER_LOG_HEX(range.pBegin, rangeSize,
                             "begin: %p, end: %p, label end hex: ",
                             range.pBegin, range.pEnd);

        rc = !PointerRangesOverlap(pHead);
    }
    else
    {
        if (typeLength > MaximumLabelLength)
        {
            NN_DNSPARSER_LOG_DEBUG("\n");
            goto bail;
        };

        // rangesize is computed as the sizeof the length parameter plus
        // the number of characters
        rangeSize = sizeof(uint8_t) + sizeof(uint8_t) * typeLength;

        // the label characters must fit inside the message
        if (rangeSize > remaining)
        {
            NN_DNSPARSER_LOG_DEBUG("\n");
            goto bail;
        };

        range.pEnd = pCursor + rangeSize;
        NN_DNSPARSER_LOG_HEX(range.pBegin, rangeSize,
                             "begin: %p, end: %p, label data hex: ",
                             range.pBegin, range.pEnd);

        // the next label starts right after this one; the recursive call
        // rejects it if that is at or past pEndOfMessage
        rc = IsLabelValidInternal(pMessageData, pEndOfMessage, range.pEnd, pHead, &range);
    };

bail:
    return rc;
};

/**
 * @brief Validates the encoded domain name that starts at pCursor inside the
 * DNS message spanning pMessageData up to (not including) pEndOfMessage.
 *
 * Public entry point: it starts IsLabelValidInternal() with an empty range
 * list, that is, with no list head and no previous node.
 *
 * @param pMessageData  first byte of the DNS message.
 * @param pEndOfMessage one past the last byte of the DNS message.
 * @param pCursor       first label of the name to validate.
 * @return the result of IsLabelValidInternal(); true when the name is valid.
 */
bool IsLabelValid(const uint8_t* pMessageData,
                  const uint8_t* pEndOfMessage,
                  const uint8_t* pCursor) NN_NOEXCEPT
{
    // a fresh walk has neither a list head nor a node to link from
    struct rangelist* const pNoHead = nullptr;
    struct rangelist* const pNoPrevious = nullptr;

    const bool isValid = IsLabelValidInternal(pMessageData, pEndOfMessage, pCursor,
                                              pNoHead, pNoPrevious);
    return isValid;
}

}}}; //nn::dns::parser
