﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/socket/socket_Api.h>
#include <nn/dns/parser.h>

//#define NN_DNSPARSER_LOG_LEVEL NN_DNSPARSER_LOG_LEVEL_HEX
#define NN_DNSPARSER_LOG_MODULE_NAME "Message" // NOLINT(preprocessor/const)
#include "dns_ParserLog.h"
#include "dns_ParserMacros.h"

extern "C"
{
#include <nnc/dns/parser.h>
};

// Compile-time layout checks: the C structure nndnsparserMessage and the C++
// class nn::dns::parser::Message must have the same size and alignment.
// NOTE(review): presumably this allows the C API to treat one as the other - confirm.
NN_DNSPARSER_STATIC_ASSERT(sizeof(struct nndnsparserMessage) == sizeof(nn::dns::parser::Message));
NN_DNSPARSER_STATIC_ASSERT(alignof(struct nndnsparserMessage) == alignof(nn::dns::parser::Message));

namespace nn { namespace dns { namespace parser {

/**
 * @brief Default constructor.
 *
 * Produces a message with no backing buffer: the buffer pointer is null,
 * both sizes are zero, the header is default-constructed and every section
 * block has null start and end pointers.
 */
Message::Message()
    : m_pBuffer(nullptr)
    , m_Size(0)
    , m_DirtySize(0)
    , m_Header(Header())
{
    MemoryBlock* const sections[] =
    {
        &m_HeaderSection,
        &m_QuestionSection,
        &m_AnswerSection,
        &m_AuthoritySection,
        &m_AdditionalSection
    };

    // Every section starts out empty.
    for (MemoryBlock* pSection : sections)
    {
        pSection->pEnd = nullptr;
        pSection->pStart = nullptr;
    }
}

/**
 * @brief Copy constructor.
 *
 * Performs a member-wise copy. Only the pointer values are copied, so the
 * new object refers to the same buffer memory as @a rhs.
 *
 * NOTE(review): Initialize() passes `this` to m_Header.Initialize(); whether
 * a copied Header still refers to @a rhs is not visible here - confirm.
 *
 * @param[in] rhs The message to copy from.
 */
Message::Message(const Message& rhs)
    : m_pBuffer(rhs.m_pBuffer)
    , m_Size(rhs.m_Size)
    , m_DirtySize(rhs.m_DirtySize)
    , m_HeaderSection(rhs.m_HeaderSection)
    , m_QuestionSection(rhs.m_QuestionSection)
    , m_AnswerSection(rhs.m_AnswerSection)
    , m_AuthoritySection(rhs.m_AuthoritySection)
    , m_AdditionalSection(rhs.m_AdditionalSection)
    , m_Header(rhs.m_Header)
{
}

/**
 * @brief Destructor.
 *
 * Does no work of its own; in particular the memory m_pBuffer points at is
 * not released here.
 */
Message::~Message()
{
}

/**
 * @brief Gets the buffer pointer currently stored in the message.
 *
 * @return The stored buffer pointer; nullptr when no buffer has been set.
 */
const uint8_t* Message::GetBuffer() const NN_NOEXCEPT
{
    return m_pBuffer;
}

/**
 * @brief Copy assignment operator.
 *
 * Performs a member-wise copy of @a rhs. Only the pointer values are copied,
 * so afterwards both objects refer to the same buffer memory.
 * Self-assignment leaves the object untouched.
 *
 * @param[in] rhs The message to copy from.
 *
 * @return A reference to this object.
 */
Message& Message::operator=(const Message& rhs)
{
    if (this != &rhs)
    {
        m_pBuffer = rhs.m_pBuffer;
        m_Size = rhs.m_Size;
        m_DirtySize = rhs.m_DirtySize;

        m_HeaderSection = rhs.m_HeaderSection;
        m_QuestionSection = rhs.m_QuestionSection;
        m_AnswerSection = rhs.m_AnswerSection;
        m_AuthoritySection = rhs.m_AuthoritySection;
        m_AdditionalSection = rhs.m_AdditionalSection;

        m_Header = rhs.m_Header;
    }

    return *this;
}

/**
 * @brief Gets the buffer size stored in the message.
 *
 * @return The stored size; FromBuffer() sets it to the size it was given.
 */
size_t Message::GetBufferSize() const NN_NOEXCEPT
{
    return m_Size;
}

/**
 * @brief Gets the memory block describing the header section.
 *
 * @return A mutable reference to the header section block.
 */
MemoryBlock& Message::GetHeaderSection() NN_NOEXCEPT
{
    return m_HeaderSection;
}

/**
 * @brief Gets the memory block describing the question section.
 *
 * @return A mutable reference to the question section block.
 */
MemoryBlock& Message::GetQuestionSection() NN_NOEXCEPT
{
    return m_QuestionSection;
}

/**
 * @brief Gets the memory block describing the answer section.
 *
 * @return A mutable reference to the answer section block.
 */
MemoryBlock& Message::GetAnswerSection() NN_NOEXCEPT
{
    return m_AnswerSection;
}

/**
 * @brief Gets the memory block describing the authority section.
 *
 * @return A mutable reference to the authority section block.
 */
MemoryBlock& Message::GetAuthoritySection() NN_NOEXCEPT
{
    return m_AuthoritySection;
}

/**
 * @brief Gets the memory block describing the additional section.
 *
 * @return A mutable reference to the additional section block.
 */
MemoryBlock& Message::GetAdditionalSection() NN_NOEXCEPT
{
    return m_AdditionalSection;
}

/**
 * @brief Gets the header object owned by the message.
 *
 * @return A mutable reference to the header.
 */
Header& Message::GetHeader() NN_NOEXCEPT
{
    return m_Header;
}

/**
 * @brief Resets the message to an all-zero state and re-initializes the header.
 *
 * The whole object is zero-filled, then the header is initialized with a
 * pointer to this message. The function also carries compile-time checks
 * that each member sits at the same offset as its counterpart in the
 * C structure nndnsparserMessage.
 */
void Message::Initialize()
{
    // Scalar members.
    NN_DNSPARSER_STATIC_ASSERT(
        offsetof(struct nndnsparserMessage, pBuffer) ==
        offsetof(nn::dns::parser::Message, m_pBuffer));
    NN_DNSPARSER_STATIC_ASSERT(
        offsetof(struct nndnsparserMessage, size) ==
        offsetof(nn::dns::parser::Message, m_Size));
    NN_DNSPARSER_STATIC_ASSERT(
        offsetof(struct nndnsparserMessage, dirtySize) ==
        offsetof(nn::dns::parser::Message, m_DirtySize));

    // Section blocks.
    NN_DNSPARSER_STATIC_ASSERT(
        offsetof(struct nndnsparserMessage, headerSection) ==
        offsetof(nn::dns::parser::Message, m_HeaderSection));
    NN_DNSPARSER_STATIC_ASSERT(
        offsetof(struct nndnsparserMessage, questionSection) ==
        offsetof(nn::dns::parser::Message, m_QuestionSection));
    NN_DNSPARSER_STATIC_ASSERT(
        offsetof(struct nndnsparserMessage, answerSection) ==
        offsetof(nn::dns::parser::Message, m_AnswerSection));
    NN_DNSPARSER_STATIC_ASSERT(
        offsetof(struct nndnsparserMessage, authoritySection) ==
        offsetof(nn::dns::parser::Message, m_AuthoritySection));
    NN_DNSPARSER_STATIC_ASSERT(
        offsetof(struct nndnsparserMessage, additionalSection) ==
        offsetof(nn::dns::parser::Message, m_AdditionalSection));

    // Header object.
    NN_DNSPARSER_STATIC_ASSERT(
        offsetof(struct nndnsparserMessage, header) ==
        offsetof(nn::dns::parser::Message, m_Header));

    // Wipe every member, then let the header record its owning message.
    memset(this, 0, sizeof(*this));
    m_Header.Initialize(this);
}


namespace {

/**
 * @brief Calculates the size of the provided @ref Section containing
 * either @ref Question or @ref Record objects.
 *
 * @param[in] message The message containing the section.
 *
 * @param[in] section The provided @ref Section.
 */
template <typename TIterator>
ssize_t SizeOfSection(const Message& message, MessageSectionConstant tag)
{
    NN_DNSPARSER_LOG_DEBUG("Message: %p, tag: %d\n", &message, tag);

    ssize_t rc = -1;
    ssize_t sz = 0;
    int count;
    MemoryBlock section;
    TIterator iter;

    switch (tag)
    {
    case MessageSectionConstant::Question:
        count = const_cast<Message&>(message).GetHeader().GetQuestionCount();
        section = const_cast<Message&>(message).GetQuestionSection();
        break;
    case MessageSectionConstant::Answer:
        count = const_cast<Message&>(message).GetHeader().GetAnswerCount();
        section = const_cast<Message&>(message).GetAnswerSection();
        break;
    case MessageSectionConstant::Authority:
        count = const_cast<Message&>(message).GetHeader().GetAuthorityCount();
        section = const_cast<Message&>(message).GetAuthoritySection();
        break;
    case MessageSectionConstant::Additional:
        count = const_cast<Message&>(message).GetHeader().GetAdditionalCount();
        section = const_cast<Message&>(message).GetAdditionalSection();
        break;
    default:
        goto bail;
    };

    if (count == 0)
    {
        rc = 0;
        goto bail;
    };

    rc = iter.Initialize(message, tag);
    if (-1 == rc)
    {
        goto bail;
    };

    for (;
         1 == rc;
         rc = iter.GetNext())
    {
        ssize_t val = iter.GetCurrent().SizeOf();
        if (-1 == val)
        {
            goto bail;
        };
        sz += val;
    };

    rc = sz;

bail:
    NN_DNSPARSER_LOG_DEBUG("returning: %zd\n", rc);
    return rc;

};
}; // end anonymous namespace

/**
 * @brief Calculates the size of the whole message.
 *
 * Adds the header size to the sizes of the question, answer, authority and
 * additional sections, in that order.
 *
 * @return The accumulated size, or -1 as soon as any part reports -1.
 */
ssize_t Message::SizeOf() const NN_NOEXCEPT
{
    NN_DNSPARSER_LOG_DEBUG("\n");

    ssize_t rc = -1;
    ssize_t total = 0;

    // Header.
    rc = m_Header.SizeOf();
    if (rc == -1)
    {
        goto bail;
    }
    total += rc;

    // Question section.
    rc = SizeOfSection<QuestionIterator>(*this, MessageSectionConstant::Question);
    if (rc == -1)
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        goto bail;
    }
    total += rc;

    // Answer section.
    rc = SizeOfSection<RecordIterator>(*this, MessageSectionConstant::Answer);
    if (rc == -1)
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        goto bail;
    }
    total += rc;

    // Authority section.
    rc = SizeOfSection<RecordIterator>(*this, MessageSectionConstant::Authority);
    if (rc == -1)
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        goto bail;
    }
    total += rc;

    // Additional section.
    rc = SizeOfSection<RecordIterator>(*this, MessageSectionConstant::Additional);
    if (rc == -1)
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        goto bail;
    }
    total += rc;

    rc = total;

bail:
    NN_DNSPARSER_LOG_DEBUG("returning: %zd\n", rc);
    return rc;
}

namespace {
/**
 * @brief Compares the provided section containing @ref Question or
 * @ref Record objects for equality.
 *
 * @param[in] msg1 The first message to compare.
 *
 * @param[in] msg2 The second message to compare
 *
 * @param[in] tag The provided @ref Section.
 */
template <typename TIterator>
bool SectionIsEqual(const Message& msg1, const Message& msg2, MessageSectionConstant tag)
{
    NN_DNSPARSER_LOG_DEBUG("Message1: %p, Message2: %p, tag: %d\n", &msg1, &msg2, tag);

    bool rc = false;
    int count1, count2;
    TIterator iter1, iter2;
    int rc1, rc2;

    switch (tag)
    {
    case MessageSectionConstant::Question:
        count1 = const_cast<Message&>(msg1).GetHeader().GetQuestionCount();
        count2 = const_cast<Message&>(msg2).GetHeader().GetQuestionCount();
        break;
    case MessageSectionConstant::Answer:
        count1 = const_cast<Message&>(msg1).GetHeader().GetAnswerCount();
        count2 = const_cast<Message&>(msg2).GetHeader().GetAnswerCount();
        break;
    case MessageSectionConstant::Authority:
        count1 = const_cast<Message&>(msg1).GetHeader().GetAuthorityCount();
        count2 = const_cast<Message&>(msg2).GetHeader().GetAuthorityCount();
        break;
    case MessageSectionConstant::Additional:
        count1 = const_cast<Message&>(msg1).GetHeader().GetAdditionalCount();
        count2 = const_cast<Message&>(msg2).GetHeader().GetAdditionalCount();
        break;
    default:
        goto bail;
    };

    if (count1 != count2)
    {
        goto bail;
    };

    if (count1 == 0)
    {
        rc = true;
        goto bail;
    };

    rc1 = iter1.Initialize(msg1, tag);
    rc2 = iter2.Initialize(msg2, tag);

    if (-1 == rc1 || -1 == rc2)
    {
        goto bail;
    };

    while (1 == rc1 && 1 == rc2)
    {
        if (!(iter1.GetCurrent() == iter2.GetCurrent()))
        {
            goto bail;
        };

        rc1 = iter1.GetNext();
        rc2 = iter2.GetNext();
        if (-1 == rc1 || -1 == rc2)
        {
            goto bail;
        };
    };

    rc = true;

bail:
    return rc;
};
}; // end anonymous namespace

/**
 * @brief Compares two messages for equality.
 *
 * A message always equals itself. Otherwise the question counts, the headers
 * and then the question, answer, authority and additional sections are
 * compared in that order; the first mismatch ends the comparison.
 *
 * @param[in] that The message to compare against.
 *
 * @return true when the messages are equal, false otherwise.
 */
bool Message::operator==(const Message& that) const NN_NOEXCEPT
{
    if (this == &that)
    {
        return true;
    }

    // GetQuestionCount() is not const-qualified; the counts are only read here.
    Header& lhsHeader = const_cast<Header&>(m_Header);
    Header& rhsHeader = const_cast<Header&>(that.m_Header);
    if (lhsHeader.GetQuestionCount() != rhsHeader.GetQuestionCount())
    {
        return false;
    }

    if (!(m_Header == that.m_Header))
    {
        return false;
    }

    if (!SectionIsEqual<QuestionIterator>(*this, that, MessageSectionConstant::Question))
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        return false;
    }

    if (!SectionIsEqual<RecordIterator>(*this, that, MessageSectionConstant::Answer))
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        return false;
    }

    if (!SectionIsEqual<RecordIterator>(*this, that, MessageSectionConstant::Authority))
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        return false;
    }

    if (!SectionIsEqual<RecordIterator>(*this, that, MessageSectionConstant::Additional))
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        return false;
    }

    return true;
}

namespace
{
/**
 * @brief Traverses one section of a message buffer and records its bounds.
 *
 * The section block selected by @a tag is first set to span from @a pCursor
 * to the end of the remaining buffer so that the iterator can walk it. On
 * success its end is then pulled back to just past the last object visited.
 *
 * @tparam TIterator Iterator type used to walk the section. It must provide
 * Initialize(), GetNext() and GetCurrent().
 *
 * @param[in,out] message The message whose section block is updated.
 *
 * @param[in] pCursor The position in the buffer where the section starts.
 *
 * @param[in] size The number of bytes available from @a pCursor.
 *
 * @param[in] tag Selects which section is traversed.
 *
 * @return The section size in bytes, 0 when the header count for the section
 * is zero, or -1 on error (unknown tag, iterator failure, or an object whose
 * size cannot be determined).
 */
template <typename TIterator>
ssize_t SectionFromBuffer(Message& message, const uint8_t* pCursor, size_t size, MessageSectionConstant tag)
{
    NN_DNSPARSER_LOG_DEBUG("message: %p, pCursor: %p, size: %zu, tag: %d\n", &message, pCursor, size, tag);

    // ssize_t matches both the return type and the %zd format used below.
    ssize_t rc = -1;
    ssize_t elementSize = -1;
    uint16_t count = 0;
    MemoryBlock* pSection;
    TIterator iter;
    const uint8_t* const pEom = pCursor + size;

    switch (tag)
    {
    case MessageSectionConstant::Question:
        count = message.GetHeader().GetQuestionCount();
        pSection = &message.GetQuestionSection();
        break;
    case MessageSectionConstant::Answer:
        count = message.GetHeader().GetAnswerCount();
        pSection = &message.GetAnswerSection();
        break;
    case MessageSectionConstant::Authority:
        count = message.GetHeader().GetAuthorityCount();
        pSection = &message.GetAuthoritySection();
        break;
    case MessageSectionConstant::Additional:
        count = message.GetHeader().GetAdditionalCount();
        pSection = &message.GetAdditionalSection();
        break;
    default:
        goto bail;
    }

    if (0 == count)
    {
        rc = 0;
        goto bail;
    }

    // Provisional bounds: the section may extend to the end of the buffer.
    pSection->pStart = pCursor;
    pSection->pEnd = pEom;

    rc = iter.Initialize(message, tag);
    while (1 == rc)
    {
        elementSize = iter.GetCurrent().SizeOf();
        if (-1 == elementSize)
        {
            rc = -1;
            goto bail;
        }
        pCursor += elementSize;

        rc = iter.GetNext();
    }

    // An iterator failure (from Initialize() or GetNext()) must not be
    // reported as a successfully parsed section.
    if (-1 == rc)
    {
        goto bail;
    }

    pSection->pEnd = pCursor;
    rc = pSection->pEnd - pSection->pStart;

bail:
    NN_DNSPARSER_LOG_DEBUG("returning: %zd\n", rc);
    return rc;
}
}; // end anonymous namespace


/**
 * @brief Parses a message from a caller-supplied buffer.
 *
 * The object is re-initialized first, then the header and each section with
 * a non-zero header count are traversed in order (question, answer,
 * authority, additional). The buffer is not copied: the message and its
 * section blocks keep pointers into @a pBuffer.
 *
 * @param[in] pBuffer The buffer to parse.
 *
 * @param[in] size The size of @a pBuffer in bytes.
 *
 * @return The number of bytes consumed (also stored as the dirty size), or
 * -1 on error.
 */
ssize_t Message::FromBuffer(const uint8_t* pBuffer, size_t size) NN_NOEXCEPT
{
    NN_DNSPARSER_LOG_DEBUG("pBuffer: %p, size: %zu\n", pBuffer, size);

    ssize_t rc = -1;
    size_t remaining = size;
    const uint8_t* pCursor = pBuffer;

    // Start from a clean object, then remember the caller's buffer.
    Initialize();
    m_pBuffer = pBuffer;
    m_Size = size;
    NN_DNSPARSER_ERROR_IF(nullptr == pCursor, bail);

    // Header.
    rc = m_Header.FromBuffer(pCursor, remaining);
    if (rc == -1)
    {
        goto bail;
    }
    m_HeaderSection.pStart = pCursor;
    pCursor += rc;
    m_HeaderSection.pEnd = pCursor;
    remaining -= rc;

    // Question section.
    if (m_Header.GetQuestionCount() != 0)
    {
        rc = SectionFromBuffer<QuestionIterator>(*this, pCursor, remaining, MessageSectionConstant::Question);
        if (rc == -1)
        {
            NN_DNSPARSER_LOG_DEBUG("\n");
            goto bail;
        }
        pCursor += rc;
        remaining -= rc;
    }

    // Answer section.
    if (m_Header.GetAnswerCount() != 0)
    {
        rc = SectionFromBuffer<RecordIterator>(*this, pCursor, remaining, MessageSectionConstant::Answer);
        if (rc == -1)
        {
            NN_DNSPARSER_LOG_DEBUG("\n");
            goto bail;
        }
        pCursor += rc;
        remaining -= rc;
    }

    // Authority section. Unlike the other sections, its bounds are also
    // written here in addition to whatever SectionFromBuffer() records.
    if (m_Header.GetAuthorityCount() != 0)
    {
        m_AuthoritySection.pStart = pCursor;
        rc = SectionFromBuffer<RecordIterator>(*this, pCursor, remaining, MessageSectionConstant::Authority);
        if (rc == -1)
        {
            NN_DNSPARSER_LOG_DEBUG("\n");
            goto bail;
        }
        pCursor += rc;
        remaining -= rc;
        m_AuthoritySection.pEnd = pCursor;
    }

    // Additional section.
    if (m_Header.GetAdditionalCount() != 0)
    {
        rc = SectionFromBuffer<RecordIterator>(*this, pCursor, remaining, MessageSectionConstant::Additional);
        if (rc == -1)
        {
            NN_DNSPARSER_LOG_DEBUG("\n");
            goto bail;
        }
        pCursor += rc;
        remaining -= rc;
    }

    // Everything up to the cursor has been consumed.
    m_DirtySize = pCursor - pBuffer;
    rc = m_DirtySize;

bail:
    NN_DNSPARSER_LOG_DEBUG("returning: %zd\n", rc);
    return rc;
}

namespace
{
template <typename TIterator>
ssize_t SectionToBuffer(const Message& message, uint8_t * const pBuffer, size_t left, MessageSectionConstant tag)
{
    NN_DNSPARSER_LOG_DEBUG("Message: %p, pBuffer: %p, size: %zu, tag: %d\n", &message, pBuffer, left, tag);

    ssize_t rc = -1;
    size_t total = 0;
    uint16_t* pCount;
    MemoryBlock* pSection;
    TIterator iter;
    uint8_t* pCursor = pBuffer;

    switch (tag)
    {
    case MessageSectionConstant::Question:
        pCount = &const_cast<Message&>(message).GetHeader().GetQuestionCount();
        pSection = &const_cast<Message&>(message).GetQuestionSection();
        break;
    case MessageSectionConstant::Answer:
        pCount = &const_cast<Message&>(message).GetHeader().GetAnswerCount();
        pSection = &const_cast<Message&>(message).GetAnswerSection();
        break;
    case MessageSectionConstant::Authority:
        pCount = &const_cast<Message&>(message).GetHeader().GetAuthorityCount();
        pSection = &const_cast<Message&>(message).GetAuthoritySection();
        break;
    case MessageSectionConstant::Additional:
        pCount = &const_cast<Message&>(message).GetHeader().GetAdditionalCount();
        pSection = &const_cast<Message&>(message).GetAdditionalSection();
        break;
    default:
        goto bail;
    };

    if (*pCount == 0)
    {
        rc = 0;
        goto bail;
    };

    if (-1 == (rc = iter.Initialize(message, tag)))
    {
        goto bail;
    };

    for (;
         1 == rc;
         rc = iter.GetNext())
    {
        rc = iter.GetCurrent().ToBuffer(pCursor, left);
        if (-1 == rc )
        {
            NN_DNSPARSER_LOG_DEBUG("\n");
            goto bail;
        };

        total += rc;
        pCursor += rc;
        left -= rc;
    };

    rc = total;

bail:
    NN_DNSPARSER_LOG_DEBUG("returning: %zd\n", rc);
    return rc;
};
}; // end anonymous namespace

/**
 * @brief Writes the message into a caller-supplied buffer.
 *
 * The header is written first, followed by the question, answer, authority
 * and additional sections.
 *
 * @param[out] pBuffer The buffer to write to.
 *
 * @param[in] size The size of @a pBuffer in bytes.
 *
 * @return The number of bytes written, or -1 on error.
 */
ssize_t Message::ToBuffer(uint8_t* const pBuffer, size_t size) const NN_NOEXCEPT
{
    NN_DNSPARSER_LOG_DEBUG("pBuffer: %p, size: %zu\n", pBuffer, size);

    ssize_t rc = -1;
    uint8_t* pCursor = pBuffer;

    NN_DNSPARSER_ERROR_IF(nullptr == pBuffer, bail);

    // Header.
    rc = m_Header.ToBuffer(pCursor, size);
    if (-1 == rc)
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        goto bail;
    }
    size -= rc;
    pCursor += rc;

    // Question section.
    rc = SectionToBuffer<QuestionIterator>(*this, pCursor, size, MessageSectionConstant::Question);
    if (-1 == rc)
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        goto bail;
    }
    size -= rc;
    pCursor += rc;

    // Answer section.
    rc = SectionToBuffer<RecordIterator>(*this, pCursor, size, MessageSectionConstant::Answer);
    if (-1 == rc)
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        goto bail;
    }
    size -= rc;
    pCursor += rc;

    // Authority section.
    rc = SectionToBuffer<RecordIterator>(*this, pCursor, size, MessageSectionConstant::Authority);
    if (-1 == rc)
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        goto bail;
    }
    size -= rc;
    pCursor += rc;

    // Additional section. The remaining size is not needed after this point.
    rc = SectionToBuffer<RecordIterator>(*this, pCursor, size, MessageSectionConstant::Additional);
    if (-1 == rc)
    {
        NN_DNSPARSER_LOG_DEBUG("\n");
        goto bail;
    }
    pCursor += rc;

    rc = pCursor - pBuffer;

bail:
    NN_DNSPARSER_LOG_DEBUG("returning: %zd\n", rc);
    return rc;
}

}}}; // nn::dns::parser
