﻿// --------------------------------------------------------------------------------
// <copyright>
// Copyright (C)Nintendo. All rights reserved.
//
// These coded instructions, statements, and computer programs contain proprietary
// information of Nintendo and/or its licensed developers and are protected by
// national and international copyright laws. They may not be disclosed to third
// parties or copied or duplicated in any form, in whole or in part, without the
// prior written consent of Nintendo.
//
// The content herein is highly confidential and should be handled accordingly.
// </copyright>
// --------------------------------------------------------------------------------
using System;
using System.IO;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using MakeNso.Elf;

namespace MakeNso
{
    /// <summary>
    /// Header record of an NSO object.
    /// The struct is marshaled to and from raw bytes with <c>Marshal.StructureToPtr</c> /
    /// <c>Marshal.PtrToStructure</c>, so the field order and the <c>SizeConst</c> values
    /// define the on-disk layout. The declared fields add up to 0x100 bytes.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    internal struct NsoHeader
    {
        // Magic bytes; NsoFile writes and expects 'N','S','O','0'.
        [MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
        public byte[] Signature;
        // Not assigned anywhere in this file.
        public uint Version;
        // Not assigned anywhere in this file.
        public uint Reserved1;
        // Combination of NsoHeaderFlags bits (per-segment compression / hash presence).
        public uint Flags;
        // File position of the stored text segment image.
        public uint TextFileOffset;
        // Virtual address of the text segment; also the position ExtractData writes it to.
        public uint TextMemoryOffset;
        // In-memory size of the text segment (the size expected after decompression).
        public uint TextSize;
        // File position of the module name (placed right after the header by CalcPosition).
        public uint ModuleNameOffset;
        // File position of the stored RO segment image.
        public uint RoFileOffset;
        // Virtual address of the RO segment.
        public uint RoMemoryOffset;
        // In-memory size of the RO segment (the size expected after decompression).
        public uint RoSize;
        // Number of bytes reserved in the file for the module name.
        public uint ModuleNameSize;
        // File position of the stored data segment image.
        public uint DataFileOffset;
        // Virtual address of the data segment.
        public uint DataMemoryOffset;
        // In-memory size of the data segment (the size expected after decompression).
        public uint DataSize;
        // Size of the BSS area; SetBssSegment also adds the gap between data end and BSS start.
        public uint BssSize;
        // Module identifier, 0x20 bytes.
        [MarshalAs(UnmanagedType.ByValArray, SizeConst = 0x20)]
        public byte[] ModuleId;
        // Number of bytes the text segment occupies in the file (compressed size if compressed).
        public uint TextFileSize;
        // Number of bytes the RO segment occupies in the file (compressed size if compressed).
        public uint RoFileSize;
        // Number of bytes the data segment occupies in the file (compressed size if compressed).
        public uint DataFileSize;
        // Reserved area, 0x1c bytes.
        [MarshalAs(UnmanagedType.ByValArray, SizeConst = 0x1c)]
        public byte[] Reserved2;
        // Offset of the api_info section (set by SetApiInfo). The misspelling is part of the field name.
        public uint EmbededOffset;
        // Size of the api_info section (set by SetApiInfo).
        public uint EmbededSize;
        // Offset of the dynstr section (set by SetDynStrInfo).
        public uint DynStrOffset;
        // Size of the dynstr section (set by SetDynStrInfo).
        public uint DynStrSize;
        // Offset of the dynsym section (set by SetDynSymInfo).
        public uint DynSymOffset;
        // Size of the dynsym section (set by SetDynSymInfo).
        public uint DynSymSize;
        // SHA-256 of the uncompressed text segment contents; meaningful when the TextHash flag is set.
        [MarshalAs(UnmanagedType.ByValArray, SizeConst = 0x20)]
        public byte[] TextHash;
        // SHA-256 of the uncompressed RO segment contents; meaningful when the RoHash flag is set.
        [MarshalAs(UnmanagedType.ByValArray, SizeConst = 0x20)]
        public byte[] RoHash;
        // SHA-256 of the uncompressed data segment contents; meaningful when the DataHash flag is set.
        [MarshalAs(UnmanagedType.ByValArray, SizeConst = 0x20)]
        public byte[] DataHash;
    }

    /// <summary>
    /// Bit masks stored in the Flags field of the NSO header.
    /// The *Compress bits mark a segment whose file image is compressed;
    /// the *Hash bits mark a segment whose SHA-256 hash is recorded in the header.
    /// </summary>
    enum NsoHeaderFlags
    {
        TextCompress = 0x01,
        RoCompress = 0x02,
        DataCompress = 0x04,
        TextHash = 0x08,
        RoHash = 0x10,
        DataHash = 0x20,
    }

    /// <summary>
    /// NSO  ファイルを表現するクラス
    /// </summary>
    internal class NsoFile
    {
        // Header that is filled in by the setters or loaded from an existing NSO file.
        private NsoHeader header;
        // Marshaled size of NsoHeader in bytes (Marshal.SizeOf).
        private uint headerSize;
        // Module name written right after the header; only set through SetModuleName.
        private string moduleName;
        // Text segment bytes as stored in the file (compressed once CompressTextSegment succeeds).
        private byte[] textBinary;
        // RO segment bytes as stored in the file (compressed once CompressRoSegment succeeds).
        private byte[] roBinary;
        // Data segment bytes as stored in the file (compressed once CompressDataSegment succeeds).
        private byte[] dataBinary;

        // Virtual address of the BSS segment as passed to SetBssSegment.
        private uint bssMemoryOffset;

        /// <summary>
        /// Creates an empty NSO image with the "NSO0" signature and zeroed header arrays.
        /// Segment contents and the remaining header fields are supplied through the Set* methods.
        /// </summary>
        public NsoFile()
        {
            header = new NsoHeader();
            header.Signature = new byte[] { (byte)'N', (byte)'S', (byte)'O', (byte)'0' };
            header.ModuleId = new byte[0x20];
            // Must match the SizeConst declared on NsoHeader.Reserved2 (0x1c); the previous
            // 0x34 allocation did not agree with the marshaled layout.
            header.Reserved2 = new byte[0x1c];
            header.TextHash = new byte[0x20];
            header.RoHash = new byte[0x20];
            header.DataHash = new byte[0x20];
            headerSize = (uint)Marshal.SizeOf(typeof(NsoHeader));

            // Provisional: start with no flags; the Compress*Segment methods set them later.
            header.Flags = 0;
        }

        /// <summary>
        /// Loads an existing NSO file: reads and validates the header, then reads the stored
        /// (possibly compressed) text, RO and data segment images described by it.
        /// The module name is not read.
        /// </summary>
        /// <param name="nsoPath">Path of the NSO file to load.</param>
        /// <exception cref="InvalidDataException">The signature is not "NSO0" or a segment size is unusable.</exception>
        /// <exception cref="EndOfStreamException">The file is shorter than the header or a segment claims.</exception>
        public NsoFile(string nsoPath)
        {
            headerSize = (uint)Marshal.SizeOf(typeof(NsoHeader));

            using (var fs = File.OpenRead(nsoPath))
            {
                var buffer = ReadBlock(fs, 0, headerSize, "header");

                // PtrToStructure allocates the embedded byte arrays itself, so no pre-allocation is needed.
                var handle = GCHandle.Alloc(buffer, GCHandleType.Pinned);
                try
                {
                    header = (NsoHeader)Marshal.PtrToStructure(handle.AddrOfPinnedObject(), typeof(NsoHeader));
                }
                finally
                {
                    handle.Free();
                }

                if (header.Signature[0] != (byte)'N' ||
                        header.Signature[1] != (byte)'S' ||
                        header.Signature[2] != (byte)'O' ||
                        header.Signature[3] != (byte)'0')
                {
                    throw new InvalidDataException(string.Format("'{0}' is not an NSO file (bad signature).", nsoPath));
                }

                if (header.TextFileSize > 0)
                {
                    textBinary = ReadBlock(fs, header.TextFileOffset, header.TextFileSize, "text segment");
                }

                if (header.RoFileSize > 0)
                {
                    roBinary = ReadBlock(fs, header.RoFileOffset, header.RoFileSize, "RO segment");
                }

                if (header.DataFileSize > 0)
                {
                    dataBinary = ReadBlock(fs, header.DataFileOffset, header.DataFileSize, "data segment");
                }
            }
        }

        /// <summary>
        /// Reads exactly <paramref name="size"/> bytes starting at <paramref name="offset"/>.
        /// Stream.Read may return fewer bytes than requested, so the read is repeated until
        /// the block is complete; a short file is reported instead of being zero-filled.
        /// </summary>
        private static byte[] ReadBlock(Stream stream, long offset, uint size, string what)
        {
            if (size > int.MaxValue)
            {
                throw new InvalidDataException(string.Format("The {0} size 0x{1:X} is too large.", what, size));
            }
            // Check against the file length first so a corrupt size does not trigger a huge allocation.
            if (offset + size > stream.Length)
            {
                throw new EndOfStreamException(string.Format("The file is too short to contain the {0}.", what));
            }

            var block = new byte[size];
            stream.Position = offset;

            int total = 0;
            while (total < block.Length)
            {
                int read = stream.Read(block, total, block.Length - total);
                if (read <= 0)
                {
                    throw new EndOfStreamException(string.Format("Unexpected end of file while reading the {0}.", what));
                }
                total += read;
            }
            return block;
        }

        /// <summary>
        /// Sets the module name and records how many bytes it occupies in the file.
        /// </summary>
        /// <param name="moduleName">Module name.</param>
        /// <exception cref="ArgumentNullException"><paramref name="moduleName"/> is null.</exception>
        public void SetModuleName(string moduleName)
        {
            if (moduleName == null)
            {
                throw new ArgumentNullException("moduleName");
            }

            // WriteData emits the name with BinaryWriter.Write(string): the UTF-8 bytes preceded
            // by their count as a 7-bit encoded integer. The recorded size must equal that byte
            // count exactly, otherwise the segment file offsets derived from it are wrong.
            // For ASCII names shorter than 128 bytes this equals the former "Length + 1".
            int payloadLength = new UTF8Encoding(false, true).GetByteCount(moduleName);

            uint prefixLength = 0;
            uint remaining = (uint)payloadLength;
            do
            {
                prefixLength++;
                remaining >>= 7;
            }
            while (remaining != 0);

            this.moduleName = moduleName;
            header.ModuleNameSize = prefixLength + (uint)payloadLength;
        }

        /// <summary>
        /// Copies the module id into the header. Ids shorter than the header field are
        /// padded with zeros.
        /// </summary>
        /// <param name="moduleId">Module id bytes; at most as long as the header's ModuleId field.</param>
        /// <exception cref="ArgumentNullException"><paramref name="moduleId"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="moduleId"/> is longer than the header field.</exception>
        public void SetModuleId(byte[] moduleId)
        {
            if (moduleId == null)
            {
                throw new ArgumentNullException("moduleId");
            }
            if (moduleId.Length > header.ModuleId.Length)
            {
                throw new ArgumentException(
                    string.Format("The module id is {0} bytes long; at most {1} bytes fit into the header.", moduleId.Length, header.ModuleId.Length),
                    "moduleId");
            }

            // Clear first so a shorter id does not leave bytes of a previously set id behind.
            Array.Clear(header.ModuleId, 0, header.ModuleId.Length);
            moduleId.CopyTo(header.ModuleId, 0);
        }

        /// <summary>
        /// Records the location of the api_info section in the header.
        /// </summary>
        /// <param name="offset">Offset of the api_info section.</param>
        /// <param name="size">Size of the api_info section.</param>
        /// <exception cref="ArgumentOutOfRangeException">A value does not fit into the 32-bit header field.</exception>
        public void SetApiInfo(ulong offset, ulong size)
        {
            // The header fields are 32 bits wide; reject values that would be silently truncated.
            if (offset > uint.MaxValue)
            {
                throw new ArgumentOutOfRangeException("offset", offset, "The api_info offset does not fit into 32 bits.");
            }
            if (size > uint.MaxValue)
            {
                throw new ArgumentOutOfRangeException("size", size, "The api_info size does not fit into 32 bits.");
            }

            header.EmbededOffset = (uint)offset;
            header.EmbededSize = (uint)size;
        }

        /// <summary>
        /// Records the location of the dynstr section in the header.
        /// </summary>
        /// <param name="offset">Offset of the dynstr section.</param>
        /// <param name="size">Size of the dynstr section.</param>
        /// <exception cref="ArgumentOutOfRangeException">A value does not fit into the 32-bit header field.</exception>
        public void SetDynStrInfo(ulong offset, ulong size)
        {
            // The header fields are 32 bits wide; reject values that would be silently truncated.
            if (offset > uint.MaxValue)
            {
                throw new ArgumentOutOfRangeException("offset", offset, "The dynstr offset does not fit into 32 bits.");
            }
            if (size > uint.MaxValue)
            {
                throw new ArgumentOutOfRangeException("size", size, "The dynstr size does not fit into 32 bits.");
            }

            header.DynStrOffset = (uint)offset;
            header.DynStrSize = (uint)size;
        }

        /// <summary>
        /// Records the location of the dynsym section in the header.
        /// </summary>
        /// <param name="offset">Offset of the dynsym section.</param>
        /// <param name="size">Size of the dynsym section.</param>
        /// <exception cref="ArgumentOutOfRangeException">A value does not fit into the 32-bit header field.</exception>
        public void SetDynSymInfo(ulong offset, ulong size)
        {
            // The header fields are 32 bits wide; reject values that would be silently truncated.
            if (offset > uint.MaxValue)
            {
                throw new ArgumentOutOfRangeException("offset", offset, "The dynsym offset does not fit into 32 bits.");
            }
            if (size > uint.MaxValue)
            {
                throw new ArgumentOutOfRangeException("size", size, "The dynsym size does not fit into 32 bits.");
            }

            header.DynSymOffset = (uint)offset;
            header.DynSymSize = (uint)size;
        }

        /// <summary>
        /// Enables segment compression. While this is false, Compress throws, and the
        /// Compress*Segment methods therefore leave their segment uncompressed.
        /// </summary>
        public bool CompressMode { get; set; }

        /// <summary>
        /// Compresses a segment with LZ4 and verifies that it can be decompressed in place
        /// inside a buffer of <paramref name="bufferSize"/> bytes (compressed bytes placed at
        /// the end of the buffer, output written from its start).
        /// </summary>
        /// <param name="srcData">Uncompressed segment contents.</param>
        /// <param name="bufferSize">Size of the buffer used for the in-place decompression check.</param>
        /// <returns>The compressed bytes, trimmed to the compressed length.</returns>
        /// <exception cref="InvalidOperationException">
        /// Compression is disabled, did not shrink the data, or the in-place round trip failed.
        /// </exception>
        /// <exception cref="ArgumentNullException"><paramref name="srcData"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="bufferSize"/> is smaller than the source data.</exception>
        private byte[] Compress(byte[] srcData, int bufferSize)
        {
            if (!CompressMode)
            {
                throw new InvalidOperationException("Compression is disabled (CompressMode is false).");
            }
            if (srcData == null)
            {
                throw new ArgumentNullException("srcData");
            }

            int srcSize = srcData.Length;
            if (bufferSize < srcSize)
            {
                // The verification buffer has to hold the whole decompressed segment.
                throw new ArgumentOutOfRangeException("bufferSize", bufferSize, "The buffer is smaller than the data to compress.");
            }

            int compressBufferSize = Lz4.LZ4_compressBound(srcSize);
            if (compressBufferSize <= 0)
            {
                throw new InvalidOperationException(string.Format("LZ4_compressBound failed for input size {0}.", srcSize));
            }

            var compressData = new byte[compressBufferSize];

            // Compress; a result that is not strictly smaller than the input is treated as failure.
            int compSize = Lz4.LZ4_compress_default(srcData, compressData, srcSize, compressBufferSize);
            if (!(0 < compSize && compSize < srcSize))
            {
                throw new InvalidOperationException(string.Format("LZ4_compress_default did not shrink the data (result {0}, input {1}).", compSize, srcSize));
            }
            Array.Resize(ref compressData, compSize);

            // Verify in-place decompression: the compressed bytes sit at the tail of the
            // buffer and are decompressed towards its head.
            {
                var verifyData = new byte[bufferSize];
                Array.Copy(compressData, 0, verifyData, bufferSize - compSize, compSize);

                int ret = Lz4.LZ4_decompress_safe(verifyData, bufferSize - compSize, verifyData, 0, compSize, srcSize);
                if (ret != srcSize)
                {
                    throw new InvalidOperationException(string.Format("In-place LZ4_decompress_safe failed (result {0}, expected {1}).", ret, srcSize));
                }
                Array.Resize(ref verifyData, srcSize);
                if (!verifyData.SequenceEqual(srcData))
                {
                    throw new InvalidOperationException("In-place LZ4 decompression produced data that differs from the source.");
                }
            }

            return compressData;
        }

        /// <summary>
        /// Decompresses an LZ4-compressed segment image.
        /// (The method name is misspelled; it is kept because other members call it by this name.)
        /// </summary>
        /// <param name="compData">Compressed bytes.</param>
        /// <param name="bufferSize">Expected size of the decompressed data.</param>
        /// <returns>The decompressed bytes, exactly <paramref name="bufferSize"/> long.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="compData"/> is null.</exception>
        /// <exception cref="InvalidDataException">
        /// The expected size is negative or the data does not decompress to exactly that size.
        /// </exception>
        private byte[] Deompress(byte[] compData, int bufferSize)
        {
            if (compData == null)
            {
                throw new ArgumentNullException("compData");
            }
            if (bufferSize < 0)
            {
                throw new InvalidDataException(string.Format("Invalid decompressed size {0}.", bufferSize));
            }

            var decompressData = new byte[bufferSize];

            int ret = Lz4.LZ4_decompress_safe(compData, 0, decompressData, 0, compData.Length, bufferSize);
            if (ret != bufferSize)
            {
                throw new InvalidDataException(string.Format("LZ4_decompress_safe failed (result {0}, expected {1}).", ret, bufferSize));
            }
            return decompressData;
        }

        /// <summary>
        /// Writes the memory image of every stored segment to <paramref name="fs"/>, each at
        /// the position given by its memory offset. Compressed segments are decompressed and,
        /// when the matching hash flag is set, checked against the SHA-256 hash in the header.
        /// Segments are processed in the order text, RO, data.
        /// </summary>
        /// <param name="fs">Output stream that receives the memory image.</param>
        /// <exception cref="ArgumentNullException"><paramref name="fs"/> is null.</exception>
        /// <exception cref="InvalidDataException">A segment fails its hash check.</exception>
        public void ExtractData(FileStream fs)
        {
            if (fs == null)
            {
                throw new ArgumentNullException("fs");
            }

            if (header.TextFileSize > 0)
            {
                ExtractSegment(fs, "text", textBinary, header.TextSize, header.TextMemoryOffset,
                    NsoHeaderFlags.TextCompress, NsoHeaderFlags.TextHash, header.TextHash);
            }

            if (header.RoFileSize > 0)
            {
                ExtractSegment(fs, "RO", roBinary, header.RoSize, header.RoMemoryOffset,
                    NsoHeaderFlags.RoCompress, NsoHeaderFlags.RoHash, header.RoHash);
            }

            if (header.DataFileSize > 0)
            {
                ExtractSegment(fs, "data", dataBinary, header.DataSize, header.DataMemoryOffset,
                    NsoHeaderFlags.DataCompress, NsoHeaderFlags.DataHash, header.DataHash);
            }
        }

        /// <summary>
        /// Decompresses (if flagged), verifies (if flagged) and writes a single segment.
        /// </summary>
        private void ExtractSegment(
            FileStream fs,
            string segmentName,
            byte[] storedImage,
            uint memorySize,
            uint memoryOffset,
            NsoHeaderFlags compressFlag,
            NsoHeaderFlags hashFlag,
            byte[] expectedHash)
        {
            var memoryImage = storedImage;
            if ((header.Flags & (uint)compressFlag) != 0)
            {
                memoryImage = Deompress(storedImage, (int)memorySize);
            }

            if ((header.Flags & (uint)hashFlag) != 0)
            {
                byte[] actualHash;
                // The hash algorithm object is IDisposable; release it as soon as the hash is computed.
                using (var sha256 = new SHA256Managed())
                {
                    actualHash = sha256.ComputeHash(memoryImage);
                }
                if (!actualHash.SequenceEqual(expectedHash))
                {
                    throw new InvalidDataException(string.Format("The {0} segment does not match the hash stored in the header.", segmentName));
                }
            }

            fs.Position = memoryOffset;
            fs.Write(memoryImage, 0, memoryImage.Length);
        }


        /// <summary>
        /// Registers the text segment: validates it, stores its address and memory size in
        /// the header and keeps its contents. A null segment is ignored.
        /// </summary>
        /// <param name="info">Segment information; may be null.</param>
        public void SetTextSegment(ElfSegmentInfo info)
        {
            if (info != null)
            {
                // The text segment has to start at virtual address 0.
                var address = info.VirtualAddress;
                if (address != 0)
                {
                    throw new ArgumentException(string.Format(Properties.Resources.Message_InvalidTextAddress, address));
                }

                var memorySize = info.MemorySize;
                if (memorySize > 0x7fffffff)
                {
                    throw new ArgumentException(string.Format(Properties.Resources.Message_InvalidSegmentSize, "Text", memorySize));
                }

                header.TextMemoryOffset = (uint)address;
                header.TextSize = (uint)memorySize;

                var contents = info.GetContents();
                textBinary = contents;
                header.TextFileSize = (uint)contents.Length;
            }
        }

        /// <summary>
        /// Registers the RO segment: validates it, stores its address and memory size in
        /// the header and keeps its contents. A null segment is ignored.
        /// </summary>
        /// <param name="info">Segment information; may be null.</param>
        public void SetRoSegment(ElfSegmentInfo info)
        {
            if (info != null)
            {
                // The segment address has to be a multiple of 0x1000.
                var address = info.VirtualAddress;
                if ((address & 0xFFF) != 0)
                {
                    throw new ArgumentException(string.Format(Properties.Resources.Message_InvalidSegmentAlign, "RO", address));
                }

                var memorySize = info.MemorySize;
                if (memorySize > 0x7fffffff)
                {
                    throw new ArgumentException(string.Format(Properties.Resources.Message_InvalidSegmentSize, "RO", memorySize));
                }

                header.RoMemoryOffset = (uint)address;
                header.RoSize = (uint)memorySize;

                var contents = info.GetContents();
                roBinary = contents;
                header.RoFileSize = (uint)contents.Length;
            }
        }

        /// <summary>
        /// Registers the data segment: validates it, stores its address and memory size in
        /// the header and keeps its contents. A null segment is ignored.
        /// </summary>
        /// <param name="info">Segment information; may be null.</param>
        public void SetDataSegment(ElfSegmentInfo info)
        {
            if (info != null)
            {
                // The segment address has to be a multiple of 0x1000.
                var address = info.VirtualAddress;
                if ((address & 0xFFF) != 0)
                {
                    throw new ArgumentException(string.Format(Properties.Resources.Message_InvalidSegmentAlign, "Data", address));
                }

                var memorySize = info.MemorySize;
                if (memorySize > 0x7fffffff)
                {
                    throw new ArgumentException(string.Format(Properties.Resources.Message_InvalidSegmentSize, "Data", memorySize));
                }

                header.DataMemoryOffset = (uint)address;
                header.DataSize = (uint)memorySize;

                var contents = info.GetContents();
                dataBinary = contents;
                header.DataFileSize = (uint)contents.Length;
            }
        }

        /// <summary>
        /// Registers the BSS segment. The BSS size stored in the header also covers any gap
        /// between the end of the data segment and the start of the BSS segment.
        /// A null segment is ignored.
        /// </summary>
        /// <param name="info">Segment information; may be null.</param>
        /// <exception cref="ArgumentException">
        /// The BSS address is not 16-byte aligned, or the BSS segment starts before the end of the data segment.
        /// </exception>
        public void SetBssSegment(ElfSegmentInfo info)
        {
            if (info == null)
            {
                return;
            }

            if ((info.VirtualAddress & 0xF) > 0)
            {
                throw new ArgumentException(string.Format(Properties.Resources.Message_InvalidBssSegmentAlign, info.VirtualAddress));
            }

            header.BssSize = (uint)info.MemorySize;
            bssMemoryOffset = (uint)info.VirtualAddress;

            if (header.BssSize > 0)
            {
                // When there is no data, place the (empty) data segment at the start of BSS so
                // that "data offset + data size" still marks where BSS begins. This previously
                // overwrote RoMemoryOffset, which corrupted the RO segment placement.
                if (header.DataSize == 0)
                {
                    header.DataMemoryOffset = bssMemoryOffset;
                }

                uint dataEndAddr = header.DataMemoryOffset + header.DataSize;
                if (bssMemoryOffset < dataEndAddr)
                {
                    throw new ArgumentException(Properties.Resources.Message_InvalidBssSegments);
                }

                // Add the gap between the end of data and the start of BSS to the BSS size.
                header.BssSize += (dataEndAddr > 0) ? (uint)(bssMemoryOffset - dataEndAddr) : 0;
            }
        }

        /// <summary>
        /// Tries to compress the text segment. On success the stored image is replaced by the
        /// compressed bytes, the SHA-256 of the uncompressed contents is recorded, and the
        /// TextCompress/TextHash flags are set. On any failure the segment is left untouched.
        /// </summary>
        public void CompressTextSegment()
        {
            var srcData = textBinary;
            if (srcData == null)
            {
                // No text segment has been set; nothing to compress.
                return;
            }

            try
            {
                // The in-place decompression buffer spans from this segment's start to the
                // page-aligned end of data + BSS.
                var end = (header.DataMemoryOffset + header.DataSize + header.BssSize + 0xFFF) & ~0xFFF;
                var bin = Compress(srcData, (int)(end - header.TextMemoryOffset));

                using (var sha256 = new SHA256Managed())
                {
                    var hash = sha256.ComputeHash(srcData);
                    hash.CopyTo(header.TextHash, 0);
                }

                textBinary = bin;
                header.Flags |= (uint)NsoHeaderFlags.TextCompress;
                header.Flags |= (uint)NsoHeaderFlags.TextHash;
                header.TextFileSize = (uint)textBinary.Length;
            }
            catch (Exception)
            {
                // Deliberate best-effort: Compress throws when compression is disabled, does
                // not shrink the data, or fails verification. The segment then stays uncompressed.
            }
        }

        /// <summary>
        /// Tries to compress the RO segment. On success the stored image is replaced by the
        /// compressed bytes, the SHA-256 of the uncompressed contents is recorded, and the
        /// RoCompress/RoHash flags are set. On any failure the segment is left untouched.
        /// </summary>
        public void CompressRoSegment()
        {
            var srcData = roBinary;
            if (srcData == null)
            {
                // No RO segment has been set; nothing to compress.
                return;
            }

            try
            {
                // The in-place decompression buffer spans from this segment's start to the
                // page-aligned end of data + BSS.
                var end = (header.DataMemoryOffset + header.DataSize + header.BssSize + 0xFFF) & ~0xFFF;
                var bin = Compress(srcData, (int)(end - header.RoMemoryOffset));

                using (var sha256 = new SHA256Managed())
                {
                    var hash = sha256.ComputeHash(srcData);
                    hash.CopyTo(header.RoHash, 0);
                }

                roBinary = bin;
                header.Flags |= (uint)NsoHeaderFlags.RoCompress;
                header.Flags |= (uint)NsoHeaderFlags.RoHash;
                header.RoFileSize = (uint)roBinary.Length;
            }
            catch (Exception)
            {
                // Deliberate best-effort: Compress throws when compression is disabled, does
                // not shrink the data, or fails verification. The segment then stays uncompressed.
            }
        }

        /// <summary>
        /// Tries to compress the data segment. On success the stored image is replaced by the
        /// compressed bytes, the SHA-256 of the uncompressed contents is recorded, and the
        /// DataCompress/DataHash flags are set. On any failure the segment is left untouched.
        /// </summary>
        public void CompressDataSegment()
        {
            var srcData = dataBinary;
            if (srcData == null)
            {
                // No data segment has been set; nothing to compress.
                return;
            }

            try
            {
                // The in-place decompression buffer spans from this segment's start to the
                // page-aligned end of data + BSS.
                var end = (header.DataMemoryOffset + header.DataSize + header.BssSize + 0xFFF) & ~0xFFF;
                var bin = Compress(srcData, (int)(end - header.DataMemoryOffset));

                using (var sha256 = new SHA256Managed())
                {
                    var hash = sha256.ComputeHash(srcData);
                    hash.CopyTo(header.DataHash, 0);
                }

                dataBinary = bin;
                header.Flags |= (uint)NsoHeaderFlags.DataCompress;
                header.Flags |= (uint)NsoHeaderFlags.DataHash;
                header.DataFileSize = (uint)dataBinary.Length;
            }
            catch (Exception)
            {
                // Deliberate best-effort: Compress throws when compression is disabled, does
                // not shrink the data, or fails verification. The segment then stays uncompressed.
            }
        }

        /// <summary>
        /// Derives the file offsets from the sizes set so far. The file is laid out as
        /// header, module name, text image, RO image, data image, back to back.
        /// </summary>
        public void CalcPosition()
        {
            // Running file position, advanced by the stored size of each part.
            uint position = headerSize;
            header.ModuleNameOffset = position;

            position += header.ModuleNameSize;
            header.TextFileOffset = position;

            position += header.TextFileSize;
            header.RoFileOffset = position;

            position += header.RoFileSize;
            header.DataFileOffset = position;
        }

        /// <summary>
        /// Writes the NSO file: the marshaled header, the module name, and then the stored
        /// text, RO and data images in that order.
        /// </summary>
        /// <param name="fs">Output file stream. It is left open for the caller.</param>
        /// <exception cref="ArgumentNullException"><paramref name="fs"/> is null.</exception>
        /// <exception cref="InvalidOperationException">No module name has been set.</exception>
        public void WriteData(FileStream fs)
        {
            if (fs == null)
            {
                throw new ArgumentNullException("fs");
            }
            if (moduleName == null)
            {
                // Fail before anything is written instead of after the header has been emitted.
                throw new InvalidOperationException("The module name must be set with SetModuleName before the file is written.");
            }

            var buffer = new byte[headerSize];
            var handle = GCHandle.Alloc(buffer, GCHandleType.Pinned);
            try
            {
                Marshal.StructureToPtr(header, handle.AddrOfPinnedObject(), false);
            }
            finally
            {
                handle.Free();
            }

            // The writer is intentionally not disposed: disposing it would close the caller's stream.
            BinaryWriter bw = new BinaryWriter(fs);
            bw.Write(buffer);
            bw.Write(moduleName);

            // Write exactly the stored images that CalcPosition accounted for. The file
            // offsets are derived from the stored (file) sizes, not from the memory sizes.
            if (textBinary != null && textBinary.Length > 0)
            {
                bw.Write(textBinary);
            }
            if (roBinary != null && roBinary.Length > 0)
            {
                bw.Write(roBinary);
            }
            if (dataBinary != null && dataBinary.Length > 0)
            {
                bw.Write(dataBinary);
            }
        }

        /// <summary>
        /// Prints the signature and the main header fields to the console, one per line,
        /// with numeric values in upper-case hexadecimal.
        /// </summary>
        public void PrintNsoHeader()
        {
            Console.Write("signature: ");
            foreach (byte signatureByte in header.Signature)
            {
                Console.Write("{0}", (char)signatureByte);
            }
            Console.WriteLine();

            // Format string / value pairs, listed in output order.
            var rows = new[]
            {
                new KeyValuePair<string, uint>("flags: 0x{0:X}", header.Flags),
                new KeyValuePair<string, uint>("text file offset: 0x{0:X}", header.TextFileOffset),
                new KeyValuePair<string, uint>("text mem offset: 0x{0:X}", header.TextMemoryOffset),
                new KeyValuePair<string, uint>("text file size: 0x{0:X}", header.TextSize),
                new KeyValuePair<string, uint>("module offset: 0x{0:X}", header.ModuleNameOffset),
                new KeyValuePair<string, uint>("ro file offset: 0x{0:X}", header.RoFileOffset),
                new KeyValuePair<string, uint>("ro mem offset: 0x{0:X}", header.RoMemoryOffset),
                new KeyValuePair<string, uint>("ro file size: 0x{0:X}", header.RoSize),
                new KeyValuePair<string, uint>("module file size: 0x{0:X}", header.ModuleNameSize),
                new KeyValuePair<string, uint>("data file offset: 0x{0:X}", header.DataFileOffset),
                new KeyValuePair<string, uint>("data mem offset: 0x{0:X}", header.DataMemoryOffset),
                new KeyValuePair<string, uint>("data file size: 0x{0:X}", header.DataSize),
                new KeyValuePair<string, uint>("bss file size: 0x{0:X}", header.BssSize),
                new KeyValuePair<string, uint>("embeded offset: 0x{0:X}", header.EmbededOffset),
                new KeyValuePair<string, uint>("embeded size: 0x{0:X}", header.EmbededSize),
            };

            foreach (var row in rows)
            {
                Console.WriteLine(row.Key, row.Value);
            }
        }
    }
}
