﻿/*---------------------------------------------------------------------------*
  Copyright (C)2015 Nintendo Co., Ltd.  All rights reserved.

  These coded instructions, statements, and computer programs contain
  proprietary information of Nintendo of America Inc. and/or Nintendo
  Company Ltd., and are protected by Federal copyright law.  They may
  not be disclosed to third parties or copied or duplicated in any form,
  in whole or in part, without the prior written consent of Nintendo.
 *---------------------------------------------------------------------------*/

#include <nn/os.h>
#include <nn/nn_Abort.h>
#include <nn/nn_SdkLog.h>
#include <nn/mem/mem_StandardAllocator.h>
#include <nn/socket/socket_Config.h>

extern "C"
{

#include <sys/malloc.h>
#include <sys/queue.h>
#include <sys/systm.h>
#include <sys/kernel.h>
#include <sys/sysproto.h>
#include <siglo/client.h>
#include <siglo/uma.h>
#include <siglo/idtab.h>
#include <net/if.h>

#include <netinet/tcp_var.h>

// Protocol-level "shut down every socket" helpers defined elsewhere in the
// stack. The *_uid variants are called from client_terminate() with the
// client's pid as the uid; the non-uid variants back
// client_shutdown_all_sockets(). The tcp_* counterparts are not declared
// here -- presumably they come from <netinet/tcp_var.h>; verify.
extern int udp_shutdown_all_uid(uid_t uid,int forced);
extern int rip_shutdown_all_uid(uid_t uid,int forced);
extern int udplite_shutdown_all_uid(uid_t uid,int forced);
extern int udp_shutdown_all(int forced);
extern int rip_shutdown_all(int forced);
extern int udplite_shutdown_all(int forced);

// Allocator of the bsdsocket process itself; client_get_resource_statistics()
// reports from it when the DEFAULT_PID container is queried.
extern nn::mem::StandardAllocator g_Allocator;

/*
 * Appends a text description of `zone` at *buffer and advances the cursor.
 *
 * @param zone           zone to describe; NULL means "nothing to print" and
 *                       leaves *buffer and *written untouched.
 * @param buffer         in/out write cursor.
 * @param buffer_length  bytes available starting at *buffer.
 * @param written        receives the raw return value of uma_print_zone_siglo().
 *
 * uma_print_zone_siglo() is called like snprintf(buf, len, ...). An
 * snprintf-style writer returns the length it WOULD have produced, which can
 * exceed buffer_length. The cursor is therefore advanced by at most
 * buffer_length - 1 (the slot of the terminating NUL). Otherwise it would
 * move past the end of the buffer, and a caller computing
 * "total - (cursor - start)" would wrap to a huge size_t.
 */
static void uma_snprint_zone_siglo(uma_zone_t zone, char** buffer, size_t buffer_length, int* written)
{
    if (zone == NULL)
    {
        return;
    }

    *written = uma_print_zone_siglo(*buffer, buffer_length, zone);
    if (*written > 0)
    {
        size_t advance = (size_t)*written;
        if (advance >= buffer_length)
        {
            // Truncated: stop on the terminator slot (or stay put if there is no room at all).
            advance = (buffer_length > 0) ? (buffer_length - 1) : 0;
        }
        *buffer += advance;
    }
}

// List of all registered clients. Every access in this file is bracketed by
// critical_enter()/critical_exit().
static struct slisthead client_list = SLIST_HEAD_INITIALIZER(client_list);
// malloc(9) type used for the client containers themselves
// (allocated in client_create(), released in client_free()).
static MALLOC_DEFINE(M_CLIENT, "iclient", "client container allocations");

/*
 * Registers a client container for process `pid`.
 *
 * If a client with the same pid is already on client_list, nothing changes
 * and 0 is returned. Otherwise:
 *  - a zeroed `struct client` is allocated from M_CLIENT;
 *  - its embedded prison/proc/filedesc/ucred/plimit objects are wired together;
 *  - a per-client allocator is set up over `mempool` (non-default pids with a
 *    non-empty pool only);
 *  - mbuf zone allocators are created (non-default pids only);
 *  - the container is pushed onto client_list.
 *
 * @param pid           process id; also used as the ucred uid and proc pid.
 * @param mempool       client transfer memory, may be NULL.
 * @param mempool_size  size of `mempool` in bytes.
 * @param private_data  opaque cookie, handed back by client_terminate().
 * @param p_config      socket library configuration; NULL for the core stack
 *                      (DEFAULT_PID), per the comment in the else-branch below.
 * @return 0 on success or if the client already exists; -1 if the container
 *         allocation fails.
 *
 * client_create_mutex serialises creators, so the duplicate check and the
 * insert cannot interleave with another client_create(). client_list itself
 * is only touched inside critical_enter()/critical_exit().
 */
int client_create(uint64_t pid, void* mempool, size_t mempool_size, void* private_data, const client_config* p_config)
{
    static nn::os::MutexType client_create_mutex = NN_OS_MUTEX_INITIALIZER(1);
    uint32_t sb_limit;
    int      rval = -1;
    struct client* container = NULL;

    nn::os::LockMutex(&client_create_mutex);
    critical_enter();
    SLIST_FOREACH(container, &client_list, cl_entry) {
        if (container->cl_pid == pid) {
            // already exists
            critical_exit();
            nn::os::UnlockMutex(&client_create_mutex);
            return 0;
        }
    }
    critical_exit();

    // NOTE(review): M_ZERO already requests zeroed memory; the memset below is
    // redundant unless this port's malloc ignores M_ZERO -- confirm before removing.
    container = (struct client*)malloc(sizeof(struct client), M_CLIENT, M_ZERO);
    if (container != NULL) {
        memset(container, 0, sizeof(struct client));
        mtx_init(&(container->cl_prison.pr_mtx), "prison mutex", NULL, MTX_DEF);
        mtx_init(&(container->cl_proc.p_mtx),    "proc mtx",     NULL, MTX_DEF);
        FILEDESC_LOCK_INIT(&(container->cl_filedesc));

        // Determine limit for total space used for client's socket buffers,
        // use 90% of transfer memory, or 1Mbyte if no transfer memory was provided

        if (mempool != NULL &&
            mempool_size > 0)
        {
            // NOTE(review): narrowed to uint32_t; pools above ~4.7 GB would truncate.
            sb_limit = (mempool_size * 9) / 10;
        } else {
            sb_limit = (1 * 1024 * 1024);
        }

        // Per process limit contains current and max values,
        // our bsd port is using only the current value,
        // the value is a combined limit for tx/rx sockbufs.
        container->cl_plimit.pl_rlimit[RLIMIT_SBSIZE] = {sb_limit, sb_limit};

        // Propagate socket library configuration down into process limits
        if(p_config != NULL){
            container->cl_gids[0] = p_config->gid;
            container->cl_plimit.pl_rlimit[RLIMIT_TCP_SEND_SBSIZE] =
                {p_config->tcpInitialSendBufferSize, p_config->tcpAutoSendBufferSizeMax};
            container->cl_plimit.pl_rlimit[RLIMIT_TCP_RECEIVE_SBSIZE] =
                {p_config->tcpInitialReceiveBufferSize, p_config->tcpAutoReceiveBufferSizeMax};
            // Single initializer: the max value is zero-initialized, which
            // matches the explicit {x, 0} entries below.
            container->cl_plimit.pl_rlimit[RLIMIT_UDP_SEND_SBSIZE] =
                {p_config->udpSendBufferSize};
            container->cl_plimit.pl_rlimit[RLIMIT_UDP_RECEIVE_SBSIZE] =
                {p_config->udpReceiveBufferSize, 0};
            container->cl_plimit.pl_rlimit[RLIMIT_SB_EFFICIENCY] =
                {p_config->socketBufferEfficiency, 0};
            // {version requested by the client, version this stack implements}
            container->cl_plimit.pl_rlimit[RLIMIT_LIB_COMPATIBILITY] =
                {p_config->version, nn::socket::LibraryVersion};
        }else{
            /* Applies only to DEFAULT_PID, core stack, not a full client */
            container->cl_plimit.pl_rlimit[RLIMIT_LIB_COMPATIBILITY] =
                {nn::socket::LibraryVersion, nn::socket::LibraryVersion};
        }

        // Wire the embedded BSD objects to each other; the ucred starts with a
        // single reference owned by the container (dropped in client_terminate()).
        container->cl_pid                   = pid;
        container->cl_private_data          = private_data;
        container->cl_mempool               = mempool;
        container->cl_mempool_size          = mempool_size;
        container->cl_ucred.cr_uid          = pid;
        container->cl_ucred.cr_ref          = 1;
        container->cl_ucred.cr_prison       = &container->cl_prison;
        container->cl_ucred.cr_uidinfo      = &container->cl_uidinfo;
        container->cl_ucred.cr_groups       = &container->cl_gids[0];
        container->cl_ucred.cr_container    = container;
        container->cl_filedesc.fd_nfiles    = 0;
        container->cl_filedesc.fd_lastfile  = 0;
        container->cl_filedesc.fd_freefile  = 0;
        container->cl_filedesc.fd_ofiles    = container->cl_filedescent;
        container->cl_filedesc.fd_map       = container->cl_fdmap;
        container->cl_proc.p_limit          = &container->cl_plimit;
        container->cl_proc.p_fd             = &container->cl_filedesc;
        container->cl_proc.p_malloc         = &container->cl_malloc;
        container->cl_proc.p_pid            = (uint32_t)pid;
        container->cl_flags                 = ClientFlags::FLAG_VALID;

        // Only real clients with transfer memory get a private allocator;
        // a NULL ks_allocator is left for the default/core container.
        if (pid != DEFAULT_PID &&
            mempool != NULL &&
            mempool_size > 0)
        {
            container->cl_allocator.Initialize(mempool, mempool_size);
            container->cl_malloc.ks_allocator = &container->cl_allocator;
        } else {
            container->cl_malloc.ks_allocator = NULL;
        }

        snprintf(container->cl_name, sizeof(container->cl_name), "client pid: %u", pid == DEFAULT_PID ? 0 : (unsigned)pid);
        container->cl_malloc.ks_shortdesc = container->cl_name;
        container->cl_malloc.ks_longdesc  = container->cl_name;

        // this client's malloc
        malloc_init(&container->cl_malloc);

        // One uintptr_t slot per PAGE_SIZE page of the mempool, allocated from
        // the client's own malloc type.
        // NOTE(review): the result is not checked; client_terminate() tolerates NULL here.
        if (container->cl_mempool_size != 0)
        {
            container->cl_vtab_entries = (uintptr_t*)malloc(container->cl_mempool_size / PAGE_SIZE * sizeof(uintptr_t),
                                                            container->cl_proc.p_malloc,
                                                            M_ZERO);
        }

        // setup mbuf allocators (default proc zones are created in uma.cpp)
        // mempool size is used to set limits on number of mbufs, clusters, etc...
        if (pid != DEFAULT_PID) {
            create_zone_allocators(
                    &container->cl_proc,
                    mempool_size ? mempool_size : (1 * 1024 * 1024)
            );
        }
        // Publish the fully initialised container.
        critical_enter();
        SLIST_INSERT_HEAD(&client_list, container, cl_entry);
        critical_exit();
        rval = 0;
    }
    nn::os::UnlockMutex(&client_create_mutex);

    return rval;
}

/*
 * Finds the valid (FLAG_VALID set) client registered for `pid` and takes a
 * reference on its embedded ucred with crhold().
 *
 * Aborts the program if no such client exists, so a non-NULL pointer is
 * always returned. The reference must later be dropped with crfree(), e.g. via
 * client_release_ref() / client_release_ref_by_pid().
 */
struct client* client_acquire_ref(uint64_t pid)
{
    struct client* found = NULL;
    struct client* cursor;

    critical_enter();
    SLIST_FOREACH(cursor, &client_list, cl_entry) {
        const bool is_match = (cursor->cl_pid == pid) &&
                              (cursor->cl_flags & ClientFlags::FLAG_VALID);
        if (!is_match) {
            continue;
        }
        crhold(&(cursor->cl_ucred));
        found = cursor;
        break;
    }
    critical_exit();

    NN_ABORT_UNLESS(found != NULL);
    return found;
}

/*
 * Drops one reference on the credential installed on the calling thread
 * (curthread->td_ucred). Presumably pairs with a client_acquire_ref() whose
 * ucred was installed on this thread -- verify against callers.
 */
void client_release_ref()
{
    struct ucred* const thread_cred = curthread->td_ucred;
    crfree(thread_cred);
}

/*
 * Drops one ucred reference on the first client registered for `pid`.
 *
 * Unlike client_acquire_ref(), this does not require FLAG_VALID, so it also
 * works for a client that is already being terminated. Does nothing if no
 * client with that pid is registered.
 */
void client_release_ref_by_pid(uint64_t pid)
{
    struct client* entry;

    critical_enter();
    SLIST_FOREACH(entry, &client_list, cl_entry) {
        if (entry->cl_pid != pid) {
            continue;
        }
        crfree(&(entry->cl_ucred));
        break;
    }
    critical_exit();
}

/*
 * Clears FLAG_VALID on every client registered for `pid`, so that
 * client_acquire_ref() stops handing out references to it. Each flag update is
 * done under that client's filedesc exclusive lock.
 *
 * @return the last matching client visited, or NULL if none matched.
 */
static struct client* client_deactivate(uint64_t pid)
{
    struct client* last_match = NULL;
    struct client* entry;

    critical_enter();
    SLIST_FOREACH(entry, &client_list, cl_entry) {
        if (entry->cl_pid != pid) {
            continue;
        }
        FILEDESC_XLOCK(&(entry->cl_filedesc));
        last_match = entry;
        last_match->cl_flags &= ~ClientFlags::FLAG_VALID;
        FILEDESC_XUNLOCK(&(entry->cl_filedesc));
    }
    critical_exit();

    return last_match;
}

static void close_all_sockets(struct client* container)
{
    int fd;
    struct thread *td = curthread;
    for (fd = 0; fd < maxfilesperproc; fd++) {
        if (get_field(container->cl_fdmap, fd)) {

            // linger-0 for guaranteed non-blocking close
            struct linger soLinger = {true,  0};
            struct setsockopt_args my_setsockopt_args;
            my_setsockopt_args.s       = fd;
            my_setsockopt_args.level   = SOL_SOCKET;
            my_setsockopt_args.name    = SO_LINGER;
            my_setsockopt_args.val     = (caddr_t)&soLinger;
            my_setsockopt_args.valsize = sizeof(soLinger);
            td->td_retval[0] = 0;
            sys_setsockopt(td, &my_setsockopt_args);

            // Close the descriptor
            struct close_args my_close_args;
            my_close_args.fd = fd;
            td->td_retval[0] = 0;
            sys_close(td, &my_close_args);
        }
    }
    return;
}

/*
 * Shuts down sockets of every client across TCP, UDP, UDP-Lite and raw IP,
 * in that order.
 *
 * @param pcount  receives the sum of the per-protocol counts returned by the
 *                *_shutdown_all() helpers.
 * @param forced  passed through to each helper.
 * @return always 0.
 */
int client_shutdown_all_sockets(int * pcount, int forced)
{
    int total = 0;

    total += tcp_shutdown_all(forced);
    total += udp_shutdown_all(forced);
    total += udplite_shutdown_all(forced);
    total += rip_shutdown_all(forced);

    *pcount = total;
    return 0;
}

int client_check_is_valid_from_cred(struct ucred * cred)
{
    if (cred == NULL)
    {
        return 0;
    }
    struct client * client = (struct client *)cred->cr_container;
    if (client == NULL ||
        !(client->cl_flags & ClientFlags::FLAG_VALID))
    {
        return 0;
    }

    return 1;
}

/*
 * Tears down the client registered for `pid`.
 *
 * Steps, in order (each wait polls with DELAY(CLIENT_TERMINATE_POLL_INTERVAL)
 * and has no timeout):
 *  1. Clear FLAG_VALID so client_acquire_ref() no longer returns this client.
 *  2. Run the rest of the teardown with curthread impersonating the client
 *     (td_ucred/td_proc are NOT restored afterwards).
 *  3. Force-shutdown and close every socket the client left open.
 *  4. Wait for cr_cross_proc_ref to reach 0, then drop remaining TCP
 *     connections for this uid.
 *  5. Wait until the container's own reference is the only one left on the ucred.
 *  6. Purge egress queues, wait for all zone memory to be returned, then
 *     destroy the zones, the vtab array, the client malloc and the allocator.
 *  7. Unlink from client_list and drop the container's ucred reference.
 *
 * @param pid           pid the client was created with; aborts if unknown.
 * @param private_data  receives the cookie passed to client_create().
 * @return 0 (unknown pids abort instead of returning an error).
 */
int client_terminate(uint64_t pid, void **private_data)
{
    struct client *container = client_deactivate(pid);
    NN_ABORT_UNLESS(container != NULL, "unknown client process");


    // Subsequent commands executed on behalf of terminating process
    curthread->td_ucred = &container->cl_ucred;
    curthread->td_proc  = &container->cl_proc;

    // Any sockets that client failed to close are now forcibly shutdown

    tcp_shutdown_all_uid(pid, true);
    udp_shutdown_all_uid(pid, true);
    udplite_shutdown_all_uid(pid, true);
    rip_shutdown_all_uid(pid, true);

    // close sockets
    close_all_sockets(container);

    // Wait for cross-process (shared socket) ref count to drain.
    // Fortunately Nintendo SSL system process is well behaved. So this will be reliable and quick.
    // The varargs below are cast to int to match "%d": passing uint64_t/u_int
    // for %d is undefined behaviour and misaligns arguments on 32-bit ABIs.
    u_int last_reported_cr_ref = 0;
    u_int cr_cross_proc_ref;
    while((cr_cross_proc_ref=container->cl_ucred.cr_cross_proc_ref) > 0) {
        if(cr_cross_proc_ref != last_reported_cr_ref){
            printf("bsdsocket pid=%d termination pending with cr_cross_proc_ref %d.\n", (int)pid, (int)cr_cross_proc_ref);
            last_reported_cr_ref = cr_cross_proc_ref;
        }
        DELAY(CLIENT_TERMINATE_POLL_INTERVAL);
    }
    if(last_reported_cr_ref != 0){
        printf("bsdsocket pid=%d termination proceeding.\n", (int)pid);
    }



    // drop any remaining TCP connections, such as:
    // - connections still sending previously submitted data
    // - connections still in TIME_WAIT, FIN_WAIT2, etc.
    tcp_drop_uid(pid,1);

    // ensure all API calls have finished.
    // cr_ref == 1 means only the container's own reference (set in client_create()) remains.
    last_reported_cr_ref = 1;
    u_int cr_ref;
    while ((cr_ref = container->cl_ucred.cr_ref) > 1) {
        if (cr_ref != last_reported_cr_ref) {
            printf("bsdsocket pid=%d termination pending with cr_ref %d.\n", (int)pid, (int)cr_ref);
            last_reported_cr_ref = cr_ref;
        }
        DELAY(CLIENT_TERMINATE_POLL_INTERVAL);
    }
    if (last_reported_cr_ref != 1) {
        printf("bsdsocket pid=%d termination proceeding.\n", (int)pid);
    }
    // clear interface egress queues of any outstanding client zone mbufs

    vnet_if_inc_zones_terminating();
    vnet_if_clear_zone_egress_packets(&container->cl_proc.p_zones);

    // this verifies that that all chunks have been released
    // back into each zone, and will wait until all packets are free.
    u_int allocated_count = get_zone_allocator_count(&container->cl_proc);
    u_int last_allocated_count = 0;
    while (allocated_count > 0)
    {
        if (allocated_count != last_allocated_count)
        {
            printf("bsdsocket pid=%d waiting for zone memory to free (%d allocated).\n", (int)pid, (int)allocated_count);
            last_allocated_count = allocated_count;
        }
        DELAY(CLIENT_TERMINATE_POLL_INTERVAL);
        allocated_count = get_zone_allocator_count(&container->cl_proc);
    }
    // allocated_count is 0 here, so this only reports if we actually waited.
    if (allocated_count != last_allocated_count)
    {
        printf("bsdsocket pid=%d zone memory has freed!\n", (int)pid);
    }
    // this will again verify that all chunks have been released
    // back into each zone, it will panic if
    // something is still in use.
    destroy_zone_allocators(&container->cl_proc);
    vnet_if_dec_zones_terminating();

    if (container->cl_vtab_entries)
    {
        free(container->cl_vtab_entries, container->cl_proc.p_malloc);
    }

    malloc_uninit(&container->cl_malloc);


    // Mirrors the condition under which client_create() initialised the allocator.
    if (container->cl_pid != DEFAULT_PID &&
        container->cl_mempool != NULL &&
        container->cl_mempool_size > 0)
    {
        container->cl_allocator.Finalize();
    }
    critical_enter();
    SLIST_REMOVE(&client_list, container, client, cl_entry);
    critical_exit();
    // cookie stored by the upper layer
    *private_data = container->cl_private_data;

    // Release reference on this container.
    // This structure will no longer be used, but some
    // internal code may still use ucred embedded in it.
    // So we keep this around until crfree drops ucred's
    // reference count to 0.
    crfree(&(container->cl_ucred));

    return 0;
}

/*
 * Releases a client container: destroys the filedesc lock and the prison/proc
 * mutexes initialised in client_create(), then returns the memory to M_CLIENT.
 *
 * @param arg  the `struct client*` to free (typed void* for use as a callback;
 *             presumably invoked once the ucred reference count reaches 0 -- verify).
 */
void client_free(void* arg)
{
    struct client* const doomed = static_cast<struct client*>(arg);

    FILEDESC_LOCK_DESTROY(&(doomed->cl_filedesc));
    mtx_destroy(&(doomed->cl_prison.pr_mtx));
    mtx_destroy(&(doomed->cl_proc.p_mtx));

    free(doomed, M_CLIENT);
}

struct client* client_find_by_mempool(void* addr)
{
    client* container = NULL;

    critical_enter();
    SLIST_FOREACH(container, &client_list, cl_entry) {
        if (container->cl_mempool &&
            container->cl_mempool_size > 0 &&
            addr >= container->cl_mempool &&
            addr < static_cast<char *>(container->cl_mempool) + container->cl_mempool_size)
        {
            break;
        }
    }
    critical_exit();

    return container;
}

int client_get_resource_statistics(uint64_t pid, nn::socket::StatisticsType type, void* buffer, size_t bufferLength, uint32_t options)
{
    int rval                 = -1;
    struct client* container = NULL;
    uint64_t interesting_pid = (options & nn::socket::StatisticsOption_BsdSocketProcess) ? 0 : pid;

    critical_enter();
    SLIST_FOREACH(container, &client_list, cl_entry) {
        if (container->cl_pid == interesting_pid) {
            rval = 0;
            break;
        }
    }

    if ((rval == 0) && (container != NULL) && (buffer != NULL))
    {
        struct uma_client_zones *zones = &container->cl_proc.p_zones;
        nn::mem::StandardAllocator *stdalloc_p = (interesting_pid == DEFAULT_PID) ? &g_Allocator : &container->cl_allocator;
        int free_size = stdalloc_p->GetTotalFreeSize();
        int allocatable_size = stdalloc_p->GetAllocatableSize();

        switch (type)
        {
        case nn::socket::StatisticsType_Default:  // return ResourceStatistics type
            nn::socket::ResourceStatistics* rs;
            rs = (nn::socket::ResourceStatistics*) buffer;
            rs->pid = pid;
            rs->descriptorCount = container->cl_filedesc.fd_nfiles;
            rs->transferMemoryPoolTotalSize = container->cl_mempool_size;
            rs->transferMemoryPoolTotalFreeSize = free_size;
            rs->transferMemoryPoolAllocatableSize = allocatable_size;
            rval = sizeof(nn::socket::ResourceStatistics);
            break;

        case nn::socket::StatisticsType_Memory:  // return a string with the info
            char *s;
            int   written;
            s = (char*) buffer;
            written = snprintf(s, bufferLength, "pid = %d%s\nmem pool size = %d KB\n\tfree size = %d KB\n\tallocatable size = %d KB\n",
                               interesting_pid,
                               (interesting_pid == DEFAULT_PID) ? "(bsdsocket)" : "",
                               container->cl_mempool_size / 1024,
                               free_size / 1024,
                               allocatable_size / 1024);
            if (written > 0) s += written;

            uma_snprint_zone_siglo(zones->zone_mbuf, &s, (size_t)(bufferLength - (s - (char*)buffer)), &written);
            if (interesting_pid == 0)
            {
                uma_snprint_zone_siglo(zones->zone_pack, &s, (size_t)(bufferLength - (s - (char*)buffer)), &written);
            }
            uma_snprint_zone_siglo(zones->zone_clust,& s, (size_t)(bufferLength - (s - (char*)buffer)), &written);
            uma_snprint_zone_siglo(zones->zone_jumbop, &s, (size_t)(bufferLength - (s - (char*)buffer)), &written);
            uma_snprint_zone_siglo(zones->zone_jumbo9, &s, ((size_t)bufferLength - (s - (char*)buffer)), &written);
            uma_snprint_zone_siglo(zones->zone_jumbo16,&s, (size_t)(bufferLength - (s - (char*)buffer)), &written);

            if ((written >0) && (written < bufferLength))
            {
                written++;
                s[written] = '\0';
            }

            if (s > buffer)
            {
                rval = (int) (s - (char*)buffer);
            }
            break;

        default:
            rval = -1;
        }
    }

    critical_exit();

    return rval;
}

}
