/* _GNU_SOURCE for asprintf */
#ifndef _GNU_SOURCE
#define _GNU_SOURCE 1
#endif

#include <fenv.h>
#include <float.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include "mtest.h"

/*
 * Exponent of one ulp of x for IEEE binary32, i.e. ulp(x) == 2^eulpf(x).
 * Zero and subnormals share the exponent of the smallest normal; inf/nan are
 * not special-cased and just use the all-ones exponent field.
 */
int eulpf(float x)
{
    union { float f; uint32_t i; } bits = { x };
    int biased = (int)((bits.i >> 23) & 0xff);

    return (biased ? biased : 1) - 0x7f - 23;
}

/*
 * Exponent of one ulp of x for IEEE binary64, i.e. ulp(x) == 2^eulp(x).
 * Zero and subnormals share the exponent of the smallest normal; inf/nan are
 * not special-cased and just use the all-ones exponent field.
 */
int eulp(double x)
{
    union { double f; uint64_t i; } bits = { x };
    int biased = (int)((bits.i >> 52) & 0x7ff);

    /* a zero exponent field (zero/subnormal) has the same ulp as field 1 */
    if (biased == 0)
        biased = 1;
    return biased - 0x3ff - 52;
}

/*
 * Exponent of one ulp of x for long double, i.e. ulp(x) == 2^eulpl(x).
 * The implementation is chosen by LDBL_MANT_DIG:
 *   53 - long double has the binary64 precision, so defer to eulp();
 *   64 - 80-bit extended layout read as a 64-bit significand followed by a
 *        16-bit sign/exponent word (63 == LDBL_MANT_DIG - 1);
 *   other formats are not handled yet and return a placeholder 0.
 */
int eulpl(long double x)
{
#if LDBL_MANT_DIG == 53
    return eulp(x);
#elif LDBL_MANT_DIG == 64
    union { long double f; struct {uint64_t m; uint16_t e; uint16_t pad;} i; } bits = { x };
    int biased = bits.i.e & 0x7fff; /* drop the sign bit */

    return (biased ? biased : 1) - 0x3fff - 63;
#else
    /* TODO: unsupported long double format */
    return 0;
#endif
}

/*
 * Signed error of got relative to want in units of ulp(want) (binary32),
 * plus the expected residual dwant.
 * Two NaNs count as no error; equal values whose sign bits differ
 * (e.g. +0 vs -0) count as an infinite error.
 */
float ulperrf(float got, float want, float dwant)
{
    if (got == want)
        return signbit(got) == signbit(want) ? dwant : inf;
    if (isnan(got) && isnan(want))
        return 0;
    if (isinf(got)) {
        /*
         * Keep the difference finite: use 2^127 for inf and halve want.
         * For a normal want, halving also lowers eulpf(want) by one, so the
         * scaled result stays (2^128 - want) / ulp(want).
         */
        got = copysignf(0x1p127, got);
        want *= 0.5;
    }
    return scalbn(got - want, -eulpf(want)) + dwant;
}

/*
 * Signed error of got relative to want in units of ulp(want) (binary64),
 * plus the expected residual dwant.
 * Two NaNs count as no error; equal values whose sign bits differ
 * (e.g. +0 vs -0) are treated as an infinite error.
 */
float ulperr(double got, double want, float dwant)
{
    int both_nan = isnan(got) && isnan(want);

    if (both_nan)
        return 0;
    if (got == want) {
        if (signbit(got) != signbit(want))
            return inf; /* zero with the wrong sign is a hard failure */
        return dwant;
    }
    if (isinf(got)) {
        /* substitute 2^1023 for inf and halve want to avoid overflow;
           the smaller eulp(want) compensates for the halving */
        got = copysign(0x1p1023, got);
        want *= 0.5;
    }
    return scalbn(got - want, -eulp(want)) + dwant;
}

/*
 * Signed error of got relative to want in units of ulp(want) for long double,
 * plus the expected residual dwant. Dispatches on LDBL_MANT_DIG:
 *   53 - same precision as double, so ulperr() is reused;
 *   64 - computed directly with long double arithmetic;
 *   other formats are not handled yet and report inf.
 */
float ulperrl(long double got, long double want, float dwant)
{
#if LDBL_MANT_DIG == 53
    return ulperr(got, want, dwant);
#elif LDBL_MANT_DIG == 64
    if (got == want)
        return signbit(got) == signbit(want) ? dwant : inf;
    if (isnan(got) && isnan(want))
        return 0;
    if (isinf(got)) {
        /* finite stand-in for inf; halving want keeps the scale right */
        got = copysignl(0x1p16383L, got);
        want *= 0.5;
    }
    return scalbnl(got - want, -eulpl(want)) + dwant;
#else
    /* TODO: unsupported long double format */
    return inf;
#endif
}

/* Number of elements of an array object (not valid on a pointer). */
#define length(a) (sizeof(a)/sizeof*(a))
/* Table entry pairing a flag value with its spelled-out name. */
#define flag(x) {x, #x}
/*
 * Floating-point exception flags and their printable names, listed in the
 * order they are printed when iterated front to back. The table is never
 * written, so it is const-qualified.
 */
static const struct {
    int flag;
    const char *s;
} eflags[] = {
    flag(INEXACT),
    flag(INVALID),
    flag(DIVBYZERO),
    flag(UNDERFLOW),
    flag(OVERFLOW)
};


/*
 * Render the exception mask f as "NAME|NAME|...". Bits that have no entry in
 * eflags are appended as one decimal number, and an empty mask prints "0".
 * Returns a static buffer that the next estr() call overwrites.
 */
char *estr(int f)
{
    static char buf[256];
    char *end = buf;
    int named = 0;
    size_t k;

    for (k = 0; k < length(eflags); k++) {
        int bit = eflags[k].flag;

        if (!(f & bit))
            continue;
        end += sprintf(end, "%s%s", named ? "|" : "", eflags[k].s);
        named |= bit;
    }
    if (named != f) {
        /* leftover bits without a symbolic name */
        end += sprintf(end, "%s%d", named ? "|" : "", f & ~named);
        named = f;
    }
    if (!named)
        sprintf(end, "%s", "0");
    return buf;
}

/*
 * Format an exception-mismatch report "want <want> got <got>" into a static
 * buffer that the next call overwrites.
 *
 * Before formatting, the bits of want that the check mode does not compare
 * are taken from got. That way the message only shows differences that matter:
 *  - r == RN, check_RN_inexact_omisson: INEXACT is added to want if got has it
 *  - r == RN, check_RN_inexact_igore:   want's INEXACT bit is made equal to got's
 *  - r != RN, e_check != check_all:     want's INEXACT and UNDERFLOW bits are
 *                                       made equal to got's
 * In every other case want is printed unchanged.
 */
char *estrs(int got_in, int want_in, int r, E_CHECK_TYPE e_check)
{
    static char buf[256];
    int want = want_in;
    char *tail;

    if (r == RN) {
        if (e_check == check_RN_inexact_omisson) {
            /* got == want || got == (want|INEXACT) */
            want |= got_in & INEXACT;
        } else if (e_check == check_RN_inexact_igore) {
            /* (got|INEXACT) == (want|INEXACT) */
            want = (want & ~INEXACT) | (got_in & INEXACT);
        }
    } else if (e_check != check_all) {
        const int ignored = INEXACT | UNDERFLOW;

        want = (want & ~ignored) | (got_in & ignored);
    }

    /* two separate sprintf calls: estr() returns one shared static buffer */
    tail = buf + sprintf(buf, "want %s got ", estr(want));
    sprintf(tail, "%s", estr(got_in));
    return buf;
}

/*
 * Same formatting as estr(): "NAME|NAME|...", then any unnamed bits in
 * decimal, or "0" for an empty mask. The result lives in a static buffer of
 * its own, which is separate from estr()'s. Presumably this is so both can
 * appear in one printf call. The buffer is overwritten by the next estr2() call.
 */
char *estr2(int f)
{
    static char buf[256];
    char *out = buf;
    int covered = 0;
    int idx;

    for (idx = 0; idx < (int)length(eflags); idx++) {
        if (f & eflags[idx].flag) {
            const char *sep = covered ? "|" : "";

            out += sprintf(out, "%s%s", sep, eflags[idx].s);
            covered |= eflags[idx].flag;
        }
    }
    if (covered != f) {
        const char *sep = covered ? "|" : "";

        out += sprintf(out, "%s%d", sep, f & ~covered);
        covered = f;
    }
    sprintf(out, "%s", covered ? "" : "0");
    return buf;
}

/*
 * Short printable name of rounding mode r: "RN", "RZ", "RU" or "RD".
 * A directed mode is only recognized when the matching FE_* macro is
 * defined. Anything unrecognized yields "R?".
 */
char *rstr(int r)
{
    char *name = "R?";

    switch (r) {
    case RN:
        name = "RN";
        break;
#ifdef FE_TOWARDZERO
    case RZ:
        name = "RZ";
        break;
#endif
#ifdef FE_UPWARD
    case RU:
        name = "RU";
        break;
#endif
#ifdef FE_DOWNWARD
    case RD:
        name = "RD";
        break;
#endif
    default:
        break;
    }
    return name;
}


/*
 * Run every entry of t[0..n-1] through the two-argument function f and check
 * both the floating-point exceptions it raises and the value it returns.
 *
 * f       - function under test
 * f_name  - name of f, used only in diagnostics
 * t, n    - test vectors; entries with r < 0 are skipped
 *
 * For each entry:
 *  - the rounding mode p->r is installed and the exception flags are cleared;
 *  - f is called;
 *  - the raised flags are compared with p->e via t_checkexceptall();
 *  - the verdict is reported through NTD_TESTCASE, or through
 *    NTD_TESTCASE_KNOWN_FAILURE when p->expect_failure is set.
 * Results rejected by checkcr() are printed to stdout with their ulp error.
 * The caller's rounding mode is restored before returning.
 *
 * Returns 1 if any exception check (other than a known failure) or any
 * result check failed, 0 otherwise.
 */
int test_ll_l_checkexceptall(long double (*f)(long double, long double), const char *f_name, struct ll_l *t, size_t n)
{
    #pragma STDC FENV_ACCESS ON
    int saved_round = fegetround();
    long double y;
    float d;
    int e, err = 0;
    size_t i;
    struct ll_l *p;

    for (i = 0; i < n; i++) {
        p = t + i;

        if (p->r < 0)
            continue;
        fesetround(p->r);
        feclearexcept(FE_ALL_EXCEPT);
        y = f(p->x, p->x2);
        /* sample the flags immediately, before any other FP work */
        e = fetestexcept(INEXACT|INVALID|DIVBYZERO|UNDERFLOW|OVERFLOW);

        int passed = t_checkexceptall(e, p->e, p->r);
        char *msg = NULL;
        /* on failure asprintf leaves msg undefined: reset it so free() is safe */
        if (asprintf(&msg, "bad fp exception: %s %s(%La,%La)=%La, %s",
                rstr(p->r), f_name, p->x, p->x2, p->y, ESTRS_ALL(e, p)) < 0)
            msg = NULL;
        char *report = msg ? msg : "bad fp exception";

        if (p->expect_failure) {
            NTD_TESTCASE_KNOWN_FAILURE(p->fail_msg, passed, p->file, p->line, report);
        } else {
            if (!passed)
                err++;
            NTD_TESTCASE(passed, p->file, p->line, report);
        }
        free(msg);

        d = ulperrl(y, p->y, p->dy);
        if (!checkcr(y, p->y, p->r)) {
            /* report the function actually under test, not a fixed name */
            printf("%s:%d: %s %s(%La,%La) want %La got %La ulperr %.3f = %a + %a\n",
                short_path(p->file), p->line, rstr(p->r), f_name, p->x, p->x2, p->y, y, d, d-p->dy, p->dy);
            err++;
        }
    }
    /* do not leak the last test's rounding mode to the caller */
    fesetround(saved_round);
    return !!err;
}

