﻿using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using CredentialRegister.Exceptions;

namespace CredentialRegister
{
    /// <summary>
    /// Wrapper over the Win32 CredUI / Credential Manager APIs for prompting for,
    /// reading, writing and deleting generic credentials.
    /// </summary>
    public static class Credential
    {
        /// <summary>A user name / password pair.</summary>
        public struct AuthenticationInfo
        {
            /// <summary>The account user name.</summary>
            public string Username;

            /// <summary>The account password, in plain text.</summary>
            public string Password;
        }

        /// <summary>Initial character capacity of each buffer passed to CredUnPackAuthenticationBuffer.</summary>
        private const int DefaultUnpackBufferLength = 255;

        /// <summary>
        /// Shows the Windows credential prompt (<c>CredUIPromptForWindowsCredentials</c> with
        /// <c>CREDUIWIN_GENERIC</c>) and returns the user name and password that were entered.
        /// </summary>
        /// <param name="authenticationInfo">Receives the entered credentials; both fields are empty when the method returns <c>false</c>.</param>
        /// <param name="caption">Caption text of the prompt.</param>
        /// <param name="message">Message text of the prompt.</param>
        /// <returns><c>true</c> if credentials were entered; <c>false</c> if the user cancelled the prompt.</returns>
        /// <exception cref="CredentialException">The prompt failed for a reason other than cancellation, or its result could not be unpacked.</exception>
        public static bool RequireAuthenticationInfo(out AuthenticationInfo authenticationInfo, string caption, string message)
        {
            authenticationInfo.Username = string.Empty;
            authenticationInfo.Password = string.Empty;

            // No user name is pre-filled today, so this yields an empty (0, IntPtr.Zero) input buffer.
            var inputBuffer = PackUsernameBuffer(string.Empty);
            var outputBuffer = IntPtr.Zero;
            var outputBufferSize = 0;

            try
            {
                var ui = new NativeMethods.CREDUI_INFO()
                {
                    pszCaptionText = caption,
                    pszMessageText = message
                };
                ui.cbSize = Marshal.SizeOf(ui);

                var authPackage = 0;
                var save = false;

                var returnCode = NativeMethods.CredUIPromptForWindowsCredentials(
                    uiInfo: ref ui,
                    authError: 0,
                    authPackage: ref authPackage,
                    InAuthBuffer: inputBuffer.Item2,
                    InAuthBufferSize: inputBuffer.Item1,
                    refOutAuthBuffer: out outputBuffer,
                    refOutAuthBufferSize: out outputBufferSize,
                    fSave: ref save,
                    flags: NativeMethods.PromptForWindowsCredentialsFlags.CREDUIWIN_GENERIC);

                if (returnCode == NativeMethods.CredUIReturnCodes.ERROR_CANCELLED)
                {
                    return false;
                }

                if (returnCode != NativeMethods.CredUIReturnCodes.NO_ERROR)
                {
                    throw new CredentialException("Failed to require an account.");
                }

                string username;
                string password;
                UnPackAuthBuffer(outputBuffer, outputBufferSize, out username, out password);

                authenticationInfo.Username = username;
                authenticationInfo.Password = password;

                return true;
            }
            finally
            {
                if (inputBuffer.Item2 != IntPtr.Zero)
                {
                    Marshal.FreeHGlobal(inputBuffer.Item2);
                }

                ReleasePromptOutputBuffer(outputBuffer, outputBufferSize);
            }
        }

        /// <summary>
        /// Reads the generic credential stored under <paramref name="target"/>.
        /// </summary>
        /// <param name="target">Target name of the stored credential.</param>
        /// <param name="authenticationInfo">Receives the stored user name and password (empty strings when absent).</param>
        /// <returns><c>true</c> if a credential was found; <c>false</c> if none exists for <paramref name="target"/>.</returns>
        /// <exception cref="CredentialException">CredRead failed with an error other than "not found".</exception>
        public static bool FindAuthenticationInfo(string target, out AuthenticationInfo authenticationInfo)
        {
            authenticationInfo.Username = string.Empty;
            authenticationInfo.Password = string.Empty;

            var credentialPtr = IntPtr.Zero;

            try
            {
                if (!NativeMethods.CredRead(target, NativeMethods.CRED_TYPE.GENERIC, 0, out credentialPtr))
                {
                    var error = Marshal.GetLastWin32Error();

                    if (error == (int)NativeMethods.CredUIReturnCodes.ERROR_NOT_FOUND)
                    {
                        return false;
                    }

                    throw new CredentialException("Failed to read an account.");
                }

                var credential = (NativeMethods.CREDENTIAL)Marshal.PtrToStructure(credentialPtr, typeof(NativeMethods.CREDENTIAL));

                authenticationInfo.Username = credential.userName ?? string.Empty;

                // An empty password may be stored with a null blob; PtrToStringUni rejects a null pointer.
                // The blob is decoded as UTF-16, matching how SaveAuthenticationInfo writes it.
                authenticationInfo.Password = credential.credentialBlob != IntPtr.Zero && credential.credentialBlobSize > 0
                    ? Marshal.PtrToStringUni(credential.credentialBlob, credential.credentialBlobSize / 2)
                    : string.Empty;

                return true;
            }
            finally
            {
                if (credentialPtr != IntPtr.Zero)
                {
                    NativeMethods.CredFree(credentialPtr);
                }
            }
        }

        /// <summary>
        /// Writes (creates or overwrites) a generic credential under <paramref name="target"/>,
        /// persisted with <c>CRED_PERSIST.ENTERPRISE</c>.
        /// </summary>
        /// <param name="target">Target name to store the credential under.</param>
        /// <param name="authenticationInfo">Credentials to store; a null password is stored as an empty one.</param>
        /// <exception cref="CredentialException">CredWrite failed.</exception>
        public static void SaveAuthenticationInfo(string target, AuthenticationInfo authenticationInfo)
        {
            var secret = authenticationInfo.Password ?? string.Empty;

            // Unmanaged UTF-16 copy of the password; it must be cleared and freed once CredWrite returns.
            var blob = Marshal.StringToCoTaskMemUni(secret);

            try
            {
                var credential = new NativeMethods.CREDENTIAL()
                {
                    type = NativeMethods.CRED_TYPE.GENERIC,
                    targetName = target,
                    credentialBlob = blob,
                    credentialBlobSize = Encoding.Unicode.GetByteCount(secret),
                    persist = NativeMethods.CRED_PERSIST.ENTERPRISE,
                    attributeCount = 0,
                    userName = authenticationInfo.Username
                };

                if (!NativeMethods.CredWrite(ref credential, 0))
                {
                    throw new CredentialException("Failed to save an account.");
                }
            }
            finally
            {
                if (blob != IntPtr.Zero)
                {
                    Marshal.ZeroFreeCoTaskMemUnicode(blob);
                }
            }
        }

        /// <summary>
        /// Deletes the generic credential stored under <paramref name="target"/>.
        /// A missing credential is not treated as an error.
        /// </summary>
        /// <param name="target">Target name of the credential to delete.</param>
        /// <exception cref="CredentialException">CredDelete failed with an error other than "not found".</exception>
        public static void DeleteAuthenticationInfo(string target)
        {
            if (!NativeMethods.CredDelete(target, NativeMethods.CRED_TYPE.GENERIC, 0))
            {
                var error = Marshal.GetLastWin32Error();

                if (error != (int)NativeMethods.CredUIReturnCodes.ERROR_NOT_FOUND)
                {
                    throw new CredentialException("Failed to delete an account.");
                }
            }
        }

        /// <summary>
        /// Zero-fills and frees a buffer returned by <c>CredUIPromptForWindowsCredentials</c>.
        /// </summary>
        /// <remarks>
        /// The Win32 documentation requires this buffer to be released with <c>CoTaskMemFree</c>.
        /// It contains the packed password, so it is cleared before being freed.
        /// </remarks>
        private static void ReleasePromptOutputBuffer(IntPtr buffer, int size)
        {
            if (buffer == IntPtr.Zero)
            {
                return;
            }

            if (size > 0)
            {
                Marshal.Copy(new byte[size], 0, buffer, size);
            }

            Marshal.FreeCoTaskMem(buffer);
        }

        /// <summary>
        /// Extracts the user name and password from a packed authentication buffer.
        /// </summary>
        /// <param name="buffer">Packed buffer, e.g. as returned by the credential prompt.</param>
        /// <param name="size">Size of <paramref name="buffer"/> in bytes.</param>
        /// <param name="username">Receives the unpacked user name.</param>
        /// <param name="password">Receives the unpacked password.</param>
        /// <exception cref="CredentialException">CredUnPackAuthenticationBuffer failed.</exception>
        private static void UnPackAuthBuffer(IntPtr buffer, int size, out string username, out string password)
        {
            username = string.Empty;
            password = string.Empty;

            var userNameSize = DefaultUnpackBufferLength;
            var domainSize = DefaultUnpackBufferLength;
            var passwordSize = DefaultUnpackBufferLength;

            // On ERROR_INSUFFICIENT_BUFFER the API reports the required sizes through the
            // pcchMax* arguments, so one retry with those sizes is enough.
            for (var attempt = 0; ; attempt++)
            {
                var usernameBuffer = new StringBuilder(userNameSize);
                var domainBuffer = new StringBuilder(domainSize);
                var passwordBuffer = new StringBuilder(passwordSize);

                if (NativeMethods.CredUnPackAuthenticationBuffer(
                    dwFlags: 0,
                    pAuthBuffer: buffer,
                    cbAuthBuffer: size,
                    pszUserName: usernameBuffer,
                    pcchMaxUserName: ref userNameSize,
                    pszDomainName: domainBuffer,
                    pcchMaxDomainame: ref domainSize,
                    pszPassword: passwordBuffer,
                    pcchMaxPassword: ref passwordSize))
                {
                    username = usernameBuffer.ToString();
                    password = passwordBuffer.ToString();
                    return;
                }

                var error = Marshal.GetLastWin32Error();

                if (attempt > 0 || error != (int)NativeMethods.CredUIReturnCodes.ERROR_INSUFFICIENT_BUFFER)
                {
                    throw new CredentialException(string.Format("Failed to unpack an account. : {0}", GetErrorMessage(error)));
                }

                // Never shrink below the defaults, in case the API rewrote a size that was already sufficient.
                userNameSize = Math.Max(userNameSize, DefaultUnpackBufferLength);
                domainSize = Math.Max(domainSize, DefaultUnpackBufferLength);
                passwordSize = Math.Max(passwordSize, DefaultUnpackBufferLength);
            }
        }

        /// <summary>
        /// Packs <paramref name="username"/> (with an empty password) into a buffer that can pre-fill the credential prompt.
        /// </summary>
        /// <param name="username">User name to pack; null or whitespace means "no pre-fill".</param>
        /// <returns>
        /// A (size in bytes, buffer) pair. The buffer is allocated with <c>Marshal.AllocHGlobal</c> and must be
        /// released with <c>Marshal.FreeHGlobal</c>; it is <c>IntPtr.Zero</c> (size 0) when nothing was packed.
        /// </returns>
        /// <exception cref="CredentialException">Determining the buffer size or packing failed.</exception>
        private static Tuple<int, IntPtr> PackUsernameBuffer(string username)
        {
            if (string.IsNullOrWhiteSpace(username))
            {
                return Tuple.Create(0, IntPtr.Zero);
            }

            var size = 0;

            // Determine the required size: a call with a null buffer is expected to fail with
            // ERROR_INSUFFICIENT_BUFFER and report the needed byte count in 'size'.
            NativeMethods.CredPackAuthenticationBuffer(
                dwFlags: NativeMethods.CRED_PACK.GENERIC_CREDENTIALS,
                pszUserName: username,
                pszPassword: string.Empty,
                pPackedCredentials: IntPtr.Zero,
                pcbPackedCredentials: ref size);

            if (Marshal.GetLastWin32Error() != (int)NativeMethods.CredUIReturnCodes.ERROR_INSUFFICIENT_BUFFER)
            {
                throw new CredentialException(string.Format("Failed to get size of buffer. : {0}", GetLastErrorMessage()));
            }

            var buffer = Marshal.AllocHGlobal(size);

            if (!NativeMethods.CredPackAuthenticationBuffer(
                dwFlags: NativeMethods.CRED_PACK.GENERIC_CREDENTIALS,
                pszUserName: username,
                pszPassword: string.Empty,
                pPackedCredentials: buffer,
                pcbPackedCredentials: ref size))
            {
                // Read the error text before freeing, so the release cannot disturb the reported error.
                var errorMessage = GetLastErrorMessage();
                Marshal.FreeHGlobal(buffer);
                throw new CredentialException(string.Format("Failed to pack data. : {0}", errorMessage));
            }

            return Tuple.Create(size, buffer);
        }

        /// <summary>Returns the system message for the last Win32 error recorded by the P/Invoke marshaler.</summary>
        private static string GetLastErrorMessage()
        {
            return GetErrorMessage(Marshal.GetLastWin32Error());
        }

        /// <summary>Returns the system message text for the given Win32 error code.</summary>
        private static string GetErrorMessage(int lastError)
        {
            return new Win32Exception(lastError).Message;
        }
    }
}
