﻿// --------------------------------------------------------------------------------
// <copyright>
// Copyright (C)Nintendo. All rights reserved.
//
// These coded instructions, statements, and computer programs contain proprietary
// information of Nintendo and/or its licensed developers and are protected by
// national and international copyright laws. They may not be disclosed to third
// parties or copied or duplicated in any form, in whole or in part, without the
// prior written consent of Nintendo.
//
// The content herein is highly confidential and should be handled accordingly.
// </copyright>
// --------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace MakeSvcVeneer
{
    /// <summary>
    /// Generates assembly source for the SVC "driver" stubs. For every operation whose
    /// <see cref="LayoutConversion"/> reports that conversion is required, a stub is
    /// emitted that rearranges arguments from the SVC layout into the ABI layout, calls
    /// the corresponding handler, and converts the outputs back into the SVC layout.
    /// </summary>
    internal class SvcDriverSourceGenerator : SourceGenerator
    {
        /// <summary>
        /// Creates a generator using the given code-generation naming parameters.
        /// </summary>
        /// <param name="p">Naming parameters; forwarded to the base class.</param>
        public SvcDriverSourceGenerator(CodeGenNames p) : base(p) { }

        /// <summary>
        /// Public entry point. Only reorders the arguments and forwards them to the
        /// six-argument <c>Generate</c> inherited from the base class.
        /// </summary>
        /// <param name="ss">The set of SVC operations and types.</param>
        /// <param name="conv">Layout conversion, keyed by operation name.</param>
        /// <param name="abi">ABI layout, keyed by operation name.</param>
        /// <param name="svc">SVC layout, keyed by operation name.</param>
        /// <param name="templatePath">Template path; presumably interpreted by the base class.</param>
        /// <param name="path">Output path; presumably interpreted by the base class.</param>
        public void Generate(SvcSet ss,
            Dictionary<string, LayoutConversion> conv,
            Dictionary<string, AbiLayout> abi,
            Dictionary<string, SvcLayout> svc,
            string templatePath,
            string path)
        {
            this.Generate(ss, abi, svc, conv, templatePath, path);
        }

        /// <summary>
        /// Builds the generated text. It contains:
        /// <list type="number">
        /// <item>a comment table of all non-predefined types;</item>
        /// <item>one stub per operation whose conversion is required;</item>
        /// <item>a terminating <c>.end</c> directive.</item>
        /// </list>
        /// </summary>
        /// <returns>The generated assembly text.</returns>
        /// <exception cref="ErrorException">
        /// Rethrown with the failing operation's name when stub generation fails.
        /// </exception>
        protected override string Generate(
            SvcSet ss,
            Dictionary<string, AbiLayout> abi,
            Dictionary<string, SvcLayout> svc,
            Dictionary<string, LayoutConversion> conv)
        {
            var sb = new StringBuilder();
            sb.AppendLine();

            // Informational table of the user-defined types.
            sb.AppendLine("//");
            sb.AppendFormat("// type name                                size al abstr comon\r\n");
            sb.AppendFormat("//\r\n");
            foreach (var t in ss.Types.Values)
            {
                if (!t.IsPredefined)
                {
                    sb.AppendFormat("// {0,-40} {1,4} {2,2} {3,-5} {4,-5}\r\n",
                        t.Name, t.Size, t.Alignment, t.IsAbstract, t.IsCommon);
                }
            }
            sb.AppendLine("//");
            sb.AppendLine();

            foreach (var op in ss.Operations)
            {
                var lc = conv[op.Name];
                if (lc.IsRequireConversion)
                {
                    try
                    {
                        var text = this.Generate(op, abi[op.Name], svc[op.Name], lc);
                        sb.Append(text);
                        sb.AppendLine();
                    }
                    catch (ErrorException ee)
                    {
                        // Add the operation name so the failure can be located.
                        throw new ErrorException(
                            string.Format("operation: {0}", op.Name), ee);
                    }
                }
            }

            sb.AppendLine("    .end");
            sb.AppendLine();
            return sb.ToString();
        }

        /// <summary>
        /// Returns the mangled symbol for
        /// <c>{KernelCommonNamespace}::{DriverClassName}::{DriverNamePrefix}{name}</c>,
        /// declared with an empty parameter list.
        /// </summary>
        /// <param name="name">Operation name.</param>
        private string MakeDriverSymbolName(string name)
        {
            var fi = new Mangler.FunctionInfo();
            fi.Name = string.Format("{0}::{1}::{2}{3}",
                CodeGenNames.KernelCommonNamespace,
                Params.DriverClassName,
                CodeGenNames.DriverNamePrefix,
                name);
            fi.Parameters = new Mangler.TypeInfo[0];
            return Mangler.Mangle(fi, null);
        }

        /// <summary>
        /// Emits one driver stub. The stub consists of:
        /// <list type="bullet">
        /// <item>a header comment with the prototype and the SVC/ABI layout table;</item>
        /// <item>a dedicated <c>.text.&lt;symbol&gt;</c> section;</item>
        /// <item>the global driver symbol;</item>
        /// <item>the conversion/call code;</item>
        /// <item>a <c>.size</c> directive.</item>
        /// </list>
        /// </summary>
        private string Generate(Operation op, AbiLayout al, SvcLayout sl, LayoutConversion lc)
        {
            // "{0}" is kept as a placeholder for MakeSymbolName to fill in.
            var symbolFormat = string.Format("{0}::{1}{2}{3}",
                CodeGenNames.KernelCommonNamespace,
                CodeGenNames.HandlerNamePrefix,
                "{0}",
                Params.HandlerNamePostfix);
            var callerSymbolName = this.MakeDriverSymbolName(op.Name);
            var calleeSymbolName = MakeSymbolName(symbolFormat, op, al, true);

            var code = this.GenerateCode(calleeSymbolName, al, sl, lc);

            var sb = new StringBuilder();
            sb.AppendFormat("//-------------------------------------------------\r\n");
            sb.AppendFormat("// {0}\r\n", SourceGenerator.MakePrototype(op, al, true));
            sb.AppendFormat("//\r\n");
            sb.Append(this.MakeLayoutComment(al, sl, lc, op));
            sb.AppendLine();
            sb.AppendFormat("    .section .text.{0}, \"ax\"\r\n", callerSymbolName);
            sb.AppendFormat("    .align   2\r\n");
            sb.AppendFormat("    .global  {0}\r\n", callerSymbolName);
            sb.AppendFormat("    .type    {0}, %function\r\n", callerSymbolName);
            sb.AppendFormat("    .type    {0}, %function\r\n", calleeSymbolName);
            sb.AppendFormat("{0}:\r\n", callerSymbolName);
            sb.AppendLine();
            sb.Append(code);
            sb.AppendLine();
            sb.AppendFormat("    .size {0}, [.-{0}]\r\n", callerSymbolName);
            return sb.ToString();
        }

        /// <summary>
        /// Generates the stub body. In order, it:
        /// <list type="number">
        /// <item>saves the registers that must be kept;</item>
        /// <item>reserves stack slots;</item>
        /// <item>runs the SVC→ABI pre-conversions;</item>
        /// <item>prepares outputs;</item>
        /// <item>calls the handler;</item>
        /// <item>runs the ABI→SVC post-conversions;</item>
        /// <item>releases the stack and restores the saved registers.</item>
        /// </list>
        /// </summary>
        private string GenerateCode(string calleeSymbolName, AbiLayout al, SvcLayout sl, LayoutConversion lc)
        {
            // Number of stack slots; each slot is al.In.StorageSize units.
            int usedStackCount = 0;
            // AbiIndex -> offset of its scratch storage in the reserved area
            // (slot index * al.In.StorageSize).
            var storageMap = new Dictionary<int, int>();

            // Reserve the stack slots the ABI layout assigns to stack-passed parameters.
            foreach (var se in al.In.Params)
            {
                var stackPositions = al.In.GetPositions(se.VariableName)
                                    .Where(x => x.Storage == Layout.StorageType.Stack);
                foreach (var sp in stackPositions)
                {
                    if (usedStackCount < sp.Index + 1)
                    {
                        usedStackCount = sp.Index + 1;
                    }
                }
            }

            // Reserve stack storage used to turn by-value SVC data into by-reference
            // ABI arguments (scatter conversions).
            foreach (var lco in lc.PreOperations.Concat(lc.PostOperations))
            {
                if ((lco is LayoutConversionOperationScatter)
                    || (lco is LayoutConversionOperationIndirectScatter))
                {
                    storageMap.Add(lco.AbiIndex, usedStackCount * al.In.StorageSize);
                    usedStackCount += Util.DivUp(lco.SvcIndex.Length, al.In.StorageSize / sl.In.StorageSize);
                }
            }

            // NOTE(review): presumably keeps the total pushed slot count even for
            // stack alignment. The "+1" likely accounts for an extra register saved by
            // SaveRegisters(..., true, ...). Confirm against CodeGenerator.
            bool isRequirePadding = ((usedStackCount + lc.KeepRegisters.Length + 1) % 2) != 0;

            CodeGenerator cg = CodeGenerator.FromCodeGenParams(al.CodeGenParams, sl.CodeGenParams);

            // Mark every register that holds an incoming SVC parameter as occupied.
            var tracer = new RegisterUsageTracer();
            for (int regNo = 0; regNo < sl.In.RegisterCount; ++regNo)
            {
                var name = sl.In.GetRegisterParam(regNo);
                if (name != null)
                {
                    tracer.Occupy(regNo, "svc " + name);
                }
            }

            // Save the registers that must be preserved.
            cg.SaveRegisters(lc.KeepRegisters, true, isRequirePadding);

            // Allocate stack space if any slot is needed.
            if (usedStackCount > 0)
            {
                cg.AllocateFromStack(usedStackCount * al.In.StorageSize);
                cg.AddEmptyLine();
            }

            // Before the handler call:
            //   SVC -> ABI
            if (lc.PreOperations.Length > 0)
            {
                GeneratePreOperationCode(
                    cg,
                    tracer,
                    lc.PreOperations,
                    storageMap,
                    al.CodeGenParams.RegisterCount);
                cg.AddEmptyLine();
            }
            if (lc.PostOperations.Length > 0)
            {
                foreach (var lco in lc.PostOperations)
                {
                    lco.GenerateCodeOutPrepareToAbi(cg, tracer, storageMap);
                }
                cg.AddEmptyLine();
            }

            // Call the handler.
            cg.CallFunction(calleeSymbolName);

            // After the handler call:
            //   ABI -> SVC
            if (lc.PostOperations.Length > 0)
            {
                cg.AddEmptyLine();
                foreach (var lco in lc.PostOperations)
                {
                    lco.GenerateCodeOutAbiToSvc(cg, tracer, storageMap);
                }
            }

            // Release the stack space if it was allocated.
            if (usedStackCount > 0)
            {
                cg.AddEmptyLine();
                cg.FreeToStack(usedStackCount * al.In.StorageSize);
            }

            // Restore the preserved registers.
            cg.RestoreRegisters(lc.KeepRegisters, true, isRequirePadding);

            return cg.CodeText;
        }

        /// <summary>Returns <c>dic[key]</c>, or an empty string when the key is absent.</summary>
        private static string GetOrEmpty(Dictionary<string, string> dic, string key)
        {
            string v;
            return dic.TryGetValue(key, out v) ? v : string.Empty;
        }

        /// <summary>
        /// Returns the size text of a variable.
        /// "return" maps to the return type's size; any other name maps to the
        /// matching parameter's size text.
        /// </summary>
        /// <exception cref="InvalidOperationException">No parameter named <paramref name="name"/>.</exception>
        private static string GetVariableSize(Operation op, string name, int pointerSize, int registerSize)
        {
            if (name == "return")
            {
                return op.ReturnType.Size.ToString();
            }
            else
            {
                return op.Parameters.First(x => x.Name == name).GetSizeText(pointerSize, registerSize);
            }
        }

        /// <summary>
        /// Formats every storage position of <paramref name="layout"/> as a comment row.
        /// Rows are keyed by the "storage index" text so that matching SVC and ABI
        /// rows can be lined up.
        /// </summary>
        private static Dictionary<string, string> FormatParams(Layout layout, Operation op)
        {
            var list = new Dictionary<string, string>();
            foreach (var kv in layout.StorageParams)
            {
                var pos = string.Format("{0,-8} {1}", kv.Item1.Storage, kv.Item1.Index);
                var text = string.Format("{0}  {2,3} {3,2} {1}",
                    pos,
                    kv.Item2.VariableName,
                    kv.Item2.IsReference ? "ref" : string.Empty,
                    GetVariableSize(op, kv.Item2.VariableName, layout.PointerSize, layout.RegisterSize));
                list.Add(pos, text);
            }
            return list;
        }

        /// <summary>
        /// Appends one section ("in" or "out") of the layout comment table.
        /// The section is a title row followed by one row per storage position present
        /// in either layout, with SVC and ABI side by side.
        /// </summary>
        private static void AppendLayoutSection(
            StringBuilder sb,
            string lineFormat,
            string title,
            Dictionary<string, string> abiList,
            Dictionary<string, string> svcList)
        {
            sb.AppendFormat(lineFormat, title, "svc", "abi");
            var keys = abiList.Keys.Concat(svcList.Keys).Distinct().OrderBy(x => x).ToArray();
            foreach (var key in keys)
            {
                var abi = GetOrEmpty(abiList, key);
                var svc = GetOrEmpty(svcList, key);
                sb.AppendFormat(lineFormat, string.Empty, svc, abi);
            }
        }

        /// <summary>
        /// Builds the per-stub comment block with three parts:
        /// <list type="bullet">
        /// <item>the input layout table (SVC vs ABI);</item>
        /// <item>the output layout table (SVC vs ABI);</item>
        /// <item>the list of conversion operations.</item>
        /// </list>
        /// </summary>
        private string MakeLayoutComment(
            AbiLayout al, SvcLayout sl, LayoutConversion lc, Operation op)
        {
            var inAbiList = FormatParams(al.In, op);
            var inSvcList = FormatParams(sl.In, op);
            var outAbiList = FormatParams(al.Out, op);
            var outSvcList = FormatParams(sl.Out, op);

            // Width of the SVC column; the empty string keeps Max() valid on empty input.
            var svcMaxLength = inSvcList.Concat(outSvcList).Select(x => x.Value).Add(string.Empty).Max(x => x.Length);

            var lineFormat = string.Concat("// {0,-3} {1,-", svcMaxLength, "}   {2}\r\n");

            var sb = new StringBuilder();
            sb.AppendFormat(lineFormat, string.Empty, sl.CodeGenParams, al.CodeGenParams);
            AppendLayoutSection(sb, lineFormat, "in", inAbiList, inSvcList);
            sb.AppendFormat("//\r\n");
            // The output section was previously mislabelled "in".
            AppendLayoutSection(sb, lineFormat, "out", outAbiList, outSvcList);
            sb.AppendFormat("//\r\n");
            sb.AppendFormat("// conversion\r\n");
            foreach (var c in lc.PreOperations.Concat(lc.PostOperations))
            {
                sb.AppendFormat("//   {0,-15} {1}\r\n", c.Type, c.VariableName);
            }
            return sb.ToString();
        }

        /// <summary>
        /// Returns the lowest register number in [0, <paramref name="regMax"/>] (inclusive)
        /// that meets both conditions:
        /// <list type="bullet">
        /// <item>the tracer does not mark it as in use;</item>
        /// <item>no pending operation targets it as its AbiIndex.</item>
        /// </list>
        /// </summary>
        /// <exception cref="ErrorException">No such register exists.</exception>
        private static int FindFreeRegister(
            RegisterUsageTracer tracer, List<LayoutConversion.Operation> ops, int regMax)
        {
            // NOTE(review): the bound is inclusive. The caller passes
            // al.CodeGenParams.RegisterCount, so register number RegisterCount itself may be
            // chosen. Confirm that this is intended, e.g. that it is a caller-saved
            // scratch register.
            for (int i = 0; i <= regMax; ++i)
            {
                if (tracer.IsUsing(i))
                {
                    continue;
                }
                if (ops.Any(x => x.AbiIndex == i))
                {
                    continue;
                }
                return i;
            }

            throw new ErrorException("レジスタの使用を解決できませんでした");
        }

        /// <summary>
        /// Emits the SVC→ABI pre-operations in a dependency-respecting order.
        /// Each pass emits every operation whose code can currently be generated; this is
        /// checked against a clone of the tracer. When a pass makes no progress
        /// (presumably because operations block each other), the first remaining
        /// operation is rebuilt with <c>MakeModified(freeRegNo)</c> on a free register,
        /// and a move operation built from <c>(op1.AbiIndex, freeRegNo)</c> is appended.
        /// </summary>
        private static void GeneratePreOperationCode(
            CodeGenerator cg,
            RegisterUsageTracer tracer,
            LayoutConversion.Operation[] operations,
            Dictionary<int, int> storageMap,
            int regMax)
        {
            var ops = operations.ToList();

            while (ops.HasItem())
            {
                // Snapshot so that ops can be modified while iterating.
                var ops2 = ops.ToArray();
                foreach (var lco in ops2)
                {
                    if (lco.CanGenerateCodeInSvcToAbi(tracer.Clone()))
                    {
                        lco.GenerateCodeInSvcToAbi(cg, tracer, storageMap);
                        ops.Remove(lco);
                    }
                }

                if (ops.Count == ops2.Length)
                {
                    // No progress: break the deadlock through a free register.
                    int freeRegNo = FindFreeRegister(tracer, ops, regMax);

                    var op1 = ops[0];

                    var op1Modified = op1.MakeModified(freeRegNo);
                    var opLast = new LayoutConversionOperationMove(op1.AbiIndex, freeRegNo, op1.StorageSize, op1.VariableName);
                    ops[0] = op1Modified;
                    ops.Add(opLast);
                }
            }
        }
    }
}
