﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <algorithm>
#include <memory>
#include <nn/nn_Common.h>
#include <nn/ldn/ldn_Result.h>
#include <nn/ldn/detail/Advertise/ldn_AdvertiseDirector.h>
#include <nn/ldn/detail/Advertise/ldn_AdvertiseManager.h>
#include <nn/ldn/detail/Advertise/ldn_Aes128CtrSha256Advertise.h>
#include <nn/ldn/detail/Advertise/ldn_PlainSha256Advertise.h>
#include <nn/ldn/detail/Debug/ldn_Log.h>
#include <nn/ldn/detail/Utility/ldn_Crypto.h>
#include <nn/ldn/detail/Utility/ldn_ReverseByteOrder.h>
#include <nn/os/os_Tick.h>
#include <nn/result/result_HandlingUtility.h>

namespace nn { namespace ldn { namespace detail { namespace
{
    // 鍵を生成するために使用するパラメータです。
    const Bit8 KekSourceForAdvertiseKey[] =
    {
        0xC3, 0xC0, 0x78, 0xFA, 0x13, 0x20, 0x6A, 0x7B,
        0xF5, 0x69, 0xD5, 0x19, 0x4F, 0x83, 0xE1, 0x99
    };

    /**
     * @brief           アドバータイズの暗号化と復号化に使用する鍵を生成します。
     * @param[in]       pOutAesKey      鍵の出力先です。
     * @param[in]       networkId       ネットワーク・バイトオーダのネットワーク識別子です。
     */
    void CreateKey(Bit8 (&pOutAesKey)[16], const NetworkId& networkId) NN_NOEXCEPT
    {
        // NetworkID から Advertise 用のセッション鍵を生成します。
        GenerateAesKey(
            pOutAesKey, &networkId, sizeof(NetworkId),
            KekSourceForAdvertiseKey, sizeof(KekSourceForAdvertiseKey));

        // 生成されたキーをダンプします。
        NN_LDN_LOG_DEBUG("Advertise Key : ");
        for (size_t i = 0; i < 16; ++i)
        {
            NN_LDN_LOG_DEBUG_WITHOUT_PREFIX("%02X ", pOutAesKey[i]);
        }
        NN_LDN_LOG_DEBUG_WITHOUT_PREFIX("\n");
    }

    /**
     * @brief           アドバータイズを構築するビルダーを生成します。
     * @param[in]       networkId       ネットワークの識別子です。
     * @param[in]       security        セキュリティ・モードです。
     * @param[in]       buffer          パーサーに供給するバッファです。
     * @param[in]       bufferSize      バッファのバイトサイズです。
     * @return          スキャン結果の解析に適したパーサーです。
     */
    IAdvertiseBuilder* CreateBuilder(
        const NetworkId& networkId, SecurityMode security,
        void* buffer, size_t bufferSize) NN_NOEXCEPT
    {
        switch (security)
        {
        case SecurityMode_Product:
        case SecurityMode_Debug:
            Bit8 key[16];
            CreateKey(key, ConvertToNetworkByteOrder(networkId));
            return new Aes128CtrSha256AdvertiseBuilder(buffer, bufferSize, key);
        case SecurityMode_SystemDebug:
            return new PlainSha256AdvertiseBuilder();
        default:
            return nullptr;
        }
    }

    /**
     * @brief           スキャン結果の解析に適切なパーサーを生成します。
     * @param[in]       result          スキャン結果です。
     * @param[in]       buffer          パーサーに供給するバッファです。
     * @param[in]       bufferSize      バッファのバイトサイズです。
     * @return          スキャン結果の解析に適したパーサーです。
     */
    IAdvertiseParser* CreateParser(
        const L2ScanResult& result, void* buffer, size_t bufferSize) NN_NOEXCEPT
    {
        if (AdvertiseSizeMin <= result.dataSize)
        {
            const auto& header = *reinterpret_cast<const AdvertiseHeader*>(result.data);
            switch (header.format)
            {
            case AdvertiseFormat_PlainSha256:
                return new PlainSha256AdvertiseParser(buffer, bufferSize);
            case AdvertiseFormat_Aes128CtrSha256:
                Bit8 key[16];
                CreateKey(key, header.networkId);
                return new AesCtr128Sha256AdvertiseParser(buffer, bufferSize, key);
            default:
                break;
            }
        }
        return nullptr;
    }

    /**
     * @brief           スキャン結果の解析に適切なパーサーを生成します。
     * @param[in]       networkId       ネットワークの識別子です。
     * @param[in]       security        セキュリティ・モードです。
     * @param[in]       buffer          パーサーに供給するバッファです。
     * @param[in]       bufferSize      バッファのバイトサイズです。
     * @return          スキャン結果の解析に適したパーサーです。
     */
    IAdvertiseParser* CreateParser(
        const NetworkId& networkId, SecurityMode security,
        void* buffer, size_t bufferSize) NN_NOEXCEPT
    {
        switch (security)
        {
        case SecurityMode_Product:
        case SecurityMode_Debug:
            Bit8 key[16];
            CreateKey(key, ConvertToNetworkByteOrder(networkId));
            return new AesCtr128Sha256AdvertiseParser(buffer, bufferSize, key);
        case SecurityMode_SystemDebug:
            return new PlainSha256AdvertiseParser(buffer, bufferSize);
        default:
            return nullptr;
        }
    }

    /**
     * @brief           指定されたネットワークが条件を満たすか判定します。
     * @param[in]       network         対象のネットワークです。
     * @param[in]       filter          フィルタリングの条件です。
     * @return          @a network が @a filter で指定された条件を満たしていれば true です。
     */
    bool IsFilterPassed(const NetworkInfo& network, const ScanFilter& filter) NN_NOEXCEPT
    {
        const auto& common    = network.common;
        const auto& networkId = network.networkId;

        #define NN_LDN_FILTER_REQUIRES(name, pred) do {\
            if ((filter.flag & (name)) == 0 || (pred)) { } else { return false; }\
        } while (NN_STATIC_CONDITION(false))

        // ネットワーク種別によるフィルタです。
        NN_LDN_FILTER_REQUIRES(ScanFilterFlag_NetworkType,
            common.networkType == filter.networkType);

        // BSSID によるフィルタです。
        NN_LDN_FILTER_REQUIRES(ScanFilterFlag_Bssid,
            common.bssid == filter.bssid);

        // SSID によるフィルタです。
        NN_LDN_FILTER_REQUIRES(ScanFilterFlag_Ssid,
            IsValidSsid(common.ssid) && common.ssid == filter.ssid);

        // ローカル通信識別子によるフィルタです。
        NN_LDN_FILTER_REQUIRES(ScanFilterFlag_LocalCommunicationId,
            common.networkType == NetworkType_Ldn &&
            networkId.intentId.localCommunicationId ==
            filter.networkId.intentId.localCommunicationId);

        // シーン識別子によるフィルタです。
        NN_LDN_FILTER_REQUIRES(ScanFilterFlag_SceneId,
            common.networkType == NetworkType_Ldn &&
            networkId.intentId.sceneId == filter.networkId.intentId.sceneId);

        // SessionId によるフィルタです。
        NN_LDN_FILTER_REQUIRES(ScanFilterFlag_SessionId,
            common.networkType == NetworkType_Ldn &&
            networkId.sessionId == filter.networkId.sessionId);

        #undef NN_LDN_FILTER_REQUIRES
        return true;
    }

}}}} // namespace nn::ldn::detail::<unnamed>

namespace nn { namespace ldn { namespace detail
{
    AdvertiseDistributor::AdvertiseDistributor() NN_NOEXCEPT
        : m_Buffer(nullptr),
          m_pBuilder(nullptr),
          m_pNetworkInterface(nullptr)
    {
    }

    AdvertiseDistributor::~AdvertiseDistributor() NN_NOEXCEPT
    {
        if (m_Buffer != nullptr)
        {
            Finalize();
        }
    }

    size_t AdvertiseDistributor::GetRequiredBufferSize() NN_NOEXCEPT
    {
        return sizeof(impl::AdvertiseDistributorBuffer);
    }

    void AdvertiseDistributor::Initialize(
        void* buffer, size_t bufferSize, INetworkInterface* pInterface) NN_NOEXCEPT
    {
        NN_SDK_ASSERT(m_Buffer == nullptr);
        NN_SDK_ASSERT(m_pNetworkInterface == nullptr);
        NN_SDK_ASSERT_NOT_NULL(buffer);
        NN_SDK_ASSERT(GetRequiredBufferSize() <= bufferSize);
        NN_SDK_ASSERT_NOT_NULL(pInterface);
        NN_UNUSED(bufferSize);
        m_pNetworkInterface = pInterface;
        m_Buffer = static_cast<impl::AdvertiseDistributorBuffer*>(buffer);
    }

    void AdvertiseDistributor::Finalize() NN_NOEXCEPT
    {
        NN_SDK_ASSERT_NOT_NULL(m_Buffer);
        NN_SDK_ASSERT_NOT_NULL(m_pNetworkInterface);

        // 配信中であれば停止します。
        if (m_pBuilder != nullptr)
        {
            StopDistribution();
        }

        // リソースを解放します。
        m_Buffer = nullptr;
        m_pNetworkInterface = nullptr;
    }

    Result AdvertiseDistributor::StartDistribution(
        Version version, const NetworkId& networkId,
        SecurityMode security, const LdnNetworkInfo& ldn) NN_NOEXCEPT
    {
        NN_SDK_ASSERT(m_pBuilder == nullptr);
        NN_SDK_ASSERT_NOT_NULL(m_Buffer);
        NN_SDK_ASSERT_NOT_NULL(m_pNetworkInterface);

        // SecurityMode に応じて適切な Builder を生成します。
        m_pBuilder = CreateBuilder(
            networkId, security, m_Buffer->builder, sizeof(m_Buffer->builder));
        NN_SDK_ASSERT_NOT_NULL(m_pBuilder);

        // Advertise で配信するデータを生成します。
        m_pBuilder->SetNetworkId(networkId);
        m_pBuilder->SetVersion(version);
        Result result = SetAdvertiseData(ldn);
        if (result.IsFailure())
        {
            delete m_pBuilder;
            m_pBuilder = nullptr;
        }
        return result;
    }

    Result AdvertiseDistributor::Update(const LdnNetworkInfo& ldn) NN_NOEXCEPT
    {
        NN_SDK_ASSERT_NOT_NULL(m_pBuilder);
        NN_SDK_ASSERT_NOT_NULL(m_Buffer);
        NN_SDK_ASSERT_NOT_NULL(m_pNetworkInterface);
        NN_RESULT_DO(SetAdvertiseData(ldn));
        NN_RESULT_SUCCESS;
    }

    Result AdvertiseDistributor::StopDistribution() NN_NOEXCEPT
    {
        NN_SDK_ASSERT_NOT_NULL(m_pBuilder);
        NN_SDK_ASSERT_NOT_NULL(m_Buffer);
        NN_SDK_ASSERT_NOT_NULL(m_pNetworkInterface);

        // Builder を解放します。
        delete m_pBuilder;
        m_pBuilder = nullptr;

        // Advertise の配信を停止します。
        NN_RESULT_DO(m_pNetworkInterface->SetBeaconData(nullptr, 0));
        NN_RESULT_SUCCESS;
    }

    Result AdvertiseDistributor::SetAdvertiseData(const LdnNetworkInfo& ldn) NN_NOEXCEPT
    {
        // アドバータイズで配信するデータのボディ部分を構築します。
        size_t bodySize;
        auto& body = m_Buffer->body;
        CreateAdvertiseBody(body, &bodySize, sizeof(body), ldn);
        NN_SDK_ASSERT(bodySize <= AdvertiseBodySizeMax);

        // アドバータイズを構築します。
        size_t advertiseSize;
        auto& advertise = m_Buffer->advertise;
        m_pBuilder->Build(advertise, &advertiseSize, sizeof(advertise), body, sizeof(body));

        // Advertise の配信を開始します。
        NN_RESULT_DO(m_pNetworkInterface->SetBeaconData(advertise, advertiseSize));
        NN_RESULT_SUCCESS;
    }

    AdvertiseScanner::AdvertiseScanner() NN_NOEXCEPT
      : m_Buffer(nullptr),
        m_pNetworkInterface(nullptr),
        m_ScanResultCountMax(0)
    {
    }

    AdvertiseScanner::~AdvertiseScanner() NN_NOEXCEPT
    {
        if (m_Buffer)
        {
            Finalize();
        }
    }

    size_t AdvertiseScanner::GetRequiredBufferSize(int scanResultCount) NN_NOEXCEPT
    {
        return (4 * 1024) + (sizeof(L2ScanResult) * scanResultCount);
    }

    void AdvertiseScanner::Initialize(
        void* buffer, size_t bufferSize, int scanResultCount,
        INetworkInterface* pInterface) NN_NOEXCEPT
    {
        NN_SDK_ASSERT(m_Buffer == nullptr);
        NN_SDK_ASSERT(m_pNetworkInterface == nullptr);
        NN_SDK_ASSERT_NOT_NULL(buffer);
        NN_SDK_ASSERT(GetRequiredBufferSize(scanResultCount) <= bufferSize);
        NN_SDK_ASSERT_MINMAX(scanResultCount, 1, ScanResultCountMax);
        NN_SDK_ASSERT_NOT_NULL(pInterface);
        NN_UNUSED(bufferSize);
        m_Buffer = static_cast<impl::AdvertiseScannerBuffer*>(buffer);
        m_ScanResultCountMax = scanResultCount;
        m_pNetworkInterface = pInterface;
    }

    void AdvertiseScanner::Finalize() NN_NOEXCEPT
    {
        NN_SDK_ASSERT_NOT_NULL(m_Buffer);
        NN_SDK_ASSERT_NOT_NULL(m_pNetworkInterface);
        m_Buffer = nullptr;
        m_ScanResultCountMax = 0;
        m_pNetworkInterface = nullptr;
    }

    Result AdvertiseScanner::Scan(
        NetworkInfo* pOutScanResultArray, int* pOutCount, int bufferCount,
        const ScanFilter& filter, int channel) NN_NOEXCEPT
    {
        NN_SDK_ASSERT_NOT_NULL(m_Buffer);
        NN_SDK_ASSERT_NOT_NULL(m_pNetworkInterface);
        NN_SDK_ASSERT_NOT_NULL(pOutScanResultArray);
        NN_SDK_ASSERT_NOT_NULL(pOutCount);
        NN_SDK_ASSERT_MINMAX(bufferCount, 1, m_ScanResultCountMax);

        // スキャン結果を 0 で初期化しておきます。
        *pOutCount = 0;
        std::memset(pOutScanResultArray, 0, sizeof(NetworkInfo) * bufferCount);

        // 周辺のネットワークをスキャンします。
        int count;
        NN_RESULT_DO(m_pNetworkInterface->Scan(
            m_Buffer->l2ScanResult, &count, m_ScanResultCountMax, channel));

        // スキャン結果を解析します。
        for (int i = 0; i < count && *pOutCount < bufferCount; ++i)
        {
            const auto& scanResult = m_Buffer->l2ScanResult[i];
            auto& out = pOutScanResultArray[*pOutCount];
            out.common.bssid = scanResult.bssid;
            out.common.ssid = scanResult.ssid;
            out.common.channel = scanResult.channel;
            out.common.linkLevel = static_cast<int8_t>(
                m_pNetworkInterface->ConvertRssiToLinkLevel(scanResult.rssi));
            if (0U < scanResult.dataSize)
            {
                out.common.networkType = NetworkType_Ldn;
                std::unique_ptr<IAdvertiseParser> pParser(CreateParser(
                    scanResult, m_Buffer->parser, sizeof(m_Buffer->parser)));
                if (pParser)
                {
                    auto result = pParser->Parse(scanResult.data, scanResult.dataSize);
                    if (result == AdvertiseParserResult_Success)
                    {
                        out.networkId = pParser->GetNetworkId();
                        auto data = pParser->GetData();
                        auto dataSize = pParser->GetDataSize();
                        if (AnalyzeAdvertiseBody(&out.ldn, data, dataSize) &&
                            IsFilterPassed(out, filter))
                        {
                            *pOutCount += 1;
                        }
                    }
                }
            }
            else
            {
                out.common.networkType = NetworkType_General;
                if (IsFilterPassed(out, filter))
                {
                    *pOutCount += 1;
                }
            }
        }
        NN_LDN_LOG_INFO("Found %d networks and %d networks passed filter\n", count, *pOutCount);
        NN_RESULT_SUCCESS;
    }

    AdvertiseMonitor::AdvertiseMonitor() NN_NOEXCEPT
      : m_Buffer(nullptr),
        m_pNetworkInterface(nullptr),
        m_pParser(nullptr)
    {
    }

    AdvertiseMonitor::~AdvertiseMonitor() NN_NOEXCEPT
    {
        if (m_Buffer)
        {
            Finalize();
        }
    }

    size_t AdvertiseMonitor::GetRequiredBufferSize() NN_NOEXCEPT
    {
        return sizeof(impl::AdvertiseMonitorBuffer);
    }

    void AdvertiseMonitor::Initialize(
        void* buffer, size_t bufferSize, INetworkInterface* pInterface) NN_NOEXCEPT
    {
        NN_SDK_ASSERT(m_Buffer == nullptr);
        NN_SDK_ASSERT(m_pNetworkInterface == nullptr);
        NN_SDK_ASSERT_NOT_NULL(buffer);
        NN_SDK_ASSERT(GetRequiredBufferSize() <= bufferSize);
        NN_SDK_ASSERT_NOT_NULL(pInterface);
        NN_UNUSED(bufferSize);
        m_pNetworkInterface = pInterface;
        m_Buffer = static_cast<impl::AdvertiseMonitorBuffer*>(buffer);
        NetworkInterfaceProfile profile;
        m_pNetworkInterface->GetNetworkInterfaceProfile(&profile);
        m_MacAddress = profile.physicalAddress;
    }

    void AdvertiseMonitor::Finalize() NN_NOEXCEPT
    {
        NN_SDK_ASSERT_NOT_NULL(m_Buffer);
        NN_SDK_ASSERT_NOT_NULL(m_pNetworkInterface);

        // アドバータイズを監視中であれば停止します。
        if (m_pParser != nullptr)
        {
            StopMonitoring();
        }

        // リソースを解放します。
        m_Buffer = nullptr;
        m_pNetworkInterface = nullptr;
    }

    Result AdvertiseMonitor::StartMonitoring(
        int* pOutAid, LdnNetworkInfo* pOutInfo,
        const NetworkId& networkId, SecurityMode security) NN_NOEXCEPT
    {
        NN_SDK_ASSERT_NOT_NULL(m_Buffer);
        NN_SDK_ASSERT_NOT_NULL(m_pNetworkInterface);
        NN_SDK_ASSERT(m_pParser == nullptr);
        NN_SDK_ASSERT_NOT_NULL(pOutAid);
        NN_SDK_ASSERT_NOT_NULL(pOutInfo);

        // アドバータイズの解析に使用する Parser を作成します。
        m_pParser = CreateParser(networkId, security, m_Buffer->parser, sizeof(m_Buffer->parser));
        NN_SDK_ASSERT_NOT_NULL(m_pParser);
        m_NetworkId = networkId;

        // 最初のアドバータイズの受信まで待機します。
        const TimeSpan timeout = TimeSpan::FromMilliSeconds(1000);
        nn::os::Tick start = nn::os::GetSystemTick();
        nn::os::Tick now = start;
        auto diff = (now - start).ToTimeSpan();
        while (diff < timeout && GetReceivedEvent().TimedWait(timeout - diff))
        {
            if (GetAdvertise(pOutInfo))
            {
                for (int i = 1; i < NodeCountMax; ++i)
                {
                    const auto& node = pOutInfo->nodes[i];
                    if (node.isConnected && node.macAddress == m_MacAddress)
                    {
                        *pOutAid = i;
                        return ResultSuccess();
                    }
                }
            }
            now = nn::os::GetSystemTick();
            diff = (now - start).ToTimeSpan();
        }

        // タイムアウト時間内に最初のアドバータイズを受信できませんでした。
        StopMonitoring();
        return ResultConnectionTimeout();
    }

    void AdvertiseMonitor::StopMonitoring() NN_NOEXCEPT
    {
        NN_SDK_ASSERT_NOT_NULL(m_Buffer);
        NN_SDK_ASSERT_NOT_NULL(m_pNetworkInterface);
        NN_SDK_ASSERT_NOT_NULL(m_pParser);

        // アドバータイズの解析に使用する Parser を破棄します。
        delete m_pParser;
        m_pParser = nullptr;
    }

    nn::os::SystemEvent& AdvertiseMonitor::GetReceivedEvent() NN_NOEXCEPT
    {
        NN_SDK_ASSERT_NOT_NULL(m_pNetworkInterface);
        return m_pNetworkInterface->GetBeaconReceivedEvent();
    }

    bool AdvertiseMonitor::GetAdvertise(LdnNetworkInfo* pOutInfo) NN_NOEXCEPT
    {
        NN_SDK_ASSERT_NOT_NULL(m_Buffer);
        NN_SDK_ASSERT_NOT_NULL(m_pNetworkInterface);
        NN_SDK_ASSERT_NOT_NULL(m_pParser);
        NN_SDK_ASSERT_NOT_NULL(pOutInfo);

        // 最後に受信したアドバータイズを取得します。
        auto& scanResult = m_Buffer->l2ScanResult;
        if (!m_pNetworkInterface->GetBeaconData(&scanResult))
        {
            return false;
        }

        // アドバータイズを復号します。
        auto result = m_pParser->Parse(scanResult.data, scanResult.dataSize);
        if (result != AdvertiseParserResult_Success)
        {
            return false;
        }
        else if (m_pParser->GetNetworkId() != m_NetworkId)
        {
            return false;
        }

        // アドバータイズのペイロード部分を解析します。
        return AnalyzeAdvertiseBody(pOutInfo, m_pParser->GetData(), m_pParser->GetDataSize());
    }

}}} // namespace nn::ldn::detail
