﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <nn/nn_Log.h>
#include <nn/util/util_FormatString.h>
#include <ShellServer.h>

namespace nnt
{
    namespace abuse
    {
        // Commands accepted by the shell. parseCommand() compares the beginning of each
        // received command string against these names, in table order, and maps the first
        // match to the associated ShellCommandType.
        static RegisteredShellCommand sValidCommands[] =
        {
            RegisteredShellCommand(ShellCommandType::SHELL_SCRIPT_FILE,         "run_script"),
            RegisteredShellCommand(ShellCommandType::SHELL_RUN_SCRIPT_COMMAND,  "do_command"),
            RegisteredShellCommand(ShellCommandType::SHELL_STOP_SCRIPT,         "stop_script"),
            RegisteredShellCommand(ShellCommandType::SHELL_EXIT_ABUSE,          "exit_abuse"),
            RegisteredShellCommand(ShellCommandType::SHELL_INIT_MONITOR,        "init_monitor"),
            RegisteredShellCommand(ShellCommandType::SHELL_GET_AVAILABLE_TASKS, "list_available_tasks"),
        };

        // Construct an idle server. No sockets are opened until Initialize() is called.
        // m_serverSocket starts at -1 (an invalid descriptor) rather than 0: 0 may be a valid
        // descriptor, and Finalize() on a never-initialized server would otherwise close it.
        ShellServer::ShellServer()
            : m_nextId(0)
            , m_serverSocket(-1)
            , m_initialized(false)
            , m_running(false)
        {
        }

        // The destructor performs no cleanup of its own. Call Finalize() beforehand to
        // close the listening socket and any connected client sockets.
        ShellServer::~ShellServer()
        {
            // Intentionally empty.
        }

        bool ShellServer::Initialize(uint16_t listenPort)
        {
            Platform::InitializeSocketLib();
            m_running = true;

            NN_LOG("Listening on port %u\n", (uint16_t) listenPort);

            nn::socket::SockAddrIn m_serverAddress;
            memset(&m_serverAddress,0,sizeof(m_serverAddress));
            m_serverAddress.sin_family = nn::socket::Family::Af_Inet;
            m_serverAddress.sin_port = Platform::htons( (uint16_t) listenPort);
            m_serverAddress.sin_addr.S_addr = Platform::GetAddrAny();

            m_serverSocket = Platform::socket(nn::socket::Family::Af_Inet, nn::socket::Type::Sock_Stream, nn::socket::Protocol::IpProto_Tcp);
            Platform::SetNonBlocking(m_serverSocket, true);
            if(m_serverSocket < 0)
                NN_LOG("Socket create failed %d\n", Platform::GetLastError());

            if(Platform::bind(m_serverSocket, (nn::socket::SockAddr*)&m_serverAddress, sizeof(m_serverAddress)) != 0)
            {
                NN_LOG("Bind failed with error: %d\n", Platform::GetLastError());
                return false;
            }


            if(Platform::listen(m_serverSocket, 100) != 0)
            {
                NN_LOG("Listen failed with error %d\n", Platform::GetLastError());
                return false;
            }
            nn::os::InitializeMutex(&m_clientMutex, false, 0);
            m_initialized = true;

            return true;
        }

        void ShellServer::Finalize()
        {
            for(int client : m_clientSockets)
            {
                Platform::close(client);
            }
            Platform::close(m_serverSocket);

            m_running = false;

        }

        // Reports whether Initialize() ran to completion: the socket was bound and is
        // listening, and the client mutex was created.
        bool ShellServer::IsInitialized()
        {
            const bool ready = m_initialized;
            return ready;
        }

        // True when at least one parsed command is waiting to be consumed through
        // PeekNextCommand() / PopNextCommand().
        bool ShellServer::IsCommandPending()
        {
            return !m_pendingCommands.empty();
        }

        // Returns the oldest queued command (commands are appended at the back by
        // UpdateServer()) without removing it.
        // Precondition: IsCommandPending() is true. The queue is not checked here.
        const ShellCommand& ShellServer::PeekNextCommand()
        {
            const ShellCommand& oldest = m_pendingCommands.front();
            return oldest;
        }

        // Discards the oldest queued command. Calling this with an empty queue is a no-op
        // rather than undefined behavior (pop_front() on an empty container is UB).
        void ShellServer::PopNextCommand()
        {
            if(m_pendingCommands.empty())
                return;
            m_pendingCommands.pop_front();
        }

        int ShellServer::SendString(int socket, const char* format, ...)
        {
            const uint16_t HEADER_SIZE = sizeof(ShellPacketHeader);
            static const uint16_t BUFFER_SIZE = 1024 - HEADER_SIZE;
            char buffer[BUFFER_SIZE];
            memset(buffer, 0, BUFFER_SIZE);
            va_list marker;
            va_start(marker, format);
            uint16_t length = (uint16_t)nn::util::VSNPrintf(buffer + HEADER_SIZE, BUFFER_SIZE, format, marker);
            va_end(marker);

            return sendPacket(socket, buffer,length + 1, PACKET_STRING);
        }

        int ShellServer::SendCommandResult(int socket, bool result, const char* format, ...)
        {
            const uint16_t HEADER_SIZE = sizeof(ShellPacketHeader);
            const uint16_t BUFFER_SIZE = 1024 - HEADER_SIZE;
            char buffer[BUFFER_SIZE];
            memset(buffer, 0, BUFFER_SIZE);
            ShellPacketType type = type = result ? PACKET_ACK_PASS : PACKET_ACK_FAIL;

            va_list marker;
            va_start(marker, format);
            uint16_t length = (uint16_t)nn::util::VSNPrintf(buffer + HEADER_SIZE, BUFFER_SIZE, format, marker);
            va_end(marker);

            return sendPacket(socket, buffer, length + 1, type);
        }

        int ShellServer::SendReply(int socket, ShellPacketType type, const char* message, uint16_t length)
        {
            const uint16_t HEADER_SIZE = sizeof(ShellPacketHeader);
            static const uint16_t BUFFER_SIZE = 1024 - HEADER_SIZE;

            char buffer[BUFFER_SIZE];
            memset(buffer, 0, BUFFER_SIZE);
            memcpy(buffer + HEADER_SIZE, message, length);

            return sendPacket(socket, buffer,length, type);
        };

        // Writes the packet header into the start of `buffer` and sends header + payload to
        // `socket`, but only if `socket` is currently in the connected-client list.
        //
        // buffer:     must hold sizeof(ShellPacketHeader) + bufferSize bytes, with the payload
        //             already placed right after the header area.
        // bufferSize: payload size in bytes (excluding the header). It is sent in network
        //             byte order in header bytes 0-1; header byte 2 carries `type`.
        // Returns send()'s result, or -1 if `socket` is not a known client or the send failed.
        // NOTE(review): sockets are non-blocking; a short write is not retried here.
        int ShellServer::sendPacket(int socket, char* buffer, uint16_t bufferSize, ShellPacketType type)
        {
            int result = -1;

            // Value-initialize so header bytes beyond data[2] (if any) are zero, not stack garbage.
            ShellPacketHeader header{};
            const uint16_t netSize = Platform::htons(bufferSize);
            memcpy(&header.data[0], &netSize, sizeof(netSize));
            header.data[2] = static_cast<char>(type);
            memcpy(buffer, &header, sizeof(header));

            nn::os::LockMutex(&m_clientMutex);
            for(int client : m_clientSockets)
            {
                if(client != socket)
                    continue;

                result = Platform::send(socket, buffer, bufferSize + sizeof(header), nn::socket::MsgFlag::Msg_None);
                if(result == -1)
                    NN_LOG("Send error: %d\n", Platform::GetLastError());
                // Send at most once, even if the descriptor were listed twice.
                break;
            }
            nn::os::UnlockMutex(&m_clientMutex);
            return result;
        }

        void ShellServer::UpdateServer()
        {
            static const int BUFFER_SIZE = 1024;
            char commandBuffer[BUFFER_SIZE + 1];
            if(m_running)
            {
                int client = Platform::accept(m_serverSocket, nullptr, nullptr);

                if(client > 0)
                {
                    Platform::SetNonBlocking(client, true);
                    m_clientSockets.push_back(client);
                }

                auto itr = m_clientSockets.begin();
                while(itr != m_clientSockets.end())
                {

                    int result = Platform::recv(*itr, commandBuffer, BUFFER_SIZE, nn::socket::MsgFlag::Msg_None);

                    if(result > 0 && nn::socket::GetLastError() != nn::socket::Errno::EWouldBlock)
                    {
                        commandBuffer[result]  = '\0';
                        NN_LOG("Received \"%s\"\n", commandBuffer);
                        ShellCommand command = parseCommand(commandBuffer, *itr);

                        if(command.type != SHELL_INVALID)
                        {
                            m_pendingCommands.push_back( command );
                            //SendString(*itr, "Created command %d\n", command.id);
                        }
                        else
                           SendString(*itr, "Error parsing \"%s\"", commandBuffer);
                        memset(commandBuffer, 0, result);
                    }
                    ++itr;
                }
            }
        }

        // Matches the start of the NUL-terminated `buffer` against sValidCommands (in table order)
        // and builds the corresponding ShellCommand for `socket`.
        //
        // A name matches only when it is followed by end-of-string or a single whitespace
        // separator (' ', '\t', '\r', '\n'). So "stop_scriptX" is rejected instead of being treated
        // as "stop_script" with a mangled argument. The argument string starts right after that
        // separator, or is the empty string at the terminator when there is no separator. It
        // never points past the end of `buffer`.
        // Returns a SHELL_INVALID command when `buffer` is null or nothing matches.
        // NOTE(review): the argument pointer aims into `buffer`, which UpdateServer() keeps on its
        // stack. If ShellCommand does not copy it, queued commands dangle; confirm in ShellCommand.
        ShellCommand ShellServer::parseCommand(const char* buffer, int socket)
        {
            if(buffer == nullptr)
                return ShellCommand(SHELL_INVALID, nullptr, -1, -1, -1);

            const size_t commandCount = sizeof(sValidCommands) / sizeof(sValidCommands[0]);
            for(size_t i = 0; i < commandCount; ++i)
            {
                const char* name = sValidCommands[i].name;
                const size_t nameLength = strlen(name);

                // strncmp stops at buffer's terminator, so a shorter buffer simply fails to match.
                if(strncmp(buffer, name, nameLength) != 0)
                    continue;

                const char next = buffer[nameLength];
                const char* args = nullptr;
                if(next == '\0')
                    args = buffer + nameLength;       // no arguments: empty string
                else if(next == ' ' || next == '\t' || next == '\r' || next == '\n')
                    args = buffer + nameLength + 1;   // skip the one separator character
                else
                    continue;                         // a longer word that merely starts with this name

                return ShellCommand(sValidCommands[i].type, args, static_cast<int>(m_pendingCommands.size()), socket, m_nextId++);
            }
            return ShellCommand(SHELL_INVALID, nullptr, -1, -1, -1);
        }
    }
}
