﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <cstring>
#include <cstdlib>
#include <memory>

#include <nn/os.h>
#include <nn/init.h>

#include <nn/nn_Abort.h>
#include <nn/nn_SdkLog.h>
#include <nn/nn_Common.h>
#include <nn/nn_Log.h>

#include <nn/socket/socket_Api.h>
#include <nn/socket/socket_Types.h>
#include <nn/nn_Assert.h>
#include <nn/os/os_Thread.h>


#include <nnt.h>

namespace nnt { namespace {

        // host expected to resolve via GetHostEntByName
        const char* const s_HostName = "www.nintendo.com";

        // host used by the GetAddrInfo tests, which also attempt a TCP connection
        const char* const s_HostNameToConnectTo = "www.example.com";

        // dotted-quad form of s_NintendoIpAddress; used for reverse lookups
        const char* const s_NintendoIpAddressString = "205.166.76.26";

        // service name passed to GetAddrInfo
        const char* const s_Service = "http";

        // hostname can't contain a URL (scheme prefix)
        const char* const s_BadHostname = "https://www.google.com";

        // poorly formed ip address (empty octet)
        const char* const s_BadIpAddressString = "1..2.3.4";

        // there is presently no such service
        const char* const s_BadService = "nmzxcjasdfuioewqio";

        // 0xCDA64C1A == 205.166.76.26, i.e. the same address as
        // s_NintendoIpAddressString, converted with InetHtonl
        const uint32_t s_NintendoIpAddress = nn::socket::InetHtonl(0xCDA64C1A);

        // breaks ip address constraints (0.0.0.1), converted with InetHtonl
        const uint32_t s_BadIpAddress = nn::socket::InetHtonl(0x00000001);

        // maximum number of threads for concurrent test
        const uint32_t s_MaxThreads = 100;

        // maximum number of requests per thread
        const uint32_t s_MaxRequestsPerThread = 10; // 100 also tested

        void TestGetHostByNameInternal(bool shouldFail, const char* hostName, unsigned int retryCount=0)
        {
            // TODO
            // herror should be implemented

            int i;
            nn::socket::HostEnt *he = NULL;
            nn::socket::InAddr **addr_list;

            for (;;)
            {
                he = nn::socket::GetHostEntByName(hostName);
                if (he != nullptr)
                {
                    break;
                }
                else if (he == nullptr)
                {
                    if (nn::socket::GetLastError() == nn::socket::Errno::EAgain && retryCount != 0)
                    {
                        --retryCount;
                        continue;
                    }
                    // add failure case here (herror vs perror)
                    else if ( shouldFail == false )
                    {
                        ADD_FAILURE();
                    };
                    return;
                }
            };

            // print information about this host:
            //NN_SDK_LOG("Official name is: %s\n", he->h_name);
            //NN_SDK_LOG("IP address: %s\n", inet_ntoa(*(nn::socket::InAddr*)he->h_addr));
            //NN_SDK_LOG("All addresses: ");
            addr_list = (nn::socket::InAddr **)he->h_addr_list;
            for(i = 0; addr_list[i] != NULL; i++)
            {
                //NN_SDK_LOG("%s ", inet_ntoa(*addr_list[i]));
            }
            //NN_SDK_LOG("\n");

            if ( shouldFail == true )
            {
                ADD_FAILURE();
            }
        }

        // Parses the dotted-quad string address, reverse-resolves it with
        // GetHostEntByAddr, and records a test failure when the outcome does not
        // match shouldFail.
        void TestGetHostByAddrInternal(bool shouldFail, const char* address)
        {
            // Zero-initialised so addr is never read uninitialised.
            nn::socket::InAddr addr = {};

            // NOTE(review): assumes InetAton follows inet_aton and returns 0 for an
            // unparseable string. There is then no address to reverse-resolve,
            // which is only an acceptable outcome for a negative test.
            if (nn::socket::InetAton(address, &addr) == 0)
            {
                if (shouldFail == false)
                {
                    ADD_FAILURE();
                }
                return;
            }

            nn::socket::HostEnt* he = nn::socket::GetHostEntByAddr(&addr, sizeof(addr), nn::socket::Family::Af_Inet);
            if (he == nullptr)
            {
                if (shouldFail == false)
                {
                    ADD_FAILURE();
                }
                return;
            }

            //NN_SDK_LOG("Host name: %s\n", he->h_name);

            if (shouldFail == true)
            {
                ADD_FAILURE();
            }
        }

        void TestGetAddrInfoInternal(bool shouldFail, const char* host, const char* service)
        {
            // todo: fix af_unspec
            // fix non null servinfo
            int sockfd;
            nn::socket::AddrInfo hints, *servinfo = NULL, *p;
            nn::socket::AiErrno rv;

            memset(&hints, 0, sizeof hints);
            hints.ai_family = nn::socket::Family::Af_Inet; // use Af_Inet6 to force IPv6 -- FIX Af_Unspec
            hints.ai_socktype = nn::socket::Type::Sock_Stream;

            // we use a different host here
            if ((rv = nn::socket::GetAddrInfo(host, service, &hints, &servinfo)) != nn::socket::AiErrno::EAi_Success)
            {
                //NN_SDK_LOG("GetAddrInfo (host=%s, service=%s) returned (%d): %s\n",
                //host, service, rv, gai_strerror(rv));
                if (shouldFail == false)
                {
                    ADD_FAILURE();
                };
                return;
            }

            // loop through all the results and connect to the first we can
            for(p = servinfo; p != NULL; p = p->ai_next)
            {
                nn::socket::SockAddrIn* pin = (nn::socket::SockAddrIn*) p->ai_addr;
                NN_UNUSED(pin);
                if ((sockfd = nn::socket::Socket(p->ai_family, p->ai_socktype,
                                                 p->ai_protocol)) == -1)
                {
                    //perror("socket"); // don't fail test if the socket
                    continue;         // can't be created ...
                }
                if (nn::socket::Connect(sockfd, p->ai_addr, p->ai_addrlen) == -1)
                {
                    nn::socket::Close(sockfd);    // ... or connected
                    /*
                      NN_SDK_LOG("connect to %d.%d.%d.%d:%d failure\n" ,
                               (pin->sin_addr.S_addr >> 24) & 0xFF,
                               (pin->sin_addr.S_addr >> 16) & 0xFF,
                               (pin->sin_addr.S_addr >>  8) & 0xFF,
                               (pin->sin_addr.S_addr      ) & 0xFF,
                               pin->sin_port);
                    */
                    //perror("connect");
                    continue;
                }

/*                NN_SDK_LOG("connect to %d.%d.%d.%d:%d success\n" ,
                           (pin->sin_addr.S_addr >> 24) & 0xFF,
                           (pin->sin_addr.S_addr >> 16) & 0xFF,
                           (pin->sin_addr.S_addr >>  8) & 0xFF,
                           (pin->sin_addr.S_addr      ) & 0xFF,
                           pin->sin_port);
*/
                nn::socket::Close(sockfd);       // done
                if (p == NULL)
                {
                    //NN_SDK_LOG("failed to connect\n");
                    if ( shouldFail == false )
                    {
                        ADD_FAILURE();
                    }
                    return;
                };
            };

            nn::socket::FreeAddrInfo(servinfo);
            if ( shouldFail == true )
            {
                ADD_FAILURE();
            }

        }

        // Calls GetNameInfo on sockAddrIn (Ni_None flags) and records exactly one
        // test failure when the outcome does not match shouldFail. On success the
        // resolved host and service strings are logged.
        void TestGetNameInfoInternal(bool shouldFail, const nn::socket::SockAddrIn & sockAddrIn)
        {
            char host[1024];
            char service[1024];

            const bool succeeded =
                nn::socket::AiErrno::EAi_Success == nn::socket::GetNameInfo(reinterpret_cast<const nn::socket::SockAddr*>(&sockAddrIn),
                                                                            sizeof(sockAddrIn),
                                                                            host, sizeof(host),
                                                                            service, sizeof(service),
                                                                            nn::socket::NameInfoFlag::Ni_None);
            if (succeeded)
            {
                NN_SDK_LOG("   host: %s\n", host);
                NN_SDK_LOG("service: %s\n", service);
            }

            // Fail when a positive test did not succeed or a negative test did.
            if (succeeded == shouldFail)
            {
                ADD_FAILURE();
            }
        }

        // ----- GetHostEntByName -------------------------------------------------

        TEST(Resolver, GetHostByNameGoodHost)
        {
            const bool expectFailure = false; // a plain hostname must resolve
            TestGetHostByNameInternal(expectFailure, s_HostName);
        }

        TEST(Resolver, GetHostByNameBadHost)
        {
            const bool expectFailure = true; // a URL is not a hostname
            TestGetHostByNameInternal(expectFailure, s_BadHostname);
        }

        // ----- GetHostEntByAddr -------------------------------------------------

        TEST(Resolver, GetHostByAddrGoodAddress)
        {
            const bool expectFailure = false; // well-formed address must reverse-resolve
            TestGetHostByAddrInternal(expectFailure, s_NintendoIpAddressString);
        }

        TEST(Resolver, GetHostByAddrBadAddress)
        {
            const bool expectFailure = true; // malformed dotted-quad string
            TestGetHostByAddrInternal(expectFailure, s_BadIpAddressString);
        }

        // ----- GetAddrInfo: every host/service combination ----------------------

        TEST(Resolver, GetAddrInfoGoodHostGoodService)
        {
            const bool expectFailure = false; // only the all-good combination succeeds
            TestGetAddrInfoInternal(expectFailure, s_HostNameToConnectTo, s_Service);
        }

        TEST(Resolver, GetAddrInfoGoodHostBadService)
        {
            const bool expectFailure = true;
            TestGetAddrInfoInternal(expectFailure, s_HostNameToConnectTo, s_BadService);
        }

        TEST(Resolver, GetAddrInfoBadHostGoodService)
        {
            const bool expectFailure = true;
            TestGetAddrInfoInternal(expectFailure, s_BadHostname, s_Service);
        }

        TEST(Resolver, GetAddrInfoBadHostBadService)
        {
            const bool expectFailure = true;
            TestGetAddrInfoInternal(expectFailure, s_BadHostname, s_BadService);
        }

        TEST(Resolver, GetNameInfoGoodSockaddr)
        {
            // Reverse lookup of a known-good IPv4 address on port 80.
            nn::socket::SockAddrIn sin;
            memset(&sin, '\0', sizeof(nn::socket::SockAddrIn));
            sin.sin_family = nn::socket::Family::Af_Inet;
            // sin_port must be in network byte order (like s_NintendoIpAddress,
            // which goes through InetHtonl); a raw 80 would be read as port 20480
            // on a little-endian host.
            sin.sin_port = nn::socket::InetHtons(80);
            sin.sin_addr.S_addr = s_NintendoIpAddress;
            TestGetNameInfoInternal(false, sin);
        }

        TEST(Resolver, GetHostErrorStringInRange)
        {
            // Every HErrno value from Netdb_Internal through No_Address *inclusive*
            // must map to a non-null string. The bound is inclusive so that the
            // last defined value is checked as well.
            const int first = static_cast<int>(nn::socket::HErrno::Netdb_Internal);
            const int last  = static_cast<int>(nn::socket::HErrno::No_Address);
            for (int index = first; index <= last; ++index)
            {
                const char* string = nn::socket::HStrError(static_cast<nn::socket::HErrno>(index));
                if ( string != nullptr )
                {
                    NN_SDK_LOG("hstrerror(%d): %s\n", index, string);
                }
                else
                {
                    ADD_FAILURE();
                }
            }
        }

        TEST(Resolver, GetHostErrorStringOutOfRange) // mostly a 'dontcrash' test
        {
            // HStrError is probed with values on both sides of the defined range
            // ([-255, 255)) and must always hand back a string.
            const int lowest = -255;
            const int pastHighest = 255;
            for (int code = lowest; code != pastHighest; ++code)
            {
                const char* const text = nn::socket::HStrError(static_cast<nn::socket::HErrno>(code));
                if (text == nullptr)
                {
                    ADD_FAILURE();
                }
            }
        }

        TEST(Resolver, GetGaiErrorStringInRange)
        {
            // Every AiErrno value below EAi_Max must map to a non-null string.
            // The index is an int so it matches both the bound it is compared with
            // and the "%d" conversion used to log it.
            const int count = static_cast<int>(nn::socket::AiErrno::EAi_Max);
            for (int index = 0; index < count; ++index)
            {
                const char* string = nn::socket::GAIStrError(static_cast<nn::socket::AiErrno>(index));
                if ( string != nullptr )
                {
                    NN_SDK_LOG("gai_strerror(%d): %s\n", index, string);
                }
                else
                {
                    ADD_FAILURE();
                }
            }
        }

        TEST(Resolver, GetGaiErrorStringOutOfRange)
        {
            // GAIStrError must return a string for any value in [-255, 255),
            // including ones that are not valid AiErrno codes.
            int code = -255;
            while (code < 255)
            {
                if (nn::socket::GAIStrError(static_cast<nn::socket::AiErrno>(code)) == nullptr)
                {
                    ADD_FAILURE();
                }
                ++code;
            }
        }


        /*
         * @NOTE:
         * The concurrent resolver test below should remain the last test in this
         * file, so that the load it generates cannot interfere with the tests above.
         */


        class ThreadContext
        {
            static const int s_StackSize = 65536;
            nn::os::ThreadType m_Thread;
            void* m_ThreadStack;

            static void s_ConcurrentResolverTestFunction(void* argument)
            {
                for (int idx=0; idx<s_MaxRequestsPerThread; idx++)
                {
                    int negativeOne = -1;
                    TestGetHostByNameInternal(false, s_HostName, (unsigned int) negativeOne);
                }
            };

        public:
            ThreadContext()
            {
                NN_SDK_LOG("Creating thread: %p\n", this);
                m_ThreadStack = memalign(nn::os::ThreadStackAlignment, s_StackSize);

                NN_ASSERT( true == (nn::os::CreateThread( &m_Thread,
                                                          s_ConcurrentResolverTestFunction,
                                                          NULL,
                                                          m_ThreadStack,
                                                          s_StackSize,
                                                          nn::os::DefaultThreadPriority )).IsSuccess(),
                           "Cannot create thread." );

                nn::os::StartThread( &m_Thread );

            };

            ~ThreadContext()
            {
                free(m_ThreadStack);
            };

            void Wait()
            {
                nn::os::WaitThread( &m_Thread );
            };

            void Destroy()
            {
                nn::os::DestroyThread( &m_Thread );
            };

        };

        TEST(Resolver, ConcurrentResolverTests)
        {
            // Create (and thereby start) every thread first so their lookups overlap.
            std::unique_ptr<ThreadContext> threads[s_MaxThreads];
            for (uint32_t idx = 0; idx < s_MaxThreads; ++idx)
            {
                threads[idx].reset(new ThreadContext());
            }

            // Then wait for each thread to finish and release its OS thread object.
            for (uint32_t idx = 0; idx < s_MaxThreads; ++idx)
            {
                // Log the context itself, matching the "Creating thread" message.
                NN_SDK_LOG("Waiting on thread: %p\n", static_cast<void*>(threads[idx].get()));
                threads[idx]->Wait();
                threads[idx]->Destroy();
            }

            // The ThreadContext objects (and their stacks) are deleted when
            // `threads` goes out of scope, after every thread has been destroyed.
        }

    } //resolver
} //nnt

namespace
{
    // Static memory pool handed to nn::socket::Initialize in nnMain().
    // Aligned to 4096 bytes via NN_ALIGNAS; presumably this is the page-size
    // alignment the socket library expects for its pool (TODO confirm).
    NN_ALIGNAS(4096) uint8_t g_SocketMemoryPoolBuffer[nn::socket::DefaultSocketMemoryPoolSize];
}

// Startup hook: sizes the memory heap and hands a 16 MiB block of it to the
// malloc allocator.
extern "C" void nninitStartup()
{
    // Size of the region that backs malloc.
    const size_t MemoryHeapSize = 16 * 1024 * 1024;

    // Total memory-heap size: the malloc region plus DefaultSocketMemoryPoolSize.
    const size_t totalHeapSize = MemoryHeapSize + nn::socket::DefaultSocketMemoryPoolSize;
    auto result = nn::os::SetMemoryHeapSize( totalHeapSize );
    NN_ASSERT( result.IsSuccess() );

    // Carve the malloc region out of the memory heap ...
    uintptr_t heapBlock = 0;
    result = nn::os::AllocateMemoryBlock( &heapBlock, MemoryHeapSize );
    NN_ASSERT( result.IsSuccess() );

    // ... and register it as the allocator's backing storage.
    void* const allocatorRegion = reinterpret_cast<void*>(heapBlock);
    nn::init::InitializeAllocator( allocatorRegion, MemoryHeapSize );
}


//---------------------------------------------------------------------------
//  Test Main 関数
//---------------------------------------------------------------------------

extern "C"
void nnMain()
{
    int     argc = nnt::GetHostArgc();
    char **argv = nnt::GetHostArgv();

    nn::socket::Initialize(reinterpret_cast<void*>(g_SocketMemoryPoolBuffer),
                           nn::socket::DefaultSocketMemoryPoolSize,
                           nn::socket::DefaultSocketAllocatorSize,
                           nn::socket::DefaultConcurrencyLimit);

    // GoogleTest おまじない
    ::testing::InitGoogleTest(&argc, argv);
    int result = RUN_ALL_TESTS();

    // テスト終了
    NN_LOG("\n=== End Test of resolver tests\n");

    nnt::Exit(result);
}

