﻿// 文字コード:UTF-8
/// @file
#include "lib/Base64.hpp"

//------------------------------------------------------------------------------
namespace {

// Character used to pad the final Base64 quantum up to four characters.
const char fPaddingChar('=');

// Walks a Base64-encoded character sequence one 4-character quantum at a
// time and exposes the decoded octets of the current quantum.
class fBase64InputStream
{
public:
    // Begins reading aSrcSize characters from aSrc.
    fBase64InputStream(const char* aSrc, int aSrcSize);

    // false once malformed input has been detected.
    bool isValid() const
    {
        return mIsValid;
    }

    // true when no further quantum will be produced.
    bool isEnd() const
    {
        return mIsEnd;
    }

    // Advances to (and decodes) the next quantum.
    void step();

    // Number of decoded octets available for the current quantum.
    int numOctet() const
    {
        return mNumOctet;
    }

    // Decoded octet aIndex of the current quantum.
    uint8_t octet(int aIndex) const;

private:
    // Upper bound on octets produced by a single quantum.
    static const int OctetCapacity = 3;

    // Maps a Base64 alphabet character to its 6-bit value, or -1.
    static int GetValue(char aChr);

    // --- reading state (declaration order matches the constructor) ---
    const char* mPtr;      // next character to read
    const char* mEndPtr;   // last readable character
    bool mIsValid;
    bool mIsEnd;
    uint8_t mOctet[OctetCapacity];
    int mNumOctet;
};

//------------------------------------------------------------------------------
/// Prepares to read aSrcSize characters starting at aSrc and, when there is
/// any input, decodes the first quantum immediately.
/// @param aSrc      Base64 text; it is read only up to aSrcSize characters,
///                  so it need not be NUL-terminated.
/// @param aSrcSize  Number of characters in aSrc; must be non-negative.
fBase64InputStream::fBase64InputStream(const char* aSrc, int aSrcSize)
: mPtr(aSrc)
// mEndPtr points at the last readable character. For empty input,
// "aSrc + 0 - 1" would form a pointer before the array, which is undefined
// behaviour; park it on aSrc instead. It is never dereferenced in that case
// because the stream starts in the end state and step() is not called.
, mEndPtr(0 < aSrcSize ? aSrc + aSrcSize - 1 : aSrc)
, mIsValid(true)
, mIsEnd(aSrcSize <= 0)
, mOctet()
, mNumOctet(0)
{
    for (int i = 0; i < OctetCapacity; ++i) {
        mOctet[i] = 0;
    }
    SYS_ASSERT(0 <= aSrcSize);
    if (!mIsEnd) {
        // Decode the first quantum so numOctet()/octet() are usable at once.
        step();
    }
}

//------------------------------------------------------------------------------
/// Reads the next 4 significant characters (one quantum) and decodes them
/// into mOctet / mNumOctet. Whitespace is skipped.
/// On malformed input mIsValid becomes false (mNumOctet stays 0). When the
/// input is exhausted, or a padded quantum has been consumed, mIsEnd becomes
/// true.
/// Malformed input is any of:
///  - a character outside the Base64 alphabet,
///  - padding in the first or second position of a quantum,
///  - a data character after padding,
///  - non-whitespace after a padded quantum,
///  - a significant-character count that is not a multiple of 4.
void fBase64InputStream::step()
{
    SYS_ASSERT(isValid());
    SYS_ASSERT(!isEnd());
    const int QuantumCapacity = 4;
    char quantum[QuantumCapacity];

    mNumOctet = 0;
    // Collect 4 significant characters.
    {
        int index = 0;
        // Number of padding characters seen in this quantum.
        int numPadding = 0;
        while (index < QuantumCapacity) {
            if (mEndPtr < mPtr) {
                if (index != 0) {
                    // Invalid: the number of significant characters is not
                    // a multiple of 4.
                    mIsValid = false;
                } else {
                    mIsEnd = true;
                }
                return;
            }
            const char ch = *mPtr++;
            // Cast first: passing a negative char (a non-ASCII byte where
            // char is signed) to isspace is undefined behaviour.
            if (::std::isspace(static_cast<unsigned char>(ch))) {
                // Skip whitespace.
                continue;
            }
            if (ch == fPaddingChar) {
                // Padding may only occupy the last one or two positions.
                if (index < 2) {
                    mIsValid = false;
                    return;
                }
                ++numPadding;
            } else if (numPadding != 0) {
                // A data character after padding (e.g. "QQ=A") is invalid.
                mIsValid = false;
                return;
            } else if (GetValue(ch) < 0) {
                // Character outside the Base64 alphabet.
                mIsValid = false;
                return;
            }
            quantum[index++] = ch;
        }
        if (numPadding != 0) {
            // Padding terminates the encoded data. Reject anything other than
            // whitespace after it instead of silently discarding it.
            for (; mPtr <= mEndPtr; ++mPtr) {
                if (!::std::isspace(static_cast<unsigned char>(*mPtr))) {
                    mIsValid = false;
                    return;
                }
            }
            mIsEnd = true;
        } else if (mEndPtr < mPtr) {
            mIsEnd = true;
        }
    }
    // Determine the octet count and decode the 24-bit group.
    // Padding positions contribute no bits.
    uint32_t octets = 0;
    if (quantum[2] == fPaddingChar) {
        mNumOctet = 1;
        octets =
            (GetValue(quantum[0]) << 18) |
            (GetValue(quantum[1]) << 12) |
            0;
    } else if (quantum[3] == fPaddingChar) {
        mNumOctet = 2;
        octets =
            (GetValue(quantum[0]) << 18) |
            (GetValue(quantum[1]) << 12) |
            (GetValue(quantum[2]) <<  6) |
            0;
    } else {
        mNumOctet = 3;
        octets =
            (GetValue(quantum[0]) << 18) |
            (GetValue(quantum[1]) << 12) |
            (GetValue(quantum[2]) <<  6) |
            (GetValue(quantum[3]) <<  0) |
            0;
    }
    mOctet[0] = (octets >> 16) & 0xff;
    mOctet[1] = (octets >>  8) & 0xff;
    mOctet[2] = (octets >>  0) & 0xff;
}

//------------------------------------------------------------------------------
// Returns decoded octet aIndex (0 <= aIndex < numOctet()) of the current
// quantum. Only meaningful while the stream is valid.
uint8_t fBase64InputStream::octet(const int aIndex) const
{
    SYS_ASSERT(isValid());
    SYS_ASSERT(0 <= aIndex && aIndex < numOctet());
    const uint8_t value = mOctet[aIndex];
    return value;
}

//------------------------------------------------------------------------------
// Converts one Base64 alphabet character to its 6-bit value:
// 'A'-'Z' -> 0..25, 'a'-'z' -> 26..51, '0'-'9' -> 52..61, '+' -> 62,
// '/' -> 63. Any other character (including the padding char) yields -1.
int fBase64InputStream::GetValue(const char aChr)
{
    if (aChr >= 'A' && aChr <= 'Z') {
        return aChr - 'A';
    }
    if (aChr >= 'a' && aChr <= 'z') {
        return 26 + (aChr - 'a');
    }
    if (aChr >= '0' && aChr <= '9') {
        return 52 + (aChr - '0');
    }
    switch (aChr) {
    case '+':
        return 62;
    case '/':
        return 63;
    default:
        return -1;
    }
}

} // namespace

//------------------------------------------------------------------------------
namespace lib {

//------------------------------------------------------------------------------
/// Decodes Base64 text into a byte buffer.
/// @param aSrc      Base64 text of aSrcSize characters (whitespace is skipped).
/// @param aSrcSize  Number of characters in aSrc (>= 0).
/// @param aDst      Destination buffer.
/// @param aDstSize  Capacity of aDst in bytes (>= 0).
/// @return Number of bytes written, or -1 when the input is malformed or aDst
///         is too small for the decoded data. On -1, aDst may already hold a
///         partially decoded prefix.
int Base64::Decode(const char* aSrc, int aSrcSize, uint8_t* aDst, int aDstSize)
{
    SYS_ASSERT_POINTER(aSrc);
    SYS_ASSERT(0 <= aSrcSize);
    SYS_ASSERT_POINTER(aDst);
    SYS_ASSERT(0 <= aDstSize);
    fBase64InputStream s(aSrc, aSrcSize);
    // Index-based bounds: "aDst + aDstSize - 1" would be an out-of-range
    // pointer (UB) for an empty destination.
    int decodedSize = 0;
    while (true) {
        if (!s.isValid()) {
            return -1;
        }
        for (int i = 0; i < s.numOctet(); ++i) {
            SYS_ASSERT(decodedSize < aDstSize);
            if (aDstSize <= decodedSize) {
                // Destination too small: fail instead of writing past aDst
                // (the assert above is compiled out of release builds).
                return -1;
            }
            aDst[decodedSize] = s.octet(i);
            ++decodedSize;
        }
        if (s.isEnd()) {
            break;
        }
        s.step();
    }
    return decodedSize;
}

//------------------------------------------------------------------------------
int Base64::GetDecodedSize(const char* aSrc, int aSrcSize)
{
    SYS_ASSERT_POINTER(aSrc);
    SYS_ASSERT(0 <= aSrcSize);
    fBase64InputStream s(aSrc, aSrcSize);
    int decodedSize = 0;
    while (true) {
        if (!s.isValid()) {
            return -1;
        }
        decodedSize += s.numOctet();
        if (s.isEnd()) {
            break;
        }
        s.step();
    }
    return decodedSize;
}

} // namespace
// EOF
