﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include "dhcps_BpfManager.h"
#include "dhcps_Coordinator.h"

//#define NN_DETAIL_DHCPS_LOG_LEVEL NN_DETAIL_DHCPS_LOG_LEVEL_DEBUG
#define NN_DETAIL_DHCPS_LOG_MODULE_NAME "BpfManager"
#include "dhcps_Log.h"

namespace nn { namespace dhcps { namespace detail {

extern Config g_Config;

namespace {

/**
 * @brief This structure exists for the 32-bit use case. It forces
 * the timeval seconds and useconds to be 64-bit values and the rest
 * of the structure mirrors the BpfHeader structure.
 *
 * @details OnFileRead() overlays this type on the start of the buffer read
 * from the BPF descriptor and uses bh_caplen, bh_datalen and bh_hdrlen from
 * it, so member order and types must stay byte-compatible with the record
 * header the BPF device writes. Do not reorder or retype members.
 * NOTE(review): assumes the device always writes a 64-bit timestamp,
 * independent of the client's pointer width — confirm against the
 * nn::socket BPF implementation.
 */
struct BpfHeaderTechnique
{
    /** Timestamp with both fields widened to 64 bits. */
    struct TimeVal64
    {
        uint64_t seconds;
        uint64_t useconds;
    };

    TimeVal64 bh_tstamp;              // capture timestamp
    nn::socket::BpfUInt32 bh_caplen;  // per bpf(4): number of captured bytes following the header
    nn::socket::BpfUInt32 bh_datalen; // per bpf(4): original length of the packet on the wire
    uint16_t bh_hdrlen;               // header length incl. padding; the frame starts at this offset
};

/**
 * @brief Opens the first BPF device that can be opened read/write.
 *
 * Tries "/dev/bpf0", "/dev/bpf1", ... for at most
 * LibraryConstants::MaximumDevices devices (at least one attempt is made).
 *
 * @return The open descriptor, or the (negative) result of the last
 *         nn::socket::Open() attempt when no device could be opened.
 */
int BpfOpen() NN_NOEXCEPT
{
    const char DevicePrefix[] = "/dev/bpf";
    const unsigned int MaximumDevices = static_cast<unsigned int>(LibraryConstants::MaximumDevices);
    char deviceBuffer[static_cast<size_t>(LibraryConstants::MaximumDeviceStringBufferSize)];
    unsigned int idx = 0;   // unsigned: matches "%u" and the unsigned bound below
    int fd = -1;

    do
    {
        nn::util::SNPrintf(deviceBuffer, sizeof(deviceBuffer), "%s%u", DevicePrefix, idx);

        NN_DETAIL_DHCPS_LOG_DEBUG("Attempting to open: %s\n", deviceBuffer);

        fd = nn::socket::Open(deviceBuffer, nn::socket::OpenFlag::O_RdWr);
        ++idx;
    }
    while (fd < 0 && idx < MaximumDevices);

    return fd;
}

/**
 * @brief Binds a BPF descriptor to a network interface and enables
 * immediate mode.
 *
 * @param fd        An open BPF descriptor.
 * @param interface NUL-terminated interface name; names longer than
 *                  sizeof(ifr_name) - 1 characters are truncated.
 * @return -1 on failure (the failing step is logged), otherwise the result
 *         of the BiocImmediate ioctl.
 */
int BpfSetOptions(int fd, const char* interface)  NN_NOEXCEPT
{
    int rc = -1;
    int yes = 1;
    nn::socket::IfReq ifr;
    nn::socket::Errno errorNumber;

    if (nullptr == interface)
    {
        NN_DETAIL_DHCPS_LOG_MAJOR("error: interface name is null\n");
        return -1;
    }

    // Zero the whole request: the ioctl must not see uninitialized bytes, and
    // strncpy() leaves ifr_name unterminated when the name fills all copied
    // bytes, so the final byte has to be pre-zeroed to guarantee termination.
    memset(&ifr, 0, sizeof(ifr));
    strncpy(reinterpret_cast<char*>(ifr.ifr_name), interface, sizeof(ifr.ifr_name) - 1);

    // Attach the descriptor to the interface.
    rc = nn::socket::Ioctl(fd,
                           static_cast<nn::socket::IoctlCommand>(nn::socket::IoctlCommandPrivate::BiocSetIf),
                           &ifr,
                           sizeof(ifr));
    if (-1 == rc)
    {
        errorNumber = nn::socket::GetLastError();
        NN_DETAIL_DHCPS_LOG_MAJOR("error ioctl BiocSetIf: %s (%d)\n",
                                  strerror(static_cast<int>(errorNumber)),
                                  static_cast<int>(errorNumber));
        return rc;
    }

    // Per bpf(4): return from read() as soon as a packet arrives instead of
    // waiting for the buffer to fill.
    rc = nn::socket::Ioctl(fd,
                           static_cast<nn::socket::IoctlCommand>(nn::socket::IoctlCommandPrivate::BiocImmediate),
                           &yes,
                           sizeof(yes));
    if (-1 == rc)
    {
        errorNumber = nn::socket::GetLastError();
        NN_DETAIL_DHCPS_LOG_MAJOR("error ioctl BiocImmediate: %s (%d).\n",
                                  strerror(static_cast<int>(errorNumber)),
                                  static_cast<int>(errorNumber));
    }

    return rc;
}

/**
 * @brief Verifies that a BPF descriptor's data-link type is Ethernet
 * (Dlt_En10Mb), the only type the DHCP filter and parser handle.
 *
 * @param fd An open BPF descriptor already bound to an interface.
 * @return 0 if the link type is Ethernet; -1 if the BiocGDlt ioctl fails or
 *         the link type is unsupported (both cases are logged).
 */
int BpfCheckDatalink(int fd)  NN_NOEXCEPT
{
    nn::socket::DataLinkType dlt = nn::socket::DataLinkType::Dlt_Null;

    if (-1 == nn::socket::Ioctl(fd,
                                static_cast<nn::socket::IoctlCommand>(nn::socket::IoctlCommandPrivate::BiocGDlt),
                                &dlt,
                                sizeof(dlt)))
    {
        nn::socket::Errno errorNumber = nn::socket::GetLastError();
        NN_DETAIL_DHCPS_LOG_MAJOR("error ioctl BiocGDlt: %s (%d)\n",
                                  strerror(static_cast<int>(errorNumber)),
                                  static_cast<int>(errorNumber));
        return -1;
    }

    if (nn::socket::DataLinkType::Dlt_En10Mb != dlt)
    {
        // Must be a failure: the successful ioctl above returned 0, which
        // callers would otherwise take as "datalink OK".
        NN_DETAIL_DHCPS_LOG_MAJOR("error ioctl BiocGDlt: unsupported type:%u\n",
                                  static_cast<unsigned int>(dlt));
        return -1;
    }

    return 0;
}

/**
 * @brief Installs the DHCP receive filter on a BPF descriptor (BiocSetf).
 *
 * The classic-BPF program accepts a frame (returns the whole packet) only if
 * all of the following hold, and drops it otherwise:
 *  - the first IP header byte (offset 14) is 0x45: IPv4, 20-byte header;
 *  - the EtherType is IPv4;
 *  - the IP protocol is UDP;
 *  - it is not a fragment (fragment-offset bits & 0x1fff == 0);
 *  - the UDP source port is the configured DHCP client port and the
 *    destination port is the configured DHCP server port.
 *
 * @param fd An open BPF descriptor.
 * @return The ioctl result; -1 on failure (logged).
 */
int BpfSetFilter(int fd)  NN_NOEXCEPT
{
    // Jump targets are relative: target = index + 1 + offset. Every rejecting
    // branch lands on instruction [14], the final "drop" return, so adding or
    // removing an instruction requires recomputing all offsets.
    const nn::socket::BpfInsn DhcpReadInstructions[] = {
        // [0-1] IPv4 with a 20-byte header (X is still 0, so this indirect load
        //       reads absolute offset 14)...
        nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ld  + nn::socket::BpfCode::Bpf_B   + nn::socket::BpfCode::Bpf_Ind, 14),
        nn::socket::Bpf_Jump(nn::socket::BpfCode::Bpf_Jmp + nn::socket::BpfCode::Bpf_Jeq + nn::socket::BpfCode::Bpf_K,   (nn::socket::IpVersion << 4) + 5, 0, 12),

        // [2-3] Make sure this is an IP packet...
        nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ld  + nn::socket::BpfCode::Bpf_H   + nn::socket::BpfCode::Bpf_Abs, 12),
        nn::socket::Bpf_Jump(nn::socket::BpfCode::Bpf_Jmp + nn::socket::BpfCode::Bpf_Jeq + nn::socket::BpfCode::Bpf_K,   static_cast<nn::socket::BpfUInt32>(nn::socket::EtherType::EtherType_Ip), 0, 10),

        // [4-5] Make sure it's a UDP packet...
        nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ld  + nn::socket::BpfCode::Bpf_B   + nn::socket::BpfCode::Bpf_Abs, 23),
        nn::socket::Bpf_Jump(nn::socket::BpfCode::Bpf_Jmp + nn::socket::BpfCode::Bpf_Jeq + nn::socket::BpfCode::Bpf_K,   static_cast<nn::socket::BpfUInt32>(nn::socket::Protocol::IpProto_Udp), 0, 8),

        // [6-7] Make sure this isn't a fragment...
        nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ld  + nn::socket::BpfCode::Bpf_H    + nn::socket::BpfCode::Bpf_Abs, 20),
        nn::socket::Bpf_Jump(nn::socket::BpfCode::Bpf_Jmp + nn::socket::BpfCode::Bpf_Jset + nn::socket::BpfCode::Bpf_K,   0x1fff, 6, 0),

        // [8] Get the IP header length into X...
        nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ldx + nn::socket::BpfCode::Bpf_B    + nn::socket::BpfCode::Bpf_Msh, 14),

        // [9-10] Make sure it's from the right (UDP source) port...
        nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ld  + nn::socket::BpfCode::Bpf_H    + nn::socket::BpfCode::Bpf_Ind, 14),
        nn::socket::Bpf_Jump(nn::socket::BpfCode::Bpf_Jmp + nn::socket::BpfCode::Bpf_Jeq  + nn::socket::BpfCode::Bpf_K,   g_Config.GetDhcpClientUdpPort(), 0, 3),

        // [11-12] Make sure it is to the right (UDP destination) port...
        nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ld  + nn::socket::BpfCode::Bpf_H    + nn::socket::BpfCode::Bpf_Ind, 16),
        nn::socket::Bpf_Jump(nn::socket::BpfCode::Bpf_Jmp + nn::socket::BpfCode::Bpf_Jeq  + nn::socket::BpfCode::Bpf_K,   g_Config.GetDhcpServerUdpPort(), 0, 1),

        // [13] If we passed all the tests, ask for the whole packet.
        nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ret + nn::socket::BpfCode::Bpf_K, static_cast<u_int>(-1)),

        // [14] Otherwise, drop it.
        nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ret + nn::socket::BpfCode::Bpf_K, 0),
    };

    static_assert(sizeof(DhcpReadInstructions) / sizeof(DhcpReadInstructions[0]) == 15,
                  "jump offsets in DhcpReadInstructions assume exactly 15 instructions");

    nn::socket::BpfProgram dhcpReadBpfProgram;

    // The program is copied inline into the request; make sure it fits.
    static_assert(sizeof(DhcpReadInstructions) <= sizeof(dhcpReadBpfProgram.bf_insns),
                  "BpfProgram::bf_insns is too small for the DHCP filter");

    memset(&dhcpReadBpfProgram, 0, sizeof(dhcpReadBpfProgram));
    dhcpReadBpfProgram.bf_len = sizeof(DhcpReadInstructions) / sizeof(DhcpReadInstructions[0]);
    memcpy(dhcpReadBpfProgram.bf_insns, DhcpReadInstructions, sizeof(DhcpReadInstructions));

    int rc = nn::socket::Ioctl(fd,
                               static_cast<nn::socket::IoctlCommand>(nn::socket::IoctlCommandPrivate::BiocSetf),
                               &dhcpReadBpfProgram,
                               sizeof(dhcpReadBpfProgram));
    if (-1 == rc)
    {
        nn::socket::Errno errorNumber = nn::socket::GetLastError();
        NN_DETAIL_DHCPS_LOG_MAJOR("error ioctl BiocSetf: %s (%d)\n",
                                  strerror(static_cast<int>(errorNumber)),
                                  static_cast<int>(errorNumber));
    }

    return rc;
}

}; // end anonymous namespace


/**
 * @brief Returns a printable name for a BpfManager::State value; used by
 * BpfManager::ChangeState() in its state-transition log message.
 *
 * @param in The state to name.
 * @return The string produced by NN_DETAIL_DHCPS_STRINGIFY_CASE for each
 *         listed state, or "Unknown BpfManager::State" for any other value.
 *
 * NOTE(review): the macro presumably stringifies its argument, so the
 * fully-qualified enumerator spelling below is likely what appears in the
 * returned string — do not shorten it without checking the macro.
 */
const char* BpfManagerStateToString(BpfManager::State in)  NN_NOEXCEPT
{
    switch (in)
    {
        NN_DETAIL_DHCPS_STRINGIFY_CASE(BpfManager::State::Uninitialized);
        NN_DETAIL_DHCPS_STRINGIFY_CASE(BpfManager::State::Initialized);
        NN_DETAIL_DHCPS_STRINGIFY_CASE(BpfManager::State::FileOpen);
        NN_DETAIL_DHCPS_STRINGIFY_CASE(BpfManager::State::SetOptions);
        NN_DETAIL_DHCPS_STRINGIFY_CASE(BpfManager::State::CheckDatalink);
        NN_DETAIL_DHCPS_STRINGIFY_CASE(BpfManager::State::SetFilter);
        NN_DETAIL_DHCPS_STRINGIFY_CASE(BpfManager::State::Ready);
        NN_DETAIL_DHCPS_STRINGIFY_CASE(BpfManager::State::FileRead);
        NN_DETAIL_DHCPS_STRINGIFY_CASE(BpfManager::State::FileWrite);
        NN_DETAIL_DHCPS_STRINGIFY_CASE(BpfManager::State::FileClosed);
        NN_DETAIL_DHCPS_STRINGIFY_CASE(BpfManager::State::FileError);
    default:
        // Unlisted values fall through to the generic name below.
        ;
    }
    return "Unknown BpfManager::State";
};

/**
 * @brief Constructs an idle manager: state Uninitialized, no BPF descriptor
 * (-1) and no coordinator. Initialize() attaches the coordinator.
 */
BpfManager::BpfManager() NN_NOEXCEPT
    : m_State(State::Uninitialized)
    , m_FileDescriptor(-1)
    , m_pCoordinator(nullptr)
{
}

/**
 * @brief Destroys the manager.
 *
 * Nothing is released here: an open BPF descriptor is closed only by
 * Finalize() or by close/error handling, so call Finalize() before
 * destruction to avoid leaking the descriptor.
 */
BpfManager::~BpfManager() NN_NOEXCEPT
{
    // Intentionally empty.
}

/**
 * @brief Returns the BPF descriptor held by the manager, or -1 when no
 * device is currently open.
 */
int BpfManager::GetFileDescriptor() NN_NOEXCEPT
{
    const int currentFd = m_FileDescriptor;
    return currentFd;
}

/**
 * @brief Attaches the coordinator that receives this manager's events and
 * initializes the packet parser.
 *
 * @param pCoordinator Coordinator to notify; must not be null.
 * @return 0 on success; -1 if the manager is already initialized (a
 *         coordinator is attached; call Finalize() first) or pCoordinator is
 *         null. Both failure cases also trip an SDK assertion.
 */
int BpfManager::Initialize(Coordinator* pCoordinator)  NN_NOEXCEPT
{
    if (nullptr != m_pCoordinator)
    {
        NN_SDK_ASSERT(false);
        return -1;
    }

    // A null coordinator would be dereferenced by every error path and would
    // also defeat the double-initialization check above.
    if (nullptr == pCoordinator)
    {
        NN_SDK_ASSERT(false);
        return -1;
    }

    m_pCoordinator = pCoordinator;

    m_PacketParser.Initialize(m_pCoordinator);

    ChangeState(State::Initialized);
    return 0;
}

/**
 * @brief Releases everything Initialize() and OnFileOpen() acquired: closes
 * the BPF descriptor if open, finalizes the packet parser, detaches the
 * coordinator and returns to the Uninitialized state.
 *
 * @return Always 0.
 */
int BpfManager::Finalize() NN_NOEXCEPT
{
    const bool hasOpenDevice = (-1 != m_FileDescriptor);
    if (hasOpenDevice)
    {
        OnFileClose();
    }

    m_PacketParser.Finalize();

    // Clearing the coordinator lets Initialize() succeed again.
    m_pCoordinator = nullptr;

    ChangeState(State::Uninitialized);

    return 0;
}

/**
 * @brief Dispatches an internal event to the matching handler.
 *
 * @param e The event. For OnPacketWrite its value is an array of
 *          NetworkLayersContainer and its size the number of entries.
 *
 * Unknown event types are logged at MINOR level and otherwise ignored.
 */
void BpfManager::OnEvent(const InternalEvent& e) NN_NOEXCEPT
{
    switch (e.GetType())
    {
    case EventType::OnFileOpen:
        OnFileOpen();
        // OnFileOpen() ends in FileOpen only when every setup step succeeded;
        // only then does the manager become ready for I/O.
        if (State::FileOpen == m_State)
        {
            ChangeState(State::Ready);
        }
        break;
    case EventType::OnFileRead:
        OnFileRead();
        break;
    case EventType::OnPacketWrite:
        OnPacketWrite(reinterpret_cast<const NetworkLayersContainer*>(e.GetValue()),
                      e.GetSize());
        break;
    case EventType::OnFileClose:
        OnFileClose();
        break;
    case EventType::OnFileError:
        OnFileError();
        break;
    case EventType::OnTimerExpired:
        OnTimerExpired();
        break;
    default:
        // Scoped enums are not promoted through varargs; cast for "%u".
        NN_DETAIL_DHCPS_LOG_MINOR("Unhandled event: %s (%u)\n",
                                  EventTypeToString(e.GetType()),
                                  static_cast<unsigned int>(e.GetType()));
        break;
    }
}

/**
 * @brief Moves the manager to a new state, logging the transition and
 * publishing the new state via SetGlobalState().
 *
 * @param in The new state.
 */
void BpfManager::ChangeState(State in) NN_NOEXCEPT
{
    // Scoped enums are not promoted through varargs; cast explicitly for "%d".
    NN_DETAIL_DHCPS_LOG_DEBUG("Changing state from: %s (%d) to %s (%d)\n",
                              BpfManagerStateToString(m_State), static_cast<int>(m_State),
                              BpfManagerStateToString(in), static_cast<int>(in));

    SetGlobalState(GlobalState::BpfManager, static_cast<uint8_t>(in));

    m_State = in;
}

/**
 * @brief Handles an OnFileOpen event: opens a BPF device and configures it
 * to capture DHCP client-to-server traffic on the configured interface.
 *
 * Steps, each preceded by a state change: open (BpfOpen), SetOptions
 * (interface binding and immediate mode), CheckDatalink, SetFilter. The
 * first failing step sends an OnFileError event to the coordinator and
 * skips the rest; full success leaves the state at FileOpen. A request while
 * a descriptor is already held is logged and ignored.
 *
 * NOTE(review): when a step after BpfOpen() fails, the descriptor remains in
 * m_FileDescriptor; presumably it is closed when the coordinator routes the
 * OnFileError event back to this manager — confirm.
 */
void BpfManager::OnFileOpen() NN_NOEXCEPT
{
    // Reports a setup failure to the coordinator.
    auto reportFileError = [this]()
    {
        InternalEvent errorEvent(EventType::OnFileError, nullptr, 0);
        m_pCoordinator->OnEvent(errorEvent);
    };

    if (m_FileDescriptor != -1)
    {
        NN_DETAIL_DHCPS_LOG_MAJOR("Open event when Bpf fd is not closed (fd=%d)\n",
                                  m_FileDescriptor);
        return;
    }

    m_FileDescriptor = BpfOpen();
    if (m_FileDescriptor == -1)
    {
        reportFileError();
        return;
    }

    ChangeState(State::SetOptions);
    if (BpfSetOptions(m_FileDescriptor, g_Config.GetInterfaceName()) == -1)
    {
        reportFileError();
        return;
    }

    ChangeState(State::CheckDatalink);
    if (BpfCheckDatalink(m_FileDescriptor) == -1)
    {
        reportFileError();
        return;
    }

    ChangeState(State::SetFilter);
    if (BpfSetFilter(m_FileDescriptor) == -1)
    {
        reportFileError();
        return;
    }

    ChangeState(State::FileOpen);
}

/**
 * @brief Handles an OnFileRead event: reads the BPF buffer and hands the
 * first captured Ethernet frame to the packet parser.
 *
 * The kernel buffer length is queried with BiocGBLen and exactly that many
 * bytes are requested from Read(), so it must fit the static buffer below.
 * Outcomes:
 *  - ioctl/read failure, or a kernel buffer larger than the static buffer:
 *    an OnFileError event is sent to the coordinator (an oversized buffer can
 *    never be read, so ignoring it would leave the descriptor permanently
 *    readable);
 *  - a record shorter than its header claims, or a truncated capture
 *    (bh_caplen != bh_datalen): discarded without an error;
 *  - otherwise the parser's result decides; -1 raises OnFileError.
 * The state always returns to Ready.
 *
 * NOTE(review): a single read can return several bpf records; only the
 * first is parsed. Confirm immediate mode makes that acceptable, or walk
 * the records.
 */
void BpfManager::OnFileRead() NN_NOEXCEPT
{
    int rc = -1;
    int buflen = -1;
    nn::socket::Errno errorNumber;
    const size_t StaticMemorySize = 4096;
    // Shared by every call (not reentrant). Aligned because the start of the
    // buffer is accessed through BpfHeaderTechnique, which has 64-bit members.
    alignas(BpfHeaderTechnique) static uint8_t pInBuffer[StaticMemorySize];

    ChangeState(State::FileRead);

    if (-1 == (rc = nn::socket::Ioctl(m_FileDescriptor,
                                      static_cast<nn::socket::IoctlCommand>(nn::socket::IoctlCommandPrivate::BiocGBLen),
                                      &buflen,
                                      sizeof(buflen))))
    {
        errorNumber = nn::socket::GetLastError();
        NN_DETAIL_DHCPS_LOG_MAJOR("ioctl BIOCGBLEN: %s (%d)\n",
                                  strerror(static_cast<int>(errorNumber)),
                                  static_cast<int>(errorNumber));
        goto bail;
    }
    else if (static_cast<size_t>(buflen) > StaticMemorySize)
    {
        // A negative buflen also lands here (it converts to a huge size_t).
        NN_DETAIL_DHCPS_LOG_MAJOR("dynamic buflen (%d) > static buffer size (%zu)\n",
                                  buflen, StaticMemorySize);
        rc = -1;
        goto bail;
    }
    else
    {
        NetworkLayersContainer packet[static_cast<size_t>(NetworkLayer::Max)];
        memset(packet, 0, sizeof(packet));
        memset(pInBuffer, 0, sizeof(pInBuffer));
        const BpfHeaderTechnique* header = reinterpret_cast<const BpfHeaderTechnique*>(pInBuffer);
        unsigned int layerCount = 0;

        if (-1 == (rc = nn::socket::Read(m_FileDescriptor, pInBuffer, buflen)))
        {
            errorNumber = nn::socket::GetLastError();
            NN_DETAIL_DHCPS_LOG_MAJOR("read error: %s (%d)\n",
                                      strerror(static_cast<int>(errorNumber)),
                                      static_cast<int>(errorNumber));
            goto bail;
        }
        else if (rc < static_cast<int>(sizeof(BpfHeaderTechnique))
                 || static_cast<size_t>(header->bh_hdrlen) + header->bh_caplen > static_cast<size_t>(rc))
        {
            // The header (or the frame it describes) extends past the bytes
            // actually read; parsing it would read stale or foreign data.
            NN_DETAIL_DHCPS_LOG_MAJOR("short bpf record (read %d bytes, hdrlen %u, caplen %u); discarding\n",
                                      rc,
                                      static_cast<unsigned int>(header->bh_hdrlen),
                                      static_cast<unsigned int>(header->bh_caplen));
            rc = 0;
            goto bail;
        }
        else if (header->bh_caplen != header->bh_datalen)
        {
            NN_DETAIL_DHCPS_LOG_MAJOR("header->bh_caplen (%u) != header->bh_datalen (%u); discarding\n",
                                      static_cast<unsigned int>(header->bh_caplen),
                                      static_cast<unsigned int>(header->bh_datalen));
            rc = 0;
            goto bail;
        }

        packet[layerCount].type = NetworkLayer::BpfHeader;
        packet[layerCount].pBuffer = pInBuffer;
        packet[layerCount].size = header->bh_hdrlen;

        // Pass the captured frame length, not the number of bytes read: rc
        // also counts the bpf header, which would let the parser run past the
        // end of the frame (and, for large reads, past the end of pInBuffer).
        rc = m_PacketParser.OnEthernetPacket(packet,
                                             ++layerCount,
                                             pInBuffer + header->bh_hdrlen,
                                             header->bh_caplen);
    }

bail:
    if (-1 == rc)
    {
        InternalEvent e(EventType::OnFileError, nullptr, 0);
        m_pCoordinator->OnEvent(e);
    }

    ChangeState(State::Ready);
}

/**
 * @brief Handles an OnPacketWrite event: writes one outgoing frame to the
 * BPF descriptor.
 *
 * @param pPacket Layer descriptors of one outgoing packet.
 * @param count   Number of entries in pPacket.
 *
 * Layers before the first NetworkLayer::Ethernet entry are skipped. The
 * write starts at that layer's pBuffer and covers the summed sizes of it and
 * every following layer, i.e. those layers are assumed to be contiguous in
 * memory — NOTE(review): confirm with the packet builder. A packet with no
 * Ethernet layer is rejected instead of indexing past the end of pPacket.
 * Write failures are logged; the state always returns to Ready.
 */
void BpfManager::OnPacketWrite(const NetworkLayersContainer pPacket[], size_t count) NN_NOEXCEPT
{
    ChangeState(State::FileWrite);

    size_t ethernetIndex = 0;
    while (nullptr != pPacket
           && ethernetIndex < count
           && pPacket[ethernetIndex].type != NetworkLayer::Ethernet)
    {
        ++ethernetIndex;
    }

    if (nullptr == pPacket || ethernetIndex >= count)
    {
        NN_DETAIL_DHCPS_LOG_MAJOR("no ethernet layer in packet (count=%zu); not writing\n", count);
        ChangeState(State::Ready);
        return;
    }

    size_t size = 0;
    for (size_t idx = ethernetIndex; idx < count; ++idx)
    {
        size += pPacket[idx].size;
    }

    ssize_t rc = nn::socket::Write(m_FileDescriptor, pPacket[ethernetIndex].pBuffer, size);
    if (rc < 0)
    {
        nn::socket::Errno errorNumber = nn::socket::GetLastError();
        NN_DETAIL_DHCPS_LOG_MAJOR("write error: %s (%d)\n",
                                  strerror(static_cast<int>(errorNumber)),
                                  static_cast<int>(errorNumber));
    }
    else if (static_cast<size_t>(rc) != size)
    {
        NN_DETAIL_DHCPS_LOG_MAJOR("short write: %zd of %zu bytes\n", rc, size);
    }
    else
    {
        NN_DETAIL_DHCPS_LOG_DEBUG("wrote: %zd bytes\n", rc);
    }

    ChangeState(State::Ready);
}

/**
 * @brief Closes the BPF descriptor, if one is open, and moves to FileClosed.
 *
 * Close is attempted exactly once and the descriptor is forgotten (-1)
 * whatever the result. POSIX leaves a descriptor's state unspecified after
 * close() fails with EINTR, and on Linux and FreeBSD it has already been
 * released, so a retry could close an unrelated descriptor that reused the
 * number. Failures are logged. Does nothing but a debug log when no
 * descriptor is open.
 */
void BpfManager::OnFileClose() NN_NOEXCEPT
{
    if (-1 == m_FileDescriptor)
    {
        NN_DETAIL_DHCPS_LOG_DEBUG("file descriptor is -1\n");
        return;
    }

    if (-1 == nn::socket::Close(m_FileDescriptor))
    {
        nn::socket::Errno errorNumber = nn::socket::GetLastError();
        NN_DETAIL_DHCPS_LOG_MAJOR("close error: %s (%d)\n",
                                  strerror(static_cast<int>(errorNumber)),
                                  static_cast<int>(errorNumber));
    }
    m_FileDescriptor = -1;

    ChangeState(State::FileClosed);
}

/**
 * @brief Handles an OnFileError event: closes the BPF descriptor if one is
 * open, then records the FileError state (after FileClosed).
 */
void BpfManager::OnFileError() NN_NOEXCEPT
{
    OnFileClose();

    // Overrides the FileClosed state set by a successful close.
    ChangeState(State::FileError);
}

/**
 * @brief Timer handler: when DetectNetworkChange() reports a change, sends
 * an OnDetectedNetworkChange event to the coordinator and then closes the
 * BPF descriptor.
 */
void BpfManager::OnTimerExpired() NN_NOEXCEPT
{
    if (DetectNetworkChange() != true)
    {
        return;
    }

    InternalEvent networkChanged(static_cast<EventType>(Event::OnDetectedNetworkChange), nullptr, 0);
    m_pCoordinator->OnEvent(networkChanged);
    OnFileClose();
}

/**
 * @brief Reports this manager's timeout: a fixed 1 second.
 *
 * @param pTv Receives the timeout; must not be null (a null pointer trips an
 *            SDK assertion and is otherwise ignored).
 */
void BpfManager::GetTimeout(nn::socket::TimeVal* pTv) NN_NOEXCEPT
{
    const unsigned int Timeout = 1;

    if (nullptr != pTv)
    {
        pTv->tv_sec = Timeout;
        pTv->tv_usec = 0;
        return;
    }

    NN_SDK_ASSERT(false);
}

}}}; // nn::dhcps::detail
