﻿// --------------------------------------------------------------------------------
// <copyright>
// Copyright (C)Nintendo. All rights reserved.
//
// These coded instructions, statements, and computer programs contain proprietary
// information of Nintendo and/or its licensed developers and are protected by
// national and international copyright laws. They may not be disclosed to third
// parties or copied or duplicated in any form, in whole or in part, without the
// prior written consent of Nintendo.
//
// The content herein is highly confidential and should be handled accordingly.
// </copyright>
// --------------------------------------------------------------------------------
using System;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Linq;
using System.Text.RegularExpressions;
using System.Collections.Generic;
using System.Collections;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;

namespace Nintendo.ControlTarget
{
    /// <summary>
    /// Text captured from one shell command together with the result code parsed from it.
    /// </summary>
    public class ShellResult
    {
        // A result-code line looks like "<whitespace>-> <code>..."; group 1 captures the code.
        private static readonly Regex ResultCodePattern = new Regex(@"\s+->\s+([-0-9]+).*");

        private static readonly string[] LineSeparators = new string[] { "\r\n", "\n" };

        /// <summary>
        /// Creates a result from raw shell output and parses its result code.
        /// </summary>
        /// <param name="message">Raw shell output; may be null.</param>
        /// <remarks>
        /// <see cref="ResultCode"/> is set to -1 when <paramref name="message"/> is null, contains no
        /// result-code line, or the code on that line is not a valid <see cref="int"/>.
        /// </remarks>
        public ShellResult(string message)
        {
            Message = message;

            // Parse once; a failed parse is an expected outcome, not an exceptional one.
            int resultCode;
            ResultCode = TryGetResultCode(message, out resultCode) ? resultCode : -1;
        }

        /// <summary>
        /// Creates a result with an explicitly supplied result code.
        /// </summary>
        /// <param name="message">Raw shell output.</param>
        /// <param name="resultCode">Result code to store as-is.</param>
        public ShellResult(string message, int resultCode)
        {
            Message = message;
            ResultCode = resultCode;
        }

        /// <summary>Raw shell output.</summary>
        public string Message { get; set; }

        /// <summary>
        /// Code parsed by the single-argument constructor (-1 if none), or the value given explicitly.
        /// </summary>
        public int ResultCode { get; set; }

        /// <summary>
        /// Returns whether <paramref name="result"/> contains a parsable result code.
        /// </summary>
        /// <param name="result">Shell output; null yields false.</param>
        public static bool FindResultCode(string result)
        {
            int ignored;
            return TryGetResultCode(result, out ignored);
        }

        /// <summary>
        /// Parses the result code from the last line of <paramref name="result"/> that matches the
        /// result-code pattern.
        /// </summary>
        /// <param name="result">Shell output; null yields false.</param>
        /// <param name="resultCode">The parsed code on success; 0 otherwise.</param>
        /// <returns>
        /// False when <paramref name="result"/> is null, no line matches, or the captured text is not
        /// a valid <see cref="int"/> (for example "-" or an out-of-range value).
        /// </returns>
        public static bool TryGetResultCode(string result, out int resultCode)
        {
            resultCode = 0;
            if (result == null)
            {
                return false;
            }

            var line = FindLastResultCodeLine(result);
            if (line == null)
            {
                return false;
            }

            return int.TryParse(
                ResultCodePattern.Match(line).Groups[1].Value,
                System.Globalization.NumberStyles.Integer,
                System.Globalization.CultureInfo.InvariantCulture,
                out resultCode);
        }

        /// <summary>
        /// Parses the result code from the last line of <paramref name="result"/> that matches the
        /// result-code pattern.
        /// </summary>
        /// <param name="result">Shell output.</param>
        /// <returns>The parsed result code.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="result"/> is null.</exception>
        /// <exception cref="FormatException">
        /// No line matches, or the captured text is not a valid integer.
        /// </exception>
        /// <exception cref="OverflowException">The captured code does not fit in an <see cref="int"/>.</exception>
        public static int GetResultCode(string result)
        {
            if (result == null)
            {
                throw new ArgumentNullException(nameof(result));
            }

            var line = FindLastResultCodeLine(result);
            if (line == null)
            {
                throw new FormatException("No result code line was found in the shell output.");
            }

            return int.Parse(
                ResultCodePattern.Match(line).Groups[1].Value,
                System.Globalization.NumberStyles.Integer,
                System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Returns whether the last non-empty line of the trimmed output is exactly "shell#".
        /// </summary>
        /// <param name="result">Shell output accumulated so far.</param>
        public static bool HasNextShellPrompt(string result)
        {
            return SplitLines(result).LastOrDefault() == "shell#";
        }

        // Last line (after trimming the whole text) that matches ResultCodePattern, or null.
        private static string FindLastResultCodeLine(string result)
        {
            return SplitLines(result).LastOrDefault(line => ResultCodePattern.IsMatch(line));
        }

        // Trims the text and splits it into non-empty CRLF- or LF-terminated lines.
        private static string[] SplitLines(string text)
        {
            return text.Trim().Split(LineSeparators, StringSplitOptions.RemoveEmptyEntries);
        }
    }

    /// <summary>
    /// Connects to a target shell over TCP and exchanges text with it through TPL Dataflow blocks.
    /// Received bytes are passed through a <see cref="TelnetIacEliminator"/>, decoded as UTF-8 and
    /// broadcast on <see cref="ReadPort"/>; text posted to <see cref="WritePort"/> is written to the socket.
    /// </summary>
    public class ShellAccessor : IDisposable
    {
        // Per-attempt connect timeout and number of connect attempts.
        private const int ConnectTimeoutMilliseconds = 2000;
        private const int ConnectAttempts = 5;

        // Size of the receive buffer used by the read loop.
        private const int ReadBufferSize = 4096;

        // How long the public SendCommand waits for the prompt before returning partial output.
        private static readonly TimeSpan CommandTimeout = TimeSpan.FromMinutes(1);

        // Matches process-list lines containing "PID= <n>"; group 1 captures the pid.
        private static readonly Regex PidPattern = new Regex(@"\s+PID=\s+([0-9]+).*");

        private static readonly string[] LineSeparators = new string[] { "\n", "\r\n" };

        /// <summary>
        /// Connects to <paramref name="endPoint"/> (retrying on failure), starts the background read and
        /// write loops, then sends an empty command and waits for the prompt.
        /// </summary>
        /// <param name="endPoint">Address of the shell, for example <see cref="ShellEndPoint"/>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="endPoint"/> is null.</exception>
        /// <exception cref="Exception">All connection attempts failed (inner exception holds the cause).</exception>
        public ShellAccessor(IPEndPoint endPoint)
        {
            if (endPoint == null)
            {
                throw new ArgumentNullException(nameof(endPoint));
            }

            Initialize(endPoint);
        }

        private void Initialize(IPEndPoint endPoint)
        {
            Trace.WriteLine($"Connect: {endPoint}");

            try
            {
                RetryUtility.Do(
                    () =>
                    {
                        socket = ConnectOnce(endPoint);
                    },
                    e =>
                    {
                        Trace.WriteLine($"Failed to connect: {e.Message}");
                    },
                    ConnectAttempts,
                    TimeSpan.Zero);
            }
            catch (Exception e)
            {
                // Keep the original cause available to callers.
                throw new Exception($"[ERROR] Failed to connect: {e.Message}", e);
            }

            if (socket == null)
            {
                throw new Exception($"[ERROR] Failed to connect: {endPoint}");
            }

            socketStream = new NetworkStream(socket);
            socketWriter = new StreamWriter(socketStream);
            socketWriter.AutoFlush = true;

            Trace.WriteLine($"Connected: {endPoint}");

            StartDataflow();
        }

        // Makes one connection attempt on a fresh socket. A socket whose connect failed or timed out
        // is disposed instead of being reused for the next attempt.
        private static Socket ConnectOnce(IPEndPoint endPoint)
        {
            Trace.WriteLine($"Try to connect: {endPoint}");

            var candidate = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
            try
            {
                var pending = candidate.BeginConnect(endPoint, null, null);
                if (!pending.AsyncWaitHandle.WaitOne(ConnectTimeoutMilliseconds))
                {
                    throw new TimeoutException($"Timed out connecting to {endPoint}.");
                }

                // EndConnect rethrows any connect error (e.g. connection refused), so failures
                // surface here instead of leaving an unconnected socket behind.
                candidate.EndConnect(pending);
                return candidate;
            }
            catch
            {
                candidate.Dispose();
                throw;
            }
        }

        private void StartDataflow()
        {
            DataflowUtility.LinkBlock("HostBridgeAccessor::BroadCastBlock->NullBlock", BroadCastBlock, NullBlock);
            BroadCastBlock.Post(string.Empty);

            // The read loop blocks on the socket for the connection's lifetime, so give it a
            // dedicated thread instead of occupying a thread-pool worker.
            ReadTask = Task.Factory.StartNew(
                () => ReadLoop(),
                CancellationToken.None,
                TaskCreationOptions.LongRunning,
                TaskScheduler.Default);

            WriteTask = Task.Run(() => WriteLoopAsync());

            SendCommand(string.Empty);
        }

        // Pumps socket data into BroadCastBlock until the peer closes the connection or the
        // stream/socket is closed by Dispose.
        private void ReadLoop()
        {
            var buffer = new byte[ReadBufferSize];
            var eliminatedBuffer = new byte[ReadBufferSize];

            // A stateful decoder carries an incomplete multi-byte UTF-8 sequence over to the next
            // read instead of corrupting characters that straddle two reads.
            var decoder = Encoding.UTF8.GetDecoder();
            var chars = new char[Encoding.UTF8.GetMaxCharCount(ReadBufferSize)];

            while (true)
            {
                int read;
                try
                {
                    read = socketStream.Read(buffer, 0, buffer.Length);
                }
                catch (IOException e)
                {
                    Trace.WriteLine($"Read loop stopped: {e.Message}");
                    break;
                }
                catch (ObjectDisposedException)
                {
                    break;
                }

                // Read returns 0 once the remote side has closed the connection; continuing would
                // spin forever.
                if (read == 0)
                {
                    break;
                }

                var eliminatedRead = eliminator.CopyWithElimination(eliminatedBuffer, buffer, read);

                var charCount = decoder.GetCharCount(eliminatedBuffer, 0, eliminatedRead);
                if (charCount > chars.Length)
                {
                    chars = new char[charCount];
                }

                charCount = decoder.GetChars(eliminatedBuffer, 0, eliminatedRead, chars, 0);
                BroadCastBlock.Post(new string(chars, 0, charCount));
            }
        }

        // Writes every string posted to WriteBuffer to the socket, in order, until WriteBuffer is
        // completed (Dispose does this) or the connection breaks.
        private async Task WriteLoopAsync()
        {
            try
            {
                while (await WriteBuffer.OutputAvailableAsync().ConfigureAwait(false))
                {
                    string text;
                    while (WriteBuffer.TryReceive(out text))
                    {
                        socketWriter.Write(text);
                    }
                }
            }
            catch (IOException e)
            {
                Trace.WriteLine($"Write loop stopped: {e.Message}");
            }
            catch (ObjectDisposedException)
            {
                // The writer was disposed; nothing left to write to.
            }
        }

        /// <summary>
        /// Kills every process reported by <see cref="ListProcess"/>.
        /// </summary>
        public void KillAll()
        {
            foreach (var pid in ListProcess().Item2)
            {
                KillProcess(pid);
            }
        }

        /// <summary>
        /// Sends the shell command "T &lt;pid&gt;".
        /// </summary>
        /// <param name="pid">Process id to kill.</param>
        public ShellResult KillProcess(int pid)
        {
            return SendCommand($"T {pid}");
        }

        /// <summary>
        /// Returns the process-list result and whether <paramref name="pid"/> appears in it.
        /// </summary>
        /// <param name="pid">Process id to look for.</param>
        public Tuple<ShellResult, bool> ExistProcess(int pid)
        {
            var result = ListProcess();
            return Tuple.Create(result.Item1, result.Item2.Contains(pid));
        }

        /// <summary>
        /// Sends the shell command "P" and extracts every "PID= &lt;n&gt;" value from its output.
        /// </summary>
        /// <returns>The raw command result and the pids found, in output order.</returns>
        public Tuple<ShellResult, List<int>> ListProcess()
        {
            var result = SendCommand("P");
            var pids = result.Message
                .Split(LineSeparators, StringSplitOptions.RemoveEmptyEntries)
                .Select(line => PidPattern.Match(line))
                .Where(match => match.Success)
                .Select(match => int.Parse(
                    match.Groups[1].Value,
                    System.Globalization.NumberStyles.Integer,
                    System.Globalization.CultureInfo.InvariantCulture))
                .ToList();

            return Tuple.Create(result, pids);
        }

        /// <summary>
        /// Sends the shell command "L &lt;path&gt; [arguments...]". Paths ending in ".kip" are
        /// prefixed with "host:".
        /// </summary>
        /// <param name="file">Program to load; its full path is sent unquoted.</param>
        /// <param name="arguments">Optional arguments, joined with single spaces.</param>
        /// <exception cref="ArgumentNullException"><paramref name="file"/> is null.</exception>
        public ShellResult LoadProgram(FileInfo file, string[] arguments = null)
        {
            if (file == null)
            {
                throw new ArgumentNullException(nameof(file));
            }

            var firstArgument = file.FullName;
            if (firstArgument.EndsWith(".kip", StringComparison.Ordinal))
            {
                firstArgument = "host:" + firstArgument;
            }

            var command = "L " + firstArgument;
            if (arguments != null && arguments.Length > 0)
            {
                command += " " + string.Join(" ", arguments);
            }

            return SendCommand(command);
        }

        /// <summary>
        /// Sends <paramref name="command"/> followed by CRLF and collects output until the next
        /// "shell#" prompt, or until one minute has passed (partial output is returned then).
        /// Calls are serialized so outputs of concurrent commands do not interleave.
        /// </summary>
        /// <param name="command">Command text without a line terminator.</param>
        /// <returns>The collected output and the result code parsed from it.</returns>
        public ShellResult SendCommand(string command)
        {
            // Trace.WriteLine(string, string) treats its second argument as a category, not as a
            // format argument, so the message is built with interpolation.
            Trace.WriteLine($"SendCommand: {command}");

            lock (commandLock)
            {
                // Overwrite the broadcast block's current value with an empty string; the static
                // overload discards the first message its new link receives.
                BroadCastBlock.Post(string.Empty);
                return SendCommand(command, WritePort, ReadPort, CommandTimeout);
            }
        }

        /// <summary>
        /// Stops the write loop and closes the connection. Safe to call more than once.
        /// </summary>
        public void Dispose()
        {
            if (disposed)
            {
                return;
            }

            disposed = true;

            // Lets WriteLoopAsync exit once the queued text has been consumed.
            WriteBuffer.Complete();

            try
            {
                socket.Shutdown(SocketShutdown.Both);
            }
            catch (SocketException e)
            {
                // The socket may already be disconnected; closing it below is still required.
                Trace.WriteLine($"Shutdown failed: {e.Message}");
            }

            try
            {
                socketWriter.Dispose();
            }
            catch (IOException e)
            {
                Trace.WriteLine($"Closing writer failed: {e.Message}");
            }

            socketStream.Dispose();
            socket.Dispose();
        }

        /// <summary>Text received from the shell.</summary>
        public ISourceBlock<string> ReadPort { get { return BroadCastBlock; } }

        /// <summary>Text to be written to the shell verbatim.</summary>
        public ITargetBlock<string> WritePort { get { return WriteBuffer; } }

        // Posts one command to target and accumulates text from source until
        // ShellResult.HasNextShellPrompt is satisfied or timeout elapses. Everything runs on the
        // calling thread, so the builder is never read while another thread is appending to it.
        private static ShellResult SendCommand(string command, ITargetBlock<string> target, ISourceBlock<string> source, TimeSpan timeout)
        {
            var builder = new StringBuilder();
            var stopwatch = Stopwatch.StartNew();
            var sourceBuffer = new BufferBlock<string>();

            // Link before posting the command so that no part of the response can arrive while
            // nothing is listening.
            using (DataflowUtility.LinkBlock("ShellAccessor::source->sourceBuffer", source, sourceBuffer))
            {
                target.Post(command + "\r\n");

                try
                {
                    // The first message is expected to be the broadcast block's replayed current
                    // value, which was posted before the command; drop it.
                    sourceBuffer.Receive(GetRemainingTime(timeout, stopwatch));

                    while (!ShellResult.HasNextShellPrompt(builder.ToString()))
                    {
                        builder.Append(sourceBuffer.Receive(GetRemainingTime(timeout, stopwatch)));
                    }
                }
                catch (TimeoutException)
                {
                    Trace.WriteLine($"SendCommand timed out: {command}");
                }
                catch (InvalidOperationException)
                {
                    // The source completed without producing more data; return what was received.
                }
            }

            return new ShellResult(builder.ToString());
        }

        // Time left until timeout, never negative; an infinite timeout stays infinite.
        private static TimeSpan GetRemainingTime(TimeSpan timeout, Stopwatch stopwatch)
        {
            if (timeout == Timeout.InfiniteTimeSpan)
            {
                return timeout;
            }

            var remaining = timeout - stopwatch.Elapsed;
            return remaining > TimeSpan.Zero ? remaining : TimeSpan.Zero;
        }

        // Serializes SendCommand calls; a private object so external code cannot take the same lock.
        private readonly object commandLock = new object();
        private readonly BroadcastBlock<string> BroadCastBlock = new BroadcastBlock<string>(s => { return s; });
        private readonly BufferBlock<string> WriteBuffer = new BufferBlock<string>();
        private readonly ITargetBlock<string> NullBlock = DataflowBlock.NullTarget<string>();
        private Task ReadTask;
        private Task WriteTask;
        private Socket socket;
        private NetworkStream socketStream;
        private StreamWriter socketWriter;
        private readonly TelnetIacEliminator eliminator = new TelnetIacEliminator();
        private bool disposed;

        public static readonly IPEndPoint ShellEndPoint = new IPEndPoint(IPAddress.Parse("192.168.0.10"), 23);
    }
}
