
Fix non-deterministic hangs caused by MeshDevice trace replay #17696

Merged
merged 3 commits into main from jchu/fix-meshdevice-trace-replay on Feb 11, 2025

Conversation

@cfjchu (Collaborator) commented Feb 7, 2025

Ticket

Link to Github Issue

Problem description

There are non-deterministic hangs in llama model tests that use trace functionality. @tt-aho pointed out that some changes to the blocking behavior were incorrectly introduced in d2ba114.

What's changed

This change modifies the IDevice::replay_trace method to add more granular control over blocking behavior. Previously, we used a single boolean blocking to denote both stalls that happen on the device and stalls that happen on the worker thread. Since we've moved the push_work API underneath the device APIs, we need this fine-grained control for orchestrating trace replay from MeshDevice.
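As a rough sketch of the intended call pattern (parameter names follow the diffs further down; none of this is verbatim code from the PR), a single-device caller maps the old flag onto both new parameters, while MeshDevice can stall on the device without blocking each worker thread:

    // Sketch only; parameter names follow the diffs in this PR.
    // Before: one flag covered both the device stall and the worker-thread stall.
    device->replay_trace(cq_id, tid, /*blocking=*/true);

    // After: the two stalls are controlled independently. A single-device caller
    // simply passes the old flag for both ...
    device->replay_trace(cq_id, tid, /*block_on_device=*/blocking, /*block_on_worker_thread=*/blocking);

    // ... while MeshDevice can stall each device on-device without blocking its
    // worker thread, and wait for the worker threads itself afterwards.
    device->replay_trace(cq_id, tid, /*block_on_device=*/true, /*block_on_worker_thread=*/false);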

Checklist

@cfjchu changed the title from "Jchu/fix meshdevice trace replay" to "Fix non-deterministic hangs caused by MeshDevice trace replay" on Feb 7, 2025
@cfjchu marked this pull request as ready for review on February 7, 2025 04:36
tt_metal/impl/device/device.cpp (resolved review thread)
@@ -141,7 +141,8 @@ class IDevice {
     // Metal trace device capture mode
     virtual void begin_trace(const uint8_t cq_id, const uint32_t tid) = 0;
     virtual void end_trace(const uint8_t cq_id, const uint32_t tid) = 0;
-    virtual void replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking) = 0;
+    virtual void replay_trace(
+        const uint8_t cq_id, const uint32_t tid, const bool block_on_device, const bool block_on_worker_thread) = 0;
Contributor:
Outside of this PR... const T in function declarations doesn't do anything; it is equivalent to just T. (This doesn't apply to function definitions, where const T makes sure the parameter doesn't get mutated in the function body, nor to const T& or const T*.)
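As a quick illustration of this point (a hypothetical function, not one from tt-metal): top-level const is not part of the function's type, so both declarations below declare the same function, and the const only matters inside a definition's body.

    // Hypothetical example; both lines declare the exact same function.
    void set_id(int id);
    void set_id(const int id);   // top-level const on a parameter is ignored in a declaration

    // In a definition, the const only prevents mutation inside the body:
    void set_id(const int id) {
        // id = 42;   // would not compile: id is const within this body
    }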

@cfjchu (Collaborator, Author):
Yes, I just used the existing precedent there but I agree.

Comment on lines +562 to +592
// If blocking, wait until worker threads have completed
if (block_on_worker_thread) {
for (auto& device : scoped_devices_->get_devices()) {
device->synchronize();
}
Contributor:
Is it not going to work if you call device->synchronize() but each device->replay_trace() call is non-blocking?

Is it possible to use the event APIs for this, perhaps in the long term? I see the TODO, but this problem will essentially remain; it's just that now you won't hop over push_work and will instead perform a blocking EnqueueTrace. How would EnqueueMeshTrace work from this perspective?

@cfjchu (Collaborator, Author) replied on Feb 7, 2025:
> Is it not going to work if you call device->synchronize() but each device->replay_trace() call is non-blocking?

At a high level, the behavior we want is for the main thread to block on all devices. With a worker thread per device, this breaks down into each worker thread blocking on its device and the main thread blocking on every worker thread. Your proposal would not guarantee the same behavior as before. Right now, I'm just reverting to the old behavior to resolve the llama hangs.

> Is it possible to use the event APIs for this, perhaps in the long term? I see the TODO, but this problem will essentially remain; it's just that now you won't hop over push_work and will instead perform a blocking EnqueueTrace. How would EnqueueMeshTrace work from this perspective?

Long term we won't have blocking on worker threads with TT-Mesh. @tt-asaigal will be implementing this as part of the Trace functionality.

@@ -1326,7 +1326,7 @@ void EndTraceCapture(IDevice* device, const uint8_t cq_id, const uint32_t tid) {
 void ReplayTrace(IDevice* device, const uint8_t cq_id, const uint32_t tid, const bool blocking) {
     LIGHT_METAL_TRACE_FUNCTION_ENTRY();
     LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureReplayTrace, device, cq_id, tid, blocking);
-    device->replay_trace(cq_id, tid, blocking);
+    device->replay_trace(cq_id, tid, blocking /* block_on_device */, blocking /* block_on_worker_thread */);
Contributor:
Both places where replay is called use the same boolean for both though? Does this fix the hang?

@cfjchu (Collaborator, Author):
Take a look at the MeshDevice::replay implementation. On each worker replay, block_on_worker_thread is set to false.
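Putting the pieces in this thread together, a minimal sketch of that orchestration (the names follow the quoted diffs; the body is illustrative, not the exact implementation) might look like this: each per-device replay stalls on the device but not on its worker thread, and the mesh-level call then waits for every worker thread if asked to.

    // Sketch only; follows the names quoted in this PR, not the exact tt-metal source.
    void MeshDevice::replay_trace(
        const uint8_t cq_id, const uint32_t tid, const bool block_on_device, const bool block_on_worker_thread) {
        // Each worker replay may stall on its device, but the main thread never blocks
        // on the worker thread at this point.
        for (auto& device : scoped_devices_->get_devices()) {
            device->replay_trace(cq_id, tid, block_on_device, /*block_on_worker_thread=*/false);
        }
        // If blocking, wait until worker threads have completed.
        if (block_on_worker_thread) {
            for (auto& device : scoped_devices_->get_devices()) {
                device->synchronize();
            }
        }
    }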

@cfjchu (Collaborator, Author):
Passes locally, waiting on our backlogged CI.

@mtairum (Contributor) commented Feb 7, 2025

We've tested this PR with a new update to the Llama3 models locally and the demo tests are passing.
Relevant PR: #17421

Pushed a branch that's basically #17421 rebased to this PR #17696 and kicked off the T3K demo pipelines: https://github.com/tenstorrent/tt-metal/actions/runs/13198728419

If these pass, this PR is good for merge from the models team side.
FYI @yieldthought

@cfjchu force-pushed the jchu/fix-meshdevice-trace-replay branch from 68f1409 to eb5e550 on February 7, 2025 19:15
@yieldthought (Contributor):
I cherry-picked these onto a new branch and still see a local hang - not convinced this is a complete fix.

@yieldthought (Contributor):

We noticed this broke the CI tests over a week ago. For all of that time main has been broken for our team and main has been completely unprotected against other commits breaking our code further.

The breaking change should have been reverted from main immediately and this fixed on branch. Can we do this today? We need a working main for our workshop tomorrow.

@cfjchu force-pushed the jchu/fix-meshdevice-trace-replay branch from eb5e550 to 61072d1 on February 11, 2025 04:29
@cfjchu (Collaborator, Author) commented Feb 11, 2025

I had a chat with @mtairum and we agree that other commits are responsible for the failures seen in regression. This commit does solve the issues originally reported.

Looking back at the original regression introduced by my changes:

The original change caused a regression only in the t3k_llama3_tests test suite. The changes in this PR revert to the original behavior and fix that regression in t3k_llama3_tests.

There seem to be three categories of failures in CI today:

  1. failures that already existed before my commit: (A)
  2. failures caused by my commit: (B)-(A)
  3. new failures not related to my commit

@cfjchu merged commit 05b16aa into main on Feb 11, 2025
222 of 223 checks passed
@cfjchu deleted the jchu/fix-meshdevice-trace-replay branch on February 11, 2025 06:10
@mtairum mentioned this pull request on Feb 12, 2025