﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <cstring>
#include <nn/nn_SdkAssert.h>
#include <nn/crypto/detail/crypto_BigNum.h>
#include <nn/crypto/detail/crypto_BigNumOnStack.h>
#include <nn/crypto/detail/crypto_Clear.h>
#include "crypto_BigNumMath.h"

namespace nn { namespace crypto { namespace detail {

// Builds an empty number: no digit buffer is attached, no digits are in use,
// and the sign / flag fields are cleared. With a capacity of zero, Set() can
// only accept an all-zero (or empty) input until a buffer is attached, e.g.
// through ReserveStatic().
BigNum::BigNum() NN_NOEXCEPT
{
    // Sign and flag state.
    m_Flags = 0;
    m_Negative = 0;

    // Storage bookkeeping: nothing reserved, nothing used.
    m_Capacity = 0;
    m_DigitsCount = 0;
    m_Digits = nullptr;
}

// Destructor. The body is intentionally empty: nothing is released and the
// memory m_Digits points at is neither freed nor wiped here.
BigNum::~BigNum() NN_NOEXCEPT
{
    // Free m_Digits if dynamically allocated
    // NOTE(review): placeholder only - the only way visible in this file to
    // attach storage is ReserveStatic(), which takes a caller-supplied buffer,
    // so there is currently nothing to free.
}

// Loads this number from an octet string.
//
//   pSrc   - first octet of the input; may be nullptr only when srcLen is 0.
//   srcLen - number of octets available at pSrc.
//
// Zero octets at the front of the input are skipped before the length is
// checked against the reserved capacity. Returns true once the digits have
// been stored, or false (after asserting) when the remaining octets do not fit
// into m_Capacity digits; in that case the object is left unmodified.
bool BigNum::Set(const void* pSrc, size_t srcLen) NN_NOEXCEPT
{
    NN_SDK_ASSERT( pSrc || !srcLen );   // (nullptr, 0) is an accepted input

    // Advance past leading zero octets, shrinking the length as we go. The
    // pointer is never dereferenced once the length reaches zero.
    const uint8_t* pOctet = static_cast<const uint8_t*>(pSrc);
    for ( ; srcLen != 0 && *pOctet == 0; --srcLen )
    {
        ++pOctet;
    }

    // The stripped value has to fit into the reserved digit buffer.
    NN_SDK_ASSERT(srcLen <= m_Capacity * sizeof(Digit) );
    const bool fits = !( srcLen > m_Capacity * sizeof(Digit) );
    if ( fits )
    {
        // Number of digits needed to hold srcLen octets, rounded up.
        m_DigitsCount = (srcLen + sizeof(Digit) - 1) / sizeof(Digit);

        // The conversion is also invoked for an empty remainder (srcLen == 0).
        DigitsFromOctetString( m_Digits, m_Capacity, pOctet, srcLen );
    }

    return fits;
}

// Writes this number into the caller's buffer as an octet string.
//
//   pDst   - destination buffer.
//   dstLen - size of the destination in octets; asserted to be at least
//            GetSize().
//
// The conversion itself is delegated to OctetStringFromDigits(), which is
// handed the full dstLen together with the digits currently in use.
void BigNum::Get(void* pDst, size_t dstLen) const NN_NOEXCEPT
{
    NN_SDK_ASSERT( dstLen >= GetSize() );

    const auto pDigits = m_Digits;
    const auto digitCount = m_DigitsCount;
    OctetStringFromDigits( pDst, dstLen, pDigits, digitCount );
}

// Returns the number of octets needed to represent this number, i.e. the
// bytes occupied by the digits in use minus the zero bytes at the top of the
// most significant digit. An empty number (no digits in use) has size 0.
size_t BigNum::GetSize() const NN_NOEXCEPT
{
    if ( m_DigitsCount <= 0 )
    {
        return 0;
    }

    size_t nBytes = m_DigitsCount * sizeof(Digit);
    Digit w = m_Digits[m_DigitsCount - 1];     // most significant digit in use
    NN_SDK_ASSERT(nBytes >= 4); // Holds because at least one digit is in use
    NN_SDK_ASSERT(sizeof(Digit) == 4); // The byte thresholds below assume 32-bit digits

    // Walk the byte boundaries from the top (2^24, 2^16, 2^8) and count how
    // many of the upper bytes of the top digit are empty; stop at the first
    // boundary the digit reaches.
    size_t unusedBytes = 0;
    for ( int shift = 24; shift > 0 && w < (1 << shift); shift -= 8 )
    {
        ++unusedBytes;
    }

    if ( unusedBytes == 3 )
    {
        // Only the lowest byte can be populated here; a top digit of zero
        // would mean m_DigitsCount is out of date.
        NN_SDK_ASSERT(w);
    }
    return nBytes - unusedBytes;
}

// Attaches caller-provided digit storage to this number.
//
//   buffer    - storage to use for the digits.
//   maxDigits - capacity of that storage, in digits.
//
// Only the buffer pointer and the capacity are recorded; the count of digits
// in use and the buffer contents are left as they are.
void BigNum::ReserveStatic( Digit* buffer, int maxDigits ) NN_NOEXCEPT
{
    m_Capacity = maxDigits;
    m_Digits = buffer;
}

// Modular exponentiation on an octet-string block.
//
// inBlock is loaded into a scratch number carved out of the work buffer, the
// digit-level ModExp overload is run in place on it with exp's digits and this
// object's digits (this object is passed in the last operand position -
// presumably the modulus; confirm against the digit-level declaration), and
// the result is written to outBlock.
//
//   outBlock       - receives blockSize octets of result on success; left
//                    untouched on failure.
//   inBlock        - input block of blockSize octets.
//   exp            - exponent; must be non-zero.
//   blockSize      - size of both blocks in octets; asserted to equal
//                    GetSize() of this object.
//   pWorkBuffer    - scratch memory handed to the digit allocator.
//   workBufferSize - size of pWorkBuffer in bytes.
//
// Returns false when either this object or exp is zero, when the scratch
// operand cannot be allocated or loaded, or when the digit-level ModExp
// reports failure; true otherwise.
bool BigNum::ModExp( void* outBlock, const void* inBlock, const BigNum& exp,
                     size_t blockSize, uint32_t* pWorkBuffer, size_t workBufferSize ) const NN_NOEXCEPT
{
    if (exp.IsZero() || IsZero())
    {
        return false;
    }
    NN_SDK_ASSERT( blockSize == GetSize() );

    BigNum::DigitAllocator allocator;
    allocator.Initialize(pWorkBuffer, static_cast<int>(workBufferSize / sizeof(Digit)));

    // Round the digit count UP. Truncating here would make the scratch operand
    // one digit shorter than this object's digit array whenever blockSize is
    // not a multiple of the digit size, so Set() could reject the input or the
    // digit-level ModExp could run past the end of the scratch buffer.
    const int sigDigitCount =
        static_cast<int>((blockSize + sizeof(Digit) - 1) / sizeof(Digit));
    Digit* pSigBuffer = allocator.AllocateDigits(sigDigitCount);
    if (pSigBuffer == nullptr)
    {
        // Nothing has been written to the work buffer yet, so there is
        // nothing to release or wipe.
        return false;
    }

    BigNum sig;
    sig.ReserveStatic(pSigBuffer, sigDigitCount);

    // From here on every outcome goes through the shared cleanup below, so
    // the scratch digits are always released and wiped - including when
    // loading the input fails.
    bool ret = sig.Set( inBlock, blockSize );
    if (ret)
    {
        ret = ModExp( sig.m_Digits, sig.m_Digits,
                      exp.m_Digits, static_cast<int>(exp.m_DigitsCount),
                      this->m_Digits, static_cast<int>(this->m_DigitsCount),
                      &allocator);
    }
    if (ret)
    {
        // The in-place computation changed the digits; refresh the in-use
        // count before serializing.
        sig.Recount();
        sig.Get( outBlock, blockSize );
    }

    allocator.FreeDigits(pSigBuffer, sigDigitCount);

    // Wipe the intermediate values left in the work memory.
    // NOTE(review): assumes GetUsedSize() still covers the region that was
    // just freed (e.g. a high-water mark) - confirm against DigitAllocator.
    ClearMemory(pWorkBuffer, allocator.GetUsedSize());

    return ret;
}

// Zeroes the digits currently in use so the value does not linger in memory.
//
// The length passed to the wipe is in BYTES, so the digit count is scaled by
// the digit size; passing the bare digit count would clear only a fraction of
// the value. ClearMemory() is used, matching how the work buffer is wiped in
// ModExp(). m_DigitsCount itself is left unchanged.
void BigNum::Cleanse() NN_NOEXCEPT
{
    // Nothing to wipe when no buffer is attached or no digits are in use;
    // this also avoids handing a null pointer to the wipe routine.
    if ( m_Digits == nullptr || m_DigitsCount <= 0 )
    {
        return;
    }

    ClearMemory( m_Digits, m_DigitsCount * sizeof(Digit) );
}

// Refreshes the in-use digit count after the digit buffer has been modified
// directly: GetDigits() is asked to examine the whole reserved capacity and
// its result becomes the new m_DigitsCount.
void BigNum::Recount() NN_NOEXCEPT
{
    const int capacityInDigits = static_cast<int>(m_Capacity);
    m_DigitsCount = GetDigits(m_Digits, capacityInDigits);
}

}}} // namespace nn::crypto::detail
