﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <cstddef>
#include <cstdint>
#include <cstring>

#include <nn/os.h>
#include <nn/crypto/crypto_Sha256Generator.h>
#include <nn/ndd/ndd_Types.h>

// Builds and verifies self-checking test payloads whose total size is bounded by
// nn::ndd::SendDataSizeMax.
// A payload consists of a seed byte, data bytes that all equal the seed, and a
// SHA-256 hash of the data bytes.
class PayloadUtil
{
public:
    // Size in bytes of the SHA-256 digest carried in a payload.
    static const size_t HashSize = nn::crypto::Sha256Generator::HashSize;
    // Size in bytes of the seed field.
    static const size_t SeedSize = 1;

    // Storage for the largest possible payload.
    struct Payload
    {
        // There is no length field; an unexpected change in size is detected through the hash.
        uint8_t seed;
        uint8_t data[nn::ndd::SendDataSizeMax - HashSize - SeedSize];
        uint8_t hash[HashSize];
    };
    // The raw bytes of Payload are handed out via GetPtr(), so it must contain no padding.
    static_assert(sizeof(Payload) == static_cast<size_t>(nn::ndd::SendDataSizeMax),
                  "PayloadUtil::Payload must be tightly packed");

    // Fills the payload for the given seed; payloadSize is the total size in bytes.
    void Generate(uint8_t seed, size_t payloadSize);
    // Validates payloadSize bytes at payload; on success stores them and returns true.
    bool Import(const void* payload, size_t payloadSize);
    // Start of the stored payload bytes.
    const void* GetPtr() const;
    // Total size in bytes of the stored payload (seed + data + hash).
    size_t GetSize() const;

private:
    // Value-initialized so GetPtr()/GetSize() never expose indeterminate values
    // before the first successful Generate()/Import().
    Payload m_Payload = {};
    // Number of meaningful bytes in m_Payload.data.
    size_t m_DataSize = 0;
    // SHA-256 of m_Payload.data[0 .. m_DataSize) into hash (size bytes of room).
    void GetHash(uint8_t hash[], size_t size);
    // SHA-256 of dataSize bytes at pData into hash (hashSize bytes of room).
    void GetHash(uint8_t hash[], size_t hashSize, const void* pData, size_t dataSize) const;
};

// Computes the SHA-256 hash of the current data bytes, m_Payload.data[0 .. m_DataSize).
// hash: output buffer; size: number of bytes of room in hash, forwarded to
// Sha256Generator::GetHash (callers in this file pass HashSize).
// Defined `inline` because this definition lives in a header.
inline void PayloadUtil::GetHash(uint8_t hash[], size_t size)
{
    // Delegate to the general overload so both hash paths stay identical.
    GetHash(hash, size, &m_Payload.data[0], m_DataSize);
}

void PayloadUtil::GetHash(uint8_t hash[], size_t hashSize, const void* pData, size_t dataSize) const
{
    nn::crypto::Sha256Generator sha256Generator;
    sha256Generator.Initialize();
    sha256Generator.Update(pData, dataSize);
    sha256Generator.GetHash(&hash[0], hashSize);
}

void PayloadUtil::Generate(uint8_t seed, size_t payloadSize)
{
    NN_ABORT_UNLESS(payloadSize > HashSize + SeedSize);
    NN_ABORT_UNLESS(payloadSize <= nn::ndd::SendDataSizeMax);

    //seed
    m_Payload.seed = seed;

    //data
    auto dataSize = payloadSize - HashSize - SeedSize;
    memset(&m_Payload.data[0], seed, dataSize);
    m_DataSize = dataSize;

    //hash
    uint8_t hash[HashSize];
    GetHash(&hash[0], HashSize);
    memcpy(&m_Payload.hash[0], &hash[0], HashSize);
}

// Validates a received payload and, if valid, stores a copy of it.
//
// Expected byte layout of the payloadSize bytes at pPayload:
//   [seed : SeedSize][data : payloadSize - SeedSize - HashSize bytes][hash : HashSize]
// The payload is valid when every data byte equals the seed and the trailing HashSize
// bytes equal SHA-256 of the data bytes. Only the payloadSize bytes at pPayload are
// read; the hash is taken from the end of that buffer, not from the fixed
// Payload::hash offset, which lies beyond the buffer for shorter payloads.
//
// Returns true and replaces the stored payload on success; returns false and leaves
// the stored payload untouched otherwise.
// Aborts if pPayload is null or unless
// HashSize + SeedSize < payloadSize <= nn::ndd::SendDataSizeMax.
// Defined `inline` because this definition lives in a header.
inline bool PayloadUtil::Import(const void* pPayload, size_t payloadSize)
{
    NN_ABORT_UNLESS(pPayload != nullptr);
    NN_ABORT_UNLESS(payloadSize > HashSize + SeedSize);
    NN_ABORT_UNLESS(payloadSize <= nn::ndd::SendDataSizeMax);

    const uint8_t* const pBytes = static_cast<const uint8_t*>(pPayload);
    const size_t candidateDataSize = payloadSize - HashSize - SeedSize;
    const uint8_t candidateSeed = pBytes[0];
    const uint8_t* const pCandidateData = pBytes + SeedSize;
    const uint8_t* const pCandidateHash = pCandidateData + candidateDataSize;

    // seed: a single byte with no invalid values, so it has no check of its own

    // data: every byte must equal the seed
    for (size_t i = 0; i < candidateDataSize; ++i)
    {
        if (pCandidateData[i] != candidateSeed)
        {
            return false;
        }
    }

    // hash: must match SHA-256 of the data; also catches a changed payload size
    uint8_t hash[HashSize];
    GetHash(&hash[0], HashSize, pCandidateData, candidateDataSize);
    if (std::memcmp(pCandidateHash, &hash[0], HashSize) != 0)
    {
        return false;
    }

    // Copy exactly the validated bytes; payloadSize <= sizeof(Payload) by the checks above.
    std::memcpy(&m_Payload, pBytes, payloadSize);
    m_DataSize = candidateDataSize;
    return true;
}

// Returns a pointer to the first byte of the stored payload (the seed). The pointer
// refers to a member of this object and is valid for this object's lifetime; read
// GetSize() bytes from it.
// Defined `inline` because this definition lives in a header.
inline const void* PayloadUtil::GetPtr() const
{
    return &m_Payload;
}

// Returns the total size in bytes of the stored payload: seed + data + hash.
// Defined `inline` because this definition lives in a header.
inline size_t PayloadUtil::GetSize() const
{
    return SeedSize + m_DataSize + HashSize;
}

