﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <string>
#include <nn/os.h>
#include <nn/os/os_SystemEvent.h>
#include <nn/socket/socket_Api.h>
#include <nn/bsdsocket/cfg/cfg_Types.h>
#include <nn/bsdsocket/cfg/cfg_ClientApi.h>
#include "Node.h"
#include "Scanner.h"
#include "Counter.h"

namespace WlanTest {

/*!--------------------------------------------------------------------------*
  @brief        Protocol receiver driven by a dedicated receive thread
                (thread handling is implemented out of line, not in this header).
 *---------------------------------------------------------------------------*/
class SocketReceiver : public IReceiver
{
    // ----- Static helpers / data members --------------------------------
private:

    // Thread entry point; the argument is presumably the owning SocketReceiver
    // (NOTE(review): confirm against the .cpp definition).
    static void ReceiveThreadFunction(void* pArg);

    nn::os::ThreadType m_ReceiveThread;
    uint16_t           m_ProtocolType;   // written by SetProtocolType()
    bool               m_RequestStop;
    bool               m_IsStop;

    // ----- Operations ----------------------------------------------------
public:

    virtual void Receive();

    // The default priority is the calling thread's current priority.
    virtual nn::Result Start(int32_t priority = nn::os::GetThreadCurrentPriority(nn::os::GetCurrentThread()));
    virtual nn::Result Stop();

    // ----- Construction --------------------------------------------------
    SocketReceiver();

    // ----- Accessors -----------------------------------------------------
    void SetProtocolType(uint16_t protocolType)
    {
        m_ProtocolType = protocolType;
    }
};


/*---------------------------------------------------------------------------
           SocketParam
  Socket test settings, filled from "item name" / "value" string pairs via
  SetParam().
---------------------------------------------------------------------------*/
class SocketParam
{
public :

    // Identifiers of the configurable items; ITEM_ID_UNDEFINED marks an unknown name.
    enum ItemId
    {
        ITEM_ID_UNDEFINED = 0,

        ITEM_ID_CHANNEL,
        ITEM_ID_SSID,
        ITEM_ID_SECURITY_MODE,
        ITEM_ID_SECURITY_KEY,
        ITEM_ID_TYPE,
        ITEM_ID_PACKET_FORMAT,
        ITEM_ID_IP_ADDRESS,
        ITEM_ID_GW_IP_ADDRESS,
        ITEM_ID_REMOTE_IP_ADDRESS,
        ITEM_ID_PORT
    };

    // Item names matched by GetItemId(); their values are defined out of line.
    static const char* ITEM_STR_CHANNEL;
    static const char* ITEM_STR_SSID;
    static const char* ITEM_STR_SECURITY_MODE;
    static const char* ITEM_STR_SECURITY_KEY;
    static const char* ITEM_STR_TYPE;
    static const char* ITEM_STR_PACKET_FORMAT;
    static const char* ITEM_STR_IP_ADDRESS;
    static const char* ITEM_STR_GW_IP_ADDRESS;
    static const char* ITEM_STR_REMOTE_IP_ADDRESS;
    static const char* ITEM_STR_PORT;

    int                channel;
    char               ssid[nn::wlan::Ssid::SsidLengthMax];  // NUL-terminated; at most sizeof(ssid) - 1 chars
    nn::wlan::Security security;
    SocketType         type;
    PacketFormat       format;
    string             ipAddress;        // dotted-quad address or "DHCP" (see SetIpAddress())
    string             gwIpAddress;
    string             remoteIpAddress;
    uint16_t           port;

    SocketParam()
    {
        channel                     = 1;
        security.privacyMode        = nn::wlan::SecurityMode_Open;
        security.groupPrivacyMode   = nn::wlan::SecurityMode_Open;
        type                        = SOCKET_TYPE_UNKNOWN;
        format                      = PACKET_FORMAT_WIT;
        gwIpAddress                 = "192.168.11.1";
        ipAddress                   = "192.168.11.2";
        remoteIpAddress             = "192.168.11.3";
        port                        = 50285;

        nn::util::Strlcpy(ssid, "wireless_test_1", sizeof(ssid));
        std::memset(security.key, 0x00, sizeof(security.key));
    }

    // Logs every setting through NN_LOG.
    void Print()
    {
        NN_LOG("  Channel                  : %d\n", channel);
        NN_LOG("  Ssid                     : %s\n", ssid);
        NN_LOG("  Security mode            : %s\n", ToString(security.privacyMode).c_str());
        // SetSecurityKey() accepts a key that fills the whole buffer, in which case
        // there is no terminating NUL; bound the output by the buffer size.
        NN_LOG("  Security key             : %.*s\n",
               static_cast<int>(sizeof(security.key)),
               reinterpret_cast<const char*>(security.key));
        NN_LOG("  Type                     : %s\n", ToStringSocketType(type).c_str());
        NN_LOG("  Packet format            : %s\n", ToStringPacketFormat(format).c_str());
        NN_LOG("  IP address               : %s\n", ipAddress.c_str());
        NN_LOG("  GW address               : %s\n", gwIpAddress.c_str());
        NN_LOG("  Remote IP address        : %s\n", remoteIpAddress.c_str());
        NN_LOG("  Port                     : %u\n", port);
    }

    // Maps an item name to its ItemId; logs and returns ITEM_ID_UNDEFINED for unknown names.
    ItemId GetItemId(const string& item)
    {
        if(strcmp(item.c_str(), ITEM_STR_CHANNEL) == 0)               return ITEM_ID_CHANNEL;
        if(strcmp(item.c_str(), ITEM_STR_SSID) == 0)                  return ITEM_ID_SSID;
        if(strcmp(item.c_str(), ITEM_STR_SECURITY_MODE) == 0)         return ITEM_ID_SECURITY_MODE;
        if(strcmp(item.c_str(), ITEM_STR_SECURITY_KEY) == 0)          return ITEM_ID_SECURITY_KEY;
        if(strcmp(item.c_str(), ITEM_STR_TYPE) == 0)                  return ITEM_ID_TYPE;
        if(strcmp(item.c_str(), ITEM_STR_PACKET_FORMAT) == 0)         return ITEM_ID_PACKET_FORMAT;
        if(strcmp(item.c_str(), ITEM_STR_IP_ADDRESS) == 0)            return ITEM_ID_IP_ADDRESS;
        if(strcmp(item.c_str(), ITEM_STR_GW_IP_ADDRESS) == 0)         return ITEM_ID_GW_IP_ADDRESS;
        if(strcmp(item.c_str(), ITEM_STR_REMOTE_IP_ADDRESS) == 0)     return ITEM_ID_REMOTE_IP_ADDRESS;
        if(strcmp(item.c_str(), ITEM_STR_PORT) == 0)                  return ITEM_ID_PORT;

        NN_LOG("  - failed : Unrecognized item name (%s)\n", item.c_str());

        return ITEM_ID_UNDEFINED;
    }

    // Dispatches `value` to the setter for `item`.
    // Returns false for an unknown item or when the setter rejects the value.
    bool SetParam(const string& item, const string& value)
    {
        switch(GetItemId(item))
        {
        case ITEM_ID_CHANNEL               : return SetChannel(value);
        case ITEM_ID_SSID                  : return SetSsid(value);
        case ITEM_ID_SECURITY_MODE         : return SetSecurityMode(value);
        case ITEM_ID_SECURITY_KEY          : return SetSecurityKey(value);
        case ITEM_ID_TYPE                  : return SetType(value);
        case ITEM_ID_PACKET_FORMAT         : return SetPacketFormat(value);
        case ITEM_ID_IP_ADDRESS            : return SetIpAddress(value);
        case ITEM_ID_GW_IP_ADDRESS         : return SetGwIpAddress(value);
        case ITEM_ID_REMOTE_IP_ADDRESS     : return SetRemoteIpAddress(value);
        case ITEM_ID_PORT                  : return SetPort(value);
        case ITEM_ID_UNDEFINED             :
        default                            : return false;
        }
    }

    // Parses a decimal channel number and stores it if IsSupportedChannel() accepts it.
    // On failure `channel` keeps its previous value.
    bool SetChannel(const string& valueStr)
    {
        int value = 0;
        if(sscanf(valueStr.c_str(), "%d", &value) != 1)
        {
            return false;
        }

        if(!IsSupportedChannel(value))
        {
            NN_LOG("  - failed : Out of range (%d)\n", value);
            return false;
        }

        channel = value;
        return true;
    }

    // Copies valueStr into `ssid`; rejects strings that would not fit with the NUL terminator.
    bool SetSsid(const string& valueStr)
    {
        // Strlcpy reserves one byte for the terminating NUL, so sizeof(ssid)
        // characters or more would be silently truncated.
        if(valueStr.size() >= sizeof(ssid))
        {
            NN_LOG("  - failed : Out of range (%s)\n", valueStr.c_str());
            return false;
        }

        nn::util::Strlcpy(ssid, valueStr.c_str(), sizeof(ssid));

        return true;
    }

    // Case-insensitively maps a mode name to nn::wlan::SecurityMode and applies it to
    // both privacyMode and groupPrivacyMode. Returns false for unknown names.
    bool SetSecurityMode(const string& valueStr)
    {
        nn::wlan::SecurityMode mode;
        string str = ToUpper(valueStr);

        if(str == "OPEN") mode = nn::wlan::SecurityMode_Open;
        else if(str == "WEP128-OPEN") mode = nn::wlan::SecurityMode_Wep128Open;
        else if(str == "WPA-AES") mode = nn::wlan::SecurityMode_WpaAes;
        else if(str == "WPA2-AES") mode = nn::wlan::SecurityMode_Wpa2Aes;
        else if(str == "WPA-TKIP") mode = nn::wlan::SecurityMode_WpaTkip;
        else if(str == "WPA2-TKIP") mode = nn::wlan::SecurityMode_Wpa2Tkip;
        else return false;

        security.privacyMode = mode;
        security.groupPrivacyMode = mode;

        return true;
    }

    // Stores the raw bytes of valueStr as the security key (up to sizeof(security.key)
    // bytes). The rest of the buffer is zero-filled.
    bool SetSecurityKey(const string& valueStr)
    {
        if(valueStr.size() > sizeof(security.key))
        {
            NN_LOG("  - failed : Out of range (%s)\n", valueStr.c_str());
            return false;
        }

        // Clear first so no bytes of a previously set (longer) key survive.
        std::memset(security.key, 0x00, sizeof(security.key));
        std::memcpy(security.key, valueStr.data(), valueStr.size());

        return true;
    }

    // Case-insensitively parses the socket type; on failure `type` becomes SOCKET_TYPE_UNKNOWN.
    bool SetType(const string& valueStr)
    {
        string str = ToUpper(valueStr);

        if(str == SOCKET_TYPE_TCP_STR) type = SOCKET_TYPE_TCP;
        else if(str == SOCKET_TYPE_UDP_STR) type = SOCKET_TYPE_UDP;
        else
        {
            type = SOCKET_TYPE_UNKNOWN;
            return false;
        }

        return true;
    }

    // Case-insensitively parses the packet format; on failure `format` becomes PACKET_FORMAT_UNKNOWN.
    bool SetPacketFormat(const string& valueStr)
    {
        string str = ToUpper(valueStr);

        if(str == PACKET_FORMAT_WIT_STR) format = PACKET_FORMAT_WIT;
        else if(str == PACKET_FORMAT_RTP_STR) format = PACKET_FORMAT_RTP;
        else
        {
            format = PACKET_FORMAT_UNKNOWN;
            return false;
        }

        return true;
    }

    // Stores the gateway address if nn::socket::InetAton accepts it.
    bool SetGwIpAddress(const string& valueStr)
    {
        string str = ToUpper(valueStr);

        in_addr in;
        int rtn = nn::socket::InetAton(str.c_str(), &in);
        if( rtn != 1 )
        {
            return false;
        }

        gwIpAddress = str;

        return true;
    }

    // Stores the local address: either the keyword "DHCP" (case-insensitive) or an
    // address accepted by nn::socket::InetAton.
    bool SetIpAddress(const string& valueStr)
    {
        string str = ToUpper(valueStr);

        if( str != "DHCP" )
        {
            in_addr addr;
            int rtn = nn::socket::InetAton(str.c_str(), &addr);
            if( rtn != 1 )
            {
                return false;
            }
        }

        ipAddress = str;

        return true;
    }

    // Stores the remote address; "BROADCAST" (case-insensitive) maps to BROADCAST_IP_ADDRESS.
    bool SetRemoteIpAddress(const string& valueStr)
    {
        string str = ToUpper(valueStr);

        if(str == "BROADCAST") remoteIpAddress = BROADCAST_IP_ADDRESS;
        else
        {
            in_addr in;
            int rtn = nn::socket::InetAton(str.c_str(), &in);
            if( rtn != 1 )
            {
                return false;
            }

            remoteIpAddress = str;
        }

        return true;
    }

    // Parses a decimal port number. Values above PORT_NUMBER_MAX, or too large for
    // uint16_t, are rejected. On failure `port` keeps its previous value.
    bool SetPort(const string& valueStr)
    {
        // Parse into a wider type: "%hu" silently wraps values such as 70000 into range.
        unsigned int value = 0;
        if(sscanf(valueStr.c_str(), "%u", &value) != 1)
        {
            return false;
        }

        if(value > static_cast<unsigned int>(PORT_NUMBER_MAX) || value > 0xFFFFu)
        {
            NN_LOG("  - failed : Out of range (%u)\n", value);
            return false;
        }

        port = static_cast<uint16_t>(value);
        return true;
    }

private:

    // Channel numbers accepted by SetChannel(): 1-13; 36-64 and 100-144 in steps
    // of 4; 149-165 in steps of 4.
    static bool IsSupportedChannel(int ch)
    {
        if(1 <= ch && ch <= 13)                       return true;
        if(36 <= ch && ch <= 64 && ch % 4 == 0)       return true;
        if(100 <= ch && ch <= 144 && ch % 4 == 0)     return true;
        if(149 <= ch && ch <= 165 && ch % 4 == 1)     return true;
        return false;
    }
};


/*!--------------------------------------------------------------------------*
  @brief        Infrastructure-mode node class (queries link state through
                nn::wlan::Infra).
 *---------------------------------------------------------------------------*/
class InfraNode : public Node
{
/*---------------------------------------------------------------------------
  Member variables
---------------------------------------------------------------------------*/
protected:

    nn::wlan::Ssid       m_Ssid;
    nn::wlan::MacAddress m_Bssid;
    int16_t              m_Channel;
    nn::wlan::Security   m_Security;

    string m_IpAddress;
    string m_GwIpAddress;
    string m_RemoteIpAddress;

    // Used to maintain the connection
    nn::os::SystemEventType    m_ConnectionEvent;
    nn::wlan::ConnectionStatus m_ConnectionStatus;

    bool m_IsWaiting;
    nn::os::Mutex m_Cs;

/*---------------------------------------------------------------------------
  Member methods
---------------------------------------------------------------------------*/
protected:

    virtual nn::Result WlanInitialize();
    virtual nn::Result WlanFinalize();
    virtual nn::Result WlanOpenMode();
    virtual nn::Result WlanCloseMode();
    virtual nn::Result WlanLocalCreateRxEntry(uint32_t* pRxId, const uint16_t pProtocols[], const int32_t& count, const int32_t& capacity);
    virtual nn::Result WlanSocketCreateRxEntry(uint32_t* pRxid, const uint16_t pProtocols[], const int32_t& count, const int32_t& capacity);
    virtual nn::Result WlanLocalCreateRxEntryForAf(uint32_t* pRxId, const uint16_t pProtocols[], const int32_t& count, const int32_t& capacity);
    virtual nn::Result WlanDetectCreateRxEntryForAf(uint32_t* pRxId, const uint16_t pProtocols[], const int32_t& count, const int32_t& capacity);
    virtual nn::Result WlanLocalDeleteRxEntry(uint32_t* pRxId);
    virtual nn::Result WlanSocketDeleteRxEntry(uint32_t* pRxId);
    virtual nn::Result WlanLocalDeleteRxEntryForAf(uint32_t* pRxId);
    virtual nn::Result WlanDetectDeleteRxEntryForAf(uint32_t* pRxId);
    virtual nn::Result WlanLocalAddMatchingData(const uint32_t& rxId, const nn::wlan::ReceivedDataMatchInfo& pMatchInfo);
    virtual nn::Result WlanLocalRemoveMatchingData(const uint32_t& rxId, const nn::wlan::ReceivedDataMatchInfo& pMatchInfo);

    virtual nn::Result WlanGetFrame( uint32_t rxId, uint8_t pOutput[], size_t* pSize, size_t maxSize, int8_t* pRssi = NULL );
    virtual nn::Result WlanPutFrame( const uint8_t pInput[], size_t size, bool selfCts = false );
    virtual nn::Result WlanGetActionFrame(nn::wlan::MacAddress* pOutSrcMac, uint8_t pOutBuf[], size_t size, size_t* pOutSize, uint32_t rxId, uint16_t* pChannel, int16_t* pRssi);
    virtual nn::Result WlanPutActionFrame(const nn::wlan::MacAddress& dstMac, const uint8_t* pData, size_t size, int16_t channel, uint32_t dwellTime);
    virtual nn::Result WlanGetState(nn::wlan::WlanState *pSta);

    virtual nn::Result WlanCancelGetFrame(uint32_t rxId);
    virtual nn::Result WlanCancelGetActionFrame(uint32_t rxId);

    virtual nn::Result CreateRxEntry();
    virtual nn::Result DeleteRxEntry();

public:

    virtual nn::Result Initialize();
    virtual nn::Result Open();
    virtual nn::Result Close();
    virtual nn::Result Finalize();

    virtual nn::Result CancelConnect();
    virtual nn::wlan::MacAddress GetApAddress();
    virtual bool IsConnected();
    virtual bool IsAfRxReady();

    // Command reception is a no-op for this class; both calls always succeed.
    virtual nn::Result StartReceiveCommand(int32_t priority = nn::os::GetThreadCurrentPriority(nn::os::GetCurrentThread()))
    { return nn::ResultSuccess(); }
    virtual nn::Result StopReceiveCommand()
    { return nn::ResultSuccess(); }

/*---------------------------------------------------------------------------
  Constructors
---------------------------------------------------------------------------*/
public:
    InfraNode();
    virtual ~InfraNode();

/*---------------------------------------------------------------------------
  Accessors
---------------------------------------------------------------------------*/
public:

    // On input `num` is the capacity of `list`. When it is at least 1, the BSSID of
    // the current connection status is written to list[0] and `num` becomes 1,
    // or `num` becomes 0 if that BSSID is the zero address.
    // When `num` is 0 on input, nothing is changed.
    virtual void GetMacAddressListInBss(nn::wlan::MacAddress list[], uint32_t& num)
    {
        if( num >= 1 )
        {
            if( m_ConnectionStatus.bssid != nn::wlan::MacAddress::CreateZeroMacAddress() )
            {
                list[0] = nn::wlan::MacAddress(m_ConnectionStatus.bssid);
                num = 1;
            }
            else
            {
                num = 0;
            }
        }
        return;
    }

    // Replaces the contents of *pRssiList with at most one sample: the RSSI reported by
    // nn::wlan::Infra::GetRssi, tagged with GetBssid() and the current system tick.
    // Returns true only if a sample was added (connected and the query succeeded).
    virtual bool GetRssi(vector<Rssi>* pRssiList)
    {
        NN_ASSERT(pRssiList);
        // The original called empty(), which only tests for emptiness; clear() is what
        // actually discards entries left over from a previous call.
        pRssiList->clear();
        if( m_ConnectionStatus.state == nn::wlan::ConnectionState_Connected )
        {
            Rssi rssi;
            nn::Result result = nn::wlan::Infra::GetRssi(&rssi.value);
            nn::os::Tick tick = nn::os::GetSystemTick();
            if( result.IsSuccess() )
            {
                rssi.address = GetBssid();
                rssi.recordTime = tick;
                pRssiList->push_back(rssi);
                return true;
            }
        }

        return false;
    }

    // Returns the link level from nn::wlan::Infra::GetLinkLevel while connected,
    // and nn::wlan::LinkLevel_0 when disconnected or when the query fails.
    virtual nn::wlan::LinkLevel GetLinkLevel()
    {
        if( m_ConnectionStatus.state == nn::wlan::ConnectionState_Connected )
        {
            nn::wlan::LinkLevel linkLevel;
            nn::Result result = nn::wlan::Infra::GetLinkLevel(&linkLevel);
            if( result.IsSuccess() )
            {
                return linkLevel;
            }
        }

        return nn::wlan::LinkLevel_0;
    }

    nn::wlan::Ssid GetSsid(){ return m_Ssid; }
    void SetSsid(nn::wlan::Ssid ssid){ m_Ssid = ssid; }
    void SetSsid(std::string str)
    {
        nn::wlan::Ssid ssid(str.c_str());
        m_Ssid = ssid;
    }

    nn::wlan::MacAddress GetBssid(){ return m_Bssid; }
    void SetBssid(nn::wlan::MacAddress id){ m_Bssid = id; }

    virtual int16_t GetChannel(){ return m_Channel; }
    // m_Channel is int16_t; make the narrowing from int explicit.
    void SetChannel(int ch){ m_Channel = static_cast<int16_t>(ch); }

    nn::wlan::Security GetSecurity(){ return m_Security; }
    void SetSecurity(nn::wlan::Security sec){ m_Security = sec; }
};

/*!--------------------------------------------------------------------------*
  @brief        Socket node class. Abstract: derived classes provide
                SocketSend / SocketRecv / CreateSocket.
 *---------------------------------------------------------------------------*/
class SocketNode : public InfraNode
{
/*---------------------------------------------------------------------------
  Member variables
---------------------------------------------------------------------------*/
protected:

    SocketType m_Type;
    int m_Socket;
    fd_set m_ReadFds;
    uint16_t m_Port;

    string m_InterfaceName;
    bool m_EnableDhcp;
    bool m_SocketInitialized;
    bool m_IsIpAddrAssigned;

    PacketFormat m_RxPacketFormat;

    CyclicSequenceNumberCounter m_Errors;

/*---------------------------------------------------------------------------
  Member methods
---------------------------------------------------------------------------*/
public:

    virtual nn::Result Initialize();
    virtual nn::Result Open();
    virtual nn::Result Close();
    virtual nn::Result Finalize();

    // Same default for pRssi as InfraNode::WlanGetFrame. Default arguments bind to the
    // static type, so without it the 4-argument call would not compile through a SocketNode.
    virtual nn::Result WlanGetFrame( uint32_t rxId, uint8_t pOutput[], size_t* pSize, size_t maxSize, int8_t* pRssi = NULL );

    virtual void MaintainConnection();
    virtual void StopMaintainConnection();

    virtual nn::Result Send(uint8_t data[], size_t size, uint8_t ieInd, size_t* pSentSize);
    virtual nn::Result Receive(uint8_t pOutput[], size_t* pSize, const size_t maxSize);
    virtual bool WaitReceive(nn::TimeSpan timeSpan);

    virtual bool IsConnected();

    virtual void ClearStatistics();

    void SetSocketType(SocketType type){ m_Type = type; }
    void SetPort(uint16_t port){ m_Port = port; }
    void SetIpAddress(const string& addr){ m_IpAddress = addr; }
    void SetGwIpAddress(const string& addr){ m_GwIpAddress = addr; }
    void SetRemoteIpAddress(const string& addr){ m_RemoteIpAddress = addr; }

    // Also sets the wrap-around point of the error counter: the full 64-bit range for
    // PACKET_FORMAT_WIT, 16 bits for PACKET_FORMAT_RTP. Other formats leave it unchanged.
    void SetRxPacketFormat(PacketFormat format)
    {
        m_RxPacketFormat = format;

        if( format == PACKET_FORMAT_WIT )
        {
            m_Errors.SetEndNo(0xffffffffffffffff);
        }
        else if( format == PACKET_FORMAT_RTP )
        {
            m_Errors.SetEndNo(0xffff);
        }
    }
    void EnableDhcp() { m_EnableDhcp = true; }
    void DisableDhcp() { m_EnableDhcp = false; }

    SocketType GetSocketType(){ return m_Type; }
    uint16_t GetPort(){ return m_Port; }
    string GetIpAddress(){ return m_IpAddress; }
    string GetGwIpAddress(){ return m_GwIpAddress; }
    string GetRemoteIpAddress(){ return m_RemoteIpAddress; }
    PacketFormat GetRxPacketFormat() { return m_RxPacketFormat; }
    // Returns a copy of the error counter.
    CyclicSequenceNumberCounter GetError() { return m_Errors; }

protected:

    virtual nn::Result SocketSend( const uint8_t pInput[], size_t size, int flags, size_t* pSentSize ) = 0;
    virtual nn::Result SocketRecv( uint8_t pOutput[], size_t* pSize, sockaddr_in* pFromAddr, const size_t maxSize, int flags = 0 ) = 0;

    virtual void CreateSocket(int& socket) = 0;
    virtual void ConfigureInterface();

/*---------------------------------------------------------------------------*
  Constructors
 *---------------------------------------------------------------------------*/
public:

    SocketNode();
    virtual ~SocketNode();

};


/*!--------------------------------------------------------------------------*
  @brief        UDP node class: supplies the UDP implementations of the
                socket hooks declared by SocketNode.
 *---------------------------------------------------------------------------*/
class UdpNode : public SocketNode
{
    // ----- Data members ---------------------------------------------------
protected:

    sockaddr_in m_TxSockAddr;
    sockaddr_in m_RxSockAddr;

    // ----- Lifecycle / public operations ---------------------------------
public:

    virtual nn::Result Initialize();
    virtual nn::Result Open();
    virtual nn::Result Close();
    virtual nn::Result Finalize();

    // ----- SocketNode hooks and UDP helpers (defined out of line) ---------
protected:

    virtual nn::Result SocketSend( const uint8_t pData[], size_t dataSize, int sendFlags, size_t* pOutSentSize );
    virtual nn::Result SocketRecv( uint8_t pBuffer[], size_t* pOutSize, sockaddr_in* pOutFromAddr, const size_t bufferSize, int recvFlags = 0 );

    virtual void CreateSocket(int& outSocket);
    virtual nn::Result OpenUdpServer(sockaddr_in& sockAddr);
    virtual nn::Result OpenUdpClient(sockaddr_in& sockAddr);

    // ----- Construction ---------------------------------------------------
public:

    UdpNode();
    virtual ~UdpNode();
};

} // WlanTest
