Skip to content

Commit

Permalink
Use an enum and static_pointer_cast instead of dynamic_pointer_cast
Browse files Browse the repository at this point in the history
This should be faster and simpler to read.

Signed-off-by: Robert Quill <[email protected]>
  • Loading branch information
robquill committed Sep 2, 2024
1 parent c28743d commit b792ed5
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 14 deletions.
3 changes: 1 addition & 2 deletions src/Algorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ Algorithm::createParameters()
KP_LOG_DEBUG("Kompute Algorithm createParameters started");

for (const std::shared_ptr<Memory>& mem : this->mMemObjects) {
if (std::shared_ptr<Image> image =
std::dynamic_pointer_cast<Image>(mem)) {
if (mem->type() == Memory::Type::eImage) {
numImages++;
}
else {
Expand Down
11 changes: 4 additions & 7 deletions src/Memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,6 @@ void
Memory::recordCopyFrom(const vk::CommandBuffer& commandBuffer,
std::shared_ptr<Memory> copyFromMemory)
{
std::shared_ptr<Tensor> tensor = std::dynamic_pointer_cast<Tensor>(copyFromMemory);
std::shared_ptr<Image> image = std::dynamic_pointer_cast<Image>(copyFromMemory);

if (copyFromMemory->dataType() != this->dataType()) {
throw std::runtime_error(fmt::format(
"Attempting to copy memory of different types from {} to {}",
Expand All @@ -268,11 +265,11 @@ Memory::recordCopyFrom(const vk::CommandBuffer& commandBuffer,
this->size()));
}

if (tensor) {
this->recordCopyFrom(commandBuffer, tensor);
if (copyFromMemory->type() == Memory::Type::eTensor) {
this->recordCopyFrom(commandBuffer, std::static_pointer_cast<Tensor>(copyFromMemory));
}
else if (image) {
this->recordCopyFrom(commandBuffer, image);
else if (copyFromMemory->type() == Memory::Type::eImage) {
this->recordCopyFrom(commandBuffer, std::static_pointer_cast<Image>(copyFromMemory));
}
else
{
Expand Down
5 changes: 3 additions & 2 deletions src/OpAlgoDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ OpAlgoDispatch::record(const vk::CommandBuffer& commandBuffer)
this->mAlgorithm->getMemObjects()) {

// For images the image layout needs to be set to eGeneral before using it for imageLoad/imageStore in a shader.
if (std::shared_ptr<Image> image =
std::dynamic_pointer_cast<Image>(mem)) {
if (mem->type() == Memory::Type::eImage) {
std::shared_ptr<Image> image = std::static_pointer_cast<Image>(mem);

image->recordPrimaryImageBarrier(
commandBuffer,
vk::AccessFlagBits::eTransferWrite,
Expand Down
2 changes: 2 additions & 0 deletions src/include/kompute/Image.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ class Image : public Memory
*/
uint32_t getNumChannels();

Type type() override { return Type::eImage; }

protected:
// -------------- ALWAYS OWNED RESOURCES
uint32_t mNumChannels;
Expand Down
18 changes: 15 additions & 3 deletions src/include/kompute/Memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class Memory
eUnsignedChar = 9
};

enum class Type
{
eTensor = 0,
eImage = 1
};

static std::string toString(MemoryTypes dt);
static std::string toString(DataTypes dt);

Expand Down Expand Up @@ -158,9 +164,8 @@ class Memory
* @param commandBuffer Vulkan Command Buffer to record the commands into
* @param copyFromMemory Memory to copy the data from
*/
void
recordCopyFrom(const vk::CommandBuffer& commandBuffer,
std::shared_ptr<Memory> copyFromMemory);
void recordCopyFrom(const vk::CommandBuffer& commandBuffer,
std::shared_ptr<Memory> copyFromMemory);

/**
* Adds this object to a Vulkan descriptor set at \p binding.
Expand Down Expand Up @@ -285,6 +290,13 @@ recordCopyFrom(const vk::CommandBuffer& commandBuffer,
*/
uint32_t getY() { return this->mY; };

/**
* Return the object type of this Memory object.
*
* @return The object type of the Memory object.
*/
virtual Type type() = 0;

protected:
// -------------- ALWAYS OWNED RESOURCES
MemoryTypes mMemoryType;
Expand Down
2 changes: 2 additions & 0 deletions src/include/kompute/Tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class Tensor : public Memory

std::shared_ptr<vk::Buffer> getPrimaryBuffer();

Type type() override { return Type::eTensor; }

protected:
// -------------- ALWAYS OWNED RESOURCES
vk::DescriptorBufferInfo mDescriptorBufferInfo;
Expand Down

0 comments on commit b792ed5

Please sign in to comment.