-
Notifications
You must be signed in to change notification settings - Fork 103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix non-deterministic hangs caused by MeshDevice trace replay #17696
Conversation
@@ -141,7 +141,8 @@ class IDevice { | |||
// Metal trace device capture mode | |||
virtual void begin_trace(const uint8_t cq_id, const uint32_t tid) = 0; | |||
virtual void end_trace(const uint8_t cq_id, const uint32_t tid) = 0; | |||
virtual void replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking) = 0; | |||
virtual void replay_trace( | |||
const uint8_t cq_id, const uint32_t tid, const bool block_on_device, const bool block_on_worker_thread) = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Outside of this PR... const T
in function declarations don't do anything, it is equivalent to just T
(doesn't apply to function definitions, where const T
makes sure the parameter doesn't get mutated in the function body, also doesn't apply to const T&
const T*
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I just used the existing precedent there but I agree.
// If blocking, wait until worker threads have completed | ||
if (block_on_worker_thread) { | ||
for (auto& device : scoped_devices_->get_devices()) { | ||
device->synchronize(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it not going to work if you call device->synchronize()
but each call device->replay_trace()
is non blocking?
Is it possible to use the event APIs for this perhaps? Maybe in the long term? I see the TODO but this problem will essentially remain, it's just now you won't hop over push_work
and instead perform a blocking EnqueueTrace
. How would EnqueueMeshTrace
work from this perspective?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it not going to work if you call device->synchronize() but each call device->replay_trace() is non blocking?
At a high-level the behavior we want is to have main thread block on all devices. With worker thread per device, this is basically broken up into each worker thread blocking on device and main thread blocking on every worker thread. Your proposal would not guarantee the same behavior as what was there before. Right now, I'm just reverting back to the old behavior to resolve the llama hangs.
Is it possible to use the event APIs for this perhaps? Maybe in the long term? I see the TODO but this problem will essentially remain, it's just now you won't hop over push_work and instead perform a blocking EnqueueTrace. How would EnqueueMeshTrace work from this perspective?
Long term we won't have blocking on worker threads with TT-Mesh. @tt-asaigal will be implementing this as part of the Trace functionality.
@@ -1326,7 +1326,7 @@ void EndTraceCapture(IDevice* device, const uint8_t cq_id, const uint32_t tid) { | |||
void ReplayTrace(IDevice* device, const uint8_t cq_id, const uint32_t tid, const bool blocking) { | |||
LIGHT_METAL_TRACE_FUNCTION_ENTRY(); | |||
LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureReplayTrace, device, cq_id, tid, blocking); | |||
device->replay_trace(cq_id, tid, blocking); | |||
device->replay_trace(cq_id, tid, blocking /* block_on_device */, blocking /* block_on_worker_thread */); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both places where replay
is called use the same boolean for both though? Does this fix the hang?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Take a look at the MeshDevice::replay
implementation. On each worker replay, block_on_worker_thread
is set to false.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Passes locally, waiting on our backlogged CI.
We've tested this PR with a new update to the Llama3 models locally and the demo tests are passing. Pushed a branch that's basically #17421 rebased to this PR #17696 and kicked off the T3K demo pipelines: https://github.com/tenstorrent/tt-metal/actions/runs/13198728419 If these past this PR his good for merge from the models team side. |
68f1409
to
eb5e550
Compare
I cherry-picked these onto a new branch and still see a local hang - not convinced this is a complete fix. |
We noticed this broke the CI tests over a week ago. For all of that time main has been broken for our team and main has been completely unprotected against other commits breaking our code further. The breaking change should have been reverted from main immediately and this fixed on branch. Can we do this today? We need a working main for our workshop tomorrow. |
eb5e550
to
61072d1
Compare
I had a chat with @mtairum and we agree there are other commits that are responsible for the failures in regression. This commit does solve the issues originally reported. Looking back at the original regression introduced by my changes:
The original change caused a regression only in the t3k_llama3_tests test suite. The changes in this PR revert to the original behavior and fixes the original regression in There seems to be three categories of failures in the CI today:
|
Ticket
Link to Github Issue
Problem description
There are non-deterministic hangs with llama model tests using trace functionality. @tt-aho pointed out there are some changes in the blocking behavior that was incorrectly introduced in: d2ba114
What's changed
This change modifies the
IDevice::replay_trace
method to add more granular control over blocking behavior. Previously, we used a single booleanblocking
to denote stalls that happen on device and stalls that happen on worker-thread. Since we've movedpush_work
API underneath the device APIs, we need this fine-grained control for orchestrating trace replay fromMeshDevice
.Checklist