﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <cstdio>

#include "DecklinkInterface.h"

// Evaluates the HRESULT-returning expression `x` exactly once. On failure it logs
// the source line, the result code and the stringified expression through the
// enclosing object's Print(), then invokes its m_errorHandler(). Usable only inside
// DecklinkInterface member functions, since it relies on those two members.
// The do { } while (0) wrapper makes the macro a single statement, so it composes
// safely with an unbraced if/else.
#define DECKLINK_ASSERT_SUCCEEDED(x)                                                        \
do                                                                                          \
{                                                                                           \
    const int hrFailure = (x);                                                              \
    if (!SUCCEEDED(hrFailure))                                                              \
    {                                                                                       \
        Print("Assertion on line %i failed (%i): %s\n", __LINE__, hrFailure, #x);           \
        m_errorHandler();                                                                   \
    }                                                                                       \
} while (0)

DecklinkInterface::DecklinkInterface(int index, int (*printFunc)(const char* format, va_list args), void (*errorHandler)(void))
    : m_printf(printFunc),
      m_errorHandler(errorHandler),
      m_previewHandler(nullptr),
      m_videoHandler(nullptr),
      m_inputCallback(this),
      m_previewCallback(this),
      m_oldFrameTime(0),
      m_audioChannels(2),
      m_audioSampleType(bmdAudioSampleType16bitInteger),
      m_audioSampleDepth(16),
      m_AudioBitrate(48000),
      m_AudioSampleRate(bmdAudioSampleRate48kHz),
      m_captureEnabled(false)
{
    // The DeckLink iterator is obtained through COM (CoCreateInstance below).
    CoInitialize(nullptr);

    // Stays null unless the requested device is found and exposes IDeckLinkInput.
    m_pDeckLinkInput = nullptr;

    // Find the right decklink
    IDeckLinkIterator* deckLinkIterator = nullptr;
    DECKLINK_ASSERT_SUCCEEDED(CoCreateInstance(CLSID_CDeckLinkIterator, nullptr, CLSCTX_ALL, IID_IDeckLinkIterator, reinterpret_cast<void**>(&deckLinkIterator)));
    if (deckLinkIterator == nullptr)
    {
        // Failure was already reported by the assertion macro.
        return;
    }

    // Advance to the device at `index`. Each Next() hands back its own reference,
    // so every device that is skipped over must be released.
    IDeckLink* deckLink = nullptr;
    DECKLINK_ASSERT_SUCCEEDED(deckLinkIterator->Next(&deckLink));
    for (int i = 0; i < index && deckLink != nullptr; ++i)
    {
        deckLink->Release();
        deckLink = nullptr;
        DECKLINK_ASSERT_SUCCEEDED(deckLinkIterator->Next(&deckLink));
    }

    if (deckLink == nullptr)
    {
        // NOTE(review): Next() is expected to signal end-of-list with a success
        // code (S_FALSE), so the assertion macro cannot catch an out-of-range index.
        Print("DecklinkInterface Error: No DeckLink device found at index %i\n", index);
        m_errorHandler();
    }
    else
    {
        DECKLINK_ASSERT_SUCCEEDED(deckLink->QueryInterface(IID_IDeckLinkInput, reinterpret_cast<void**>(&m_pDeckLinkInput)));
        deckLink->Release();
    }

    deckLinkIterator->Release();
}

void DecklinkInterface::StartCapture()
{
    VerifyInactive(__FUNCTION__);

    m_captureEnabled = true;

    // Set the callbacks
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->SetScreenPreviewCallback(&m_previewCallback));
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->SetCallback(&m_inputCallback));

    // Enable whatever the first display mode is
    IDeckLinkDisplayModeIterator* displayModeIterator;
    IDeckLinkDisplayMode* displayMode;
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->GetDisplayModeIterator(&displayModeIterator));
    DECKLINK_ASSERT_SUCCEEDED(displayModeIterator->Next(&displayMode));

    Print("Starting resolution is %ix%i\n", displayMode->GetWidth(), displayMode->GetHeight());
    BMDTimeValue frameDuration;
    BMDTimeScale timeScale;
    displayMode->GetFrameRate(&frameDuration, &timeScale);
    Print("Setting frame rate to %llu/%llu\n", frameDuration, timeScale);

    // Set the video input mode
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->EnableVideoInput(displayMode->GetDisplayMode(), bmdFormat8BitYUV, bmdVideoInputEnableFormatDetection));
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->EnableAudioInput(m_AudioSampleRate, m_audioSampleType, m_audioChannels));

    // Start the capture
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->StartStreams());

    displayModeIterator->Release();
    displayMode->Release();
}

// Reconfigures the running capture for a new video display mode and pixel format.
// displayMode / pixelFormat: the new input configuration passed to EnableVideoInput.
// Any failing call is reported through DECKLINK_ASSERT_SUCCEEDED.
void DecklinkInterface::Reset(BMDDisplayMode displayMode, BMDPixelFormat pixelFormat)
{
    // Restart the stream with the new format: pause, re-enable video input with the
    // new settings, flush anything buffered from the old format, then resume.
    // NOTE(review): audio input is left as configured; this assumes streams are
    // already running (e.g. after StartCapture) — confirm for other call sites.
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->PauseStreams());
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->EnableVideoInput(displayMode, pixelFormat, bmdVideoInputEnableFormatDetection));
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->FlushStreams());
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->StartStreams());
}

// Stops an active capture started by StartCapture. A no-op when no capture is
// active (for example when reached from the destructor without a prior start).
void DecklinkInterface::StopCapture()
{
    if (!m_captureEnabled || m_pDeckLinkInput == nullptr)
    {
        return;
    }

    // Stop the streams first so frame delivery ends before the callbacks are
    // detached, then disable the inputs that StartCapture enabled so a later
    // StartCapture can enable them again from a clean state.
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->StopStreams());
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->DisableAudioInput());
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->DisableVideoInput());

    // Clear the callbacks
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->SetScreenPreviewCallback(nullptr));
    DECKLINK_ASSERT_SUCCEEDED(m_pDeckLinkInput->SetCallback(nullptr));

    m_captureEnabled = false;
}

// Stops any active capture and releases the DeckLink input interface.
DecklinkInterface::~DecklinkInterface()
{
    // Only tear down streams that were actually started.
    if (m_captureEnabled)
    {
        StopCapture();
    }

    // The constructor may not have obtained an input interface.
    if (m_pDeckLinkInput != nullptr)
    {
        m_pDeckLinkInput->Release();
        m_pDeckLinkInput = nullptr;
    }

    // NOTE(review): the constructor calls CoInitialize, but nothing here calls
    // CoUninitialize. Balancing it is only correct if destruction happens on the
    // constructing thread — confirm before adding it.
}

// printf-style logging helper: packages the variadic arguments into a va_list
// and hands them to the print function supplied at construction.
void DecklinkInterface::Print(const char* format, ...)
{
    va_list argList;
    va_start(argList, format);
    static_cast<void>(m_printf(format, argList));  // return value intentionally unused
    va_end(argList);
}

// Installs (or, with a null handler, removes) the consumer of preview frames.
void DecklinkInterface::SetPreviewHandler(PreviewHandlerFunc func)
{
    this->m_previewHandler = func;
}

// Installs (or, with a null handler, removes) the consumer of captured video/audio.
void DecklinkInterface::SetVideoHandler(VideoHandlerFunc func)
{
    this->m_videoHandler = func;
}

void DecklinkInterface::SetAudioChannels(int channels)
{
    VerifyInactive(__FUNCTION__);

    m_audioChannels = channels;
}

void DecklinkInterface::SetAudioSampleDepth(int depth)
{
    VerifyInactive(__FUNCTION__);

    m_audioSampleDepth = depth;

    switch (m_audioSampleDepth)
    {
        case 16:
            m_audioSampleType = bmdAudioSampleType16bitInteger;
            break;
        case 32:
            m_audioSampleType = bmdAudioSampleType32bitInteger;
            break;
        default:
            printf("DecklinkInterface Error: Invalid audio sample depth (expected 16 or 32)\n");
            exit(-1);
    }
}

// Present for API completeness only: the input is configured for 48 kHz
// (m_AudioSampleRate), so any other rate is reported as an error. Nothing is stored.
void DecklinkInterface::SetAudioBitrate(int rate)
{
    VerifyInactive(__FUNCTION__);

    const bool isSupportedRate = (rate == 48000);
    if (isSupportedRate)
    {
        return;
    }

    Print("DecklinkInterface Error: Invalid audio sample rate (expected 48000)\n");
    m_errorHandler();
}

// Size in bytes of one audio sample frame (one sample for every channel).
int DecklinkInterface::GetAudioSampleSize()
{
    const int bytesPerChannelSample = m_audioSampleDepth / 8;
    return bytesPerChannelSample * m_audioChannels;
}

int DecklinkInterface::GetAudioChannels()
{
    return m_audioChannels;
}

int DecklinkInterface::GetAudioSampleDepth()
{
    return m_audioSampleDepth;
}

int DecklinkInterface::GetAudioBitrate()
{
    return m_AudioBitrate;
}

// Reports an error naming `location` if it is invoked while a capture is running.
// Used by configuration setters that only take effect on the next StartCapture.
void DecklinkInterface::VerifyInactive(const char* location)
{
    if (!m_captureEnabled)
    {
        return;
    }

    Print("DecklinkInterface Error: %s was called in the middle of a capture!\n", location);
    m_errorHandler();
}

// Forwards a preview frame's pixel buffer, dimensions and pixel format to the
// registered preview handler, if any.
void DecklinkInterface::GotPreview(IDeckLinkVideoFrame* videoFrame)
{
    if (!m_previewHandler || videoFrame == nullptr)
    {
        return;
    }

    // Skip the frame if its buffer cannot be obtained; otherwise the handler
    // would receive an uninitialized pointer.
    unsigned char* videoBytes = nullptr;
    if (!SUCCEEDED(videoFrame->GetBytes(reinterpret_cast<void**>(&videoBytes))) || videoBytes == nullptr)
    {
        return;
    }

    m_previewHandler(videoBytes, videoFrame->GetWidth(), videoFrame->GetHeight(), videoFrame->GetPixelFormat());
}

// Handles one input callback: extracts the video frame and/or audio packet (either
// may be null), flags dropped frames by checking that timestamps advance by exactly
// one 1/60 s tick, and forwards everything to the registered video handler.
void DecklinkInterface::GotVideo(IDeckLinkVideoInputFrame* videoFrame, IDeckLinkAudioInputPacket* audioFrame)
{
    bool droppedFrames = false;
    BMDTimeValue frameTime = 0;

    unsigned char* videoBytes = nullptr;
    int width = 0;
    int height = 0;
    BMDPixelFormat pixelFormat = bmdFormat8BitYUV;

    if (videoFrame)
    {
        // Times are requested in units of 1/60 s, so consecutive frames should be
        // exactly one unit apart. The HDMI is assumed to always be 60 FPS.
        BMDTimeValue frameDuration = 0;
        videoFrame->GetStreamTime(&frameTime, &frameDuration, 60);

        // We should get every single frame
        if (frameDuration != 1 || m_oldFrameTime + frameDuration != frameTime)
        {
            droppedFrames = true;
        }

        videoFrame->GetBytes(reinterpret_cast<void**>(&videoBytes));

        width = static_cast<int>(videoFrame->GetWidth());
        height = static_cast<int>(videoFrame->GetHeight());
        pixelFormat = videoFrame->GetPixelFormat();
    }

    unsigned char* audioBytes = nullptr;
    int audioFrameCount = 0;

    if (audioFrame)
    {
        audioFrame->GetPacketTime(&frameTime, 60);  // The HDMI is assumed to always be 60 FPS

        // We should get every single frame
        if (m_oldFrameTime + 1 != frameTime)
        {
            droppedFrames = true;
        }

        audioFrameCount = static_cast<int>(audioFrame->GetSampleFrameCount());
        audioFrame->GetBytes(reinterpret_cast<void**>(&audioBytes));
    }

    // Only advance the reference time when this callback actually carried a
    // timestamp; an empty callback must not reset drop detection to 0.
    // NOTE(review): m_oldFrameTime starts at 0, so the first frame is likely
    // reported as dropped — confirm whether consumers rely on that.
    if (videoFrame || audioFrame)
    {
        m_oldFrameTime = frameTime;
    }

    if (!m_videoHandler)
    {
        return;
    }

    m_videoHandler(videoBytes, width, height, pixelFormat, audioBytes, audioFrameCount, droppedFrames);
}
