﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <nn/nn_Common.h>
#include <type_traits>
#include "../../kern_Assert.h"

namespace nn {
    namespace kern {
        namespace ARMv7A
        {
            namespace detail
            {
                // LoadStoreRegEx<T> wraps the ARMv7-A exclusive load/store pair for a
                // single access width.
                //
                //  - LoadRegEx(ptr)      : exclusive load from *ptr, then "dmb ish".
                //  - StoreRegEx(val, ptr): "dmb ish", then exclusive store of val to
                //                          *ptr; returns the STREX status register
                //                          (0 means the store was performed).
                //
                // Every asm statement carries a "memory" clobber so the compiler does
                // not move other memory accesses across it. The status / loaded-value
                // outputs are early-clobber so they never share a register with an
                // input (STREX forbids its status register overlapping Rt or Rn).

                // Primary template: 32-bit accesses (LDREX / STREX).
                template <typename T>
                struct LoadStoreRegEx
                {
                    static T LoadRegEx (volatile T* ptr)
                    {
                        static_assert(sizeof(T) == sizeof(int32_t), "");
                        T loaded;
                        asm volatile("ldrex %[dst], [%[src]]"
                                     : [dst] "=&r"(loaded)
                                     : [src] "r"(ptr)
                                     : "memory");
                        asm volatile("dmb ish" ::: "memory");
                        return loaded;
                    }
                    static int StoreRegEx(T val, volatile T* ptr)
                    {
                        static_assert(sizeof(T) == sizeof(int32_t), "");
                        int status;
                        asm volatile("dmb ish" ::: "memory");
                        asm volatile("strex %[status], %[value], [%[dst]]"
                                     : [status] "=&r"(status)
                                     : [value] "r"(val), [dst] "r"(ptr)
                                     : "memory");
                        return status;
                    }
                };

                // 8-bit accesses (LDREXB / STREXB).
                template <>
                struct LoadStoreRegEx<int8_t>
                {
                    static int8_t LoadRegEx (volatile int8_t* ptr)
                    {
                        int8_t loaded;
                        asm volatile("ldrexb %[dst], [%[src]]"
                                     : [dst] "=&r"(loaded)
                                     : [src] "r"(ptr)
                                     : "memory");
                        asm volatile("dmb ish" ::: "memory");
                        return loaded;
                    }
                    static int StoreRegEx(int8_t val, volatile int8_t* ptr)
                    {
                        int status;
                        asm volatile("dmb ish" ::: "memory");
                        asm volatile("strexb %[status], %[value], [%[dst]]"
                                     : [status] "=&r"(status)
                                     : [value] "r"(val), [dst] "r"(ptr)
                                     : "memory");
                        return status;
                    }
                };

                // 16-bit accesses (LDREXH / STREXH).
                template <>
                struct LoadStoreRegEx<int16_t>
                {
                    static int16_t LoadRegEx (volatile int16_t* ptr)
                    {
                        int16_t loaded;
                        asm volatile("ldrexh %[dst], [%[src]]"
                                     : [dst] "=&r"(loaded)
                                     : [src] "r"(ptr)
                                     : "memory");
                        asm volatile("dmb ish" ::: "memory");
                        return loaded;
                    }
                    static int StoreRegEx(int16_t val, volatile int16_t* ptr)
                    {
                        int status;
                        asm volatile("dmb ish" ::: "memory");
                        asm volatile("strexh %[status], %[value], [%[dst]]"
                                     : [status] "=&r"(status)
                                     : [value] "r"(val), [dst] "r"(ptr)
                                     : "memory");
                        return status;
                    }
                };

                // 64-bit accesses (LDREXD / STREXD). The value lives in a register
                // pair: %[x] names the first register and %H[x] the second.
                template <>
                struct LoadStoreRegEx<int64_t>
                {
                    static int64_t LoadRegEx (volatile int64_t* ptr)
                    {
                        int64_t loaded;
                        asm volatile("ldrexd %[dst], %H[dst], [%[src]]"
                                     : [dst] "=&r"(loaded)
                                     : [src] "r"(ptr)
                                     : "memory");
                        asm volatile("dmb ish" ::: "memory");
                        return loaded;
                    }
                    static int StoreRegEx(int64_t val, volatile int64_t* ptr)
                    {
                        int status;
                        asm volatile("dmb ish" ::: "memory");
                        asm volatile("strexd %[status], %[value], %H[value], [%[dst]]"
                                     : [status] "=&r"(status)
                                     : [value] "r"(val), [dst] "r"(ptr)
                                     : "memory");
                        return status;
                    }
                };
            }

            /**
             * @brief Interlocked (atomic) memory operations for ARMv7-A.
             *
             * The 32-bit primitives and the uint64_t CompareAndSwap are only declared
             * in this header; their definitions are not visible here. The template
             * helpers are implemented inline on top of detail::LoadStoreRegEx
             * (LDREX/STREX family with "dmb ish" barriers).
             */
            class Interlocked
            {
            private:

                // Maps T to the signed integer type of the same size that is used to
                // move T through the exclusive-access instructions. Only 1-, 2-, 4-
                // and 8-byte types have a specialization; any other size fails to
                // compile.
                template <typename T, typename = void> struct AtomicStorageSelecter;

                template <typename T> struct AtomicStorageSelecter<T, typename std::enable_if<sizeof(T) == sizeof(int64_t)>::type>
                {
                    typedef int64_t Type;
                };

                template <typename T> struct AtomicStorageSelecter<T, typename std::enable_if<sizeof(T) == sizeof(int32_t)>::type>
                {
                    typedef int32_t Type;
                };

                template <typename T> struct AtomicStorageSelecter<T, typename std::enable_if<sizeof(T) == sizeof(int16_t)>::type>
                {
                    typedef int16_t Type;
                };

                template <typename T> struct AtomicStorageSelecter<T, typename std::enable_if<sizeof(T) == sizeof(int8_t)>::type>
                {
                    typedef int8_t Type;
                };

            public:
                // 32-bit primitives. They are defined out of line (not in this
                // header); consult the definitions for the exact return-value
                // convention of each.
                static int32_t CompareAndSwap(int32_t* pTarget, int32_t comp, int32_t swap);
                static int32_t Swap(int32_t* pTarget, int32_t value);
                static int32_t Increment(int32_t* pTarget);
                static int32_t Decrement(int32_t* pTarget);
                static int32_t Add(int32_t* pTarget, int32_t value);
                static int32_t Substract(int32_t* pTarget, int32_t value);
                static int32_t BitwiseOr(int32_t* pTarget, int32_t value);
                static int32_t BitwiseAnd(int32_t* pTarget, int32_t value);
                static int32_t BitwiseXor(int32_t* pTarget, int32_t value);
                static int32_t BitwiseNot(int32_t* pTarget);

                /**
                 * @brief Reads a 64-bit value with a single-copy-atomic access.
                 *
                 * A plain 64-bit load is not guaranteed to be atomic on ARMv7-A: the
                 * compiler may split it into two 32-bit loads, and LDRD is only
                 * single-copy atomic on cores with LPAE. LDREXD to a doubleword-
                 * aligned address is single-copy atomic, so it is used here; CLREX
                 * then releases the exclusive reservation because no store follows.
                 *
                 * @param pTarget Address to read; must be 8-byte aligned (LDREXD
                 *                raises an alignment fault otherwise).
                 * @return The 64-bit value read from *pTarget.
                 */
                static int64_t Read(int64_t* pTarget)
                {
                    int64_t value;
                    asm volatile("ldrexd %0, %H0, [%1]\n\t"
                                 "clrex"
                                 : "=&r"(value) : "r"(pTarget) : "memory");
                    return value;
                }

                // 64-bit compare-and-swap; defined out of line (not in this header).
                static uint64_t CompareAndSwap(uint64_t* pTarget, const uint64_t& comp, uint64_t swap);

                /**
                 * @brief Pointer overload of CompareAndSwap.
                 *
                 * Forwards to the int32_t overload, so it requires 32-bit pointers.
                 *
                 * @return The int32_t overload's result, converted back to T*.
                 */
                template <typename T>
                static T* CompareAndSwap(T** pTarget, T* comp, T* swap)
                {
                    static_assert(sizeof(T*) == sizeof(int32_t), "Interlocked pointer operations require 32-bit pointers");
                    return reinterpret_cast<T*>(
                        CompareAndSwap( reinterpret_cast<int32_t*>(pTarget),
                            reinterpret_cast<int32_t>(comp),
                            reinterpret_cast<int32_t>(swap) ));
                }

                /**
                 * @brief Pointer overload of Swap.
                 *
                 * Forwards to the int32_t overload, so it requires 32-bit pointers.
                 *
                 * @return The int32_t overload's result, converted back to T*.
                 */
                template <typename T>
                static T* Swap(T** pTarget, T* value)
                {
                    static_assert(sizeof(T*) == sizeof(int32_t), "Interlocked pointer operations require 32-bit pointers");
                    return reinterpret_cast<T*>(
                        Swap( reinterpret_cast<int32_t*>(pTarget),
                            reinterpret_cast<int32_t>(value) ));
                }

                /**
                 * @brief Atomically applies an update callback to *p.
                 *
                 * The function repeats the following steps:
                 *  1. Exclusively load *p.
                 *  2. Pass a copy to (*pUpdate)(T*).
                 *  3. If the callback returns true, try to store the modified copy back.
                 *
                 * The loop ends when that store succeeds. The callback can therefore
                 * run more than once, and it should only edit the value it is given.
                 *
                 * @param p       Target object; T must be 1, 2, 4 or 8 bytes.
                 * @param pUpdate Callable invoked as (*pUpdate)(T*). Returning false
                 *                aborts the update without writing.
                 * @return true if the updated value was stored, false if the
                 *         callback declined.
                 */
                template <typename T, typename UpdateFunc>
                static bool AtomicUpdate(volatile T* p, UpdateFunc* pUpdate, typename std::enable_if<sizeof(T) <= sizeof(int64_t)>::type* = 0)
                {
                    typedef typename AtomicStorageSelecter<T>::Type Storage;
                    typedef detail::LoadStoreRegEx<Storage> LoadStore;

                    // Use a union so that T is handled as plain data (POD-like) while
                    // it travels through the integer storage type.
                    union U
                    {
                        T v;
                        Storage n;
                    };

                    // Convert the pointer value, not a reference to the pointer
                    // variable, so that variable is never accessed through an
                    // unrelated pointer type.
                    volatile Storage* const pStorage = reinterpret_cast<volatile Storage*>(p);

                    U x;

                    for(;;)
                    {
                        x.n = LoadStore::LoadRegEx(pStorage);

                        if (!(*pUpdate)(&x.v))
                        {
                            // Abandon the exclusive access without storing.
                            asm volatile("clrex":::"memory");
                            return false;
                        }

                        // STREX reports 0 only when the store was performed; any
                        // other status means the reservation was lost, so retry.
                        if ( LoadStore::StoreRegEx(x.n, pStorage) == 0 )
                        {
                            return true;
                        }
                    }
                }

                /**
                 * @brief Single-attempt (weak) compare-and-swap on *p.
                 *
                 * Exclusively loads *p. If it equals compValue, tries once to store
                 * setValue. Only one STREX is attempted, so the swap can fail even
                 * when the values matched; callers that need a definite outcome must
                 * retry.
                 *
                 * @param p         Target object; T must be 1, 2, 4 or 8 bytes.
                 * @param compValue Expected current value.
                 * @param setValue  Value to store when *p == compValue.
                 * @return 0 if setValue was stored. Otherwise non-zero: 1 when the
                 *         current value differed from compValue, or the STREX status
                 *         when the exclusive store failed.
                 */
                template <typename T>
                static int CompareAndSwapWeak(volatile T* p, T compValue, T setValue, typename std::enable_if<sizeof(T) <= sizeof(int64_t)>::type* = 0)
                {
                    typedef typename AtomicStorageSelecter<T>::Type Storage;
                    typedef detail::LoadStoreRegEx<Storage> LoadStore;

                    // Use a union so that T is handled as plain data (POD-like) while
                    // it travels through the integer storage type.
                    union U
                    {
                        T v;
                        Storage n;
                    };

                    // Convert the pointer value, not a reference to the pointer
                    // variable, so that variable is never accessed through an
                    // unrelated pointer type.
                    volatile Storage* const pStorage = reinterpret_cast<volatile Storage*>(p);

                    U x;

                    x.n = LoadStore::LoadRegEx(pStorage);

                    if( x.v != compValue )
                    {
                        // Mismatch: drop the reservation with CLREX instead of
                        // writing the observed value back (same as AtomicUpdate).
                        asm volatile("clrex":::"memory");
                        return 1;
                    }

                    x.v = setValue;
                    return LoadStore::StoreRegEx(x.n, pStorage);
                }
            };
        }
    }
}

