#pragma once
#include <cstdint>
#include <vector>
#include <d3d11.h>
#include <wrl/client.h>
class PushBufferD3D11 {
public:
PushBufferD3D11(ID3D11Device *device, size_t size, D3D11_BIND_FLAG bindFlags) : size_(size) {
D3D11_BUFFER_DESC desc{};
desc.BindFlags = bindFlags;
desc.ByteWidth = (UINT)size;
desc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
desc.Usage = D3D11_USAGE_DYNAMIC;
device->CreateBuffer(&desc, nullptr, &buffer_);
}
PushBufferD3D11(PushBufferD3D11 &) = delete;
~PushBufferD3D11() {
}
ID3D11Buffer *Buf() const {
return buffer_.Get();
}
void Reset() {
pos_ = 0;
nextMapDiscard_ = true;
}
uint8_t *BeginPush(ID3D11DeviceContext *context, UINT *offset, size_t size, int align = 16) {
D3D11_MAPPED_SUBRESOURCE map;
pos_ = (pos_ + align - 1) & ~(align - 1);
if (pos_ + size > size_) {
pos_ = 0;
nextMapDiscard_ = true;
}
context->Map(buffer_.Get(), 0, nextMapDiscard_ ? D3D11_MAP_WRITE_DISCARD : D3D11_MAP_WRITE_NO_OVERWRITE, 0, &map);
nextMapDiscard_ = false;
*offset = (UINT)pos_;
uint8_t *retval = (uint8_t *)map.pData + pos_;
pos_ += size;
return retval;
}
void EndPush(ID3D11DeviceContext *context) {
context->Unmap(buffer_.Get(), 0);
}
private:
Microsoft::WRL::ComPtr<ID3D11Buffer> buffer_;
size_t pos_ = 0;
size_t size_;
bool nextMapDiscard_ = false;
};
// Shader compile/create helpers. Only the prototypes are visible here.
// NOTE(review): the parameter meanings described below are inferred from the
// names and types in these signatures — confirm against the definitions.
//
// Takes `codeSize` bytes of shader text at `code`, a `target` string and
// compile `flags`, and returns a byte vector (presumably the compiled
// bytecode; how failure is reported is not visible from here — TODO confirm).
std::vector<uint8_t> CompileShaderToBytecodeD3D11(const char *code, size_t codeSize, const char *target, UINT flags);
// The Create*ShaderD3D11 functions share one shape: device, shader text and
// its size, a feature level, flags, and a final out-parameter receiving the
// shader object; each returns an HRESULT.
// The vertex-shader variant is the only one with a `byteCodeOut` vector
// parameter (presumably so the caller can keep the bytecode, e.g. for input
// layout creation — verify; whether nullptr is accepted is not visible here).
HRESULT CreateVertexShaderD3D11(ID3D11Device *device, const char *code, size_t codeSize, std::vector<uint8_t> *byteCodeOut, D3D_FEATURE_LEVEL featureLevel, UINT flags, ID3D11VertexShader **);
// Pixel-shader counterpart; the last parameter receives the ID3D11PixelShader.
HRESULT CreatePixelShaderD3D11(ID3D11Device *device, const char *code, size_t codeSize, D3D_FEATURE_LEVEL featureLevel, UINT flags, ID3D11PixelShader **);
// Compute-shader counterpart; the last parameter receives the ID3D11ComputeShader.
HRESULT CreateComputeShaderD3D11(ID3D11Device *device, const char *code, size_t codeSize, D3D_FEATURE_LEVEL featureLevel, UINT flags, ID3D11ComputeShader **);
// Geometry-shader counterpart; the last parameter receives the ID3D11GeometryShader.
HRESULT CreateGeometryShaderD3D11(ID3D11Device *device, const char *code, size_t codeSize, D3D_FEATURE_LEVEL featureLevel, UINT flags, ID3D11GeometryShader **);
// Evaluates the HRESULT expression `x` exactly once and calls Crash() when it
// does not satisfy SUCCEEDED().
//
// Written as if/else rather than a bare `if`, so the macro's own `if` already
// has an `else`: a caller's `else` that follows the macro can no longer be
// captured by it. The expansion still ends in `Crash();`, so call sites
// written with or without their own trailing semicolon keep compiling.
#define ASSERT_SUCCESS(x) \
	if (SUCCEEDED((x))) { \
	} else \
		Crash();