﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <algorithm>
#include <cstring>
#include <mutex>
#include <nn/socket.h>

#include <nn/nifm/nifm_ApiInternetConnectionStatus.h>
#include <nn/nim/detail/nim_Log.h>
#include <nn/nim/nim_Config.h>
#include <nn/nim/nim_Result.h>
#include <nn/nim/srv/nim_LocalCommunicationConfig.h>
#include <nn/nim/srv/nim_LocalCommunicationDeliveryProtocol.h>
#include <nn/result/result_HandlingUtility.h>
#include <nn/util/util_ScopeExit.h>

namespace nn { namespace nim { namespace srv {

namespace
{

// Four-byte magic values stored in LocalCommunicationDeliveryProtocolHeader::signature.
// CheckLocalCommunicationDeliveryProtocolSignature() accepts a header only if it carries one of these.
constexpr char LocalCommunicationDeliveryProtocolRequestSignature[4] = { 'L', 'D', 'R', 'q' };   // "LDRq": Local communication Delivery protocol ReQuest
constexpr char LocalCommunicationDeliveryProtocolResponseSignature[4] = { 'L', 'D', 'R', 's' };  // "LDRs": Local communication Delivery protocol ReSponse

} // namespace

// Builds a request header ("LDRq" signature) for the given tag.
// size is the payload length that follows the header. No additional header is announced
// (additionalHeaderSize is always 0). All other fields are zero-initialized.
LocalCommunicationDeliveryProtocolHeader MakeLocalCommunicationDeliveryProtocolRequestHeader(LocalCommunicationDeliveryProtocolTag tag, int64_t size)
{
    LocalCommunicationDeliveryProtocolHeader result{};
    result.tag = tag;
    result.size = size;
    result.additionalHeaderSize = 0;
    std::memcpy(result.signature, LocalCommunicationDeliveryProtocolRequestSignature, sizeof(result.signature));
    return result;
}

// Builds a response header ("LDRs" signature) for the given tag.
// size is the payload length that follows the header. No additional header is announced
// (additionalHeaderSize is always 0). All other fields are zero-initialized.
LocalCommunicationDeliveryProtocolHeader MakeLocalCommunicationDeliveryProtocolResponseHeader(LocalCommunicationDeliveryProtocolTag tag, int64_t size)
{
    LocalCommunicationDeliveryProtocolHeader result{};
    result.tag = tag;
    result.size = size;
    result.additionalHeaderSize = 0;
    std::memcpy(result.signature, LocalCommunicationDeliveryProtocolResponseSignature, sizeof(result.signature));
    return result;
}

// Validates that header.signature is either the request ("LDRq") or the response ("LDRs") magic.
// Returns ResultLocalCommunicationInvalidSignature otherwise.
Result CheckLocalCommunicationDeliveryProtocolSignature(const LocalCommunicationDeliveryProtocolHeader& header) NN_NOEXCEPT
{
    const bool isRequest = std::memcmp(header.signature, LocalCommunicationDeliveryProtocolRequestSignature, sizeof(header.signature)) == 0;
    const bool isResponse = std::memcmp(header.signature, LocalCommunicationDeliveryProtocolResponseSignature, sizeof(header.signature)) == 0;

    NN_RESULT_THROW_UNLESS(isRequest || isResponse, ResultLocalCommunicationInvalidSignature());

    NN_RESULT_SUCCESS;
}

// Requests cancellation of in-flight protocol operations.
// While holding m_CancelMutex, shuts down the connected socket (if one is set) for both
// directions, then raises the flag reported by IsCancelRequested(). Socket errors routed
// through HandleSocketReturnCode() check that flag first and are reported as
// ResultLocalCommunicationConnectionCanceled.
// NOTE(review): m_Socket is read here under m_CancelMutex, but Initialize()/WaitClient()/Finalize()
// write it without that mutex; confirm those never run concurrently with Cancel().
void LocalCommunicationDeliveryProtocolBase::Cancel() NN_NOEXCEPT
{
    std::lock_guard<os::Mutex> lock(m_CancelMutex);

    const bool hasSocket = (m_Socket >= 0);
    if (hasSocket && socket::Shutdown(m_Socket, socket::ShutdownMethod::Shut_RdWr) < 0)
    {
        // Best effort: a failed shutdown is only traced; the cancel flag is still raised below.
        socket::Errno lastErrNo = socket::GetLastError();
        NN_UNUSED(lastErrNo);
        NN_DETAIL_NIM_TRACE("socket::Shutdown is failed: %d\n", lastErrNo);
    }

    m_CancelRequest = true;
}

// Clears a previous cancellation request so that later operations run normally again.
void LocalCommunicationDeliveryProtocolBase::ResetCancel() NN_NOEXCEPT
{
    std::lock_guard<os::Mutex> lock(m_CancelMutex);
    m_CancelRequest = false;
}

// Returns whether Cancel() has been called since construction or the last ResetCancel().
// Reads the flag under m_CancelMutex, so it waits for an in-progress Cancel() to finish.
bool LocalCommunicationDeliveryProtocolBase::IsCancelRequested() const NN_NOEXCEPT
{
    std::lock_guard<os::Mutex> lock(m_CancelMutex);
    const bool requested = m_CancelRequest;
    return requested;
}

// Sends one complete message: the header, then header.size bytes of payload streamed
// through buffer (up to size bytes per chunk; updateFunc, if set, is called before each chunk).
// Stops at the first failing step and returns its Result.
Result LocalCommunicationDeliveryProtocolBase::Send(const LocalCommunicationDeliveryProtocolHeader& header, const void* buffer, size_t size, UpdateSendBufferFunction updateFunc) NN_NOEXCEPT
{
    // The payload length is taken from the header so sender and receiver agree on it.
    const auto contentLength = header.size;

    NN_RESULT_DO(SendHeader(header));
    NN_RESULT_DO(SendData(buffer, size, contentLength, updateFunc));
    NN_RESULT_SUCCESS;
}

// Sends a protocol header over m_Socket after validating its signature.
// Returns:
//  - ResultLocalCommunicationConnectionCanceled if Cancel() was requested,
//  - the signature-check Result if the header carries neither the request nor the response magic,
//  - a socket failure mapped by HandleSocketReturnCode(), or
//  - ResultLocalCommunicationSessionClosed if the socket stops accepting data.
Result LocalCommunicationDeliveryProtocolBase::SendHeader(const LocalCommunicationDeliveryProtocolHeader& header) NN_NOEXCEPT
{
    NN_RESULT_THROW_UNLESS(!IsCancelRequested(), ResultLocalCommunicationConnectionCanceled());
    NN_RESULT_DO(CheckLocalCommunicationDeliveryProtocolSignature(header));

    // A stream socket may accept fewer bytes than requested per call; keep sending until the
    // whole header has been handed to the socket so the peer never sees a truncated header.
    const Bit8* pHeader = reinterpret_cast<const Bit8*>(&header);
    size_t sentSize = 0;
    while (sentSize < sizeof(header))
    {
        auto sentNow = socket::Send(m_Socket, pHeader + sentSize, sizeof(header) - sentSize, socket::MsgFlag::Msg_None);
        NN_RESULT_DO(HandleSocketReturnCode(static_cast<int>(sentNow)));
        if (sentNow == 0)
        {
            // No progress is possible; report cancellation if that is the cause.
            NN_RESULT_THROW_UNLESS(!IsCancelRequested(), ResultLocalCommunicationConnectionCanceled());
            NN_RESULT_THROW(ResultLocalCommunicationSessionClosed());
        }
        sentSize += static_cast<size_t>(sentNow);
    }

    NN_RESULT_SUCCESS;
}

// Streams contentLength bytes of payload over m_Socket in chunks of at most size bytes.
// Before each chunk, updateFunc (if set) is called with (chunkSize, offsetInContent). Presumably
// it refills buffer with that slice of the content; verify against callers. Without updateFunc,
// the first bytes of buffer are sent for every chunk.
// NOTE(review): size == 0 with contentLength > 0 would never make progress.
// Returns:
//  - ResultLocalCommunicationConnectionCanceled on Cancel(),
//  - the updateFunc Result if updateFunc fails,
//  - a socket failure mapped by HandleSocketReturnCode(), or
//  - ResultLocalCommunicationSessionClosed if the socket stops accepting data.
Result LocalCommunicationDeliveryProtocolBase::SendData(const void* buffer, size_t size, int64_t contentLength, UpdateSendBufferFunction updateFunc) NN_NOEXCEPT
{
    NN_RESULT_THROW_UNLESS(!IsCancelRequested(), ResultLocalCommunicationConnectionCanceled());
    const Bit8* pBuffer = static_cast<const Bit8*>(buffer);
    int64_t offset = 0;
    while (offset < contentLength)
    {
        NN_RESULT_THROW_UNLESS(!IsCancelRequested(), ResultLocalCommunicationConnectionCanceled());
        size_t writeSize = ((contentLength - offset) > static_cast<int64_t>(size)) ? size : static_cast<size_t>(contentLength - offset);
        if (updateFunc)
        {
            NN_RESULT_DO(updateFunc(writeSize, offset));
        }

        // send() may accept only part of the chunk. Send the remainder before advancing, so that
        // offset always counts bytes actually handed to the socket.
        size_t sentSize = 0;
        while (sentSize < writeSize)
        {
            auto sentNow = socket::Send(m_Socket, pBuffer + sentSize, writeSize - sentSize, socket::MsgFlag::Msg_None);
            NN_RESULT_DO(HandleSocketReturnCode(static_cast<int>(sentNow)));
            if (sentNow == 0)
            {
                // No progress is possible; report cancellation if that is the cause.
                NN_RESULT_THROW_UNLESS(!IsCancelRequested(), ResultLocalCommunicationConnectionCanceled());
                NN_RESULT_THROW(ResultLocalCommunicationSessionClosed());
            }
            sentSize += static_cast<size_t>(sentNow);
        }
        offset += static_cast<int64_t>(writeSize);
    }

    NN_RESULT_SUCCESS;
}

// Receives one protocol header from m_Socket and stores it in *outValue.
// Validates the signature, requires header.size >= 0, and requires additionalHeaderSize to be
// no larger than LocalCommunicationDeliveryProtocolAdditionalHeaderSizeMax. Any additional header
// bytes are read and discarded, because none are understood yet.
// *outValue is written only on success.
// Returns:
//  - ResultLocalCommunicationConnectionCanceled on Cancel(),
//  - ResultLocalCommunicationSessionClosed if the stream ends before the header is complete,
//  - the validation Results above, or
//  - a socket failure mapped by HandleSocketReturnCode().
Result LocalCommunicationDeliveryProtocolBase::ReceiveHeader(LocalCommunicationDeliveryProtocolHeader* outValue) NN_NOEXCEPT
{
    NN_RESULT_THROW_UNLESS(!IsCancelRequested(), ResultLocalCommunicationConnectionCanceled());

    // Receive the header as defined by the initial protocol version.
    // With Msg_WaitAll, a count shorter than requested (0 on orderly shutdown) means the stream ended.
    // In that case the header is incomplete and must not be parsed.
    LocalCommunicationDeliveryProtocolHeader header{};
    auto headerReceivedSize = socket::RecvFrom(m_Socket, &header, sizeof(header), socket::MsgFlag::Msg_WaitAll, nullptr, nullptr);
    NN_RESULT_DO(HandleSocketReturnCode(static_cast<int>(headerReceivedSize)));
    if (headerReceivedSize != static_cast<ssize_t>(sizeof(header)))
    {
        // Cancel() shuts the socket down, which also ends the receive early; report that as cancellation.
        NN_RESULT_THROW_UNLESS(!IsCancelRequested(), ResultLocalCommunicationConnectionCanceled());
        NN_RESULT_THROW(ResultLocalCommunicationSessionClosed());
    }
    NN_RESULT_DO(CheckLocalCommunicationDeliveryProtocolSignature(header));
    NN_RESULT_THROW_UNLESS(header.size >= 0, ResultLocalCommunicationInvalidDataSize());
    NN_RESULT_THROW_UNLESS(header.additionalHeaderSize <= LocalCommunicationDeliveryProtocolAdditionalHeaderSizeMax, ResultLocalCommunicationInvalidAdditionalHeaderSize());

    size_t remainSize = header.additionalHeaderSize;

    // Receive known additional headers.
    // (At the moment every additional header is unknown.)

    // Receive unknown additional headers and discard them, so the next read starts at the payload.
    Bit8 buf[128];
    while (remainSize > 0)
    {
        size_t recvSize = std::min(sizeof(buf), remainSize);
        auto receivedSize = socket::RecvFrom(m_Socket, buf, recvSize, socket::MsgFlag::Msg_WaitAll, nullptr, nullptr);
        NN_RESULT_DO(HandleSocketReturnCode(static_cast<int>(receivedSize)));
        if (receivedSize != static_cast<ssize_t>(recvSize))
        {
            NN_RESULT_THROW_UNLESS(!IsCancelRequested(), ResultLocalCommunicationConnectionCanceled());
            NN_RESULT_THROW(ResultLocalCommunicationSessionClosed());
        }
        remainSize -= recvSize;
    }

    *outValue = header;
    NN_RESULT_SUCCESS;
}

// Receives contentLength bytes of payload from m_Socket into buffer, at most bufferSize bytes at a time.
// After each receive, recvFunc (if set) is called with (receivedSize, offsetInContent) so the caller
// can consume the data before buffer is overwritten by the next receive.
// Returns:
//  - ResultLocalCommunicationConnectionCanceled on Cancel(),
//  - ResultLocalCommunicationSessionClosed if the stream ends before contentLength bytes arrive,
//  - the recvFunc Result if recvFunc fails, or
//  - a socket failure mapped by HandleSocketReturnCode().
// NOTE(review): bufferSize == 0 with contentLength > 0 is reported as SessionClosed.
Result LocalCommunicationDeliveryProtocolBase::ReceiveData(void* buffer, size_t bufferSize, int64_t contentLength, ProcessReceiveBufferFunction recvFunc) NN_NOEXCEPT
{
    int64_t offset = 0;
    while (offset < contentLength)
    {
        NN_RESULT_THROW_UNLESS(!IsCancelRequested(), ResultLocalCommunicationConnectionCanceled());
        uint64_t remainSize = static_cast<uint64_t>(contentLength - offset);
        size_t tryReceiveSize = static_cast<size_t>(std::min(remainSize, static_cast<uint64_t>(bufferSize)));
        auto recvSize = socket::RecvFrom(m_Socket, buffer, tryReceiveSize, socket::MsgFlag::Msg_WaitAll, nullptr, nullptr);
        NN_RESULT_DO(HandleSocketReturnCode(static_cast<int>(recvSize)));
        if (recvSize == 0)
        {
            // 0 bytes means the stream ended (peer close, or Cancel()'s shutdown). Without this check
            // offset would never advance and the loop would spin forever.
            NN_RESULT_THROW_UNLESS(!IsCancelRequested(), ResultLocalCommunicationConnectionCanceled());
            NN_RESULT_THROW(ResultLocalCommunicationSessionClosed());
        }
        if (recvFunc)
        {
            NN_RESULT_DO(recvFunc(recvSize, offset));
        }
        offset += recvSize;
    }
    NN_RESULT_SUCCESS;
}

// Waits until socketFd (the listening socket) becomes readable, i.e. Select reports a pending
// connection. Polls with a 1-second Select timeout and re-checks the cancel flag before each poll.
// Returns:
//  - ResultLocalCommunicationConnectionCanceled on Cancel(), or
//  - a Select failure mapped by HandleSocketReturnCode().
Result LocalCommunicationDeliveryProtocolBase::WaitClientConnection(int socketFd) NN_NOEXCEPT
{
    bool isReadable = false;
    while (!isReadable)
    {
        NN_RESULT_THROW_UNLESS(!IsCancelRequested(), ResultLocalCommunicationConnectionCanceled());

        // Select may modify both the set and the timeout, so they are rebuilt on every poll.
        socket::FdSet readSet;
        socket::FdSetZero(&readSet);
        socket::FdSetSet(socketFd, &readSet);

        socket::TimeVal timeout{};
        timeout.tv_sec = 1;

        auto selectResult = socket::Select(socketFd + 1, &readSet, nullptr, nullptr, &timeout);
        NN_RESULT_DO(HandleSocketReturnCode(selectResult));
        isReadable = (selectResult > 0);
    }
    NN_RESULT_SUCCESS;
}

// Converts a socket API return value into a Result.
// Non-negative values are success. For a negative value:
//  - if Cancel() was requested, the result is ResultLocalCommunicationConnectionCanceled;
//  - otherwise the last socket errno is traced, and host-unreachable / host-down / net-down /
//    connection-reset / broken-pipe map to ResultLocalCommunicationSessionClosed;
//  - anything else maps to ResultLocalCommunicationSocketUnexpectedError.
Result LocalCommunicationDeliveryProtocolBase::HandleSocketReturnCode(int returnCode) NN_NOEXCEPT
{
    const bool isError = (returnCode < 0);
    if (isError)
    {
        NN_RESULT_THROW_UNLESS(!IsCancelRequested(), ResultLocalCommunicationConnectionCanceled());

        socket::Errno lastErrNo = socket::GetLastError();
        NN_DETAIL_NIM_TRACE("Socket Error: %d\n", lastErrNo);

        // These errnos mean the peer or the link went away rather than a local failure.
        const bool isSessionClosed =
            lastErrNo == socket::Errno::EHostUnreach ||
            lastErrNo == socket::Errno::EHostDown ||
            lastErrNo == socket::Errno::ENetDown ||
            lastErrNo == socket::Errno::EConnReset ||
            lastErrNo == socket::Errno::EPipe;
        if (isSessionClosed)
        {
            NN_RESULT_THROW(ResultLocalCommunicationSessionClosed());
        }
        NN_RESULT_THROW(ResultLocalCommunicationSocketUnexpectedError());
    }
    NN_RESULT_SUCCESS;
}

// Enables linger with a zero timeout (onoff = 1, linger = 0) on socketFd via So_Nn_Linger.
// Presumably this matches BSD SO_LINGER semantics, where closing the socket resets the
// connection instead of lingering; confirm against the nn::socket documentation.
Result LocalCommunicationDeliveryProtocolBase::SetLinger(int socketFd) NN_NOEXCEPT
{
    socket::Linger linger { 1, 0 };
    const auto returnCode = socket::SetSockOpt(socketFd, socket::Level::Sol_Socket, socket::Option::So_Nn_Linger, &linger, sizeof(linger));
    NN_RESULT_DO(HandleSocketReturnCode(returnCode));
    NN_RESULT_SUCCESS;
}

// Closes any sockets still owned by the server (accepted client socket and listening socket).
LocalCommunicationDeliveryProtocolServer::~LocalCommunicationDeliveryProtocolServer() NN_NOEXCEPT
{
    this->Finalize();
}

// Creates the listening TCP socket for the server side of the delivery protocol.
// The socket is made non-blocking (WaitClient() polls it with Select), gets SetLinger() and
// So_ReuseAddr, is bound to port on any local address, and listens with a backlog of 1.
// ipv4 is not used for binding. It is stored in m_PeerAddress, and WaitClient() keeps only a
// connection whose source address matches it.
// On any failure the socket is closed and m_ServerSocket remains -1.
Result LocalCommunicationDeliveryProtocolServer::Initialize(uint32_t ipv4, uint16_t port) NN_NOEXCEPT
{
    m_ServerSocket = -1;
    m_Socket = -1;
    socket::SockAddrIn saPeer = {};

    int socketFd = socket::Socket(socket::Family::Af_Inet, socket::Type::Sock_Stream, socket::Protocol::IpProto_Tcp);
    NN_RESULT_DO(HandleSocketReturnCode(socketFd));

    // Install the cleanup guard right after creation, so that every failure path below
    // (Fcntl, SetLinger, SetSockOpt, Bind, Listen) closes the socket instead of leaking it.
    bool isSuccess = false;
    NN_UTIL_SCOPE_EXIT
    {
        if (!isSuccess)
        {
            socket::Close(socketFd);
        }
    };

    int fcntlFlag = socket::Fcntl(socketFd, socket::FcntlCommand::F_GetFl, 0);
    NN_RESULT_DO(HandleSocketReturnCode(fcntlFlag));
    NN_RESULT_DO(HandleSocketReturnCode(socket::Fcntl(socketFd, socket::FcntlCommand::F_SetFl, fcntlFlag | static_cast<int>(socket::FcntlFlag::O_NonBlock))));

    NN_RESULT_DO(SetLinger(socketFd));
    int reuseAddr = 1;
    NN_RESULT_DO(HandleSocketReturnCode(socket::SetSockOpt(socketFd, socket::Level::Sol_Socket, socket::Option::So_ReuseAddr, &reuseAddr, sizeof(reuseAddr))));

    // TORIAEZU (temporary): bind to any local address; the peer is filtered by address in WaitClient().
    saPeer.sin_addr.S_addr = socket::InetHtonl(socket::InAddr_Any);
    saPeer.sin_port = socket::InetHtons(port);
    saPeer.sin_family = socket::Family::Af_Inet;

    NN_RESULT_DO(HandleSocketReturnCode(socket::Bind(socketFd, reinterpret_cast<socket::SockAddr *>(&saPeer), sizeof(saPeer))));

    NN_RESULT_DO(HandleSocketReturnCode(socket::Listen(socketFd, 1)));

    m_PeerAddress = ipv4;
    m_ServerSocket = socketFd;
    isSuccess = true;

    NN_RESULT_SUCCESS;
}

// Waits for and accepts the client connection from m_PeerAddress.
// Connections from any other address are closed and the wait continues. Once the expected peer
// is accepted, the function:
//  1. applies SetLinger(),
//  2. requires nifm::GetInternetConnectionStatus() to fail (i.e. no internet connection),
//  3. switches the socket back to blocking mode,
//  4. swaps the default send/receive buffer sizes (best effort),
//  5. stores the socket in m_Socket.
// On any failure after the accept, the accepted socket is closed and m_Socket is left unchanged.
Result LocalCommunicationDeliveryProtocolServer::WaitClient() NN_NOEXCEPT
{
    socket::SockAddrIn clientAddr = {};

    int clientSocket = -1;
    for (;;)
    {
        NN_RESULT_DO(WaitClientConnection(m_ServerSocket));
        // The address length is an in/out argument of Accept, so reset it before every call.
        socket::SockLenT clientAddrSize = sizeof(clientAddr);
        clientSocket = socket::Accept(m_ServerSocket, reinterpret_cast<socket::SockAddr *>(&clientAddr), &clientAddrSize);
        NN_RESULT_DO(HandleSocketReturnCode(clientSocket));

        if (clientAddr.sin_addr.S_addr == socket::InetHtonl(m_PeerAddress))
        {
            break;
        }

        // Not the expected peer: drop it and keep waiting.
        socket::Close(clientSocket);
        clientSocket = -1;
    }

    // The accepted socket is owned locally until it is stored in m_Socket; close it on every
    // failure path below so it does not leak.
    bool isSuccess = false;
    NN_UTIL_SCOPE_EXIT
    {
        if (!isSuccess)
        {
            socket::Close(clientSocket);
        }
    };

    NN_RESULT_DO(SetLinger(clientSocket));

    // Confirm there is no internet connection, so content cannot leak onto a network other than LDN.
    nifm::InternetConnectionStatus status;
    NN_RESULT_THROW_UNLESS(nifm::GetInternetConnectionStatus(&status).IsFailure(), ResultLocalCommunicationSocketUnexpectedError());

    // The client socket should be handled synchronously, so restore blocking mode (clear O_NonBlock).
    int fcntlFlag = socket::Fcntl(clientSocket, socket::FcntlCommand::F_GetFl, socket::FcntlFlag::None);
    NN_RESULT_DO(HandleSocketReturnCode(fcntlFlag));
    NN_RESULT_DO(HandleSocketReturnCode(socket::Fcntl(clientSocket, socket::FcntlCommand::F_SetFl, (fcntlFlag & static_cast<int>(~socket::FcntlFlag::O_NonBlock)))));

    // The server's client socket needs send throughput, so swap the default send and receive buffer sizes.
    // Best effort: failures of these options are intentionally ignored.
    const int RecvBufSize = SocketSendBufferSize;
    const int SendBufSize = SocketReceiveBufferSize;
    socket::SetSockOpt(clientSocket, socket::Level::Sol_Socket, socket::Option::So_RcvBuf, &RecvBufSize, sizeof(RecvBufSize));
    socket::SetSockOpt(clientSocket, socket::Level::Sol_Socket, socket::Option::So_SndBuf, &SendBufSize, sizeof(SendBufSize));

    m_Socket = clientSocket;
    isSuccess = true;
    NN_RESULT_SUCCESS;
}

// Closes the accepted client socket, then the listening socket, and marks both as closed (-1).
// Safe to call repeatedly; sockets that are already -1 are not closed again.
// NOTE(review): m_Socket is changed here without m_CancelMutex, which Cancel() holds while
// using m_Socket; confirm Finalize() never overlaps with Cancel().
void LocalCommunicationDeliveryProtocolServer::Finalize() NN_NOEXCEPT
{
    if (m_Socket != -1)
    {
        socket::Close(m_Socket);
    }
    m_Socket = -1;

    if (m_ServerSocket != -1)
    {
        socket::Close(m_ServerSocket);
    }
    m_ServerSocket = -1;
}

// Closes the client's connected socket if it is still open.
LocalCommunicationDeliveryProtocolClient::~LocalCommunicationDeliveryProtocolClient() NN_NOEXCEPT
{
    this->Finalize();
}

// Connects a TCP socket to ipv4:port (127.0.0.1:port on Windows test builds).
// The connect runs non-blocking with a 5-second timeout. The socket then gets SetLinger() and
// is switched back to blocking mode before being stored in m_Socket.
// Returns ResultLocalCommunicationSocketUnexpectedError on timeout or when the connect attempt
// reports an error. Other socket failures are mapped by HandleSocketReturnCode().
// On any failure the socket is closed and m_Socket remains -1.
Result LocalCommunicationDeliveryProtocolClient::Initialize(uint32_t ipv4, uint16_t port) NN_NOEXCEPT
{
    m_Socket = -1;
    socket::SockAddrIn saPeer = {};

    int socketFd = socket::Socket(socket::Family::Af_Inet, socket::Type::Sock_Stream, socket::Protocol::IpProto_Tcp);
    NN_RESULT_DO(HandleSocketReturnCode(socketFd));

    // Close the socket on every failure path; ownership moves to m_Socket only on success.
    bool isSuccess = false;
    NN_UTIL_SCOPE_EXIT
    {
        if (!isSuccess)
        {
            socket::Close(socketFd);
        }
    };

    // TORIAEZU (temporary): on Windows a Bit32 cannot be stored into in_addr, so for testing
    // the destination is specified from a string instead.
    saPeer.sin_addr.S_addr = socket::InetHtonl(ipv4);

#if defined(NN_BUILD_CONFIG_OS_WIN)
    socket::InAddr peerAddr;
    socket::InetAton("127.0.0.1", &peerAddr);
    saPeer.sin_addr = peerAddr;
#endif

    saPeer.sin_port = socket::InetHtons(port);
    saPeer.sin_family = socket::Family::Af_Inet;

    // Switch to non-blocking so that Connect can be bounded by a timeout.
    int fcntlFlag = socket::Fcntl(socketFd, socket::FcntlCommand::F_GetFl, 0);
    NN_RESULT_DO(HandleSocketReturnCode(fcntlFlag));
    NN_RESULT_DO(HandleSocketReturnCode(socket::Fcntl(socketFd, socket::FcntlCommand::F_SetFl, (fcntlFlag | static_cast<int>(socket::FcntlFlag::O_NonBlock)))));

    auto connectReturnCode = socket::Connect(
        socketFd, reinterpret_cast<socket::SockAddr *>(&saPeer), sizeof(saPeer));
    // Connect runs non-blocking, so EInProgress means "still connecting": wait for completion with Select.
    // Any error other than EInProgress ends initialization.
    if ((connectReturnCode < 0) && (socket::GetLastError() == socket::Errno::EInProgress))
    {
        // Time out if the connection is not established within 5 seconds.
        socket::TimeVal tv{};
        tv.tv_sec = 5;
        socket::FdSet set;
        socket::FdSetZero(&set);
        socket::FdSetSet(socketFd, &set);
        auto selectReturnCode = socket::Select(socketFd + 1, nullptr, &set, nullptr, &tv);
        // A timeout is reported as ResultLocalCommunicationSocketUnexpectedError.
        NN_RESULT_THROW_UNLESS(selectReturnCode != 0, ResultLocalCommunicationSocketUnexpectedError());
        // If Select itself failed, trace the errno and fail.
        NN_RESULT_DO(HandleSocketReturnCode(selectReturnCode));

        // Writability only means the connect attempt finished, not that it succeeded (a refused or
        // unreachable connect also wakes Select). The outcome is reported through SO_ERROR.
        int connectError = 0;
        socket::SockLenT connectErrorSize = sizeof(connectError);
        NN_RESULT_DO(HandleSocketReturnCode(socket::GetSockOpt(socketFd, socket::Level::Sol_Socket, socket::Option::So_Error, &connectError, &connectErrorSize)));
        if (connectError != 0)
        {
            NN_DETAIL_NIM_TRACE("socket::Connect is failed: %d\n", connectError);
            NN_RESULT_THROW(ResultLocalCommunicationSocketUnexpectedError());
        }
    }
    else
    {
        NN_RESULT_DO(HandleSocketReturnCode(connectReturnCode));
    }

    NN_RESULT_DO(SetLinger(socketFd));

    // Switch from non-blocking back to blocking.
    fcntlFlag = socket::Fcntl(socketFd, socket::FcntlCommand::F_GetFl, socket::FcntlFlag::None);
    NN_RESULT_DO(HandleSocketReturnCode(fcntlFlag));
    NN_RESULT_DO(HandleSocketReturnCode(socket::Fcntl(socketFd, socket::FcntlCommand::F_SetFl, (fcntlFlag & static_cast<int>(~socket::FcntlFlag::O_NonBlock)))));

    m_Socket = socketFd;
    isSuccess = true;

    NN_RESULT_SUCCESS;
}

// Closes the connected socket, if any, and marks it as closed (-1). Safe to call repeatedly.
// NOTE(review): m_Socket is changed here without m_CancelMutex, which Cancel() holds while
// using m_Socket; confirm Finalize() never overlaps with Cancel().
void LocalCommunicationDeliveryProtocolClient::Finalize() NN_NOEXCEPT
{
    if (m_Socket != -1)
    {
        socket::Close(m_Socket);
    }
    m_Socket = -1;
}

}}} // nn::nim::srv
