﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include "NetTest_Port.h"
#include "Modules/SslHttpsDownloadModule.h"
#include <nn/ssl.h>

#include <cstdio>
#include <cstdlib>
#include <cstring>

namespace
{
    // Size, in bytes, of the chunk buffer ReceiveResponse() reads the HTTP response into.
    const uint32_t BufferSize = 2 * 1024;

    // Method and protocol version placed in the request line built by SendHttpRequest().
    const char* const g_pHttpCmd  = "GET";
    const char* const g_pHttpVers = "HTTP/1.0";

    // Timeout, in milliseconds, passed to each nn::ssl::Connection::Poll() call.
    const uint32_t PollTimeoutMs = 10 * 1000;
} // un-named namespace

namespace NATF {
namespace Modules {

    // Constructor
    //
    // Only records the test parameters; no SSL or network work is done here.
    //   doInit        - when true, Init()/CleanUp() initialize/finalize the SSL library.
    //   pHostName     - server to download from (stored as given).
    //   port          - server TCP port.
    //   pResource     - resource path to request (stored as given; null means the root).
    //   printResource - when true, downloaded data is echoed to the log.
    //   expectedHash  - MD5 digest that Run() compares the downloaded body against.
    //   isBlockingIo  - selects blocking or non-blocking SSL I/O mode.
    //   verifyOption  - verification options applied to the SSL connection.
    SslHttpsDownload::SslHttpsDownload(bool doInit, const char* pHostName, unsigned short port, const char* pResource, bool printResource, const MD5Hash::Result& expectedHash, bool isBlockingIo, nn::ssl::Connection::VerifyOption verifyOption) NN_NOEXCEPT
        : m_doInit(doInit)
        , m_pHostName(pHostName)
        , m_port(port)
        , m_pResource(pResource)
        , m_printResource(printResource)
        , m_expectedHash(expectedHash)
        , m_isBlocking(isBlockingIo)
        , m_verifyOption(verifyOption)
    {
    }

    // Init
    //
    // Initializes the SSL library, but only when the module was built with doInit == true.
    // Returns false (after logging) if nn::ssl::Initialize() fails, true otherwise.
    bool SslHttpsDownload::Init() NN_NOEXCEPT
    {
        if( !m_doInit )
        {
            // The caller owns SSL library initialization; nothing to do.
            return true;
        }

        Log("Initializing SSL library...\n");
        if( nn::ssl::Initialize().IsFailure() )
        {
            LogError("Failed to initialize SSL library.\n");
            return false;
        }

        return true;
    }

    // CleanUp
    //
    // Counterpart of Init(): finalizes the SSL library only if this module initialized it.
    void SslHttpsDownload::CleanUp() NN_NOEXCEPT
    {
        if( !m_doInit )
        {
            return;
        }

        nn::ssl::Finalize();
    }

    // SendHttpRequest
    //
    // Builds "GET /<pResource> HTTP/1.0\r\nHost: <pHostName>\r\n\r\n" and writes it to
    // sslConnection, calling Write() repeatedly until every byte has been accepted.
    //   sslConnection - connection that has completed its handshake.
    //   pResource     - resource path without its leading '/' (one is prepended).
    //   pHostName     - value sent in the Host: header.
    // Returns false (after logging) if an argument is null, the request does not fit in the
    // local request buffer, or Write() reports an error.
    bool SslHttpsDownload::SendHttpRequest(nn::ssl::Connection& sslConnection, const char* pResource, const char* pHostName) const NN_NOEXCEPT
    {
        static const unsigned MaxRequestSize = 1024;
        char pGetRequest[MaxRequestSize];

        if( pResource == nullptr || pHostName == nullptr )
        {
            LogError("Error: SendHttpRequest: resource or host name is NULL\n");
            return false;
        }

        // Exact length of the formatted request (excluding the NUL). If it does not fit,
        // formatting would truncate it and drop the closing blank line, leaving the server
        // (most likely) waiting for the rest of the header, so refuse to send it.
        const size_t expectedSize = strlen(g_pHttpCmd) + strlen(" /") + strlen(pResource)
                                  + strlen(" ") + strlen(g_pHttpVers) + strlen("\r\nHost: ")
                                  + strlen(pHostName) + strlen("\r\n\r\n");
        if( expectedSize >= MaxRequestSize )
        {
            LogError("Error: HTTP request too long (%d bytes, max %d).\n",
                     static_cast<int>(expectedSize), static_cast<int>(MaxRequestSize) - 1);
            return false;
        }

        // Copy http request into write buffer.
        NETTEST_SNPRINTF(pGetRequest, MaxRequestSize, "%s /%s %s\r\nHost: %s\r\n\r\n", g_pHttpCmd, pResource, g_pHttpVers, pHostName);
        const size_t requestSize = strlen(pGetRequest);
        Log("HTTP Request: %s", pGetRequest);

        // Write() may accept only part of the buffer; keep sending the remainder.
        unsigned totalWritten = 0;
        while( totalWritten < requestSize )
        {
            Log("Writing data...\n");
            const int rval = sslConnection.Write(&pGetRequest[totalWritten], static_cast<unsigned>(requestSize - totalWritten));
            if( rval <= 0 )
            {
                LogError("Error: NSSLWrite: rval: %d\n", rval);
                return false;
            }

            // Write was successful (at least in part); track how much has been sent.
            totalWritten += static_cast<unsigned>(rval);
        }

        return true;
    }

    // CreateSocketWithOptions
    //
    // Creates an IPv4 TCP socket and requests a (64KB - 1) receive buffer. It then logs the
    // receive-buffer size read back from the socket. Failing to set or read the buffer size
    // only logs a warning; the socket is returned regardless.
    // Returns the new descriptor, or the negative value from NetTest::Socket() on failure.
    int SslHttpsDownload::CreateSocketWithOptions() const NN_NOEXCEPT
    {
        // Create socket for SSL connection.
        const int socketFd = NetTest::Socket(nn::socket::Family::Af_Inet, nn::socket::Type::Sock_Stream, nn::socket::Protocol::IpProto_Tcp);
        if (socketFd < 0)
        {
            LogError("Error: Failed to create client socket. rval: %d, SOErr:%d\n", socketFd, NetTest::GetLastError());
            return socketFd;
        }

        int recvBufSize = 1024 * 64 - 1;
        const int setResult = NetTest::SetSockOpt(socketFd, nn::socket::Level::Sol_Socket, nn::socket::Option::So_RcvBuf, &recvBufSize, sizeof(recvBufSize));
        if (setResult < 0)
        {
            Log("WARNING: Failed to set recv buffer! rval: %d, errno: %d\n\n", setResult, NetTest::GetLastError());
        }

        // Read the value back; it may differ from the requested size.
        NetTest::SockLen optLen = sizeof(int);
        const int getResult = NetTest::GetSockOpt(socketFd, nn::socket::Level::Sol_Socket, nn::socket::Option::So_RcvBuf, &recvBufSize, &optLen);
        if (getResult >= 0)
        {
            Log("nn::socket::Option::So_RcvBuf: %d\n", recvBufSize);
        }
        else
        {
            Log("WARNING: Failed to get recv buffer! rval: %d, errno: %d\n\n", getResult, NetTest::GetLastError());
        }

        return socketFd;
    }

    void SslHttpsDownload::InetNtoP(int32_t ipAddr, char pOutIp[MaxIpStrLen]) const
    {
        uint8_t* pData = reinterpret_cast<uint8_t*>(&ipAddr);

        NETTEST_SNPRINTF(pOutIp, MaxIpStrLen, "%d.%d.%d.%d", static_cast<int>(pData[0]),
                                                             static_cast<int>(pData[1]),
                                                             static_cast<int>(pData[2]),
                                                             static_cast<int>(pData[3]));
    }

    // ConnectToServer
    //
    // Resolves pHostName and connects socketFd to the first address returned, on portNum.
    // Returns false (after logging) if pHostName is null, resolution fails or yields no
    // address, or Connect() fails. The socket is never closed here; the caller owns it.
    bool SslHttpsDownload::ConnectToServer(int socketFd, const char* pHostName, unsigned short portNum) const NN_NOEXCEPT
    {
        if( !pHostName )
        {
            LogError("Error: pHostName is NULL\n");
            return false;
        }

        Log("Resolving %s\n", pHostName);
        NetTest::HostEnt* pHostEnt = NetTest::GetHostByName(pHostName);
        if( pHostEnt == nullptr )
        {
            LogError("Failed to resolve host name.\n");
            return false;
        }

        // A non-null result can still carry an empty address list; reading
        // h_addr_list[0] would then dereference a null pointer.
        if( pHostEnt->h_addr_list == nullptr || pHostEnt->h_addr_list[0] == nullptr )
        {
            LogError("Failed to resolve host name: no addresses returned.\n");
            return false;
        }

        // Just pick the first one
        nn::socket::InAddr inetAddr;
        memcpy(&inetAddr, pHostEnt->h_addr_list[0], sizeof(inetAddr));

        char pIp[MaxIpStrLen];
        InetNtoP(inetAddr.S_addr, pIp);
        Log("Server IP: %s\n", pIp);

        // Initialize server address struct
        NetTest::SockAddrIn serverAddr;
        memset(&serverAddr, 0, sizeof(serverAddr));
        serverAddr.sin_family      = nn::socket::Family::Af_Inet;
        serverAddr.sin_port        = NetTest::Htons(portNum);
        serverAddr.sin_addr.S_addr = inetAddr.S_addr;

        Log("Connecting to %s\n", pHostName);
        const int rval = NetTest::Connect(socketFd, reinterpret_cast<NetTest::SockAddr*>(&serverAddr), sizeof(serverAddr));
        if( rval < 0 )
        {
            LogError("Error: Connect: rval: %d SOErr: %d\n", rval, NetTest::GetLastError());
            return false;
        }

        return true;
    }

    // ParseHeader
    //
    // Scans one received chunk for the end of the HTTP header. The end is taken to be four
    // consecutive '\r'/'\n' characters ("\r\n\r\n").
    //   pBuffer          - received bytes (not NUL-terminated).
    //   bufSize          - number of valid bytes in pBuffer.
    //   isLastNewLine    - state carried between calls, read and written. It is true when the
    //                      previous chunk ended with three newline characters ("\r\n\r"), so a
    //                      leading '\n' in this chunk completes the terminator.
    //   isHeaderComplete - set to true once the end of the header has been found.
    // Returns a pointer to the first body byte inside pBuffer. Returns nullptr while the header
    // has not ended, or when it ends exactly at the last byte of this chunk.
    //
    // NOTE: a terminator split across two chunks is only recognized when the split falls after
    // its third character. A single bool cannot carry a shorter partial run.
    char* SslHttpsDownload::ParseHeader(char* pBuffer, unsigned bufSize, bool& isLastNewLine, bool& isHeaderComplete) const NN_NOEXCEPT
    {
        // Length of the "\r\n\r\n" run, and how much of it must already have been seen
        // for a single leading '\n' in the next chunk to finish it.
        static const unsigned TerminatorLength = 4;
        static const unsigned CarriedNewLines  = TerminatorLength - 1;

        if( bufSize == 0 )
        {
            // Nothing to scan; keep the carried state for the next chunk.
            return nullptr;
        }

        const bool continuesTerminator = isLastNewLine;
        // The carried state only describes the previous chunk. Clear it so an old trailing
        // newline cannot complete the header on a later, unrelated chunk.
        isLastNewLine = false;

        if( continuesTerminator && pBuffer[0] == '\n' )
        {
            isHeaderComplete = true;
            return (bufSize > 1) ? &pBuffer[1] : nullptr;
        }

        unsigned newLinesCount = 0;
        for( unsigned i = 0; i < bufSize; ++i )
        {
            if( pBuffer[i] == '\n' || pBuffer[i] == '\r' )
            {
                ++newLinesCount;
                if( newLinesCount >= TerminatorLength )
                {
                    isHeaderComplete = true;
                    return (i + 1 < bufSize) ? &pBuffer[i + 1] : nullptr;
                }
            }
            else
            {
                newLinesCount = 0;
            }
        }

        // Only a trailing "\r\n\r" can be finished by a single '\n' in the next chunk. A
        // trailing "\r" that ends an ordinary header line must not end header parsing early.
        isLastNewLine = (newLinesCount >= CarriedNewLines);
        return nullptr;
    }

    // ReceiveResponse
    //
    // Reads the HTTP response from sslConnection until one of the following happens:
    //   - the peer closes the connection;
    //   - a poll times out;
    //   - the connection is not readable after a poll.
    // Header bytes are echoed to the log. Every byte after the header is fed into an MD5 digest,
    // which is finalized into hashResult. Total bytes and throughput are logged at the end.
    // Returns false on a poll or read error; hashResult is not written in that case.
    bool SslHttpsDownload::ReceiveResponse(nn::ssl::Connection& sslConnection, MD5Hash::Result& hashResult) const NN_NOEXCEPT
    {
        char pBuffer[BufferSize + 1];
        unsigned totalRead    = 0;
        bool isHeaderComplete = false;
        bool isLastNewLine    = false;
        bool hasPollTimedout  = false;
        MD5Hash md5Hash;
        NetTest::Tick start;
        NetTest::Tick end;
        NetTest::Time duration;

        // Read the response from server
        Log("Reading header...\n\n");
        start = NetTest::GetTick();
        for( ;; )
        {
            nn::ssl::Connection::PollEvent pollInEvent  = nn::ssl::Connection::PollEvent::PollEvent_Read;
            nn::ssl::Connection::PollEvent pollOutEvent = nn::ssl::Connection::PollEvent::PollEvent_None;

            nn::Result result = sslConnection.Poll(&pollOutEvent, &pollInEvent, PollTimeoutMs);

            // A timeout is itself a failure result. It has to be recognized before the generic
            // failure check, otherwise it would be reported as an error.
            if( nn::ssl::ResultIoTimeout::Includes(result) )
            {
                Log("Poll: Timed out!\n");
                hasPollTimedout = true;
                break;
            }
            if( result.IsFailure() )
            {
                LogError("Error: Poll: Desc: %d errno: %d\n", result.GetDescription(), NetTest::GetLastError());
                return false;
            }
            if( (pollOutEvent & nn::ssl::Connection::PollEvent::PollEvent_Read) !=
                 nn::ssl::Connection::PollEvent::PollEvent_Read )
            {
                Log("Warning: Ssl connection not readable after poll!\n");
                break;
            }

            // Connection is now readable, read the data.
            const int bytesRead = sslConnection.Read(pBuffer, BufferSize);
            if( bytesRead == 0 )
            {
                NN_LOG("\n");
                Log("Recv: Connection has been closed. rval: %d\n", bytesRead);
                break;
            }
            if( bytesRead < 0 )
            {
                LogError("Error: Recv: rval: %d errno: %d\n", bytesRead, NetTest::GetLastError());
                return false;
            }

            totalRead += static_cast<unsigned>(bytesRead);

            // First body byte within this chunk; nullptr while the chunk is all header.
            char* pData = pBuffer;
            if( !isHeaderComplete )
            {
                pData = ParseHeader(pBuffer, static_cast<unsigned>(bytesRead), isLastNewLine, isHeaderComplete);

                // If there was also body data in this chunk, only print the header portion.
                // The "%.*s" precision must be an int, hence the explicit conversion.
                const int headerBytes = pData ? static_cast<int>(pData - pBuffer) : bytesRead;
                NN_LOG("%.*s", headerBytes, pBuffer);

                if( isHeaderComplete )
                {
                    Log("Header download complete!\n");
                    Log("Reading data...\n\n");
                }
            }

            if( pData )
            {
                const int bodyBytes = bytesRead - static_cast<int>(pData - pBuffer);

                // Continue calculating the MD5 hash value.
                md5Hash.Update(reinterpret_cast<unsigned char*>(pData), bodyBytes);

                // Print only the body bytes; any header part was already printed above.
                if( m_printResource )
                {
                    NN_LOG("%.*s", bodyBytes, pData);
                }
            }
        }
        NN_LOG("\n");

        end = NetTest::GetTick();
        duration = NetTest::TickToTime(end - start);

        Log("Finalizing MD5 Hash...\n");
        md5Hash.Final(hashResult);

        // Calculate throughput.
        //  Don't count poll timeout against us; also avoid dividing by a non-positive interval.
        const int64_t milSec = duration.GetMilliSeconds() - (hasPollTimedout ? static_cast<int64_t>(PollTimeoutMs) : 0);
        const float throughput = (milSec > 0) ? totalRead * 8.0f / (milSec / 1000.0f) / 1000000.0f : 0.0f;
        Log("Total bytes: %u, Throughput: %d.%02d Mbits/Sec\n",
            totalRead, static_cast<int>(throughput), static_cast<int>(throughput * 100.0f) % 100);

        return true;
    }

    // ConfigSslContext
    //
    // Creates pOutContext with SslVersion_Auto and logs its context ID.
    // Returns false (after logging) if Create() or GetContextId() fails. If GetContextId()
    // fails, the context is destroyed before returning, so the caller has nothing to release.
    bool SslHttpsDownload::ConfigSslContext(nn::ssl::Context* pOutContext) const NN_NOEXCEPT
    {
        nn::Result result = pOutContext->Create(nn::ssl::Context::SslVersion_Auto);
        if(result.IsFailure())
        {
            LogError("Failed to create context.\n");
            return false;
        }

        nn::ssl::SslConnectionId contextId;
        // Capture GetContextId()'s own result so the check below tests this call,
        // not the earlier Create().
        result = pOutContext->GetContextId(&contextId);
        if(result.IsFailure())
        {
            LogError("Failed to get context ID\n");
            pOutContext->Destroy();
            return false;
        }

        // Widen explicitly: the ID type may be 64-bit, which "%d" cannot print.
        Log("Created SSL context (ID:%llu).\n", static_cast<unsigned long long>(contextId));
        return true;
    }

    // ConfigSslConnection
    //
    // Creates pOutConnection on pInContext and configures it, in this order:
    //   1. attaches socketFd;
    //   2. sets pHostName;
    //   3. enables OptionType_SkipDefaultVerify and applies verifyOption;
    //   4. registers pServerCertBuff (serverCertBuffLen bytes) for the peer certificate;
    //   5. selects blocking or non-blocking I/O from m_isBlocking, then logs the mode read back.
    // Returns false (after logging) at the first step that fails. If Create() had succeeded,
    // the connection is destroyed before returning.
    bool SslHttpsDownload::ConfigSslConnection(nn::ssl::Connection* pOutConnection, nn::ssl::Context* pInContext, int socketFd, const char* pHostName, char* pServerCertBuff, uint32_t serverCertBuffLen, nn::ssl::Connection::VerifyOption verifyOption) const NN_NOEXCEPT
    {
        if( pOutConnection->Create(pInContext).IsFailure() )
        {
            LogError("pOutConnection->Create failed.\n");
            return false;
        }

        // Any failure after a successful Create() must tear the connection down again.
        auto destroyAndFail = [pOutConnection]() -> bool
        {
            pOutConnection->Destroy();
            return false;
        };

        if( pOutConnection->SetSocketDescriptor(socketFd).IsFailure() )
        {
            LogError("pOutConnection->SetSocketDescriptor failed.\n");
            return destroyAndFail();
        }

        if( pOutConnection->SetHostName(pHostName, static_cast<uint32_t>(strlen(pHostName))).IsFailure() )
        {
            LogError("pOutConnection->SetHostName failed.\n");
            return destroyAndFail();
        }

        if( pOutConnection->SetOption(nn::ssl::Connection::OptionType_SkipDefaultVerify, true).IsFailure() )
        {
            LogError("pOutConnection->SetOption (OptionType_SkipDefaultVerify) failed.\n");
            return destroyAndFail();
        }

        if( pOutConnection->SetVerifyOption(verifyOption).IsFailure() )
        {
            LogError("pOutConnection->SetVerifyOption failed.\n");
            return destroyAndFail();
        }

        // Set buffer to get peer certificate
        if( pOutConnection->SetServerCertBuffer(pServerCertBuff, serverCertBuffLen).IsFailure() )
        {
            LogError("SetServerCertBuffer failed.\n");
            return destroyAndFail();
        }

        nn::ssl::Connection::IoMode ioMode = m_isBlocking
            ? nn::ssl::Connection::IoMode_Blocking
            : nn::ssl::Connection::IoMode_NonBlocking;

        if( pOutConnection->SetIoMode(ioMode).IsFailure() )
        {
            LogError("SetIoMode failed.\n");
            return destroyAndFail();
        }

        // Read the mode back so the log shows what the connection actually uses.
        if( pOutConnection->GetIoMode(&ioMode).IsFailure() )
        {
            LogError("GetIoMode failed.\n");
            return destroyAndFail();
        }
        Log("IoMode: %d\n", ioMode);

        return true;
    }

    // GetServerPage
    //
    // Performs one HTTPS GET of pResource from pHostName:serverPort. The MD5 digest of the
    // response body is stored in hashResult. A null pResource requests "/".
    // Steps, stopping at the first failure:
    //   create socket -> connect -> SSL context -> SSL connection -> handshake
    //   -> send request -> receive response.
    // Everything created along the way is released before returning:
    //   - If the SSL connection was configured, it is destroyed. The socket is not closed here
    //     in that case (presumably Destroy() releases it; confirm against nn::ssl docs).
    //   - Otherwise the socket, if any, is closed directly.
    // Returns true only if every step succeeded.
    bool SslHttpsDownload::GetServerPage(const char* pHostName, unsigned short serverPort, const char* pResource, MD5Hash::Result& hashResult) const NN_NOEXCEPT
    {
        nn::ssl::Context sslContext;
        nn::ssl::Connection sslConnection;
        bool hasContext    = false;
        bool hasConnection = false;
        bool isSuccess     = false;
        static const uint32_t CertBuffLen = 1024 * 4;
        char pServerCertBuff[CertBuffLen] = {0};

        const char* pPath = (pResource != nullptr) ? pResource : "";

        // Create socket with custom options
        Log("Creating socket...\n");
        const int socketFd = CreateSocketWithOptions();

        // Single pass: each failing step breaks out to the shared cleanup below.
        do
        {
            if( socketFd < 0 )
            {
                break;
            }

            // Connect socket to server
            Log("Connecting socket to server...\n");
            if( !ConnectToServer(socketFd, pHostName, serverPort) )
            {
                break;
            }

            Log("Configuring SSL context...\n");
            if( !ConfigSslContext(&sslContext) )
            {
                break;
            }
            hasContext = true;

            Log("Configuring SSL connection...\n");
            if( !ConfigSslConnection(&sslConnection, &sslContext, socketFd, pHostName, pServerCertBuff, CertBuffLen, m_verifyOption) )
            {
                break;
            }
            hasConnection = true;

            Log("Doing SSL handshake...\n");
            nn::Result handshakeResult = sslConnection.DoHandshake();
            if( handshakeResult.IsFailure() )
            {
                LogError("Ssl hand shake failed! Desc: %d\n\n", handshakeResult.GetDescription());
                break;
            }
            Log("SSL Handshake completed.\n");

            // Send HTTP request through SSL connection
            Log("Sending Http request...\n");
            if( !SendHttpRequest(sslConnection, pPath, pHostName) )
            {
                break;
            }

            // Receive response from our HTTP request
            Log("Receiving Http request...\n");
            isSuccess = ReceiveResponse(sslConnection, hashResult);
        } while( false );

        if( hasConnection )
        {
            sslConnection.Destroy();
        }
        else if( socketFd >= 0 )
        {
            // Close socket
            const int closeResult = NetTest::Close(socketFd);
            if( closeResult != 0 )
            {
                Log("Warning: Failed to close socket. rval: %d, errno: %d\n\n", closeResult, NetTest::GetLastError());
            }
        }

        if( hasContext )
        {
            sslContext.Destroy();
        }

        return isSuccess;
    }

    // Run
    //
    // Module entry point:
    //   1. initializes SSL (when requested);
    //   2. downloads the configured resource;
    //   3. logs the MD5 digest of the response body and compares it with m_expectedHash.
    // Returns true only if the download succeeded and the digest matched.
    // Once Init() has succeeded, CleanUp() runs on every path.
    bool SslHttpsDownload::Run() NN_NOEXCEPT
    {
        if( !Init() )
        {
            return false;
        }

        MD5Hash::Result hashResult;
        bool isSuccess = GetServerPage(m_pHostName, m_port, m_pResource, hashResult);

        if( isSuccess )
        {
            Log("Hash: %.2x %.2x %.2x %.2x %.2x %.2x %.2x %.2x %.2x %.2x %.2x %.2x %.2x %.2x %.2x %.2x\n"
                , hashResult.m_pHash[0]
                , hashResult.m_pHash[1]
                , hashResult.m_pHash[2]
                , hashResult.m_pHash[3]
                , hashResult.m_pHash[4]
                , hashResult.m_pHash[5]
                , hashResult.m_pHash[6]
                , hashResult.m_pHash[7]
                , hashResult.m_pHash[8]
                , hashResult.m_pHash[9]
                , hashResult.m_pHash[10]
                , hashResult.m_pHash[11]
                , hashResult.m_pHash[12]
                , hashResult.m_pHash[13]
                , hashResult.m_pHash[14]
                , hashResult.m_pHash[15] );

            if( memcmp(hashResult.m_pHash, m_expectedHash.m_pHash, sizeof(hashResult.m_pHash)) == 0 )
            {
                Log( "MD5 hash matches as expected.\n" );
            }
            else
            {
                // A mismatch fails the test, so report it through the error channel like
                // every other failure in this module.
                LogError( "Error: MD5 hash does NOT match expected.\n\n" );
                isSuccess = false;
            }
        }

        CleanUp();
        return isSuccess;
    }

    // GetName
    //
    // Returns the fixed name identifying this test module.
    const char* SslHttpsDownload::GetName() const NN_NOEXCEPT
    {
        static const char* const pModuleName = "SslHttpsDownload";
        return pModuleName;
    }
}
}
