﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include "crypto_EccP256Math.h"
#include <nn/nn_SdkAssert.h>
#include <nn/crypto/detail/crypto_BigNum.h>

namespace nn { namespace crypto { namespace detail {

/* the bigint implementation stores the least significant words first */
/*
  P-256 field prime p = 2^256 - 2^224 + 2^192 + 2^96 - 1, written as eight
  32-bit words, least significant word first. The field arithmetic in this
  file (BigNum::AddMod/SubMod/ModInv calls and p256_mod) is performed modulo
  this value.
*/
static const BigNum::Digit p[ECC_P256_BIGINT_DIGITS] =
{
    0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000,
    0x00000000, 0x00000000, 0x00000001, 0xFFFFFFFF,
};

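/* Unused reference values for the NIST P-256 curve over Fp: the curve-generation SEED and the value c derived from it (kept commented out below). */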
/*
#define SEED_SIZE 5
static BigNum::Digit seed[SEED_SIZE] =
{
    0x819f7e90, 0x139d26b7, 0x6a6678e1, 0x86e70493, 0xc49d3608,
};

static BigNum::Digit c[ECC_P256_BIGINT_DIGITS] =
{
    0x0104fa0d, 0xaf317768, 0xc5114abc, 0xce8d84a9,
    0x75d4f7e0, 0x03cb055c, 0x2985be94, 0x7efba166,
};
*/
/*
  NIST P-256 domain parameters. Every value is stored as 32-bit words, least
  significant word first.

  The curve equation used by this file (p256_validate, and the doubling step
  of p256_add) is y^2 = x^3 - a4 * x + a6 (mod p), so a4 holds the NEGATED
  curve coefficient: the standard a = -3 is stored as a4 = 3.
*/
const p256_ec_parameter p256_named_parameter =
{
    {/* curve */
        {/* a4 = 3, i.e. standard curve coefficient a = -3 */
            0x00000003, 0x00000000, 0x00000000, 0x00000000,
            0x00000000, 0x00000000, 0x00000000, 0x00000000,
        },
        {/* a6 = b */
            0x27d2604b, 0x3bce3c3e, 0xcc53b0f6, 0x651d06b0,
            0x769886bc, 0xb3ebbd55, 0xaa3a93e7, 0x5ac635d8,
        },
    },
    {/* base point G */
        {/* gx */
            0xd898c296, 0xf4a13945, 0x2deb33a0, 0x77037d81,
            0x63a440f2, 0xf8bce6e5, 0xe12c4247, 0x6b17d1f2,
        },
        {/* gy */
            0x37bf51f5, 0xcbb64068, 0x6b315ece, 0x2bce3357,
            0x7c0f9e16, 0x8ee7eb4a, 0xfe1a7f9b, 0x4fe342e2,
        },
    },
    {/* order n of G */
        0xFC632551, 0xF3B9CAC2, 0xA7179E84, 0xBCE6FAAD,
        0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF,
    },
    {/* cofactor h = 1 */
        0x00000001, 0x00000000, 0x00000000, 0x00000000,
        0x00000000, 0x00000000, 0x00000000, 0x00000000,
    },
};

/*
  return res = a mod p
*/
static void p256_mod(BigNum::Digit* r, BigNum::Digit* a, int a_digits)
{
    BigNum::Digit res[ECC_P256_BIGINT_DIGITS];
    BigNum::Digit ap[2 * ECC_P256_BIGINT_DIGITS];
    BigNum::Digit x[ECC_P256_BIGINT_DIGITS];
    NN_SDK_ASSERT(a_digits <= 2 * ECC_P256_BIGINT_DIGITS);
    BigNum::SetZero(ap, 16);
    BigNum::Copy(ap, a, a_digits);

    // res = t + 2 s1 + 2 s2 + s3 + s4 - d1 - d2 - d3 - d4 mod p

    // t
    res[0] = ap[0]; res[1] = ap[1]; res[2] = ap[2]; res[3] = ap[3];
    res[4] = ap[4]; res[5] = ap[5]; res[6] = ap[6]; res[7] = ap[7];

    // s1
    x[0] = 0;      x[1] = 0;      x[2] = 0;      x[3] = ap[11];
    x[4] = ap[12]; x[5] = ap[13]; x[6] = ap[14]; x[7] = ap[15];

    BigNum::AddMod(res, res, x, p, ECC_P256_BIGINT_DIGITS);
    BigNum::AddMod(res, res, x, p, ECC_P256_BIGINT_DIGITS);

    // s2
    x[0] = 0;      x[1] = 0;      x[2] = 0;      x[3] = ap[12];
    x[4] = ap[13]; x[5] = ap[14]; x[6] = ap[15]; x[7] = 0;

    BigNum::AddMod(res, res, x, p, ECC_P256_BIGINT_DIGITS);
    BigNum::AddMod(res, res, x, p, ECC_P256_BIGINT_DIGITS);

    // s3
    x[0] = ap[8];  x[1] = ap[9];  x[2] = ap[10]; x[3] = 0;
    x[4] = 0;      x[5] = 0;      x[6] = ap[14]; x[7] = ap[15];

    BigNum::AddMod(res, res, x, p, ECC_P256_BIGINT_DIGITS);

    // s4
    x[0] = ap[9];  x[1] = ap[10]; x[2] = ap[11]; x[3] = ap[13];
    x[4] = ap[14]; x[5] = ap[15]; x[6] = ap[13]; x[7] = ap[8];

    BigNum::AddMod(res, res, x, p, ECC_P256_BIGINT_DIGITS);

    // d1
    x[0] = ap[11]; x[1] = ap[12]; x[2] = ap[13]; x[3] = 0;
    x[4] = 0;      x[5] = 0;      x[6] = ap[8];  x[7] = ap[10];

    BigNum::SubMod(res, res, x, p, ECC_P256_BIGINT_DIGITS);

    // d2
    x[0] = ap[12]; x[1] = ap[13]; x[2] = ap[14]; x[3] = ap[15];
    x[4] = 0;      x[5] = 0;      x[6] = ap[9];  x[7] = ap[11];

    BigNum::SubMod(res, res, x, p, ECC_P256_BIGINT_DIGITS);

    // d3
    x[0] = ap[13]; x[1] = ap[14]; x[2] = ap[15]; x[3] = ap[8];
    x[4] = ap[9];  x[5] = ap[10]; x[6] = 0;      x[7] = ap[12];

    BigNum::SubMod(res, res, x, p, ECC_P256_BIGINT_DIGITS);

    // d4
    x[0] = ap[14]; x[1] = ap[15]; x[2] = 0;      x[3] = ap[9];
    x[4] = ap[10]; x[5] = ap[11]; x[6] = 0;      x[7] = ap[13];

    BigNum::SubMod(res, res, x, p, ECC_P256_BIGINT_DIGITS);

    BigNum::Copy(r, res, ECC_P256_BIGINT_DIGITS);
}

// Addition in elliptic curve: c <- a + b over curve curv, in affine coordinates.
// One must verify beforehand that a and b are elements of the curve! If a and b
// are elements of the curve, the result will be an element of the curve.
// The point at infinity is encoded as (0, 0) and acts as the neutral element.
//
// c may alias a and/or b (e.g. p256_add(res, res, res, ...) for doubling):
// a->x is copied into t before c->x is written, b is not read after c->x is
// written, and a->y is only read again by the final SubMod that writes c->y
// (NOTE(review): that last call relies on BigNum::SubMod tolerating an output
// that aliases its second operand).
//
// Not constant time: the branches taken depend on the input coordinates.
void p256_add(p256_point* c, const p256_point* a, const p256_point* b, const p256_curve* curv)
{
    // s holds the slope denominator, then k. t and x are double width so they
    // can receive unreduced BigNum::Mult products before p256_mod.
    BigNum::Digit s[ECC_P256_BIGINT_DIGITS];
    BigNum::Digit t[2 * ECC_P256_BIGINT_DIGITS];
    BigNum::Digit x[2 * ECC_P256_BIGINT_DIGITS];
    /*--------------------------------------------------------------------*
       Definition of addition over an elliptic curve
       A (ax, ay) + B (bx, by) = C (cx, cy)

       cx = k^2 - ax - bx
       cy = k * (ax - cx) - ay

       where

       if (A == B) k = (3 * ax * ax - a4) * (2 * ay)^(-1)   (a4 = 3 for P-256)
       if (A != B) k = (by - ay) * (bx - ax)^(-1)
     *--------------------------------------------------------------------*/

    // point at infinity is the neutral element for the group law
    if (BigNum::IsZero(a->x, ECC_P256_BIGINT_DIGITS) && BigNum::IsZero(a->y, ECC_P256_BIGINT_DIGITS))
    {
        BigNum::Copy(c->x, b->x, ECC_P256_BIGINT_DIGITS);
        BigNum::Copy(c->y, b->y, ECC_P256_BIGINT_DIGITS);
        return;
    }
    if (BigNum::IsZero(b->x, ECC_P256_BIGINT_DIGITS) && BigNum::IsZero(b->y, ECC_P256_BIGINT_DIGITS))
    {
        BigNum::Copy(c->x, a->x, ECC_P256_BIGINT_DIGITS);
        BigNum::Copy(c->y, a->y, ECC_P256_BIGINT_DIGITS);
        return;
    }

    if (BigNum::Compare(a->x, b->x, ECC_P256_BIGINT_DIGITS) == 0)
    {
        // if a = -b, then a+b=O (for on-curve points, equal x with different y
        // can only mean b = -a)
        if (BigNum::Compare(a->y, b->y, ECC_P256_BIGINT_DIGITS) != 0)
        {
            BigNum::SetZero(c->x, ECC_P256_BIGINT_DIGITS);
            BigNum::SetZero(c->y, ECC_P256_BIGINT_DIGITS);
            return;
        }

        //doubling: prepare x and s such that k = x / s
        // (ay == 0 would make s non-invertible; with the P-256 parameters of
        // this file - cofactor 1, odd order - no valid point has ay == 0)
        // x <- ax ^ 2
        BigNum::Mult(x, a->x, a->x, ECC_P256_BIGINT_DIGITS);
        p256_mod(x, x, 2 * ECC_P256_BIGINT_DIGITS);

        // s <- 3
        BigNum::SetZero(s, ECC_P256_BIGINT_DIGITS);
        s[0] = 3;

        // x <- 3 * ax ^ 2
        BigNum::Mult(t, x, s, ECC_P256_BIGINT_DIGITS);
        p256_mod(x, t, 2 * ECC_P256_BIGINT_DIGITS);

        // x <- 3 * ax ^ 2 - a4
        BigNum::SubMod(x, x, curv->a4, p, ECC_P256_BIGINT_DIGITS);

        // s <- 2 * ay
        BigNum::AddMod(s, a->y, a->y, p, ECC_P256_BIGINT_DIGITS);
    }
    else
    {
        // x <- by-ay, s <- bx-ax
        BigNum::SubMod(x, b->y, a->y, p, ECC_P256_BIGINT_DIGITS);
        BigNum::SubMod(s, b->x, a->x, p, ECC_P256_BIGINT_DIGITS);
    }

    // s <- k = x * s^(-1)
    BigNum::ModInv(s, s, p, ECC_P256_BIGINT_DIGITS);
    BigNum::Mult(t, x, s, ECC_P256_BIGINT_DIGITS);
    p256_mod(s, t, 2 * ECC_P256_BIGINT_DIGITS);

    // cx <- k^2 - ax - bx
    BigNum::Mult(x, s, s, ECC_P256_BIGINT_DIGITS);
    p256_mod(x, x, 2 * ECC_P256_BIGINT_DIGITS);
    BigNum::SubMod(x, x, a->x, p, ECC_P256_BIGINT_DIGITS);
    BigNum::SubMod(x, x, b->x, p, ECC_P256_BIGINT_DIGITS);

    // do a copy of a->x for the case where c = a
    BigNum::Copy(t, a->x, ECC_P256_BIGINT_DIGITS);
    BigNum::Copy(c->x, x, ECC_P256_BIGINT_DIGITS);

    // x <- ax - cx
    BigNum::SubMod(x, t, c->x, p, ECC_P256_BIGINT_DIGITS);

    // cy <- (ax - cx)*k - ay
    BigNum::Mult(t, x, s, ECC_P256_BIGINT_DIGITS);
    p256_mod(t, t, 2 * ECC_P256_BIGINT_DIGITS);
    BigNum::SubMod(c->y, t, a->y, p, ECC_P256_BIGINT_DIGITS);
}

/* res <- n * a, simple version susceptible to timing attacks.
   a and res must not alias. */
void p256_mul(p256_point* res, const BigNum::Digit* n, int n_digits, const p256_point* a,
              const p256_curve* curv)
{
    static const BigNum::Digit HIGH_BIT_MASK = 1U << (BigNum::DigitBits - 1);

    BigNum::Digit window = HIGH_BIT_MASK;
    int bits = n_digits * BigNum::DigitBits - 1;
    const BigNum::Digit* cur = n + n_digits - 1;

    if (BigNum::IsZero(n, n_digits))
    {
        BigNum::SetZero(res->x, ECC_P256_BIGINT_DIGITS);
        BigNum::SetZero(res->y, ECC_P256_BIGINT_DIGITS);
        return;
    }

    BigNum::Copy(res->x, a->x, ECC_P256_BIGINT_DIGITS);
    BigNum::Copy(res->y, a->y, ECC_P256_BIGINT_DIGITS);
    NN_SDK_ASSERT(p256_validate(res, curv));

    while (*cur == 0)
    {
        bits -= BigNum::DigitBits;
        cur--;
    }

    while (!(window & *cur))
    {
        window >>= 1;
        bits--;
    }
    window >>= 1;
    if (window == 0)
    {
        window = HIGH_BIT_MASK;
        cur--;
    }

    while (bits--)
    {
        p256_add(res, res, res, curv);
        NN_SDK_ASSERT(p256_validate(res, curv));

        if (*cur & window)
        {
            p256_add(res, res, a, curv);
            NN_SDK_ASSERT(p256_validate(res, curv));
        }

        if (bits % BigNum::DigitBits == 0)
        {
            cur--;
        }

        window >>= 1;
        if (window == 0)
        {
            window = HIGH_BIT_MASK;
        }
    }
}

/*
  Checks whether point a lies on curve c.

  Returns 1 when a is the point at infinity (encoded as (0, 0)), or when both
  coordinates are canonical field elements (strictly less than p) satisfying

      ay^2 == ax^3 - a4 * ax + a6  (mod p).

  Returns 0 otherwise. Note that a4 holds the negated curve coefficient.
*/
int p256_validate(const p256_point* a, const p256_curve* c)
{
    // Double width so they can receive unreduced BigNum::Mult products.
    BigNum::Digit x[2 * ECC_P256_BIGINT_DIGITS];
    BigNum::Digit t[2 * ECC_P256_BIGINT_DIGITS];
    BigNum::Digit y[2 * ECC_P256_BIGINT_DIGITS];

    // point at infinity
    if (BigNum::IsZero(a->x, ECC_P256_BIGINT_DIGITS) &&
        BigNum::IsZero(a->y, ECC_P256_BIGINT_DIGITS))
    {
        return 1;
    }

    // Coordinates must be reduced modulo p. The curve equation below reduces
    // its operands, so without this check a coordinate equal to a valid value
    // plus p (when that still fits in 256 bits) would be accepted. p256_add
    // compares raw digits and would then treat x and x + p as different
    // x-coordinates.
    if (BigNum::Compare(a->x, p, ECC_P256_BIGINT_DIGITS) >= 0 ||
        BigNum::Compare(a->y, p, ECC_P256_BIGINT_DIGITS) >= 0)
    {
        return 0;
    }

    // x <- ax^3
    BigNum::Mult(t, a->x, a->x, ECC_P256_BIGINT_DIGITS);
    p256_mod(t, t, 2 * ECC_P256_BIGINT_DIGITS);
    BigNum::Mult(x, t, a->x, ECC_P256_BIGINT_DIGITS);
    p256_mod(x, x, 2 * ECC_P256_BIGINT_DIGITS);

    // t <- a4 * ax
    BigNum::Mult(t, c->a4, a->x, ECC_P256_BIGINT_DIGITS);
    p256_mod(t, t, 2 * ECC_P256_BIGINT_DIGITS);

    // x <- ax^3 - a4 * ax + a6
    BigNum::SubMod(x, x, t, p, ECC_P256_BIGINT_DIGITS);
    BigNum::AddMod(x, x, c->a6, p, ECC_P256_BIGINT_DIGITS);

    // y <- ay^2
    BigNum::Mult(y, a->y, a->y, ECC_P256_BIGINT_DIGITS);
    p256_mod(y, y, 2 * ECC_P256_BIGINT_DIGITS);

    return (BigNum::Compare(x, y, ECC_P256_BIGINT_DIGITS) == 0);
}

/*************** TEST *********************/

/*
static void test_mod(BigNum::Digit* t, int digits)
{
    BigNum::Digit res1[ECC_P256_BIGINT_DIGITS];
    BigNum::Digit res2[ECC_P256_BIGINT_DIGITS];

    p256_mod(res1, t, digits);
    BigNum::Mod(res2, t, digits, p, ECC_P256_BIGINT_DIGITS);
    if (BigNum::Compare(res1, res2, ECC_P256_BIGINT_DIGITS) == 0)
    {
        printf ("MOD -> PASS\n");
    }
    else
    {
        printf ("MOD -> FAIL\n");
    }
}

static void bigint_test_random(BigNum::Digit *r, int digits)
{
    int i;
    for (i = 0; i < digits; i++)
        r[i] = rand();
}

void test_ecc_p256_mod()
{
    //test 0 and p^2 - 1
    int i;
    BigNum::Digit x[2 * ECC_P256_BIGINT_DIGITS] = {0};
    BigNum::Digit one[2 * ECC_P256_BIGINT_DIGITS] = {1};
    BigNum::Digit random[2 * ECC_P256_BIGINT_DIGITS] = {0};
    test_mod(x, 16);
    test_mod(x, 3);

    BigNum::Mult(x, p, p, ECC_P256_BIGINT_DIGITS);
    BigNum::Sub(x, x, one, 2 * ECC_P256_BIGINT_DIGITS);
    test_mod(x, 2 * ECC_P256_BIGINT_DIGITS);

    for (i = 0; i < 100; i++)
    {
        // test random modulos
        bigint_test_random(random, 2 * ECC_P256_BIGINT_DIGITS);
        BigNum::Mod(random, random, 2 * ECC_P256_BIGINT_DIGITS, x, 2 * ECC_P256_BIGINT_DIGITS);
        test_mod(random, 2 * ECC_P256_BIGINT_DIGITS);
    }
}

static void test_validate(p256_point* p)
{
    if (p256_validate(p, &(p256_named_parameter.par_curve)))
    {
        printf("VALIDATE -> PASS\n");
    }
    else
    {
        printf("VALIDATE -> FAIL\n");
    }
}

void test_ecc_p256_points_in_curve()
{
    p256_point temp;
    BigNum::Digit r;
    int i;

    test_validate(&(p256_named_parameter.par_point));

    for (i = 0; i < 5; i++)
    {
        r = rand();
        p256_mul(&temp, &r, 1, &(p256_named_parameter.par_point), &(p256_named_parameter.par_curve));
        test_validate(&temp);
    }
}

void print_point(p256_point *p)
{
    int i;
    printf("x -> ");
    for (i = ECC_P256_BIGINT_DIGITS - 1; i >= 0; i--)
        printf("%x ", p->x[i]);
    printf("\ny -> ");
    for (i = ECC_P256_BIGINT_DIGITS - 1; i >= 0; i--)
        printf("%x ", p->y[i]);
    printf("\n");
}

void test_eq(const char *s, p256_point *p1, p256_point *p2)
{
    if (BigNum::Compare(p1->x, p2->x, ECC_P256_BIGINT_DIGITS) == 0 &&
        BigNum::Compare(p1->y, p2->y, ECC_P256_BIGINT_DIGITS) == 0)
    {
        printf("%s -> PASS\n", s);
    }
    else
    {
        printf("%s -> FAIL\n", s);
        print_point(p1);
        print_point(p2);
    }
}

void test_group()
{
    p256_point temp;
    p256_point temp2;
    BigNum::Digit zero = 0;
    BigNum::Digit one = 1;
    BigNum::Digit two = 2;
    BigNum::Digit fortytwo = 42;
    BigNum::Digit i;

    BigNum::SetZero(temp2.x, ECC_P256_BIGINT_DIGITS);
    BigNum::SetZero(temp2.y, ECC_P256_BIGINT_DIGITS);

    p256_mul(&temp, &zero, 1, &(p256_named_parameter.par_point), &(p256_named_parameter.par_curve));
    test_eq("MUL0", &temp, &temp2);

    p256_mul(&temp, &one, 1, &(p256_named_parameter.par_point), &(p256_named_parameter.par_curve));
    test_eq("MUL1", &temp, &(p256_named_parameter.par_point));

    p256_mul(&temp, &two, 1, &(p256_named_parameter.par_point), &(p256_named_parameter.par_curve));
    p256_add(&temp2, &(p256_named_parameter.par_point), &(p256_named_parameter.par_point), &(p256_named_parameter.par_curve));
    test_eq("MUL2", &temp, &temp2);

    p256_mul(&temp, &fortytwo, 1, &(p256_named_parameter.par_point), &(p256_named_parameter.par_curve));
    BigNum::SetZero(temp2.x, ECC_P256_BIGINT_DIGITS);
    BigNum::SetZero(temp2.y, ECC_P256_BIGINT_DIGITS);
    for (i = 0; i < fortytwo; ++i)
    {
        p256_add(&temp2, &temp2, &(p256_named_parameter.par_point), &(p256_named_parameter.par_curve));
    }
    test_eq("MUL42", &temp, &temp2);

    p256_mul(&temp, p256_named_parameter.point_order, ECC_P256_BIGINT_DIGITS, &(p256_named_parameter.par_point), &(p256_named_parameter.par_curve));

    BigNum::SetZero(temp2.x, ECC_P256_BIGINT_DIGITS);
    BigNum::SetZero(temp2.y, ECC_P256_BIGINT_DIGITS);

    test_eq("GROUP", &temp, &temp2);
}
*/

}}} // namespace nn::crypto::detail
