﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include "dhcpc_PrivateIncludes.h"
#include <nn\socket\socket_ApiPrivate.h>

namespace nn { namespace bsdsocket { namespace dhcpc {

// Classic BPF receive filter installed on m_RawFd by Initialize() (BiocSetf).
// It accepts only unfragmented IPv4/UDP frames whose UDP destination port is
// DhcpProtPorts_Client, and drops everything else.
//
// Instruction indices are shown as [n]. Jump targets assume nn::socket::Bpf_Jump(code, k, jt, jf)
// mirrors the classic BPF_JUMP macro, i.e. jt/jf are offsets relative to the
// instruction that follows the jump (NOTE(review): confirm in the socket headers).
// Every rejecting branch lands on [10], the final "drop" instruction, so the jump
// offsets must be recomputed together if instructions are added or removed.
//
// NOTE(review): this initializer has 11 instructions. Initialize() sets bf_len to
// BPF_FILTER_ENTRIES, so that constant should be exactly 11; otherwise trailing
// zero-filled instructions are handed to the kernel. Confirm.
// NOTE(review): Bpf_EthCook is added to every offset past the EtherType field; its
// value is defined elsewhere (presumably 0 on this platform) - confirm.
const nn::socket::BpfInsn PacketManager::m_BpfFilter[BPF_FILTER_ENTRIES] = {
    // Make sure this is an IP packet...
    // [0] load the EtherType halfword at offset 12 of the Ethernet header;
    // [1] if it is not EtherType_Ip, skip 8 -> [10] (drop).
    nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ld + nn::socket::BpfCode::Bpf_H + nn::socket::BpfCode::Bpf_Abs,
             12),
    nn::socket::Bpf_Jump(nn::socket::BpfCode::Bpf_Jmp + nn::socket::BpfCode::Bpf_Jeq + nn::socket::BpfCode::Bpf_K,
             static_cast<nn::socket::BpfUInt32>(nn::socket::EtherType::EtherType_Ip),
             0,
             8),
    // Make sure it's a UDP packet...
    // [2] load the byte at 23 (14-byte Ethernet header + 9 = IPv4 protocol field);
    // [3] if it is not IpProto_Udp, skip 6 -> [10] (drop).
    nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ld + nn::socket::BpfCode::Bpf_B + nn::socket::BpfCode::Bpf_Abs,
             nn::socket::BpfCode::Bpf_EthCook + static_cast<nn::socket::BpfUInt32>(23)),
    nn::socket::Bpf_Jump(nn::socket::BpfCode::Bpf_Jmp + nn::socket::BpfCode::Bpf_Jeq + nn::socket::BpfCode::Bpf_K,
             static_cast<nn::socket::BpfUInt32>(nn::socket::Protocol::IpProto_Udp),
             0,
             6),
    // Make sure this isn't a fragment...
    // [4] load the halfword at 20 (14 + 6 = IPv4 flags/fragment-offset field);
    // [5] if any fragment-offset bit (mask 0x1fff) is set, skip 4 -> [10] (drop).
    nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ld + nn::socket::BpfCode::Bpf_H + nn::socket::BpfCode::Bpf_Abs,
             nn::socket::BpfCode::Bpf_EthCook + static_cast<nn::socket::BpfUInt32>(20)),
    nn::socket::Bpf_Jump(nn::socket::BpfCode::Bpf_Jmp + nn::socket::BpfCode::Bpf_Jset + nn::socket::BpfCode::Bpf_K,
             0x1fff,
             4,
             0),
    // Get the IP header length...
    // [6] X = 4 * (byte at 14 & 0x0f), i.e. the IPv4 header length in bytes (Msh mode).
    nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ldx + nn::socket::BpfCode::Bpf_B + nn::socket::BpfCode::Bpf_Msh,
             nn::socket::BpfCode::Bpf_EthCook + static_cast<nn::socket::BpfUInt32>(14)),
    // Make sure it's to the right port...
    // [7] load the halfword at X + 16 (14 + IP header length + 2 = UDP destination port);
    // [8] if it is not DhcpProtPorts_Client, skip 1 -> [10] (drop).
    nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ld + nn::socket::BpfCode::Bpf_H + nn::socket::BpfCode::Bpf_Ind,
             nn::socket::BpfCode::Bpf_EthCook + static_cast<nn::socket::BpfUInt32>(16)),
    nn::socket::Bpf_Jump(nn::socket::BpfCode::Bpf_Jmp + nn::socket::BpfCode::Bpf_Jeq + nn::socket::BpfCode::Bpf_K,
             DhcpProtPorts_Client,
             0,
             1),
    // [9] If we passed all the tests, ask for the whole packet.
    nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ret + nn::socket::BpfCode::Bpf_K, 0xffffffff),
    // [10] Otherwise, drop it (accept 0 bytes).
    nn::socket::Bpf_Stmt(nn::socket::BpfCode::Bpf_Ret + nn::socket::BpfCode::Bpf_K, 0),
};


PacketManager::PacketManager()
    : m_pInterface(nullptr)
    , m_RawFd(-1)
    , m_ArpFd(-1)
    , m_Biocgblen(0)
{
    memset(&m_ReadBuffer, 0, sizeof(m_ReadBuffer));
}

/**
 * Destroys the packet manager.
 *
 * The destructor itself releases nothing: m_RawFd and m_ArpFd are closed only by
 * Finalize(), so Finalize() has to run before destruction once Initialize() has
 * opened a descriptor.
 */
PacketManager::~PacketManager() = default;

/**
 * Opens and configures the BPF device used to send and receive raw DHCP frames.
 *
 * Steps: open /dev/bpf0, attach it to the interface named by pInterface (BiocSetIf),
 * check the BPF version, enable immediate mode, verify that the kernel buffer length
 * (BiocGBLen) fits in m_ReadBuffer.data, and install m_BpfFilter (BiocSetf).
 *
 * @param pInterface  Interface to attach to. It is stored in m_pInterface and
 *                    dereferenced unconditionally, so it must be non-null.
 * @return ResultSuccess() on success. ResultIfRemoved() when BiocSetIf or BiocSetf
 *         fail with ENxIo. ResultBpfError() for the other ioctl/version/buffer
 *         failures. For an open() failure, whatever DHCPC_BREAK_UPON_ERROR(ResultBpfError())
 *         yields. On any failure Finalize() is called, so no descriptor stays open.
 */
Result PacketManager::Initialize(Interface *pInterface)
{
    Result result = ResultSuccess();
    m_pInterface       = pInterface;
    m_RawFd            = -1;
    m_ArpFd            = -1;
    memset(&m_ReadBuffer, 0, sizeof(m_ReadBuffer));
    do
    {
        char pathBuf[nn::socket::IfNamSiz] = { 0 };
        int flag = 1;
        nn::socket::BpfVersion v = { 0, 0 };
        nn::socket::IfReq ifr;
        memset(&ifr, 0, sizeof(ifr));
        // BiocSetIf selects the interface by the name placed in ifr.
        m_pInterface->GetInterfaceName((char *)ifr.ifr_name, sizeof(ifr.ifr_name));
        nn::util::Strlcpy(pathBuf, "/dev/bpf0", sizeof(pathBuf));

        if ((m_RawFd = nn::socket::Open(pathBuf, nn::socket::OpenFlag::O_RdWr)) < 0)
        {
            DHCPC_BREAK_UPON_ERROR(ResultBpfError());
        }

        if (nn::socket::Ioctl(m_RawFd, static_cast<nn::socket::IoctlCommand>(nn::socket::IoctlCommandPrivate::BiocSetIf), &ifr, sizeof(ifr)) < 0)
        {
            nn::socket::Errno lastErrno = nn::socket::GetLastError();
            DHCPC_LOG_ERROR("Ioctl( BiocSetIf ) failed, errno=%d.\n", lastErrno);
            switch (lastErrno)
            {
            case nn::socket::Errno::ENxIo:
                result = ResultIfRemoved();
                break;
            default:
                result = ResultBpfError();
                break;
            }
            break;
        }
        if (nn::socket::Ioctl(m_RawFd, static_cast<nn::socket::IoctlCommand>(nn::socket::IoctlCommandPrivate::BiocVersion), &v, sizeof(v)) >= 0)
        {
            // The filter format needs the same major version and at least our minor version.
            if (v.bv_major != nn::socket::Bpf_Major_Version || v.bv_minor < nn::socket::Bpf_Minor_Version)
            {
                DHCPC_LOG_ERROR("BPF version error\n");
                result = ResultBpfError();
                break;
            }
        }
        else
        {
            nn::socket::Errno lastErrno = nn::socket::GetLastError();
            DHCPC_LOG_ERROR("Ioctl( BiocVersion ) failed, errno=%d.\n", lastErrno);
            result = ResultBpfError();
            break;
        }

        // Set immediate mode so that reads return as soon as a packet
        // comes in, rather than waiting for the input buffer to fill
        // with packets.
        if (nn::socket::Ioctl(m_RawFd, static_cast<nn::socket::IoctlCommand>(nn::socket::IoctlCommandPrivate::BiocImmediate), &flag, sizeof(flag)) < 0)
        {
            nn::socket::Errno lastErrno = nn::socket::GetLastError();
            DHCPC_LOG_ERROR("Ioctl( BiocImmediate ) failed, errno=%d.\n", lastErrno);
            result = ResultBpfError();
            break;
        }

        // Set timeout escape mechanism
#if 0
        nn::socket::TimeVal to;
        to.tv_sec = 10;
        to.tv_usec = 0;
        if (nn::socket::Ioctl(m_RawFd, static_cast<nn::socket::IoctlCommand>(nn::socket::IoctlCommandPrivate::BiocSRTimeout), &to, sizeof(to)) < 0)
        {
            DHCPC_LOG_ERROR("Can't set timeout on bpf device, errno=%d\n",nn::socket::GetLastError());
            result = ResultBpfError();
            break;
        }
#endif

        // Get the required BPF buffer length from the kernel.
        if (nn::socket::Ioctl(m_RawFd, static_cast<nn::socket::IoctlCommand>(nn::socket::IoctlCommandPrivate::BiocGBLen), &m_Biocgblen, sizeof(m_Biocgblen)) >= 0)
        {
            // ReadRawPacket() reads into m_ReadBuffer.data, so that array must be able to
            // hold a full kernel buffer. The enclosing m_ReadBuffer struct also carries
            // the length/position bookkeeping and would overstate the capacity.
            if (m_Biocgblen > sizeof(m_ReadBuffer.data))
            {
                // %d expects int arguments; convert the size_t / ioctl-typed values explicitly.
                DHCPC_LOG_ERROR("read buffer size of %d is smaller than BiocGBLen %d\n",
                                static_cast<int>(sizeof(m_ReadBuffer.data)), static_cast<int>(m_Biocgblen));
                result = ResultBpfError();
                break;
            }
        }
        else
        {
            nn::socket::Errno lastErrno = nn::socket::GetLastError();
            DHCPC_LOG_ERROR("Ioctl( BIOCGBLEN ) failed, errno=%d.\n", lastErrno);
            result = ResultBpfError();
            break;
        }

        // Set up filter: copy the static program and hand all BPF_FILTER_ENTRIES
        // instructions to the kernel.
        nn::socket::BpfProgram bpfProgramTable;
        memset(&bpfProgramTable, 0, sizeof(bpfProgramTable));
        memcpy(bpfProgramTable.bf_insns, m_BpfFilter, sizeof(m_BpfFilter));
        bpfProgramTable.bf_len = BPF_FILTER_ENTRIES;
        if (nn::socket::Ioctl(m_RawFd, static_cast<nn::socket::IoctlCommand>(nn::socket::IoctlCommandPrivate::BiocSetf), &bpfProgramTable, sizeof(bpfProgramTable)) < 0)
        {
            nn::socket::Errno lastErrno = nn::socket::GetLastError();
            DHCPC_LOG_ERROR("Ioctl( BiocSetf ) failed, errno=%d.\n", lastErrno);
            switch (lastErrno)
            {
            case nn::socket::Errno::ENxIo:
                result = ResultIfRemoved();
                break;
            default:
                result = ResultBpfError();
                break;
            }
            break;
        }

        result = ResultSuccess();

    } while (false);

    if ( !result.IsSuccess())
    {
        // Release whatever was opened before the failing step.
        Finalize();
    }

    return result;
} // NOLINT(impl/function_size)

/**
 * Closes the raw and ARP descriptors if they are open and detaches the interface.
 *
 * Each descriptor is reset to -1 after closing, so calling this more than once
 * (Initialize() also calls it on failure) is harmless.
 *
 * @return Always ResultSuccess().
 */
Result PacketManager::Finalize()
{
    // Close one descriptor if it is open and mark it closed.
    auto closeIfOpen = [](int& descriptor)
    {
        if (descriptor < 0)
        {
            return;
        }
        nn::socket::Close(descriptor);
        descriptor = -1;
    };

    closeIfOpen(m_RawFd);
    closeIfOpen(m_ArpFd);

    m_pInterface = nullptr;

    return ResultSuccess();
}

/**
 * Sends one raw Ethernet frame through the BPF descriptor selected by protocol.
 *
 * An Ethernet header is prepended to data. The destination is broadcast
 * (ff:ff:ff:ff:ff:ff), the source is left zero, and the EtherType is protocol in
 * network byte order.
 *
 * @param protocol  EtherType written into the header. EtherType_Arp selects m_ArpFd;
 *                  any other value selects m_RawFd.
 * @param data      Frame payload (everything after the Ethernet header).
 * @param len       Payload length in bytes; must not exceed DhcpSizes_MtuMax.
 * @return ResultSuccess() when the whole frame was written. ResultIfRemoved() when the
 *         write failed with ENxIo. ResultIfSendError() for other write failures and
 *         when len does not fit in the frame buffer.
 */
Result PacketManager::SendRawPacket(nn::socket::EtherType protocol, const void *data, size_t len)
{
    Result result = ResultSuccess();
    // The frame is a full Ethernet header (Ether_Hdr_Len bytes, not just one
    // Ether_Addr_Len address) followed by at most one MTU of payload.
    uint8_t buffer[DhcpSizes_MtuMax + nn::socket::Ether_Hdr_Len];
    const size_t headerLen = static_cast<size_t>(nn::socket::Ether_Hdr_Len);
    if (len > sizeof(buffer) - headerLen)
    {
        // Refuse instead of overflowing the stack buffer in the memcpy below.
        DHCPC_LOG_ERROR("BPF Write( ) payload too large, len=%d.\n", static_cast<int>(len));
        return ResultIfSendError();
    }
    // The bound above keeps the frame size well inside int range.
    const int totalSize = static_cast<int>(headerLen + len);
    nn::socket::EtherHeader *pEthHdr = (nn::socket::EtherHeader *)buffer;
    int fd = (protocol == nn::socket::EtherType::EtherType_Arp) ? m_ArpFd : m_RawFd;

    memset(pEthHdr, 0, nn::socket::Ether_Hdr_Len);
    memset(pEthHdr->ether_dhost, 0xff, nn::socket::Ether_Addr_Len);
    pEthHdr->ether_type = nn::socket::InetHtons(static_cast<uint16_t>(protocol));
    memcpy(buffer + headerLen, data, len);
    if (nn::socket::Write(fd, buffer, totalSize) != totalSize)
    {
        nn::socket::Errno lastErrno = nn::socket::GetLastError();
        DHCPC_LOG_ERROR("BPF Write( ) failed, errno=%d.\n", lastErrno);
        switch (lastErrno)
        {
        case nn::socket::Errno::ENxIo:
            result = ResultIfRemoved();
            break;
        default:
            result = ResultIfSendError();
            break;
        }
    }
    return result;
}

/**
 * Returns the payload (Ethernet header stripped) of the next usable packet captured
 * on the BPF descriptor selected by protocol.
 *
 * One Read() may return several BPF records. They are kept in m_ReadBuffer and
 * handed out one per call, and a new Read() is issued only when the buffer is empty.
 * A record is skipped when it is truncated (bh_caplen != bh_datalen), extends past
 * the buffered data, or is too short to contain an Ethernet header. A buffer tail
 * that cannot hold a BpfHdr is discarded.
 *
 * @param protocol  EtherType_Arp selects m_ArpFd; any other value selects m_RawFd.
 * @param data      Destination for the payload.
 * @param len       Capacity of data; longer payloads are truncated to len bytes.
 * @param flags     Out. Set to 0, then RawFlags_EOF is OR-ed in when the buffered data
 *                  has been consumed or when Read() returned 0 or -1.
 * @return Number of payload bytes copied into data, or the 0 / -1 returned by Read().
 */
ssize_t PacketManager::ReadRawPacket(nn::socket::EtherType protocol, void *data, size_t len, int *flags)
{
    const size_t etherHdrLen = static_cast<size_t>(nn::socket::Ether_Hdr_Len);
    int fd = (protocol == nn::socket::EtherType::EtherType_Arp) ? m_ArpFd : m_RawFd;
    *flags = 0;
    for (;;)
    {
        if (m_ReadBuffer.length  == 0)
        {
            ssize_t readBytes = nn::socket::Read(fd, m_ReadBuffer.data, sizeof(m_ReadBuffer.data));
            if (readBytes == -1 || readBytes == 0)
            {
                char ifName[Config_MaxIfNameSize] = { 0 };
                *flags |= RawFlags_EOF;
                m_pInterface->GetInterfaceName(ifName, sizeof(ifName));
                // %d expects an int; ssize_t may be wider.
                DHCPC_LOG_VERBOSE("Interface-%s: raw read returned size of %d.\n", ifName, static_cast<int>(readBytes));
                return readBytes;
            }
            m_ReadBuffer.length = (size_t)readBytes;
            m_ReadBuffer.position = 0;
        }

        const size_t position  = static_cast<size_t>(m_ReadBuffer.position);
        const size_t remaining = static_cast<size_t>(m_ReadBuffer.length) - position;

        // Not even a full BpfHdr is left: discard the tail rather than copy a
        // header from beyond the valid data.
        if (remaining < sizeof(nn::socket::BpfHdr))
        {
            m_ReadBuffer.length = m_ReadBuffer.position = 0;
            *flags |= RawFlags_EOF;
            continue;
        }

        nn::socket::BpfHdr packet;
        memcpy(&packet, m_ReadBuffer.data + position, sizeof(packet));
        // Widen to size_t so the bounds arithmetic cannot wrap or go negative.
        const size_t hdrLen = packet.bh_hdrlen;
        const size_t capLen = packet.bh_caplen;

        const bool complete    = (packet.bh_caplen == packet.bh_datalen);
        const bool inBuffer    = (hdrLen + capLen <= remaining);
        const bool hasEtherHdr = (capLen >= etherHdrLen);

        ssize_t bytes = -1;
        if (complete && inBuffer && hasEtherHdr)
        {
            const unsigned char *payload = m_ReadBuffer.data + position + hdrLen + etherHdrLen;
            size_t copyLen = capLen - etherHdrLen;
            if (copyLen > len)
            {
                copyLen = len;
            }
            memcpy(data, payload, copyLen);
            bytes = static_cast<ssize_t>(copyLen);
        }

        // Step to the next word-aligned record. If that exhausts the buffer, or the
        // header claims a zero-length record (which would never advance), empty the
        // buffer and report EOF so the next call reads the descriptor again.
        const auto advance = nn::socket::Bpf_WordAlign(packet.bh_hdrlen + packet.bh_caplen);
        if (advance == 0 || static_cast<size_t>(advance) >= remaining)
        {
            m_ReadBuffer.length = m_ReadBuffer.position = 0;
            *flags |= RawFlags_EOF;
        }
        else
        {
            m_ReadBuffer.position += advance;
        }

        if (bytes != -1)
        {
            return bytes;
        }
    }
}

/**
 * Sends a DHCP message as an ordinary UDP datagram from the client port on
 * sourceAddr to the server port on destAddr. The datagram goes through an exempt
 * socket, not the BPF descriptor.
 *
 * SO_BROADCAST is enabled on the socket so destAddr may be a broadcast address;
 * the result of that SetSockOpt call is not checked.
 *
 * @param destAddr    Destination IPv4 address (DhcpProtPorts_Server is used as port).
 * @param sourceAddr  Local IPv4 address to bind (DhcpProtPorts_Client is used as port).
 * @param data        Message bytes.
 * @param len         Message length.
 * @return ResultSuccess() on success. ResultIfSendError() if bind fails or sendto fails
 *         with anything but ENxIo. ResultIfRemoved() for ENxIo from sendto. A socket
 *         creation failure goes through DHCPC_RETURN_UPON_ERROR(ResultSocketOpenError()).
 */
Result PacketManager::SendUdpPacket(const nn::socket::InAddr *destAddr,const nn::socket::InAddr *sourceAddr, const void *data, size_t len)
{
    Result result = ResultSuccess();
    const int udpSock = nn::socket::SocketExempt(nn::socket::Family::Pf_Inet,
                                                 nn::socket::Type::Sock_Dgram | nn::socket::TypePrivate::Sock_CloExec,
                                                 nn::socket::Protocol::IpProto_Udp);
    if (udpSock == -1)
    {
        DHCPC_RETURN_UPON_ERROR(ResultSocketOpenError());
        return result;
    }

    // Local endpoint: the DHCP client port on the requested source address.
    nn::socket::SockAddrIn localAddr;
    memset(&localAddr, 0, sizeof(localAddr));
    localAddr.sin_family = nn::socket::Family::Af_Inet;
    localAddr.sin_addr   = *sourceAddr;
    localAddr.sin_port   = nn::socket::InetHtons(DhcpProtPorts_Client);
    if (nn::socket::Bind(udpSock, reinterpret_cast<nn::socket::SockAddr *>(&localAddr), sizeof(localAddr)) < 0)
    {
        nn::socket::Errno bindErrno = nn::socket::GetLastError();
        DHCPC_LOG_ERROR("BPF Bind( ) failed, errno=%d.\n", bindErrno);
        nn::socket::Close(udpSock);
        return ResultIfSendError();
    }

    // Remote endpoint: the DHCP server port on the destination address.
    nn::socket::SockAddrIn remoteAddr;
    memset(&remoteAddr, 0, sizeof(remoteAddr));
    remoteAddr.sin_family = nn::socket::Family::Af_Inet;
    remoteAddr.sin_addr   = *destAddr;
    remoteAddr.sin_port   = nn::socket::InetHtons(DhcpProtPorts_Server);

    int broadcastEnable = 1;
    nn::socket::SetSockOpt(udpSock, nn::socket::Level::Sol_Socket, nn::socket::Option::So_Broadcast,
                           &broadcastEnable, sizeof(broadcastEnable));

    const bool sendFailed = nn::socket::SendTo(udpSock, data, len, nn::socket::MsgFlag::Msg_None,
                                               reinterpret_cast<nn::socket::SockAddr *>(&remoteAddr),
                                               sizeof(remoteAddr)) < 0;
    if (sendFailed)
    {
        nn::socket::Errno sendErrno = nn::socket::GetLastError();
        DHCPC_LOG_ERROR("BPF SendTo( ) failed, errno=%d.\n", sendErrno);
        if (sendErrno == nn::socket::Errno::ENxIo)
        {
            result = ResultIfRemoved();
        }
        else
        {
            result = ResultIfSendError();
        }
    }

    nn::socket::Close(udpSock);
    return result;
}

/**
 * Discards queued packets for protocol. It keeps calling ReadRawPacket() into a
 * scratch buffer until a call returns 0 or a negative value; payloads and flags
 * are thrown away.
 *
 * NOTE(review): Initialize() opens the raw descriptor with O_RdWr only. Unless it is
 * made non-blocking elsewhere, the final read may block until a packet arrives -
 * confirm against the callers.
 */
void PacketManager::ClearRawReceiveBuffer(nn::socket::EtherType protocol)
{
    uint8_t scratch[DhcpSizes_MtuMax + nn::socket::Ether_Addr_Len];
    int ignoredFlags;
    while (ReadRawPacket(protocol, scratch, sizeof(scratch), &ignoredFlags) > 0)
    {
        // Keep draining until a read yields no payload.
    }
}

int PacketManager::GetRawFd()
{
    return m_RawFd;
}

/**
 * Checks that data holds a well-formed IPv4/UDP datagram that fits in a DhcpPacket.
 *
 * Checks performed: data_len covers at least an IP header and at most a DhcpPacket;
 * the IP header checksum is valid; ip_len is no larger than data_len and covers at
 * least the IP and UDP headers. Unless noudpcsum is non-zero, the UDP checksum is
 * verified over a pseudo-header rebuilt in place. A zero UDP checksum field is
 * accepted as "no checksum".
 *
 * @param data       Packet bytes starting at the IP header.
 * @param data_len   Number of bytes in data.
 * @param from       If non-null, receives the IP source address. When data_len is
 *                   shorter than an IP header it receives InAddr_Any instead. It is
 *                   written even when validation then fails.
 * @param noudpcsum  Non-zero skips UDP checksum verification.
 * @return 0 when the packet is valid, -1 otherwise.
 */
int PacketManager::ValidateUdpPacket(const uint8_t *data, size_t data_len, nn::socket::InAddr *from, int noudpcsum)
{
    DhcpPacket p;
    uint16_t bytes, udpsum;
    if (data_len < sizeof(p.ip))
    {
        if (from)
        {
            from->S_addr = nn::socket::InetHtonl(nn::socket::InAddr_Any);
        }
        return -1;
    }
    // Zero first so any part of p that data does not cover (e.g. a missing UDP
    // header in a short packet) reads as 0 instead of indeterminate stack bytes.
    memset(&p, 0, sizeof(p));
    memcpy(&p, data, DHCPC_MIN(data_len, sizeof(p)));
    if (from)
    {
        *from = p.ip.ip_src;
    }
    if (data_len > sizeof(p))
    {
        return -1;
    }
    if (Util::Checksum((const uint8_t *)&p.ip, sizeof(p.ip)) != 0)
    {
        return -1;
    }

    bytes = nn::socket::InetNtohs(p.ip.ip_len);
    if (data_len < bytes)
    {
        return -1;
    }
    // A UDP datagram must at least contain the IP and UDP headers read below.
    if (bytes < sizeof(p.ip) + sizeof(p.udp))
    {
        return -1;
    }

    if (noudpcsum == 0)
    {
        // Turn the IP header into the UDP pseudo-header: keep protocol and
        // addresses, put the UDP length in ip_len, and zero everything else
        // (including the UDP checksum field) before summing.
        udpsum = p.udp.uh_sum;
        p.udp.uh_sum = 0;
        p.ip.ip_hl = 0;
        p.ip.ip_v = 0;
        p.ip.ip_tos = 0;
        p.ip.ip_len = p.udp.uh_ulen;
        p.ip.ip_id = 0;
        p.ip.ip_off = 0;
        p.ip.ip_ttl = 0;
        p.ip.ip_sum = 0;
        if (udpsum && Util::Checksum((const uint8_t *)&p, bytes) != udpsum)
        {
            return -1;
        }
    }

    return 0;
}

/**
 * Fills in the IPv4 and UDP headers of pPkt for a DHCP message of dhcpMessageSize
 * bytes that is already stored in pPkt after those headers.
 *
 * The UDP checksum is computed first, over a pseudo-header built in place. All IP
 * fields except protocol, addresses and length (temporarily the UDP length) are zero
 * at that point. The IP header is then completed and given its own checksum. Both
 * headers are zeroed up front so the result does not depend on what pPkt held
 * before (e.g. when a packet buffer is reused for a retransmission).
 *
 * @param pPkt             Packet to fill; the DHCP payload must already be in place.
 * @param dhcpMessageSize  Size of the DHCP payload in bytes.
 * @param source           IPv4 source address, copied verbatim.
 * @param dest             IPv4 destination address, copied verbatim.
 * @return Total packet size: IP header + UDP header + dhcpMessageSize.
 */
size_t PacketManager::MakeUdpPacket(DhcpPacket *pPkt, size_t dhcpMessageSize, const nn::socket::InAddr *source, const nn::socket::InAddr *dest)
{
    nn::socket::Ip *ip = &pPkt->ip;
    nn::socket::UdpHdr *udp = &pPkt->udp;
    size_t totalSize = sizeof(*ip) + sizeof(*udp) + dhcpMessageSize;

    // Both checksums below sum over header fields that must be zero at that point
    // (uh_sum, ip_sum, ip_tos, ip_off, and the not-yet-set version/length/ttl).
    memset(ip, 0, sizeof(*ip));
    memset(udp, 0, sizeof(*udp));

    // Pseudo-header part of the IP header, used for the UDP checksum.
    ip->ip_p = static_cast<uint8_t>(nn::socket::Protocol::IpProto_Udp);
    ip->ip_src = *source;
    ip->ip_dst = *dest;

    udp->uh_sport = nn::socket::InetHtons(DhcpProtPorts_Client);
    udp->uh_dport = nn::socket::InetHtons(DhcpProtPorts_Server);
    udp->uh_ulen = nn::socket::InetHtons(static_cast<uint16_t>(sizeof(*udp) + dhcpMessageSize));
    ip->ip_len = udp->uh_ulen;
    udp->uh_sum = Util::Checksum((const uint8_t *)ip, totalSize);

    // Now complete the real IP header and checksum it (ip_sum is still zero here).
    ip->ip_v = nn::socket::IpVersion;
    ip->ip_hl = sizeof(*ip) >> 2;
    ip->ip_id = 0;
    ip->ip_ttl = nn::socket::IpDefTtl;
    ip->ip_len = nn::socket::InetHtons(static_cast<uint16_t>(sizeof(*ip) + sizeof(*udp) + dhcpMessageSize));
    ip->ip_sum = Util::Checksum((const uint8_t *)ip, sizeof(*ip));

    return totalSize;
}


} // namespace dhcpc
} // namespace bsdsocket
} // namespace nn




