﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/nn_Log.h>
#include <nn/bsdsocket/cfg/cfg.h>
#include <nn/socket.h>
#include <nn/socket/socket_ApiPrivate.h>
#include <nn/util/util_StringUtil.h>
#include <nn/wlan/wlan_InfraApi.h>
#include <cstdlib>
#include <array>
#include "Command.h"
#include "Async.h"
#include "InteractiveShell.h"

#ifdef __INTELLISENSE__  // IntelliSense が #include_next を理解しないので仕方なく
#include <sys/poll.h>
#endif


namespace Command
{

namespace
{

// One entry of the command table.
struct CommandInfo
{
    const char* name;                   // command word as typed by the user
    size_t nameLength;                  // array extent of the literal, i.e. strlen(name) + 1 (includes the NUL)
    Result(*handler)(char**, int);      // called as handler(argv, argc); argv[0] is the command word
    const char* usage;                  // one-line usage text printed by "help" and on Result_BadArgument

    // Takes the command word as a string literal so its length (including the
    // terminating NUL) can be deduced from the array type at compile time.
    template <size_t N>
    CommandInfo(const char(&literal)[N], Result(*fn)(char**, int), const char* usageText)
        : name(literal)
        , nameLength(N)
        , handler(fn)
        , usage(usageText)
    {
    }
};


// Command handlers (defined below). Each one receives the tokenized command line as
// (argv, argc), where argv[0] is the command word itself.
Result Exit(char** argv, int argc);
Result Help(char** argv, int argc);
Result Join(char** argv, int argc);
Result Disconnect(char** argv, int argc);
Result WakeReason(char** argv, int argc);
Result Socket(char** argv, int argc);
Result Close(char** argv, int argc);
Result Info(char** argv, int argc);
Result SetWakeReason(char** argv, int argc);
Result GetWakeReason(char** argv, int argc);
Result EnableWowlFeatures(char** argv, int argc);

// Command table, searched by Invoke(). Entries MUST stay sorted in ascending order of
// `name`: Invoke() scans from the last entry backwards and gives up as soon as the input
// compares (case-insensitively) greater than an entry, so an out-of-order entry can
// become unreachable.
const CommandInfo Commands[] = {  // Must stay sorted, otherwise commands can no longer be found
    { "close", Close, "close <socket_fd>" },
    { "disconnect", Disconnect, "disconnect" },
    { "exit", Exit, "exit" },
    { "get_wake", GetWakeReason, "get_wake" },  // Reads the raw wake-reason value from the wireless driver
    { "help", Help, "help" },
    { "info", Info, "info <socket_fd>" },
    { "join", Join, "join <ssid> <open|wep|wpa|wpa2> [<sec key>]" },
    { "set_wake", SetWakeReason, "set_wake <wake reason hex>" },  // Sets the wake reason on the wireless driver directly
    { "socket", Socket, "socket <destination_ip> [<port>]" },
    { "wake_reason", WakeReason, "wake_reason" },
    { "wowl_feature", EnableWowlFeatures, "wowl_feature [<arp|tcpka>] ..." }, // Enables / disables individual WoWL features
};
// Number of entries in Commands.
const size_t CommandsCount = sizeof(Commands) / sizeof(*Commands);


// Converts the low 4 bits of `nibble` to an upper-case hexadecimal digit
// ('0'-'9', 'A'-'F'). The high 4 bits are ignored.
inline char NibbleToAscii(uint8_t nibble)
{
    static const char HexDigits[] = "0123456789ABCDEF";
    return HexDigits[nibble & 0x0f];
}


// "exit": if the wlan is currently in infrastructure-station state, brings wl0 down and
// disconnects first; then asks the interactive shell to quit. Always succeeds.
Result Exit(char** argv, int argc)
{
    nn::wlan::WlanState currentState;
    nn::wlan::Infra::GetState(&currentState);

    const bool isStation = (currentState == nn::wlan::WlanState_InfraSta);
    if (isStation)
    {
        // Interface down before dropping the association.
        nn::bsdsocket::cfg::SetIfDown("wl0");
        nn::wlan::Infra::Disconnect();
    }

    InteractiveShell::Quit();
    return Result_Success;
}


// "help": prints the usage line of every command, in table order.
Result Help(char** argv, int argc)
{
    for (const CommandInfo& command : Commands)
    {
        NN_LOG(">%s\n", command.usage);
    }
    return Result_Success;
}


// AP 接続
Result Join(char** argv, int argc)
{
    // パラメータチェック
    nn::wlan::SecurityMode mode = {};
    int keylen = 0;
    bool isArgValid = false;
    if (argc == 3 && !nn::util::Strncmp("open", argv[2], 5))
    {
        mode = nn::wlan::SecurityMode_Open;
        isArgValid = true;
    }
    else if (argc == 4)
    {
        keylen = nn::util::Strnlen(argv[3], 65);
        if (!nn::util::Strncmp("wep", argv[2], 4))
        {
            if (keylen == 10)
            {
                mode = nn::wlan::SecurityMode_Wep64Open;
            }
            else if (keylen == 26)
            {
                mode = nn::wlan::SecurityMode_Wep128Open;
            }
            isArgValid = true;
        }
        else if (keylen >= 8 && keylen <= 63)
        {
            if (!nn::util::Strncmp("wpa", argv[2], 4))
            {
                mode = nn::wlan::SecurityMode_WpaAes;
                isArgValid = true;
            }
            else if (!nn::util::Strncmp("wpa2", argv[2], 5))
            {
                mode = nn::wlan::SecurityMode_Wpa2Aes;
                isArgValid = true;
            }
        }
    }
    if (!isArgValid)
    {
        return Result_BadArgument;
    }

    // 接続
    nn::wlan::Ssid ssid;
    ssid.Set(argv[1]);

    nn::wlan::Security security = {};
    security.privacyMode = mode;
    security.groupPrivacyMode = mode;
    memcpy(security.key, argv[3], keylen);

    auto result = nn::wlan::Infra::Connect(ssid, nn::wlan::MacAddress::CreateBroadcastMacAddress(), -1, security, false);
    if (result.IsFailure())
    {
        return Result_Unknown;
    }
    nn::wlan::ConnectionStatus connectionStatus;
    nn::wlan::Infra::GetConnectionStatus(&connectionStatus);
    if (connectionStatus.state != nn::wlan::ConnectionState_Connected)
    {
        NN_LOG("Failed to connect.\n");
        return Result_Unknown;
    }

    // インターフェース Up
    nn::bsdsocket::cfg::IfSettings settings = {};
    settings.mode = nn::bsdsocket::cfg::IfIpAddrMode_Dhcp;
    settings.mtu = 1500;

    result = nn::bsdsocket::cfg::SetIfUp("wl0", &settings);
    if (result.IsFailure())
    {
        NN_LOG(">result_value:%08X\n", result.GetInnerValueForDebug());
        nn::wlan::Infra::Disconnect();
        return Result_Unknown;
    }

    // IP とか出力
    nn::bsdsocket::cfg::IfState state;
    nn::bsdsocket::cfg::GetIfState("wl0", &state);

    NN_LOG(">ip:%d.%d.%d.%d\n",
        state.addr.S_addr & 0xff,
        (state.addr.S_addr >> 8) & 0xff,
        (state.addr.S_addr >> 16) & 0xff,
        (state.addr.S_addr >> 24) & 0xff
    );

    NN_LOG(">netmask:%d.%d.%d.%d\n",
        state.subnetMask.S_addr & 0xff,
        (state.subnetMask.S_addr >> 8) & 0xff,
        (state.subnetMask.S_addr >> 16) & 0xff,
        (state.subnetMask.S_addr >> 24) & 0xff
    );

    NN_LOG(">def_gw:%d.%d.%d.%d\n",
        state.gatewayAddr.S_addr & 0xff,
        (state.gatewayAddr.S_addr >> 8) & 0xff,
        (state.gatewayAddr.S_addr >> 16) & 0xff,
        (state.gatewayAddr.S_addr >> 24) & 0xff
    );

    NN_LOG(">dns:%d.%d.%d.%d\n",
        state.dnsAddrs[0].S_addr & 0xff,
        (state.dnsAddrs[0].S_addr >> 8) & 0xff,
        (state.dnsAddrs[0].S_addr >> 16) & 0xff,
        (state.dnsAddrs[0].S_addr >> 24) & 0xff
    );

    NN_LOG(">alt_dns:%d.%d.%d.%d\n",
        state.dnsAddrs[1].S_addr & 0xff,
        (state.dnsAddrs[1].S_addr >> 8) & 0xff,
        (state.dnsAddrs[1].S_addr >> 16) & 0xff,
        (state.dnsAddrs[1].S_addr >> 24) & 0xff
    );

    return Result_Success;
} //NOLINT(impl/function_size)


// "disconnect": brings wl0 down, then disconnects from the access point.
// On failure the raw nn::Result value is logged and Result_Unknown is returned.
Result Disconnect(char** argv, int argc)
{
    // Interface down first, then drop the association.
    nn::bsdsocket::cfg::SetIfDown("wl0");

    auto disconnectResult = nn::wlan::Infra::Disconnect();
    if (disconnectResult.IsFailure())
    {
        NN_LOG(">result_value:%08X\n", disconnectResult.GetInnerValueForDebug());
        return Result_Unknown;
    }
    return Result_Success;
}


// "wake_reason": logs the WoWL wake reason reported by nn::wlan::Infra::GetWakeupReason
// as ">reason:<text>". Unrecognized values are shown as "Unknown".
Result WakeReason(char** argv, int argc)
{
    nn::wlan::WowlWakeReason reason;
    nn::wlan::Infra::GetWakeupReason(&reason);

    // Point at the string literals rather than strcpy-ing into a fixed-size buffer:
    // "Receive pattern data" and "Receive TCP session data" do not fit in 15 bytes.
    const char* reasonStr = "Unknown";
    switch (reason)
    {
    case nn::wlan::WowlWakeReason_Nothing:
        reasonStr = "Nothing";
        break;
    case nn::wlan::WowlWakeReason_Magicpacket:
        reasonStr = "Magic packet";
        break;
    case nn::wlan::WowlWakeReason_PatternData:
        reasonStr = "Receive pattern data";
        break;
    case nn::wlan::WowlWakeReason_Linkdown:
        reasonStr = "Link down";
        break;
    case nn::wlan::WowlWakeReason_TcpSessionData:
        reasonStr = "Receive TCP session data";
        break;
    default:
        reasonStr = "Unknown";
        break;
    }
    NN_LOG(">reason:%s\n", reasonStr);

    return Result_Success;
}


// ソケット接続
Result Socket(char** argv, int argc)
{
    // パラメータチェック
    bool isArgValid = false;
    int port;
    if (argc == 2)
    {
        port = 45678;
        isArgValid = true;
    }
    else if (argc == 3)
    {
        port = std::atoi(argv[2]);
        if (port > 0 || port <= 65535)
        {
            isArgValid = true;
        }
    }

    if (!isArgValid)
    {
        return Result_BadArgument;
    }

    auto socket = nn::socket::Socket(nn::socket::Family::Af_Inet, nn::socket::Type::Sock_Stream, nn::socket::Protocol::IpProto_Tcp);

    {
        int opt = 1;
        auto ret = nn::socket::SetSockOpt(socket, nn::socket::Level::Sol_Tcp, nn::socket::Option::Tcp_NoDelay, &opt, sizeof(opt));
        if (ret < 0)
        {
            NN_LOG(">warn:failed to setting nn::socket::Option::Tcp_NoDelay flag\n");
        }

        // disable SACK/WindowScaling/TimeStamp
        opt = 1;
        ret = nn::socket::SetSockOpt(socket, nn::socket::Level::Sol_Tcp, nn::socket::Option::Tcp_NoOpt, &opt, sizeof(opt));
        if (ret < 0)
        {
            NN_LOG(">warn:failed to setting nn::socket::Option::Tcp_NoDelay flag\n");
        }

        opt = 4096;
        ret = nn::socket::SetSockOpt(socket, nn::socket::Level::Sol_Socket, nn::socket::Option::So_RcvBuf, &opt, sizeof(opt));
        if (ret  < 0)
        {
            NN_LOG(">warn:failed to setting SO_RCVBUF value\n");
        }
    }

    nn::socket::InAddr addr;
    nn::socket::InetAton(argv[1], &addr);

    nn::socket::SockAddrIn destAddr = {};
    destAddr.sin_addr = addr;
    destAddr.sin_port = nn::socket::InetHtons(static_cast<uint16_t>(port));
    destAddr.sin_family = nn::socket::Family::Af_Inet;

    int result = nn::socket::Connect(socket, reinterpret_cast<nn::socket::SockAddr*>(&destAddr), sizeof(destAddr));
    if (result < 0)
    {
        return Result_Unknown;
    }

    // データ受信コールバック
    static const auto ReceiveCallback = [](int socket, nn::socket::PollEvent pollResult, void* userptr)
    {
        const size_t AddressByteWidth = 2;
        const size_t BytesInLine = 16;
        const size_t LineWidth = AddressByteWidth * 2 + 1 + 3 * BytesInLine + 1 + BytesInLine + 1;

        if ((pollResult & nn::socket::PollEvent::PollHup) != nn::socket::PollEvent::PollNone)
        {
            // 切断された
            NN_LOG(">socket_disconnected\n");
            nn::socket::Close(socket);
            return;
        }
        else if ((pollResult & nn::socket::PollEvent::PollNVal) != nn::socket::PollEvent::PollNone)
        {
            // ソケット閉じた
            NN_LOG(">socket_closed\n");
            nn::socket::Close(socket);
            return;
        }

        NN_LOG(">recv\n");

        {
            char format[1 + 4 + 1 + BytesInLine * 3 + 1 + BytesInLine + 1 + 1];
            nn::util::SNPrintf(format, sizeof(format), ">%%-%ds +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +A +B +C +D +E +F 0123456789ABCDEF\n", AddressByteWidth * 2);

            NN_LOG(format, "addr");
        }

        std::array<char, LineWidth> lineBuffer;
        lineBuffer.back() = '\0';

        int totalReceivedCount = 0;
        char* lineHexWritePtr;
        char* lineAsciiWritePtr;

        uint8_t recvBuf[128];
        for (;;)
        {
            auto readPtr = recvBuf;
            auto readEnd = readPtr + nn::socket::Recv(socket, recvBuf, sizeof(recvBuf), nn::socket::MsgFlag::Msg_None);

            for (; readPtr < readEnd; ++readPtr)
            {
                if (!(totalReceivedCount & 0xf))
                {
                    auto lineWritePtr = lineBuffer.begin();
                    lineHexWritePtr = lineWritePtr + AddressByteWidth * 2 + 1;
                    lineAsciiWritePtr = lineHexWritePtr + 3 * BytesInLine;

                    for (int i = AddressByteWidth * 2; i--;)
                    {
                        *lineWritePtr++ = NibbleToAscii(totalReceivedCount >> (i * 4));
                    }

                    while (lineWritePtr < lineBuffer.end())
                    {
                        *lineWritePtr++ = ' ';
                    }
                }

                *(lineHexWritePtr) = NibbleToAscii(*readPtr >> 4);
                *(lineHexWritePtr + 1) = NibbleToAscii(*readPtr);
                lineHexWritePtr += 3;

                *lineAsciiWritePtr++ = *readPtr >= 0x20 && *readPtr < 0x80 ? *readPtr : '.';

                if (!(~totalReceivedCount & 0xf))
                {
                    NN_LOG(">%s\n", lineBuffer.data());
                }

                ++totalReceivedCount;
            }

            int bufferedDataLength;
            nn::socket::Ioctl(socket, nn::socket::IoctlCommand::FionRead, &bufferedDataLength, sizeof(bufferedDataLength));
            if (bufferedDataLength == 0)
            {
                break;
            }
        }

        if (totalReceivedCount & 0xf)
        {
            NN_LOG(">%s\n", lineBuffer.data());
        }

        if (totalReceivedCount > 0)
        {
            Async::WaitRecv(socket, reinterpret_cast<void(*)(int, nn::socket::PollEvent, void*)>(userptr), userptr);
        }
        else
        {
            // RST 投げられたらここに来るようなので、たぶん RST
            NN_LOG(">socket_rst\n");
            nn::socket::Close(socket);
        }
    };
    Async::WaitRecv(socket, ReceiveCallback, reinterpret_cast<void*>(static_cast<void(*)(int, nn::socket::PollEvent, void*)>(ReceiveCallback)));

    NN_LOG(">socket_fd:%d\n", socket);
    return Result_Success;
} //NOLINT(impl/function_size)


// "close <socket_fd>": closes the given socket descriptor.
// Returns Result_Unknown if nn::socket::Close reports an error.
Result Close(char** argv, int argc)
{
    // Parameter check
    // TODO: validate the descriptor properly (std::atoi turns non-numeric text into 0)
    if (argc == 2)
    {
        const int fd = std::atoi(argv[1]);
        return (nn::socket::Close(fd) < 0) ? Result_Unknown : Result_Success;
    }
    return Result_BadArgument;
}

Result Info(char** argv, int argc)
{
    // パラメータチェック
    // TODO: まじめにチェックする
    if (argc != 2)
    {
        return Result_BadArgument;
    }

    nn::bsdsocket::cfg::IfState state;
    nn::bsdsocket::cfg::GetIfState("wl0", &state);

    int s = std::atoi(argv[1]);
    nn::socket::TcpInfo ti;
    nn::socket::SockLenT len = sizeof(ti);

    if (nn::socket::GetSockOpt(s, nn::socket::Level::Sol_Tcp, nn::socket::Option::Tcp_Info, reinterpret_cast<void*>(&ti), &len) < 0)
    {
        NN_LOG(">GetSockOpt failed (error %d)\n", nn::socket::GetLastError());
        return Result_Unknown;
    }

    nn::socket::SockAddrIn saiClient;
    len = sizeof(saiClient);
    if (nn::socket::GetSockName(s, reinterpret_cast<nn::socket::SockAddr*>(&saiClient), &len) < 0)
    {
        NN_LOG(">GetSockName failed (error %d)\n", nn::socket::GetLastError());
        return Result_Unknown;
    }
    if (saiClient.sin_addr.S_addr == nn::socket::InAddr_Any)
    {
        saiClient.sin_addr.S_addr = state.addr.S_addr;
    }

    nn::socket::SockAddrIn saiServer;
    len = sizeof(saiServer);
    if (nn::socket::GetPeerName(s, reinterpret_cast<nn::socket::SockAddr*>(&saiServer), &len) < 0)
    {
        NN_LOG(">GetPeerName failed (error %d)\n", nn::socket::GetLastError());
        return Result_Unknown;
    }

    uint8_t remoteHardwareAddress[nn::socket::Ether_Addr_Len];

    // FIXME: static routes
    nn::socket::SockAddrIn saiNext;
    if ((saiServer.sin_addr.S_addr & state.subnetMask.S_addr) == (state.gatewayAddr.S_addr & state.subnetMask.S_addr))
    {
        saiNext = saiServer;
    }
    else
    {
        saiNext.sin_addr = state.gatewayAddr;
    }

    nn::Result result = nn::bsdsocket::cfg::LookupArpEntry(remoteHardwareAddress, nn::socket::Ether_Addr_Len, saiNext.sin_addr);
    if (result.IsFailure())
    {
        NN_LOG(">LookupArpEntry failed (result 0x%08x)\n", result.GetInnerValueForDebug());
        return Result_Unknown;
    }

    NN_LOG(">       seq: % 16u[0x%08x]\n", ti.tcpi_snd_nxt, ti.tcpi_snd_nxt);
    NN_LOG(">       ack: % 16u[0x%08x]\n", ti.tcpi_rcv_nxt, ti.tcpi_rcv_nxt);
    NN_LOG(">      rwin: % 16u\n", ti.tcpi_rcv_space);
    NN_LOG(">     local: % 16s:%u\n", nn::socket::InetNtoa(saiClient.sin_addr), nn::socket::InetNtohs(saiClient.sin_port));
    NN_LOG(">    remote: % 16s:%u\n", nn::socket::InetNtoa(saiServer.sin_addr), nn::socket::InetNtohs(saiServer.sin_port));
    NN_LOG("> next addr: % 16s[%02x:%02x:%02x:%02x:%02x:%02x]\n",
        nn::socket::InetNtoa(saiNext.sin_addr),
        remoteHardwareAddress[0], remoteHardwareAddress[1], remoteHardwareAddress[2], remoteHardwareAddress[3], remoteHardwareAddress[4], remoteHardwareAddress[5]);

    nn::wlan::WlanIpv4Address src;
    nn::wlan::WlanIpv4Address dst;
    char srcIpStr[20];
    char dstIpStr[20];
    char* tok;

    std::strcpy(srcIpStr, nn::socket::InetNtoa(saiClient.sin_addr));
    std::strcpy(dstIpStr, nn::socket::InetNtoa(saiServer.sin_addr));
    tok = std::strtok(srcIpStr, ".");
    src.addr[0] = std::atoi(tok);
    for( int i = 1; i < 4; i++ )
    {
        tok = std::strtok(NULL, ".");
        src.addr[i] = std::atoi(tok);
    }

    tok = std::strtok(dstIpStr, ".");
    dst.addr[0] = std::atoi(tok);
    for( int i = 1; i < 4; i++ )
    {
        tok = std::strtok(NULL, ".");
        dst.addr[i] = std::atoi(tok);
    }

    nn::wlan::MacAddress dstMac(remoteHardwareAddress);
    result = nn::wlan::Infra::SetTcpSessionInformation(dstMac,
            src, dst, nn::socket::InetNtohs(saiClient.sin_port), nn::socket::InetNtohs(saiServer.sin_port),
            ti.tcpi_rcv_nxt, static_cast<uint16_t>(ti.tcpi_rcv_space));
    if( result.IsFailure() )
    {
        NN_LOG("SetTcpSessionInformation failed.\n");
    }

    return Result_Success;
}

// "set_wake <wake reason hex>": parses a 32-bit hexadecimal value ("0x" prefix optional)
// and passes it to nn::wlan::Infra::SetWakeupReasonRaw.
// Returns Result_BadArgument for a missing, non-hex, or out-of-range (> 0xFFFFFFFF) value.
Result SetWakeReason(char** argv, int argc)
{
    // Parameter check
    if (argc != 2)
    {
        return Result_BadArgument;
    }

    // Parse as unsigned 64-bit: strtol would saturate at LONG_MAX where long is 32 bits,
    // making 0x80000000 and above impossible to set.
    const char* text = argv[1];
    char* endptr = nullptr;
    const unsigned long long hex = std::strtoull(text, &endptr, 16);
    if (endptr == text || *endptr != '\0')
    {
        // No digits at all (e.g. an empty string) or trailing non-hex characters.
        NN_LOG("Not hex.\n");
        return Result_BadArgument;
    }
    if (hex > 0xFFFFFFFFull)
    {
        // Also catches overflow (strtoull returns ULLONG_MAX) and negative input,
        // which strtoull wraps to a huge unsigned value.
        NN_LOG("Out of range.\n");
        return Result_BadArgument;
    }

    uint32_t reason = static_cast<uint32_t>(hex);
    NN_LOG("> You set wake reason: 0x%08X\n", reason);
    auto result = nn::wlan::Infra::SetWakeupReasonRaw(reason);
    if( result.IsSuccess() )
    {
        return Result_Success;
    }
    else
    {
        return Result_Unknown;
    }
}

// "get_wake": reads the raw wake-reason value from the wlan driver and logs it in hex.
// Returns Result_Unknown if the driver query fails.
Result GetWakeReason(char** argv, int argc)
{
    uint32_t rawReason = 0;
    auto queryResult = nn::wlan::Infra::GetWakeupReasonRaw(&rawReason);
    if (!queryResult.IsSuccess())
    {
        return Result_Unknown;
    }

    NN_LOG("> wake reason: 0x%08X\n", rawReason);
    return Result_Success;
}

// "wowl_feature [<arp|tcpka>] ...": builds a feature mask from the listed names
// ("arp" -> WowlFeatures_ArpOffload, "tcpka" -> WowlFeatures_TcpKeepAlive) and passes
// it to nn::wlan::Infra::EnableWowlFeatures. With no names the mask is 0.
// Returns Result_BadArgument for an unrecognized name.
Result EnableWowlFeatures(char** argv, int argc)
{
    // Must start at 0 for every argc: each listed name only ORs its flag in.
    uint32_t feature = 0;

    for (int i = 1; i < argc; i++)
    {
        if (!nn::util::Strncmp("arp", argv[i], 4))
        {
            feature |= nn::wlan::WowlFeatures_ArpOffload;
        }
        else if (!nn::util::Strncmp("tcpka", argv[i], 6))
        {
            feature |= nn::wlan::WowlFeatures_TcpKeepAlive;
        }
        else
        {
            // Reject typos instead of silently applying a mask without that feature.
            NN_LOG("Unknown feature: %s\n", argv[i]);
            return Result_BadArgument;
        }
    }
    NN_LOG("> You set features:0x%08X\n", feature);

    auto result = nn::wlan::Infra::EnableWowlFeatures(feature);
    if( result.IsSuccess() )
    {
        return Result_Success;
    }
    else
    {
        return Result_Unknown;
    }
}

}


// Runs the command whose name matches argV[0] (case-insensitive) with (argV, argC).
// Logs the command's usage line if the handler returns Result_BadArgument, and always
// logs ">result:success" or ">result:failure".
// Returns the handler's result, or Result_CommandNotFound when there is no command word
// or no matching entry.
Result Invoke(char** argV, int argC)
{
    Result result = Result_CommandNotFound;
    int commandIndex = -1;

    // An empty command line has no argV[0] to look up.
    if (argC >= 1 && argV[0] != nullptr)
    {
        // Linear scan from the end of the sorted table; early exit once the input sorts
        // after the current entry, because every earlier entry is smaller still.
        // TODO: a binary search would be faster (not needed for the current number of commands)
        for (int i = static_cast<int>(CommandsCount); i-- > 0;)
        {
            auto cmp = nn::util::Strnicmp(argV[0], Commands[i].name, Commands[i].nameLength);
            if (!cmp)
            {
                commandIndex = i;
                result = Commands[i].handler(argV, argC);
                break;
            }
            if (cmp > 0)
            {
                break;
            }
        }
    }

    if (result == Result_BadArgument && commandIndex >= 0)
    {
        NN_LOG(">usage: %s\n", Commands[commandIndex].usage);
    }

    NN_LOG(">result:%s\n", result == Command::Result_Success ? "success" : "failure");

    return result;
}

}
