#pragma once

#include "inexor/vulkan-renderer/wrapper/device.hpp"
#include "inexor/vulkan-renderer/wrapper/framebuffer.hpp"
#include "inexor/vulkan-renderer/wrapper/swapchain.hpp"

#include <spdlog/spdlog.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

// TODO: Compute stages.
// TODO: Uniform buffers.

// Forward declarations of wrapper types that this header only refers to by reference.
namespace inexor::vulkan_renderer::wrapper {
class CommandBuffer;
class Shader;
} // namespace inexor::vulkan_renderer::wrapper

namespace inexor::vulkan_renderer {

// Forward declarations: these types are referenced below (by pointer, reference or friend declaration)
// before their definitions appear.
class RenderGraph;
class PhysicalStage;
class PhysicalResource;

/// @brief Common polymorphic base of render resources, render stages and their physical counterparts.
/// @note Render graph objects are neither copyable nor movable; they are handed out and stored by pointer
/// (see RenderGraph::add).
struct RenderGraphObject {
    RenderGraphObject() = default;
    RenderGraphObject(const RenderGraphObject &) = delete;
    RenderGraphObject(RenderGraphObject &&) = delete;
    virtual ~RenderGraphObject() = default;

    RenderGraphObject &operator=(const RenderGraphObject &) = delete;
    RenderGraphObject &operator=(RenderGraphObject &&) = delete;

    /// @brief Casts this object to `T` using `dynamic_cast`.
    /// @tparam T The derived type to cast to
    /// @return This object as `T *`, or `nullptr` if this object is not a `T`
    template <typename T>
    [[nodiscard]] T *as();

    /// @brief Const overload of as().
    /// @return This object as `const T *`, or `nullptr` if this object is not a `T`
    template <typename T>
    [[nodiscard]] const T *as() const;
};
/// @brief A single named resource in the render graph.
/// @note `m_physical` holds the physical (Vulkan) resource backing this one; presumably assigned by
/// RenderGraph during compilation.
class RenderResource : public RenderGraphObject {
    friend RenderGraph;

private:
    const std::string m_name;
    std::shared_ptr<PhysicalResource> m_physical;

protected:
    /// @param name The name of the resource
    explicit RenderResource(std::string name) : m_name(std::move(name)) {}

public:
    RenderResource(const RenderResource &) = delete;
    RenderResource(RenderResource &&) = delete;
    ~RenderResource() override = default;

    RenderResource &operator=(const RenderResource &) = delete;
    RenderResource &operator=(RenderResource &&) = delete;

    /// @brief Returns the name this resource was created with.
    [[nodiscard]] const std::string &name() const {
        return m_name;
    }
};
/// @brief Specifies how a BufferResource is used by the render graph.
// NOTE(review): the enumerator list was missing from this header; the names below are reconstructed and must
// match their uses in the render graph implementation — confirm.
enum class BufferUsage {
    /// @brief The buffer provides index data.
    INDEX_BUFFER,

    /// @brief The buffer provides per-vertex data (see BufferResource::add_vertex_attribute).
    VERTEX_BUFFER,
};
/// @brief A buffer resource in the render graph, optionally with data to upload during compilation.
class BufferResource : public RenderResource {
    friend RenderGraph;

private:
    const BufferUsage m_usage;
    std::vector<VkVertexInputAttributeDescription> m_vertex_attributes;

    // Data to upload during render graph compilation.
    // Only the pointer is stored, so the pointed-to memory must outlive the upload.
    const void *m_data{nullptr};
    std::size_t m_data_size{0};
    bool m_data_upload_needed{false};
    std::size_t m_element_size{0};

public:
    /// @param name The name of the buffer (moved into the base class)
    /// @param usage How the buffer is used
    BufferResource(std::string &&name, BufferUsage usage) : RenderResource(std::move(name)), m_usage(usage) {}

    /// @brief Adds a vertex attribute with the given format and offset (see VkVertexInputAttributeDescription).
    void add_vertex_attribute(VkFormat format, std::uint32_t offset);

    /// @brief Sets the size of a single element in bytes, e.g. when no data is uploaded up front.
    /// @param element_size The element size in bytes
    void set_element_size(std::size_t element_size) {
        m_element_size = element_size;
    }

    /// @brief Schedules `count` elements of type `T` starting at `data` for upload.
    /// @note Only the pointer is stored; the memory must stay valid until the upload has happened.
    /// @param data Pointer to at least `count` contiguous elements
    /// @param count The number of elements (not bytes)
    // TODO: Use std::span when we switch to C++ 20.
    template <typename T>
    void upload_data(const T *data, std::size_t count);

    /// @brief Schedules the contents of `data` for upload; equivalent to `upload_data(data.data(), data.size())`.
    /// @note The vector must not be destroyed or reallocated until the upload has happened.
    template <typename T>
    void upload_data(const std::vector<T> &data);
};
/// @brief Specifies how a TextureResource is used by the render graph.
// NOTE(review): the enumerator list was missing from this header; BACK_BUFFER is named by the TODO below, the
// other names are reconstructed and must match their uses in the render graph implementation — confirm.
enum class TextureUsage {
    /// @brief The texture is the output of the render graph.
    // TODO: Refactor back buffer system more (remove need for BACK_BUFFER texture usage)
    BACK_BUFFER,

    /// @brief The texture is a combined depth/stencil buffer.
    DEPTH_STENCIL_BUFFER,

    /// @brief The texture is not used for any special purpose.
    NORMAL,
};
/// @brief A texture resource in the render graph.
class TextureResource : public RenderResource {
    friend RenderGraph;

private:
    const TextureUsage m_usage;
    VkFormat m_format{VK_FORMAT_UNDEFINED};

public:
    /// @param name The name of the texture (moved into the base class)
    /// @param usage How the texture is used
    TextureResource(std::string &&name, TextureUsage usage) : RenderResource(std::move(name)), m_usage(usage) {}

    /// @brief Sets the format of this texture (VK_FORMAT_UNDEFINED until set).
    /// @note Presumably consumed when the physical image/image view is built — confirm in RenderGraph.
    void set_format(VkFormat format) {
        m_format = format;
    }
};
class RenderStage : public RenderGraphObject {
    friend RenderGraph;

    const std::string m_name;
    std::unique_ptr<PhysicalStage> m_physical;
    std::vector<const RenderResource *> m_writes;
    std::vector<const RenderResource *> m_reads;

    std::vector<VkDescriptorSetLayout> m_descriptor_layouts;
    std::vector<VkPushConstantRange> m_push_constant_ranges;
    std::function<void(const PhysicalStage &, const wrapper::CommandBuffer &)> m_on_record{[](auto &, auto &) {}};

    explicit RenderStage(std::string name) : m_name(std::move(name)) {}

    RenderStage(const RenderStage &) = delete;
    RenderStage(RenderStage &&) = delete;
    ~RenderStage() override = default;

    RenderStage &operator=(const RenderStage &) = delete;
    RenderStage &operator=(RenderStage &&) = delete;

    void writes_to(const RenderResource *resource);

    void reads_from(const RenderResource *resource);

    // TODO: Refactor descriptor management in the render graph
    void add_descriptor_layout(VkDescriptorSetLayout layout) {

    void add_push_constant_range(VkPushConstantRange range) {

    [[nodiscard]] const std::string &name() const {
        return m_name;

    void set_on_record(std::function<void(const PhysicalStage &, const wrapper::CommandBuffer &)> on_record) {
        m_on_record = std::move(on_record);

class GraphicsStage : public RenderStage {
    friend RenderGraph;

    bool m_clears_screen{false};
    bool m_depth_test{false};
    bool m_depth_write{false};
    VkPipelineColorBlendAttachmentState m_blend_attachment{};
    std::unordered_map<const BufferResource *, std::uint32_t> m_buffer_bindings;
    std::vector<VkPipelineShaderStageCreateInfo> m_shaders;

    explicit GraphicsStage(std::string &&name) : RenderStage(name) {}
    GraphicsStage(const GraphicsStage &) = delete;
    GraphicsStage(GraphicsStage &&) = delete;
    ~GraphicsStage() override = default;

    GraphicsStage &operator=(const GraphicsStage &) = delete;
    GraphicsStage &operator=(GraphicsStage &&) = delete;

    void set_clears_screen(bool clears_screen) {
        m_clears_screen = clears_screen;

    void set_depth_options(bool depth_test, bool depth_write) {
        m_depth_test = depth_test;
        m_depth_write = depth_write;

    void set_blend_attachment(VkPipelineColorBlendAttachmentState blend_attachment) {
        m_blend_attachment = blend_attachment;

    void bind_buffer(const BufferResource *buffer, std::uint32_t binding);

    void uses_shader(const wrapper::Shader &shader);

// TODO: Add wrapper::Allocation that can be made by doing `device->make<Allocation>(...)`.
/// @brief Base class of the physical (Vulkan) objects that back a RenderResource.
class PhysicalResource : public RenderGraphObject {
    friend RenderGraph;

protected:
    // Protected rather than private: derived classes declare their own destructors, which presumably need the
    // device and allocation to release their Vulkan objects.
    const wrapper::Device &m_device;
    VmaAllocation m_allocation{VK_NULL_HANDLE};

    explicit PhysicalResource(const wrapper::Device &device) : m_device(device) {}

public:
    PhysicalResource(const PhysicalResource &) = delete;
    PhysicalResource(PhysicalResource &&) = delete;
    ~PhysicalResource() override = default;

    PhysicalResource &operator=(const PhysicalResource &) = delete;
    PhysicalResource &operator=(PhysicalResource &&) = delete;
};
/// @brief The physical Vulkan buffer backing a BufferResource.
class PhysicalBuffer : public PhysicalResource {
    friend RenderGraph;

private:
    VmaAllocationInfo m_alloc_info{};
    VkBuffer m_buffer{VK_NULL_HANDLE};

public:
    // Public so it can be created through std::make_unique / std::make_shared.
    explicit PhysicalBuffer(const wrapper::Device &device) : PhysicalResource(device) {}
    PhysicalBuffer(const PhysicalBuffer &) = delete;
    PhysicalBuffer(PhysicalBuffer &&) = delete;
    ~PhysicalBuffer() override;

    PhysicalBuffer &operator=(const PhysicalBuffer &) = delete;
    PhysicalBuffer &operator=(PhysicalBuffer &&) = delete;
};
/// @brief The physical Vulkan image and image view backing a TextureResource.
class PhysicalImage : public PhysicalResource {
    friend RenderGraph;

private:
    VkImage m_image{VK_NULL_HANDLE};
    VkImageView m_image_view{VK_NULL_HANDLE};

public:
    // Public so it can be created through std::make_unique / std::make_shared.
    explicit PhysicalImage(const wrapper::Device &device) : PhysicalResource(device) {}
    PhysicalImage(const PhysicalImage &) = delete;
    PhysicalImage(PhysicalImage &&) = delete;
    ~PhysicalImage() override;

    PhysicalImage &operator=(const PhysicalImage &) = delete;
    PhysicalImage &operator=(PhysicalImage &&) = delete;
};
/// @brief Physical resource standing in for the swapchain back buffer; it only refers to the swapchain.
class PhysicalBackBuffer : public PhysicalResource {
    friend RenderGraph;

private:
    const wrapper::Swapchain &m_swapchain;

public:
    // Public so it can be created through std::make_unique / std::make_shared.
    PhysicalBackBuffer(const wrapper::Device &device, const wrapper::Swapchain &swapchain)
        : PhysicalResource(device), m_swapchain(swapchain) {}
    PhysicalBackBuffer(const PhysicalBackBuffer &) = delete;
    PhysicalBackBuffer(PhysicalBackBuffer &&) = delete;
    ~PhysicalBackBuffer() override = default;

    PhysicalBackBuffer &operator=(const PhysicalBackBuffer &) = delete;
    PhysicalBackBuffer &operator=(PhysicalBackBuffer &&) = delete;
};
/// @brief The physical Vulkan objects (pipeline and pipeline layout) backing a RenderStage.
class PhysicalStage : public RenderGraphObject {
    friend RenderGraph;

private:
    VkPipeline m_pipeline{VK_NULL_HANDLE};
    VkPipelineLayout m_pipeline_layout{VK_NULL_HANDLE};

protected:
    // Protected so derived physical stages (e.g. PhysicalGraphicsStage) can use the device in their destructors.
    const wrapper::Device &m_device;

public:
    explicit PhysicalStage(const wrapper::Device &device) : m_device(device) {}
    PhysicalStage(const PhysicalStage &) = delete;
    PhysicalStage(PhysicalStage &&) = delete;
    ~PhysicalStage() override;

    PhysicalStage &operator=(const PhysicalStage &) = delete;
    PhysicalStage &operator=(PhysicalStage &&) = delete;

    /// @brief Returns the pipeline layout of this stage (VK_NULL_HANDLE until built).
    // TODO: This can be removed once descriptors are properly implemented in the render graph.
    [[nodiscard]] VkPipelineLayout pipeline_layout() const {
        return m_pipeline_layout;
    }
};
/// @brief The physical Vulkan objects backing a GraphicsStage: a render pass plus its framebuffers.
class PhysicalGraphicsStage : public PhysicalStage {
    friend RenderGraph;

private:
    VkRenderPass m_render_pass{VK_NULL_HANDLE};
    std::vector<wrapper::Framebuffer> m_framebuffers;

public:
    // Public so it can be created through std::make_unique / std::make_shared.
    explicit PhysicalGraphicsStage(const wrapper::Device &device) : PhysicalStage(device) {}
    PhysicalGraphicsStage(const PhysicalGraphicsStage &) = delete;
    PhysicalGraphicsStage(PhysicalGraphicsStage &&) = delete;
    ~PhysicalGraphicsStage() override;

    PhysicalGraphicsStage &operator=(const PhysicalGraphicsStage &) = delete;
    PhysicalGraphicsStage &operator=(PhysicalGraphicsStage &&) = delete;
};
class RenderGraph {
    wrapper::Device &m_device;
    const wrapper::Swapchain &m_swapchain;
    std::shared_ptr<spdlog::logger> m_log{spdlog::default_logger()->clone("render-graph")};

    // Vectors of render resources and stages.
    std::vector<std::unique_ptr<BufferResource>> m_buffer_resources;
    std::vector<std::unique_ptr<TextureResource>> m_texture_resources;
    std::vector<std::unique_ptr<RenderStage>> m_stages;

    // Stage execution order.
    std::vector<RenderStage *> m_stage_stack;

    // Functions for building resource related vulkan objects.
    void build_buffer(const BufferResource &, PhysicalBuffer &) const;
    void build_image(const TextureResource &, PhysicalImage &, VmaAllocationCreateInfo *) const;
    void build_image_view(const TextureResource &, PhysicalImage &) const;

    // Functions for building stage related vulkan objects.
    void build_pipeline_layout(const RenderStage *, PhysicalStage &) const;
    void record_command_buffer(const RenderStage *, const wrapper::CommandBuffer &cmd_buf,
                               std::uint32_t image_index) const;

    // Functions for building graphics stage related vulkan objects.
    void build_render_pass(const GraphicsStage *, PhysicalGraphicsStage &) const;
    void build_graphics_pipeline(const GraphicsStage *, PhysicalGraphicsStage &) const;

    RenderGraph(wrapper::Device &device, const wrapper::Swapchain &swapchain)
        : m_device(device), m_swapchain(swapchain) {}

    template <typename T, typename... Args>
    T *add(Args &&...args) {
        auto ptr = std::make_unique<T>(std::forward<Args>(args)...);
        if constexpr (std::is_same_v<T, BufferResource>) {
            return static_cast<T *>(m_buffer_resources.emplace_back(std::move(ptr)).get());
        } else if constexpr (std::is_same_v<T, TextureResource>) {
            return static_cast<T *>(m_texture_resources.emplace_back(std::move(ptr)).get());
        } else if constexpr (std::is_base_of_v<RenderStage, T>) {
            return static_cast<T *>(m_stages.emplace_back(std::move(ptr)).get());
        } else {
            static_assert(!std::is_same_v<T, T>, "T must be a RenderResource or RenderStage");

    void compile(const RenderResource *target);

    void render(std::uint32_t image_index, const wrapper::CommandBuffer &cmd_buf);

/// @brief Casts this object to `T` with `dynamic_cast`; returns `nullptr` if this object is not a `T`.
template <typename T>
[[nodiscard]] T *RenderGraphObject::as() {
    return dynamic_cast<T *>(this);
}
/// @brief Const overload: casts this object to `const T *` with `dynamic_cast`; `nullptr` if it is not a `T`.
template <typename T>
[[nodiscard]] const T *RenderGraphObject::as() const {
    return dynamic_cast<const T *>(this);
}
/// @brief Records `count` elements of type `T` at `data` for upload and marks an upload as needed.
/// @note Only the pointer is stored; the memory must remain valid until the upload has happened.
template <typename T>
void BufferResource::upload_data(const T *data, std::size_t count) {
    m_data = data;
    m_element_size = sizeof(T);
    m_data_size = count * m_element_size;
    m_data_upload_needed = true;
}
/// @brief Records the contents of `data` for upload by forwarding its storage pointer and element count.
/// @note `data` must not be destroyed or reallocated until the upload has happened.
template <typename T>
void BufferResource::upload_data(const std::vector<T> &data) {
    upload_data(data.data(), data.size());
}
} // namespace inexor::vulkan_renderer