﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <functional>
#include <thread>

#include <nn/nn_Log.h>
#include <nn/os.h>

#include <nn/htclow.h>
#include <nn/htclow/detail/htclow_InternalTypes.h>
#include <nn/htclow/detail/htclow_DebugApi.h>

#include <nnt/nntest.h>
#include <nnt/result/testResult_Assert.h>

#include "../../../../../Programs/Eris/Sources/Libraries/htclow/server/htclow_Packet.h"

#include "../testHtclow_Util.h"

#include "testHtclow_PacketUtil.h"

namespace nnt { namespace htclow {

namespace {
    const nn::htclow::ModuleId TestModuleId = static_cast<nn::htclow::ModuleId>(0);
    const nn::htclow::ChannelId TestChannelId = static_cast<nn::htclow::ChannelId>(0);

    // Performs a handshake that is expected to succeed and then checks which
    // protocol version the target uses for a subsequent Data packet.
    //
    // synAckVersion: version written into the SynAck header sent to the target;
    //                the Data packet received afterwards must carry the same version.
    void NormalHandshakeTestCommon(int16_t synAckVersion)
    {
        nn::htclow::Module module(TestModuleId);
        nn::htclow::Channel channel(&module, TestChannelId);

        const auto channelInternal = nn::htclow::detail::ConvertChannelType(*(channel.GetBase()));

        // Connect() runs on its own thread so that this thread can play the
        // host side of the handshake in the meantime.
        std::thread connectThread([&]() NN_NOEXCEPT {
            NNT_HTCLOW_ASSERT_RESULT_SUCCESS(channel.Connect());
        });

        // Receive the Syn packet from the target.
        const auto synPacket = ReceivePacketForTest();
        AssertSynPacket(*synPacket, channelInternal);

        // Send the SynAck packet, carrying the requested version, to the target.
        auto synAckPacket = MakeSynAckPacket(channelInternal);
        synAckPacket->GetHeader()->version = synAckVersion;
        SendPacketForTest(*synAckPacket);

        connectThread.join();

        // Receive a Data packet from the target and verify its version.
        const int dataSize = 4;
        const auto data = MakeRandomArray(dataSize, 1);

        // Zero-initialized so that the size check below fails deterministically
        // if Send() returns without writing the sent size.
        // NOTE(review): Send()'s return value is not inspected here - confirm
        // whether it should be asserted as well.
        size_t size = 0;
        channel.Send(&size, data.get(), dataSize);
        ASSERT_EQ(static_cast<size_t>(dataSize), size);

        const auto dataPacket = ReceivePacketForTest();
        ASSERT_EQ(synAckVersion, dataPacket->GetHeader()->version);
    }

    // Performs a handshake in which the host answers the target's Syn with a
    // caller-supplied SynAck packet, then asserts that the result of Connect()
    // is included in nn::htclow::ResultChannelClosed.
    //
    // makeSynAckPacket: factory that returns the SynAck packet to send to the target.
    void AbnormalHandshakeTestCommon(std::function<std::unique_ptr<nn::htclow::server::SendPacket>()> makeSynAckPacket)
    {
        nn::htclow::Module testModule(TestModuleId);
        nn::htclow::Channel testChannel(&testModule, TestChannelId);

        const auto expectedChannel = nn::htclow::detail::ConvertChannelType(*(testChannel.GetBase()));

        // Connect() runs on a worker thread; its result is kept so it can be
        // inspected once the thread has been joined.
        nn::Result resultOfConnect;
        std::thread worker([&]() NN_NOEXCEPT {
            resultOfConnect = testChannel.Connect();
        });

        // Receive the Syn packet from the target.
        const auto receivedSyn = ReceivePacketForTest();
        AssertSynPacket(*receivedSyn, expectedChannel);

        // Send the SynAck packet produced by makeSynAckPacket to the target.
        const auto replyPacket = makeSynAckPacket();
        SendPacketForTest(*replyPacket);

        worker.join();
        NNT_HTCLOW_ASSERT_RESULT_INCLUDED(nn::htclow::ResultChannelClosed(), resultOfConnect);
    }
}

// Test fixture that opens the htclow driver in Debug mode before each test
// and closes it again afterwards.
class ProtocolTest : public ::testing::Test
{
    // Opens the driver using DriverType::Debug.
    virtual void SetUp() NN_NOEXCEPT NN_OVERRIDE
    {
        namespace driverApi = nn::htclow::detail;
        driverApi::OpenDriver(driverApi::DriverType::Debug);
    }

    // Closes the driver opened in SetUp().
    virtual void TearDown() NN_NOEXCEPT NN_OVERRIDE
    {
        namespace driverApi = nn::htclow::detail;
        driverApi::CloseDriver();
    }
};

// Handshake in which the SynAck carries the value returned by GetMaxVersion().
TEST_F(ProtocolTest, DefaultVersion)
{
    const auto maxVersion = GetMaxVersion();
    NormalHandshakeTestCommon(maxVersion);
}

// Handshake in which the SynAck carries version 0.
TEST_F(ProtocolTest, LowerVersion)
{
    const int16_t lowerVersion = 0;
    NormalHandshakeTestCommon(lowerVersion);
}

// Handshake that must fail because the SynAck version is one higher than the
// version MakeSynAckPacket() puts in the header.
TEST_F(ProtocolTest, IncorrectVersion)
{
    const nn::htclow::detail::ChannelInternalType targetChannel = { TestChannelId, 0, TestModuleId };

    const auto makeTooNewSynAck = [=]() {
        auto synAck = MakeSynAckPacket(targetChannel);
        // Replying to a Syn with a SynAck whose version is higher than the Syn's is an error.
        synAck->GetHeader()->version += 1;
        return synAck;
    };

    AbnormalHandshakeTestCommon(makeTooNewSynAck);
}

// Handshake that must fail because the SynAck header's protocol field is set to -1.
TEST_F(ProtocolTest, IncorrectProtocolId)
{
    const nn::htclow::detail::ChannelInternalType targetChannel = { TestChannelId, 0, TestModuleId };
    const int16_t IncorrectProtocolId = -1;

    const auto makeBadProtocolSynAck = [=]() {
        auto synAck = MakeSynAckPacket(targetChannel);
        synAck->GetHeader()->protocol = IncorrectProtocolId;
        return synAck;
    };

    AbnormalHandshakeTestCommon(makeBadProtocolSynAck);
}

// Handshake that must fail because the SynAck header's bodySize field is set to 1.
TEST_F(ProtocolTest, IncorrectBodySize)
{
    const nn::htclow::detail::ChannelInternalType targetChannel = { TestChannelId, 0, TestModuleId };
    const int16_t IncorrectSynAckBodySize = 1;

    const auto makeBadBodySizeSynAck = [=]() {
        auto synAck = MakeSynAckPacket(targetChannel);
        synAck->GetHeader()->bodySize = IncorrectSynAckBodySize;
        return synAck;
    };

    AbnormalHandshakeTestCommon(makeBadBodySizeSynAck);
}

}}
