﻿#!/usr/bin/env python
#
# Copyright (C)Nintendo All rights reserved.
#
# These coded instructions, statements, and computer programs contain proprietary
# information of Nintendo and/or its licensed developers and are protected by
# national and international copyright laws. They may not be disclosed to third
# parties or copied or duplicated in any form, in whole or in part, without the
# prior written consent of Nintendo.
#
# The content herein is highly confidential and should be handled accordingly.

import sys
import os
import subprocess
import re
from argparse import ArgumentParser

# Locate the libclang Python bindings (shipped under Externals/Binaries/libclang-5.0.0).
libclang_relative_path = os.path.join('Externals', 'Binaries', 'libclang-5.0.0', 'src', 'bindings', 'python')
# SDK root: six directory levels above the directory containing this script.
nintendo_sdk_root = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..' , '..', '..', '..', '..', '..'))

# Add the bindings to sys.path, then import them.
# Two candidate locations are appended: one five levels above this script's directory and one
# under the SDK root (six levels up). Because they are appended (not prepended), a `clang`
# package already importable from an earlier sys.path entry would take precedence.
libclang_binding_abspath = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..' ,'..', '..', '..', libclang_relative_path)
sys.path.append(libclang_binding_abspath)
sys.path.append(os.path.join(nintendo_sdk_root, libclang_relative_path))

import clang
from clang.cindex import *

# Point the bindings at the native libclang library under the SDK root.
Config.set_library_path(os.path.join(nintendo_sdk_root, 'Externals/Binaries/libclang-5.0.0/bin'))

# Collect the toolchain's system include directories.
system_include_path = []
# Toolchains live under <SDK root>/ToolChains unless SIGLO_TOOLCHAINS_ROOT overrides it.
tool_chains_path = os.path.join(nintendo_sdk_root, 'ToolChains')
if 'SIGLO_TOOLCHAINS_ROOT' in os.environ:
    tool_chains_path = os.environ['SIGLO_TOOLCHAINS_ROOT']
# NOTE(review): every clang-for-nx.<version> directory contributes include paths, and nx_clang
# ends up pointing at whichever such directory os.listdir() returns last (os.listdir order is
# not guaranteed). If no such directory exists, nx_clang is never bound and the first
# preprocessor run fails with NameError.
for directory in os.listdir(tool_chains_path):
    if directory.startswith('clang-for-nx.') and os.path.isdir(os.path.join(tool_chains_path, directory)):
        # The version is the directory-name suffix after 'clang-for-nx.'.
        version = directory.replace('clang-for-nx.', '')
        clang_include_dir = os.path.join(tool_chains_path, directory, 'nx', 'armv7l', 'include')
        if os.path.exists(clang_include_dir):
            # Presumably clang's builtin headers, then libc++ headers, then the target include
            # root (inferred from the path names; confirm against the toolchain layout).
            system_include_path.append(os.path.join(clang_include_dir, 'lib', version, 'include'))
            system_include_path.append(os.path.join(clang_include_dir, 'c++', 'v1'))
            system_include_path.append(clang_include_dir)
        # Windows executable of the preprocessor driver used by FsAccessLogParser._preprocess.
        nx_clang = os.path.join(tool_chains_path, directory, 'bin', 'nx-clang++.exe')

# Project include directories, all relative to the SDK root.
include_path = [os.path.join(nintendo_sdk_root, *relative_parts) for relative_parts in (
    ('Common', 'Configs', 'Targets', 'NX-NXFP2-a32-clang', 'Include'),
    ('Programs', 'Alice', 'Include'),
    ('Programs', 'Chris', 'Include'),
    ('Programs', 'Eris', 'Include'),
    ('Programs', 'Iris', 'Include'),
    ('Programs', 'Chris', 'Sources', 'Libraries', 'fs'),
)]

# Toolchain headers are passed with -isystem, project headers with -I (in that order).
include_options = ['-isystem' + entry for entry in system_include_path] + ['-I' + entry for entry in include_path]

# Macros predefined for the parse; _WIN32 is explicitly undefined afterwards.
predefineds = ['__NX__', '__NINTENDO__', 'NN_NINTENDO_SDK', 'NN_SDK_BUILD_LIBRARY', 'NN_SDK_BUILD_DEVELOP', '__horizon__', '__aarch64__']
define_options = ['-D' + macro for macro in predefineds] + ['-U_WIN32']

# Options shared by the preprocessor run and the libclang parse: language standard,
# include directories, then macro definitions.
clang_options = ['-std=gnu++11'] + include_options + define_options

# Becomes True once eprint() has reported anything; main() then exits with status 1.
has_error = False

def eprint(*args, **kwargs):
    """
    Report an error.

    Records the error in the module-level has_error flag, then prints the arguments to
    standard error, accepting the usual print() keyword arguments.
    """
    global has_error
    has_error = True
    err_stream = sys.stderr
    print(*args, file=err_stream, **kwargs)

class FormatParser(object):
    """
    Parse printf-style format strings such as %d or %X and render them with Python's % operator.

    Each conversion specifier, together with the literal text around it, becomes one Formatter
    in format_info. auto_format() renders every specifier with a representative value.
    """

    # (prefix)%(zero flag)(width)(length modifier)(conversion)(postfix)
    re_format = re.compile('([^%]*)%(0?)([0-9]*)(l{1,2}|z|)([a-zA-Z])([^%]*)')

    # A single conversion specifier without any surrounding literal text.
    re_spec = re.compile('%(0?)([0-9]*)(l{1,2}|z|)([a-zA-Z])')

    @staticmethod
    def _to_python_spec(padding, width, conversion):
        """
        Build the Python %-specifier for one C conversion specifier.

        Python's % operator accepts at most one (ignored) h/l/L length modifier and supports
        neither 'z' nor %p. The length modifier is therefore dropped and %p is rendered as %x.
        Only the specifier is rewritten; literal text around it is never touched.
        """
        if conversion == 'p':
            conversion = 'x'
        return '%' + padding + width + conversion

    @staticmethod
    def _to_python_format(format_string):
        """
        Convert a whole C format string into a Python %-format string.

        Every conversion specifier is rewritten with _to_python_spec, and each
        backslash-quote pair is unescaped to a plain quote.
        """
        python_format_string = FormatParser.re_spec.sub(
            lambda spec: FormatParser._to_python_spec(spec.group(1), spec.group(2), spec.group(4)),
            format_string)
        return python_format_string.replace('\\"', '"')

    class Formatter(object):
        """
        One conversion specifier together with its literal prefix and postfix text.
        """

        def __init__(self, match):
            """
            :param match: a match of FormatParser.re_format
            """
            self.prefix = match.group(1)
            self.postfix = match.group(6)
            self.padding = match.group(2)
            self.width = match.group(3)
            # C spelling of the length modifier plus conversion, e.g. 'lld', 'zu', 'x'.
            self.type = match.group(4) + match.group(5)
            python_spec = FormatParser._to_python_spec(self.padding, self.width, match.group(5))
            # The prefix/postfix cannot contain '%' (see re_format), so they are safe to keep
            # verbatim in a Python format string.
            python_format_string = self.prefix + python_spec + self.postfix
            self.python_format_string = python_format_string.replace('\\"', '"')

        def format(self, value):
            """
            Render this specifier, with its prefix and postfix, for the given value.
            """
            return self.python_format_string % value

    def __init__(self, format, fetched_id = -1):
        """
        :param str format: C format string to parse
        :param int fetched_id: index used by auto_format() for signed integers; a negative
                               value (the default) makes auto_format() use -1 instead
        """
        self.raw_string = format
        self.parse(format)
        self.fetched_id = fetched_id

    def __eq__(self, other):
        """
        Two parsers are equal when their raw format strings are equal.
        """
        if not isinstance(other, FormatParser):
            # Let Python fall back to its default comparison instead of raising AttributeError.
            return NotImplemented
        return self.raw_string == other.raw_string

    # __eq__ is defined, so explicitly make instances unhashable.
    __hash__ = None

    def parse(self, format):
        """
        Split the format string into Formatters and build the whole-string Python format.

        :param str format: C format string to parse
        """
        self.format_info = [FormatParser.Formatter(match) for match in FormatParser.re_format.finditer(format)]
        # Computed once, independently of the matches, so it is also defined for strings
        # without any conversion specifier.
        self.python_format_string = FormatParser._to_python_format(self.raw_string)

    def format(self, value):
        """
        Format the whole string with the given value (or tuple of values).
        """
        return self.python_format_string % value

    def auto_format(self):
        """
        Render the format string with representative values chosen from each specifier's type.

        Strings become 'test'. Signed integers become -1, or sys.maxsize * fetched_id when
        fetched_id is non-negative. %u / %zu / %x / %p use sys.maxsize, %llx uses a value wider
        than 64 bits, and every other specifier is rendered with 0. A string without specifiers
        is returned unchanged.
        """
        if not self.format_info:
            return self.raw_string
        pieces = []
        for info in self.format_info:
            kind = info.type.lower()
            if kind == 's':
                value = 'test'
            elif kind in ('d', 'lld'):
                value = sys.maxsize * self.fetched_id if self.fetched_id >= 0 else -1
            elif kind in ('u', 'zu', 'x', 'p'):
                value = sys.maxsize
            elif kind == 'llx':
                value = (sys.maxsize << 32) + sys.maxsize
            else:
                value = 0
            pieces.append(info.format(value))
        return ''.join(pieces)


class SourceFile(object):
    """
    Build a libclang AST for one source file and look up enclosing functions in it.
    """

    def __init__(self, path, options):
        """
        Parse the file with libclang and print any error diagnostics.

        :param str path: path of the source file
        :param list options: extra clang compiler options. The module-level clang_options are
                             appended to a copy; the caller's list is left unmodified.
        """
        self.path = path
        index = Index.create()
        # Copy instead of extending `options` in place, which would mutate the caller's list.
        clang_args = list(options) + clang_options
        self.node = index.parse(path, args=clang_args)
        self.dump_diag_error()

    def cursor(self, file, line):
        """
        Return the cursor at a position in the translation unit.

        :param str file: file name
        :param line: position handed to TranslationUnit.get_location; get_function passes a
                     (line, column) tuple
        """
        location = self.node.get_location(file, line)
        return Cursor.from_location(self.node, location)

    def get_function(self, line, column):
        """
        Return (spelling, displayname) of the function containing the given position.

        Members of named classes are reported as 'Class::name'. Functions whose mangled name
        places them in the nfp namespace are reported as 'nfp::name'. When the line lies inside
        a lambda, '::<lambda>::operator ()' is appended. (None, None) is returned when no
        enclosing function can be found.

        :param int line: line number
        :param int column: column number
        """
        cursor = self.cursor(self.path, (line, column)).semantic_parent
        if cursor is None:
            return None, None

        def find_lambda_cursor(cursor, line, has_lambda):
            """
            Depth-first search below `cursor` for a node on `line`.

            The second result tells whether a LAMBDA_EXPR lies on the path to that node.
            """
            if cursor.kind == CursorKind.LAMBDA_EXPR:
                has_lambda = True
            if cursor.location.line == line:
                return cursor, has_lambda
            for child in cursor.get_children():
                found, child_has_lambda = find_lambda_cursor(child, line, has_lambda)
                if found:
                    return found, child_has_lambda
            return None, has_lambda

        def find_function(cursor):
            """
            Walk up the semantic parents to the enclosing function and return its
            (spelling, displayname), qualified as described in get_function.
            """
            while cursor.kind != CursorKind.FUNCTION_DECL:
                # Members of named classes are reported as Class::member. Members of the unnamed
                # lambda class have an empty parent spelling, so the walk continues upwards.
                if cursor.kind in (CursorKind.CXX_METHOD, CursorKind.DESTRUCTOR, CursorKind.CONSTRUCTOR) and cursor.semantic_parent.spelling:
                    class_name = cursor.semantic_parent.spelling + "::"
                    return class_name + cursor.spelling, class_name + cursor.displayname
                cursor = cursor.semantic_parent
                if cursor is None:
                    return None, None
            # Recover the namespaces between the function name and 'nn' from the mangled name,
            # e.g. '...Func@nfp@nn...' -> 'nfp'. re.escape keeps spellings that are not plain
            # identifiers from being interpreted as regex syntax.
            pattern = ".*{}@([a-z@]+)@nn.*".format(re.escape(cursor.spelling))
            namespace = ".".join(re.sub(pattern, "\\1", cursor.mangled_name).split("@")[::-1])
            if namespace == "nfp":
                # Only the nfp namespace is attached to the reported name.
                return namespace + "::" + cursor.spelling, namespace + "::" + cursor.displayname
            return cursor.spelling, cursor.displayname

        _, has_lambda = find_lambda_cursor(cursor, line, False)
        spelling, displayname = find_function(cursor)
        if has_lambda and spelling is not None:
            lambda_name = '::<lambda>::operator ()'
            return spelling + lambda_name, displayname + lambda_name
        return spelling, displayname

    def dump_diag_error(self):
        """
        Print clang's diagnostics of severity 3 ('error'), skipping category number 2.
        """
        severity  = ['ignored', 'note', 'warning', 'error', 'fatal']
        for diag in self.node.diagnostics:
            # NOTE(review): only severity 3 is printed, so fatal (4) diagnostics are not shown;
            # category 2 is also filtered out. Confirm both are intentional.
            if (diag.severity == 3) and (diag.category_number != 2):
                print('{0}:{1}: {2}: {3}'.format(diag.location.file, diag.location.line, severity[diag.severity], diag.spelling))

    def dump(self):
        """
        Debug output: print the whole AST of the translation unit.
        """
        self._dump_ast(self.node.cursor)

    def dump_cursor(self, cursor):
        """
        Debug output: print the AST below the given cursor.

        :param Cursor cursor: cursor to start from
        """
        self._dump_ast(cursor)

    def _dump_ast(self, cursor, indent=''):
        """
        Debug output: recursively print kind, display name and line of each node.

        :param Cursor cursor: cursor to start from
        :param str indent: indentation prefix for this level
        """
        print("%s%s : %s(%s)" % (indent, cursor.kind.name, cursor.displayname, cursor.location.line))
        for child in cursor.get_children():
            self._dump_ast(child, indent + "  ")


class FsAccessLog(object):
    """
    One access-log marker detected in preprocessed source.
    """

    # A doubly parenthesised expression argument, e.g. ((a, b)); it may contain commas.
    re_expression = re.compile(r'\(\(.*\)\)')
    # A double-quoted literal preceded by a space, ending at a quote that is not escaped and
    # is followed by a space.
    re_yaml_format = re.compile(r' "(.*?[^\\])"(?= )')
    # One "key: value" element of the yaml format string; the value is one non-space token.
    re_yaml_elem = re.compile(r'(.*?)\s*:\s*(\S*)\s*')

    def __init__(self, path, values, indent):
        """
        :param str path: path of the source file
        :param str values: text between the @ markers reported by the access-log detector
        :param int indent: number of leading spaces on the preprocessed line
        """
        self.path = path
        self.raw = values
        self.indent = indent
        self.expression = None
        match = FsAccessLog.re_expression.search(values)
        if match:
            self.expression = match.group(0)
            # Replace the expression so that its commas do not disturb the split below.
            values = values.replace(self.expression, 'EXP')
        # Comma-separated fields:
        # [0] marker
        # [1] log type (mount, enable, access)
        # [2] type detail (application, system)
        # [3] function name (explicit, or __FUNCTION__)
        # [4] line number
        # [5] expression
        # [6] handle or name
        # [7..] additional yaml elements
        items = values.split(',')
        self.type = items[1].strip()
        self.type_detail = items[2].strip()
        self.function = None
        self.function_name = items[3].strip().strip('"')
        self.line = int(items[4])
        va_args = ''.join(items[7:])
        # Concatenate the quoted format-string pieces of the variadic part.
        yaml_format = ''.join(match.group(1) for match in self.re_yaml_format.finditer(va_args))

        self.yaml = {}
        # Multiplexed values: '[{' as a value opens a list of maps under that key, an inner key
        # prefixed with '} { ' starts the next map, and a key prefixed with '}] ' closes the list.
        checking_multiplex_key = ""
        inner_key_id = 0
        fetched_id = 0
        for match in self.re_yaml_elem.finditer(yaml_format):
            key = match.group(1).strip()
            value = match.group(2).strip()
            if value == "[{":
                checking_multiplex_key = key
                self.yaml[key] = [[]]
                inner_key_id = 0
            elif key[:3] == "}] ":
                # Close the list (this was a no-op comparison before, which kept the list
                # open and dropped every following key).
                checking_multiplex_key = ""
                key = key[3:]
            elif checking_multiplex_key != "":
                if key[:4] == "} { ":
                    self.yaml[checking_multiplex_key].append([])
                    inner_key_id += 1
                    key = key[4:]
                self.yaml[checking_multiplex_key][inner_key_id].append([key, FormatParser(value, fetched_id)])
                fetched_id += 1
            if checking_multiplex_key == "":
                if key in self.yaml:
                    eprint('{0}:{1}: {2}: duplicate key name ({3})'.format(self.path, self.line, self.type, key))
                self.yaml[key] = FormatParser(value)

    def dump(self):
        """
        Print 'path:line: type: function_name'.
        """
        print('{0}:{1}: {2}: {3}'.format(self.path, self.line, self.type, self.function_name))

class FsAccessLogParser(object):
    """
    Detect access-log markers in source files and validate them.
    """

    def __init__(self):
        # filepath -> list of FsAccessLog detected in that file
        self.logs = {}
        # log type ('mount' / 'enable' / 'access') -> logs of that type across all checked files
        self.type_logs = self.get_init_type_logs()
        # Marker printed by the preprocessor: @FS_ACCESS...@
        self.re_detect = re.compile(r'@(FS_ACCESS.+)@')

    def get_init_type_logs(self):
        """
        Return an empty log table keyed by log type.
        """
        return {'mount': [], 'enable': [], 'access': []}

    def parse(self, path):
        """
        Parse a file.

        :param str path: path of the source file
        :rtype list
        :return: list of detected access logs
        """
        return self._preprocess(path)

    def _preprocess(self, path):
        """
        Preprocess the source and collect the access logs found in the preprocessor output.

        The toolchain driver runs with -E -P and NN_DETAIL_FS_ACCESS_LOG_DETECT_FOR_TEST
        defined. Every @FS_ACCESS...@ marker becomes an FsAccessLog.

        :param str path: path of the source file
        :rtype list
        :return: list of detected access logs
        """
        logs = []
        commands = [nx_clang, '-E', '-P', '-DNN_DETAIL_FS_ACCESS_LOG_DETECT_FOR_TEST', path]
        commands.extend(clang_options)
        # stderr is merged into stdout; leaving the with-block closes the pipe and waits.
        with subprocess.Popen(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc:
            for raw_line in proc.stdout:
                # Decode instead of str(): str(bytes) gives the "b'...'" repr, which made the
                # indent below always 0 and doubled every backslash.
                line = raw_line.decode('utf-8', errors='replace')
                if '@FS_ACCESS' not in line:
                    continue
                # Leading spaces; checkfile uses this to derive the column for libclang.
                indent = len(line) - len(line.lstrip(' '))
                for match in self.re_detect.finditer(line):
                    logs.append(FsAccessLog(path, match.group(1), indent))
        return logs

    def has_accesslog(self, filepath):
        """
        Return True if the file appears to include the access-log header.

        Only .cpp/.c/.h/.hpp files are considered. Lines are scanned for 'fs_AccessLog.h', and
        the scan stops at the first line mentioning 'namespace', on the assumption that
        includes never come after a namespace.
        """
        if os.path.splitext(filepath)[1] not in ['.cpp', '.c', '.h', '.hpp']:
            return False
        with open(filepath, 'r', encoding='utf8') as file:
            for line in file:
                if 'fs_AccessLog.h' in line:
                    return True
                if 'namespace' in line:
                    return False
        return False

    def checkdir(self, path):
        """
        Detect and validate access logs in every file below a directory, skipping names that
        start with '.'.

        :param str path: target directory
        """
        for name in os.listdir(path):
            if not name.startswith('.'):
                filepath = os.path.join(path, name)
                if os.path.isdir(filepath):
                    self.checkdir(filepath)
                else:
                    self.checkfile(filepath)

    def checkfile(self, filepath):
        """
        Detect and validate the access logs of one file.

        :param str filepath: path of the source file
        """
        if not self.has_accesslog(filepath):
            return
        logs = self.parse(filepath)
        self.logs[filepath] = logs
        if not logs:
            return
        # Resolve __FUNCTION__ placeholders through libclang. A list is required here: a filter
        # object is always truthy, which made the expensive libclang parse unconditional.
        need_ast_logs = [log for log in logs if log.function_name == '__FUNCTION__']
        if need_ast_logs:
            source = SourceFile(filepath, [])
            for log in need_ast_logs:
                log.function_name, log.function = source.get_function(log.line, log.indent + 1)

        # Print every log and group the logs by type, both per file and globally.
        type_logs_in_file = self.get_init_type_logs()
        for log in logs:
            log.dump()
            type_logs_in_file[log.type].append(log)
            self.type_logs[log.type].append(log)

        self._check_mount_and_enable(type_logs_in_file)
        self._check_class_name(type_logs_in_file)
        self._check_open_access_has_path(type_logs_in_file)

    def _check_mount_and_enable(self, type_logs_in_file):
        """
        Check that MOUNT and ENABLE markers pair up.

        Each MOUNT needs an ENABLE in the same function with the same detail type, and vice
        versa. Neither may appear twice (on different lines) in one function.
        """
        mounts = type_logs_in_file['mount']
        enables = type_logs_in_file['enable']

        for log in mounts:
            # An enabler with the same function signature and detail type (application/system) must exist.
            if not any(log.function == x.function and log.type_detail == x.type_detail for x in enables):
                if log.type_detail == 'system':
                    enable_macro = 'NN_DETAIL_FS_ACCESS_LOG_SYSTEM_FSACCESSOR_ENABLE'
                else:
                    enable_macro = 'NN_DETAIL_FS_ACCESS_LOG_FSACCESSOR_ENABLE'
                eprint('{0}:{1}: {2} is not call {3}'.format(log.path, log.line, log.function, enable_macro))
            # Another mount in the same function on a different line is a duplicate.
            if any(log.function == x.function and log.line != x.line for x in mounts):
                eprint('{0}:{1}: {2} is duplicated NN_DETAIL_FS_ACCESS_LOG_(SYSTEM_)MOUNT'.format(log.path, log.line, log.function))

        for log in enables:
            # A mount with the same function signature and detail type (application/system) must exist.
            if not any(log.function == x.function and log.type_detail == x.type_detail for x in mounts):
                if log.type_detail == 'system':
                    mount_macro = 'NN_DETAIL_FS_ACCESS_LOG_SYSTEM_MOUNT'
                else:
                    mount_macro = 'NN_DETAIL_FS_ACCESS_LOG_MOUNT'
                eprint('{0}:{1}: {2} is not call {3}'.format(log.path, log.line, log.function, mount_macro))
            # Another enabler in the same function on a different line is a duplicate.
            if any(log.function == x.function and log.line != x.line for x in enables):
                eprint('{0}:{1}: {2} is duplicated NN_DETAIL_FS_ACCESS_LOG_(SYSTEM_)FSACCESSOR_ENABLE'.format(log.path, log.line, log.function))

    def _check_class_name(self, type_logs_in_file):
        """
        Report logs whose 'class_name' yaml element is not a prefix of the function name.
        """
        for logs in type_logs_in_file.values():
            for log in logs:
                if 'class_name' in log.yaml:
                    class_name = log.yaml['class_name'].raw_string
                    if not log.function_name.startswith(class_name):
                        eprint('{0}:{1}: class_name yaml element ({2}) of {3} does not match'.format(log.path, log.line, class_name, log.function))

    def _check_open_access_has_path(self, type_logs_in_file):
        """
        Report access logs of Open-style functions that lack a 'path' yaml element.
        """
        open_functions = ('OpenFile', 'OpenDirectory', 'DeliveryCacheDirectory::Open', 'DeliveryCacheFile::Open')
        for log in type_logs_in_file['access']:
            # Compare function_name (bare or Class::-qualified name). log.function is the libclang
            # displayname, which includes the parameter list (and is None for explicitly named
            # logs), so it never matched these entries.
            if log.function_name in open_functions and 'path' not in log.yaml:
                eprint('{0}:{1}: {2} has not path yaml element'.format(log.path, log.line, log.function_name))

    def get_function_name_set(self, type_name, type_detail):
        """
        Return the distinct function names of the logs with the given type and detail type.

        :param str type_name: access-log type
        :param str type_detail: detail type
        :rtype list
        :return: sorted list of unique function names (empty for an unknown type)
        """
        functions = set()
        if type_name in self.type_logs:
            for log in self.type_logs[type_name]:
                if log.type_detail == type_detail:
                    functions.add(log.function_name)
        return sorted(functions)

    def _output_function_list(self, output_dir, file_name, type_name, type_detail):
        """
        Write the sorted function names of one type/detail type to output_dir/file_name,
        one per line.
        """
        with open(os.path.join(output_dir, file_name), 'w') as file:
            for function in self.get_function_name_set(type_name, type_detail):
                file.write(function + '\n')

    def output_mount_function_list(self, output_dir):
        """
        Write the application mount functions to mount.txt.

        :param str output_dir: output directory
        """
        self._output_function_list(output_dir, 'mount.txt', 'mount', 'application')

    def output_system_mount_function_list(self, output_dir):
        """
        Write the system mount functions to system_mount.txt.

        :param str output_dir: output directory
        """
        self._output_function_list(output_dir, 'system_mount.txt', 'mount', 'system')

    def output_access_function_list(self, output_dir):
        """
        Write the application functions with access logs (other than mounts) to access.txt.

        :param str output_dir: output directory
        """
        self._output_function_list(output_dir, 'access.txt', 'access', 'application')

    def output_system_access_function_list(self, output_dir):
        """
        Write the system functions with access logs (other than mounts) to system_access.txt.

        :param str output_dir: output directory
        """
        self._output_function_list(output_dir, 'system_access.txt', 'access', 'system')

    @staticmethod
    def _format_multiplex(entries):
        """
        Render a multiplexed yaml value as '[{ k: v, k: v }, { ... }]'.

        :param list entries: list of maps, each a list of [key, FormatParser] pairs
        """
        return '[' + ', '.join(
            '{ ' + ', '.join(inner_key + ': ' + inner_format.auto_format() for inner_key, inner_format in entry) + ' }'
            for entry in entries) + ']'

    def output_all_yaml_accesslog(self, output_dir):
        """
        Write sample access-log lines that together contain every detected yaml element.

        Lines start from a fixed FS_ACCESS header. When a key was seen with several different
        formats, extra lines (copies of the first line as built so far) are added so that each
        format appears on its own line. Lines without a format of their own reuse the key's
        first format.

        :param str output_dir: output directory
        """
        # key -> distinct formats seen for that key (the same key may use different formats).
        # Multiplexed keys keep only the first list of inner [key, format] pairs.
        yaml = {}
        for logs in self.logs.values():
            for log in logs:
                for key, key_format in log.yaml.items():
                    if key not in yaml:
                        yaml[key] = [key_format]
                    elif not isinstance(yaml[key][0], list) and key_format not in yaml[key]:
                        yaml[key].append(key_format)

        accesslogs = [ 'FS_ACCESS: { start: 0, end: 1, result: 0x00000000, handle: 0x0000000000000000, function: "Hoge"' ]
        for key, values in yaml.items():
            while len(accesslogs) < len(values):
                accesslogs.append(accesslogs[0])
            multiplex_value = self._format_multiplex(values[0]) if isinstance(values[0], list) else None
            for i in range(len(accesslogs)):
                if multiplex_value is not None:
                    value = multiplex_value
                elif i < len(values):
                    value = values[i].auto_format()
                else:
                    value = values[0].auto_format()
                accesslogs[i] += ', ' + key + ': ' + value

        with open(os.path.join(output_dir, 'all_yaml_accesslog.txt'), 'w') as file:
            for accesslog in accesslogs:
                file.write(accesslog + '}' + '\n')


# Command-line options
def parse_command_line():
    """
    Parse the command line.

    Options: -o/--output DIRECTORY for the generated lists, followed by one or more source
    files or directories.

    :return: (parsed options, the ArgumentParser that produced them)
    """
    parser = ArgumentParser()
    parser.add_argument(
        '-o',
        '--output',
        metavar='DIRECTORY',
        dest='output_dir',
        help='output file directory.'
    )
    parser.add_argument(
        'file',
        metavar='PATH',
        nargs='+',
        help='source code file or directory'
    )
    options = parser.parse_args()
    return options, parser


def main():
    """
    Entry point of the script.

    Checks every given file or directory, optionally writes the function lists and the sample
    yaml access log to the output directory, and exits with status 1 if any error was reported.
    """
    options, _ = parse_command_line()
    parser = FsAccessLogParser()
    for filepath in options.file:
        if os.path.isdir(filepath):
            parser.checkdir(filepath)
        else:
            parser.checkfile(filepath)
    if options.output_dir:
        # makedirs also creates missing parent directories; exist_ok covers an existing directory.
        os.makedirs(options.output_dir, exist_ok=True)
        parser.output_mount_function_list(options.output_dir)
        parser.output_system_mount_function_list(options.output_dir)
        parser.output_access_function_list(options.output_dir)
        parser.output_system_access_function_list(options.output_dir)
        parser.output_all_yaml_accesslog(options.output_dir)
    if has_error:
        sys.exit(1)

if __name__ == '__main__':
    main()
