#17215: Add write/read APIs for TTNN tensors allocated on mesh buffer #17513
Conversation
@@ -378,6 +378,9 @@ struct MultiDeviceStorage {

using Storage = std::variant<OwnedStorage, DeviceStorage, BorrowedStorage, MultiDeviceHostStorage, MultiDeviceStorage>;

template <typename T>
concept OwnedOrBorrowedStorage = std::is_same_v<T, OwnedStorage> || std::is_same_v<T, BorrowedStorage>;
consider HostStorage
There is MultiDeviceHostStorage, unfortunately. I think we need to do a better job of creating a hierarchy here, e.g. HostTensor as a collection of buffers (which can be owned or borrowed) plus DeviceTensor that will (eventually) always be backed by MeshBuffer. "Owned" vs. "Borrowed" sounds to me like a lower-level concept, an implementation detail of the buffer, not of the whole tensor storage variant.
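A minimal sketch of that proposed hierarchy; all type definitions below are illustrative assumptions (OwnedBuffer, BorrowedBuffer, MeshBuffer are stand-ins, not the real TTNN types):

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <variant>
#include <vector>

// Buffer-level ownership detail: a buffer is either owned or borrowed.
struct OwnedBuffer { std::vector<std::byte> data; };
struct BorrowedBuffer { const std::byte* data = nullptr; std::size_t size = 0; };
using HostBuffer = std::variant<OwnedBuffer, BorrowedBuffer>;

// HostTensor: a collection of host buffers (e.g. one per shard);
// ownership is a property of each buffer, not of the tensor.
struct HostTensor {
    std::vector<HostBuffer> buffers;
};

struct MeshBuffer {};  // stand-in for the real mesh buffer

// DeviceTensor: always backed by a mesh buffer.
struct DeviceTensor {
    std::shared_ptr<MeshBuffer> buffer;
};
```

The key design point of the sketch: the owned/borrowed distinction lives at the buffer level inside HostTensor, so the top-level storage split is simply host vs. device.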
    CommandQueue& cq, std::shared_ptr<Buffer> device_buffer, void* host_buffer_data, bool blocking) {
    EnqueueReadBuffer(cq, device_buffer, host_buffer_data, blocking);
}

template <typename T>
inline void read_data_from_device_buffer(std::shared_ptr<Buffer> device_buffer, std::vector<T>& host_buffer) {
thank you
    CommandQueue& cq, std::shared_ptr<Buffer> device_buffer, void* host_buffer_data, bool blocking) {
    EnqueueReadBuffer(cq, device_buffer, host_buffer_data, blocking);
}

template <typename T>
-inline void read_data_from_device_buffer(std::shared_ptr<Buffer> device_buffer, std::vector<T>& host_buffer) {
+void read_data_from_device_buffer(std::shared_ptr<Buffer> device_buffer, std::vector<T>& host_buffer) {
    ::tt::tt_metal::detail::ReadFromBuffer(device_buffer, host_buffer);
What's the difference between this and EnqueueReadBuffer above?
Is this a slow dispatch path? If so, can you please mark the method deprecated, with a comment that it's a slow dispatch path?
Slow dispatch is not deprecated. It's actively used for all sorts of bringup, debug and experiments. Marking it deprecated would imply that this needs to be cleaned up from the codebase at some point, which is not the case.
I think TT-NN must not care about this distinction.
+1, TTNN needs to rely on a single API; internally we might fall back to slow dispatch if needed.
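A sketch of that single-API idea, using stand-in types and hypothetical helper names (fast_read/slow_read mirror EnqueueReadBuffer and ::tt::tt_metal::detail::ReadFromBuffer in spirit only):

```cpp
#include <cassert>
#include <vector>

// Stand-in for a device buffer holding readable data.
struct Buffer { std::vector<int> storage; };

namespace detail {
// Stand-in for the fast dispatch path (think EnqueueReadBuffer).
void fast_read(const Buffer& b, std::vector<int>& out) { out = b.storage; }
// Stand-in for the slow dispatch path (think detail::ReadFromBuffer).
void slow_read(const Buffer& b, std::vector<int>& out) { out = b.storage; }
}  // namespace detail

// The single read API the caller relies on; the dispatch mode is an
// internal implementation detail, not something TTNN code selects.
void read_buffer(const Buffer& b, std::vector<int>& out, bool fast_dispatch_available) {
    if (fast_dispatch_available) {
        detail::fast_read(b, out);
    } else {
        detail::slow_read(b, out);  // internal fallback to slow dispatch
    }
}
```

Callers see one function regardless of dispatch mode, which is the property the reviewers are asking for.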
    }
},
},
[](const auto& s) -> owned_buffer::Buffer<T> {
good change
Will you be adding Python APIs for this next? Would be great to see how TTNN code remains essentially unchanged when we switch backends.
Most likely will add a switch to use these mesh-based implementations in the existing …
Yes, a first step would be to add a switch to those top-level APIs, and then see what falls out when we integrate into the functions exposed by …
Overall looks good, but I need to review your changes and the assumptions you made about the 1:1 mapping of shards to devices.
thanks for fixing!
Latest changes look great, thanks Oleg!
Ticket
#17215
Problem description
Tensors allocated on mesh buffer (aka "mesh tensors") need write and read APIs exposed to TTNN.
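As a rough illustration of the desired round trip: only the API names to_device_mesh_tensor and to_host_mesh_tensor come from this PR; every type and signature below is a toy assumption (single shard, in-process mock device, no real hardware):

```cpp
#include <cassert>
#include <map>
#include <vector>

// Toy stand-ins: a "mesh device" storing shards by id, and a flat host tensor.
using HostTensor = std::vector<float>;
struct MeshDevice { std::map<int, std::vector<float>> shards; };
struct MeshTensor { MeshDevice* device = nullptr; int shard_id = 0; };

// Hypothetical shapes of the new APIs (names from the PR, signatures assumed).
MeshTensor to_device_mesh_tensor(const HostTensor& host, MeshDevice& device) {
    device.shards[0] = host;  // single-shard sketch: shard 0 holds all the data
    return MeshTensor{&device, 0};
}

HostTensor to_host_mesh_tensor(const MeshTensor& t) {
    return t.device->shards.at(t.shard_id);  // read the shard back to host
}
```

The real APIs would of course shard across the mesh and go through the mesh command queue; the sketch only shows the write-then-read shape TTNN code would use.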
What's changed
Added to_device_mesh_tensor and to_host_mesh_tensor, which will be the main APIs used in TTNN to read/write the mesh buffer tensors.
Checklist