﻿/*
 *  Copyright 2005-2014 Acer Cloud Technology, Inc.
 *  All Rights Reserved.
 *
 *  This software contains confidential information and
 *  trade secrets of Acer Cloud Technology, Inc.
 *  Use, disclosure or reproduction is prohibited without
 *  the prior express written permission of Acer Cloud
 *  Technology, Inc.
 */

/*
 *               Copyright (C) 2010, BroadOn Communications Corp.
 *
 *  These coded instructions, statements, and computer programs contain
 *  unpublished  proprietary information of BroadOn Communications Corp.,
 *  and  are protected by Federal copyright law. They may not be disclosed
 *  to  third  parties or copied or duplicated in any form, in whole or in
 *  part, without the prior written consent of BroadOn Communications Corp.
 *
 */
#include <nnt/nntest.h>
#include <nn/nn_Log.h>

#include <nn/escore/escore.h>
#include <nn/ioscrypto/iosctypes.h>
#include <nnt/escoreUtil/testEscore_util_istorage.h>

USING_ESCORE_UTIL_NAMESPACE
USING_ES_NAMESPACE
USING_ISTORAGE_NAMESPACE

#include "../Common/testEs_Est_utils.h"

/* TODO: dynamically generate the following files */
#include "../Common/testEs_Ca_cert_new.cpp"
#include "../Common/testEs_Cp_cert_new.cpp"
#include "../Common/testEs_Xs_cert_new.cpp"

#include "../Common/testEs_Ca_cp_certs.cpp"
#include "../Common/testEs_Ca_xs_certs.cpp"

#include "../Common/testEs_Prod_pki_certs.cpp"

#include "../Common/testEs_Tkt6.cpp"

#include "../Common/testEs_Data3_enc.cpp"
#include "../Common/testEs_Data3_dec.cpp"



/*
 * Certificates
 *
 * __certs is the certificate list that __invokeCertError() passes to
 * ETicket::Set(), and __nCerts is the number of leading entries passed.
 * The ES_INVALID_CERT_MISSING_* cases in __insertCertError() temporarily
 * rewrite the leading entries and drop __nCerts to 2; __reverseCertError()
 * puts the three-entry CA/XS/CP list back.
 *
 * The array is declared two slots longer than the chain; those extra slots
 * are initialised to NULL and are never written in this file.
 */
const int ES_TEST_CERT_CHAIN_LEN = 3; /* CA, XS, CP */

static const void *__certs[ES_TEST_CERT_CHAIN_LEN + 2] = {caNewCert, xsNewCert, cpNewCert, NULL, NULL};
static u32 __nCerts = ES_TEST_CERT_CHAIN_LEN;

/*
 * For the negative test in __testNonMatchingCerts(): a valid cert chain that
 * just happens not to match the ticket.
 *
 * __prodCerts and __prodCertSize are filled by Certificate::ParseCertList()
 * from the prodCerts blob, and __nProdCerts is used as the in/out cert count
 * for that call.  ES_TEST_PROD_CERT_CHAIN_LEN is the capacity of both arrays.
 */
const int ES_TEST_PROD_CERT_CHAIN_LEN = 8;
static void *__prodCerts[ES_TEST_PROD_CERT_CHAIN_LEN];
static u32 __prodCertSize[ES_TEST_PROD_CERT_CHAIN_LEN];
static u32 __nProdCerts = ES_TEST_PROD_CERT_CHAIN_LEN;


typedef enum
{
    ES_INVALID_CERT_SIGTYPE_CA,
    ES_INVALID_CERT_SIGTYPE_ISSUER,
    ES_INVALID_CERT_SIGTYPE_CONTAINER,
    ES_INVALID_CERT_SIG_CA,
    ES_INVALID_CERT_SIG_ISSUER,
    ES_INVALID_CERT_SIG_CONTAINER,
    ES_INVALID_CERT_MISSING_CA,
    ES_INVALID_CERT_MISSING_ISSUER,
    ES_INVALID_CERT_BAD_ISSUER
} ESInvalidCertType;

/*
 * API entry points that are exercised against each corrupted fixture.
 * Only one exists so far; __invokeCertError() dispatches on this value.
 */
typedef enum
{
    ES_VERIFY_CERT_FUNC_ET_SET  /* ETicket::Set() */
} ESVerifyCertFunc;

/*
 * Test matrix walked by __testInvalidCerts(): every entry of
 * __invalidCertTypes is run against every entry of __verifyCertFuncs.
 */
static ESInvalidCertType __invalidCertTypes[] = {ES_INVALID_CERT_SIGTYPE_CA, ES_INVALID_CERT_SIGTYPE_ISSUER, ES_INVALID_CERT_SIGTYPE_CONTAINER,
                                                 ES_INVALID_CERT_SIG_CA, ES_INVALID_CERT_SIG_ISSUER, ES_INVALID_CERT_SIG_CONTAINER,
                                                 ES_INVALID_CERT_MISSING_CA, ES_INVALID_CERT_MISSING_ISSUER, ES_INVALID_CERT_BAD_ISSUER};

static ESVerifyCertFunc __verifyCertFuncs[] = {ES_VERIFY_CERT_FUNC_ET_SET};

/*
 * Setup hook for the verify test.  There is currently nothing to prepare,
 * so it always reports ES_ERR_OK.
 */
static ESError __testSetup()
{
    const ESError status = ES_ERR_OK;

    return status;
}


/*
 * Apply one kind of corruption to the shared test fixtures (caNewCert,
 * xsNewCert, tkt6, __certs, __nCerts).
 *
 * type - which corruption to apply
 * func - not used here; accepted so the signature matches
 *        __reverseCertError()
 *
 * Returns ES_ERR_OK once the corruption has been applied, or ES_ERR_INVALID
 * (after logging) for an unknown type, in which case nothing is modified.
 *
 * The six byte-inversion cases are reverted by __reverseCertError() simply
 * by calling this function a second time with the same type.
 */
static ESError __insertCertError( ESInvalidCertType type, ESVerifyCertFunc func )
{
    NN_UNUSED(func);

    switch( type )
    {
    /* Invert the first byte of a cert / the ticket */
    case ES_INVALID_CERT_SIGTYPE_CA:
        caNewCert[0] = ~caNewCert[0];
        break;
    case ES_INVALID_CERT_SIGTYPE_ISSUER:
        xsNewCert[0] = ~xsNewCert[0];
        break;
    case ES_INVALID_CERT_SIGTYPE_CONTAINER:
        tkt6[0] = ~tkt6[0];
        break;

    /* Invert the last byte of a cert / the ticket */
    case ES_INVALID_CERT_SIG_CA:
        caNewCert[sizeof( caNewCert ) - 1] = ~caNewCert[sizeof( caNewCert ) - 1];
        break;
    case ES_INVALID_CERT_SIG_ISSUER:
        xsNewCert[sizeof( xsNewCert ) - 1] = ~xsNewCert[sizeof( xsNewCert ) - 1];
        break;
    case ES_INVALID_CERT_SIG_CONTAINER:
        tkt6[sizeof( tkt6 ) - 1] = ~tkt6[sizeof( tkt6 ) - 1];
        break;

    /* Hand over a two-entry cert list with one link of the chain left out */
    case ES_INVALID_CERT_MISSING_CA:
        __certs[0] = xsNewCert;
        __certs[1] = cpNewCert;
        __nCerts = 2;
        break;
    case ES_INVALID_CERT_MISSING_ISSUER:
        /* __certs[0] is left as it is; only the second slot is replaced */
        __certs[1] = cpNewCert;
        __nCerts = 2;
        break;

    /*
     * Overwrite two bytes in the ticket and in the XS cert with "CP".
     * NOTE(review): offsets 336 and 388 presumably fall inside the
     * issuer/subject name fields - confirm against the fixture layout.
     */
    case ES_INVALID_CERT_BAD_ISSUER:
        tkt6[336] = 'C';
        tkt6[337] = 'P';
        xsNewCert[388] = 'C';
        xsNewCert[389] = 'P';
        break;

    default:
        ES_TEST_LOG( "Invalid type %u\n", type );
        return ES_ERR_INVALID;
    }

    return ES_ERR_OK;
}


/*
 * Undo the corruption that __insertCertError() applied for the given type.
 *
 * type - the corruption that was previously inserted
 * func - forwarded to __insertCertError() for the byte-inversion cases
 *
 * Returns ES_ERR_OK, or whatever __insertCertError() returns when the undo
 * is delegated to it.
 */
static ESError __reverseCertError( ESInvalidCertType type, ESVerifyCertFunc func )
{
    ESError status = ES_ERR_OK;

    switch( type )
    {
    case ES_INVALID_CERT_BAD_ISSUER:
        /* Put the "XS" bytes back into the ticket and the XS cert */
        tkt6[336] = 'X';
        tkt6[337] = 'S';
        xsNewCert[388] = 'X';
        xsNewCert[389] = 'S';

        /* Also (re)write "CP" into the CP cert at the same offsets */
        cpNewCert[388] = 'C';
        cpNewCert[389] = 'P';
        break;

    case ES_INVALID_CERT_MISSING_CA:
    case ES_INVALID_CERT_MISSING_ISSUER:
        /* Rebuild the full CA, XS, CP list */
        __certs[0] = caNewCert;
        __certs[1] = xsNewCert;
        __certs[2] = cpNewCert;
        __nCerts = 3;
        break;

    default:
        /* Everything else is undone by applying the same change again */
        status = __insertCertError( type, func );
        break;
    }

    return status;
}


/*
 * Map a corruption type to the error code the verification call is
 * expected to return for it.
 *
 * Returns the expected error code, or ES_ERR_INVALID (after logging) when
 * the type is not one of the known values.
 */
static ESError __getExpectedInvalidCertError( ESInvalidCertType type )
{
    if( type == ES_INVALID_CERT_SIGTYPE_CA || type == ES_INVALID_CERT_SIGTYPE_ISSUER )
    {
        return ES_ERR_CERT_INCORRECT_SIG_TYPE;
    }

    if( type == ES_INVALID_CERT_SIGTYPE_CONTAINER )
    {
        return ES_ERR_INCORRECT_SIG_TYPE;
    }

    if( type == ES_INVALID_CERT_SIG_CONTAINER )
    {
        return ES_ERR_VERIFICATION;
    }

    if( type == ES_INVALID_CERT_SIG_CA || type == ES_INVALID_CERT_SIG_ISSUER )
    {
        return ES_ERR_CERT_VERIFICATION;
    }

    if( type == ES_INVALID_CERT_MISSING_CA || type == ES_INVALID_CERT_MISSING_ISSUER )
    {
        return ES_ERR_ISSUER_NOT_FOUND;
    }

    if( type == ES_INVALID_CERT_BAD_ISSUER )
    {
        return ES_ERR_CERT_INVALID;
    }

    ES_TEST_LOG( "Invalid type %u\n", type );
    return ES_ERR_INVALID;
}


/*
 * Run one cell of the negative test matrix: corrupt the fixtures, call the
 * selected verification function, check that it reports the expected error,
 * then restore the fixtures.
 *
 * type - corruption to apply (see __insertCertError())
 * func - verification entry point to exercise
 *
 * Returns ES_ERR_OK when the expected error was reported and the fixtures
 * were restored; ES_ERR_FAIL when a different result was reported;
 * ES_ERR_INVALID for an unknown type or func; otherwise the error returned
 * by __reverseCertError().
 *
 * Once the corruption has been inserted the fixtures are always restored,
 * including on the failure paths, so a failing case does not leave the
 * shared certs/ticket modified.
 */
static ESError __invokeCertError( ESInvalidCertType type, ESVerifyCertFunc func )
{
    ESError rv;
    ESError expected;
    ESError restoreRv;

    /*
     * NOTE(review): es is never referenced below; it is kept in case its
     * construction/destruction is needed by ETicket - confirm.
     */
    ETicketService es;
    ETicket ticket;
    MemoryInputStream ticketStream( tkt6, sizeof( tkt6 ) );

    rv = __insertCertError( type, func );
    if( rv != ES_ERR_OK )
    {
        /* __insertCertError() only fails for an unknown type, before it
         * has modified anything, so there is nothing to restore */
        return rv;
    }

    switch( func )
    {
    case ES_VERIFY_CERT_FUNC_ET_SET:
        rv = ticket.Set( ticketStream, __certs, __nCerts, true );

        expected = __getExpectedInvalidCertError( type );
        if( rv != expected )
        {
            ES_TEST_LOG( "Failed to detect error, %d:%d\n", rv, expected );
            rv = ES_ERR_FAIL;
        }
        else
        {
            rv = ES_ERR_OK;
        }
        break;

    default:
        ES_TEST_LOG( "Invalid func %u\n", func );
        rv = ES_ERR_INVALID;
        break;
    }

    /* Always undo the corruption; an earlier failure takes precedence over
     * the result of the restore */
    restoreRv = __reverseCertError( type, func );
    if( rv == ES_ERR_OK )
    {
        rv = restoreRv;
    }

    return rv;
}


/*
 * Run __invokeCertError() for every combination of corruption type and
 * verification function, stopping at the first combination that fails.
 *
 * Returns ES_ERR_OK when all combinations pass, otherwise the error of the
 * first failing combination.
 */
static ESError __testInvalidCerts()
{
    const u32 typeCount = (u32)( sizeof( __invalidCertTypes ) / sizeof( __invalidCertTypes[0] ) );
    const u32 funcCount = (u32)( sizeof( __verifyCertFuncs ) / sizeof( __verifyCertFuncs[0] ) );

    for( u32 typeIdx = 0; typeIdx < typeCount; typeIdx++ )
    {
        for( u32 funcIdx = 0; funcIdx < funcCount; funcIdx++ )
        {
            ES_TEST_LOG( "Testing invalid cert type %u, func %u\n", __invalidCertTypes[typeIdx], __verifyCertFuncs[funcIdx] );

            const ESError status = __invokeCertError( __invalidCertTypes[typeIdx], __verifyCertFuncs[funcIdx] );
            if( status != ES_ERR_OK )
            {
                return status;
            }
        }
    }

    return ES_ERR_OK;
}

/*
 * Negative test: parse the cert list in prodCerts and pass it to
 * ETicket::Set() together with tkt6.  The call is expected to fail with
 * ES_ERR_ISSUER_NOT_FOUND.
 *
 * Returns ES_ERR_OK when ES_ERR_ISSUER_NOT_FOUND was reported; ES_ERR_FAIL
 * when Set() unexpectedly succeeded or when prodCerts holds more certs than
 * __prodCerts can store; otherwise the error returned by the failing call.
 */
static ESError __testNonMatchingCerts()
{
    ESError rv = ES_ERR_OK;
    ETicket ticket;
    Certificate cCert;
    MemoryInputStream ticketStream( tkt6, sizeof( tkt6 ) );

    __nProdCerts = ES_TEST_PROD_CERT_CHAIN_LEN;
    rv = cCert.GetNumCertsInList( prodCerts, sizeof( prodCerts ), &__nProdCerts );
    if( rv != ES_ERR_OK )
    {
        __nProdCerts = 0;
        goto end;
    }

    /* Guard the fixed-size __prodCerts/__prodCertSize arrays before parsing */
    if( __nProdCerts > ES_TEST_PROD_CERT_CHAIN_LEN )
    {
        ES_TEST_LOG( "Too many prod certs: %d\n", __nProdCerts );
        /* Report a failure: rv is still ES_ERR_OK here, so without this the
         * test would pass without having checked anything */
        rv = ES_ERR_FAIL;
        goto end;
    }

    rv = cCert.ParseCertList( prodCerts, sizeof( prodCerts ), __prodCerts, __prodCertSize, &__nProdCerts );
    if( rv != ES_ERR_OK )
    {
        __nProdCerts = 0;
        goto end;
    }

    rv = ticket.Set( ticketStream, (const void **)__prodCerts, __nProdCerts, true );
    if( rv != ES_ERR_ISSUER_NOT_FOUND )
    {
        ES_TEST_LOG( "Failed to detect mismatching cert list , rv=%d\n", rv );
        if( rv == ES_ERR_OK )
        {
            /* Success is the wrong answer for this negative test */
            rv = ES_ERR_FAIL;
        }
        goto end;
    }

    /* The expected error was reported, so the test itself succeeded */
    rv = ES_ERR_OK;

end:
    return rv;
}

/*
 * Cleanup hook for the verify test.  There is currently nothing to release,
 * so it always reports ES_ERR_OK.
 */
static ESError __testCleanup()
{
    const ESError status = ES_ERR_OK;

    return status;
}


/*
 * Runs setup, the invalid-certificate tests, the non-matching cert list
 * test and cleanup, in that order.  Each stage is checked with EXPECT_EQ,
 * and once a stage fails the remaining stages are skipped.
 */
TEST( VerifyTest, Verify )
{
    ESError rv = ES_ERR_OK;

    rv = __testSetup();
    EXPECT_EQ( rv, ES_ERR_OK );

    if( rv == ES_ERR_OK )
    {
        rv = __testInvalidCerts();
        EXPECT_EQ( rv, ES_ERR_OK );
    }

    if( rv == ES_ERR_OK )
    {
        rv = __testNonMatchingCerts();
        EXPECT_EQ( rv, ES_ERR_OK );
    }

    if( rv == ES_ERR_OK )
    {
        rv = __testCleanup();
        EXPECT_EQ( rv, ES_ERR_OK );
    }

    /* Only announce success when every stage returned ES_ERR_OK */
    if( rv == ES_ERR_OK )
    {
        ES_TEST_LOG( "***** Passed verify tests *****\n" );
    }
}
