fix the reverted PR for Optimize the web demo for yolov4 (#15478) (#15838)

### Problem description
Provide a real-time web demo for yolov4.
An earlier PR for this was merged but then reverted because of a failure; this PR re-lands those changes and runs additional tests on them.

### What's changed
- Enable trace + 2CQ (a second command queue)
- Optimize the post-processing
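For readers unfamiliar with the terms: "trace" records the model's dispatch commands once and replays them per frame, and "2CQ" runs host-to-device transfers on a second command queue so they overlap with compute. Below is a minimal sketch of that pattern in ttnn, an illustration only and not the demo's actual code (which lives in `models/demos/yolov4/tests/yolov4_perfomant_webdemo.py`); the `open_device` parameters and queue-synchronization details vary by version.

```python
import ttnn

# The tests pass l1_small_size / trace_region_size / num_command_queues via the
# device_params fixture; opening a device directly is shown here for illustration.
device = ttnn.open_device(device_id=0)

def run_traced(model, host_input, dev_input):
    _ = model(dev_input)  # warm-up: compiles programs and allocates buffers

    # Capture: ops issued between begin/end are recorded instead of re-dispatched.
    tid = ttnn.begin_trace_capture(device, cq_id=0)
    out = model(dev_input)
    ttnn.end_trace_capture(device, tid, cq_id=0)

    # Steady state: CQ1 refills the (fixed-address) input while CQ0 replays the trace.
    ttnn.copy_host_to_device_tensor(host_input, dev_input, cq_id=1)
    ttnn.execute_trace(device, tid, cq_id=0, blocking=True)
    return ttnn.to_torch(out)
```

In the real demo the two queues are additionally synchronized with events so a replay never reads a half-written input.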



### Checklist
- [x] Post commit CI passes
- [ ] Blackhole Post commit (if applicable)
- [x] Model regression CI testing passes (if applicable)
- [x] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [x] New/Existing tests provide coverage for changes

---------

Co-authored-by: Mohamed Bahnas <[email protected]>
3 people authored Feb 21, 2025
1 parent 1eef336 commit 9ada8ab
Showing 24 changed files with 928 additions and 695 deletions.
4 changes: 2 additions & 2 deletions models/demos/wormhole/yolov4/test_yolov4_performant.py
@@ -24,7 +24,7 @@ def test_run_yolov4_inference(device, use_program_cache, batch_size, act_dtype,


@run_for_wormhole_b0()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "trace_region_size": 1843200}], indirect=True)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "trace_region_size": 6422528}], indirect=True)
@pytest.mark.parametrize(
"batch_size, act_dtype, weight_dtype",
((1, ttnn.bfloat16, ttnn.bfloat16),),
@@ -50,7 +50,7 @@ def test_run_yolov4_trace_inference(

@run_for_wormhole_b0()
@pytest.mark.parametrize(
"device_params", [{"l1_small_size": 24576, "trace_region_size": 3686400, "num_command_queues": 2}], indirect=True
"device_params", [{"l1_small_size": 24576, "trace_region_size": 6397952, "num_command_queues": 2}], indirect=True
)
@pytest.mark.parametrize(
"batch_size, act_dtype, weight_dtype",
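Editorial note: `trace_region_size` reserves device DRAM for the captured command stream, so the longer traced graph in this PR needs a larger region. A quick conversion of the values changed above, as a sanity check:

```python
# trace_region_size values from the parametrizations above, in bytes
for name, size in [("old trace", 1843200), ("new trace", 6422528), ("new trace+2cq", 6397952)]:
    print(f"{name}: {size} bytes = {size / 2**20:.2f} MiB")  # ~1.76 -> ~6.13 / ~6.10 MiB
```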
44 changes: 2 additions & 42 deletions models/demos/wormhole/yolov4/test_yolov4_performant_webdemo.py
@@ -8,52 +8,12 @@
import torch

from models.utility_functions import run_for_wormhole_b0
-from models.demos.yolov4.tests.yolov4_perfomant_webdemo import (
-    run_yolov4_inference,
-    run_yolov4_trace_inference,
-    run_yolov4_trace_2cqs_inference,
-    Yolov4Trace2CQ,
-)
-
-
-@run_for_wormhole_b0()
-@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
-@pytest.mark.parametrize(
-    "batch_size, act_dtype, weight_dtype",
-    ((1, ttnn.bfloat16, ttnn.bfloat16),),
-)
-def test_run_yolov4_inference(device, use_program_cache, batch_size, act_dtype, weight_dtype, model_location_generator):
-    run_yolov4_inference(device, batch_size, act_dtype, weight_dtype, model_location_generator)
-
-
-@run_for_wormhole_b0()
-@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "trace_region_size": 1617920}], indirect=True)
-@pytest.mark.parametrize(
-    "batch_size, act_dtype, weight_dtype",
-    ((1, ttnn.bfloat16, ttnn.bfloat16),),
-)
-@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
-def test_run_yolov4_trace_inference(
-    device,
-    use_program_cache,
-    batch_size,
-    act_dtype,
-    weight_dtype,
-    enable_async_mode,
-    model_location_generator,
-):
-    run_yolov4_trace_inference(
-        device,
-        batch_size,
-        act_dtype,
-        weight_dtype,
-        model_location_generator,
-    )
+from models.demos.yolov4.tests.yolov4_perfomant_webdemo import Yolov4Trace2CQ


@run_for_wormhole_b0()
@pytest.mark.parametrize(
"device_params", [{"l1_small_size": 24576, "trace_region_size": 1617920, "num_command_queues": 2}], indirect=True
"device_params", [{"l1_small_size": 24576, "trace_region_size": 3211264, "num_command_queues": 2}], indirect=True
)
@pytest.mark.parametrize(
"batch_size, act_dtype, weight_dtype",
27 changes: 17 additions & 10 deletions models/demos/yolov4/README.md
@@ -2,24 +2,31 @@

## How to run yolov4

-- Use the following command to run the yolov4 performant impelementation (95 FPS):
-```
-pytest models/demos/wormhole/yolov4/test_yolov4_performant.py::test_run_yolov4_trace_2cqs_inference[True-1-act_dtype0-weight_dtype0-device_params0]
-```
-
-- You may try the interactive web demo following the instructions here: models/demos/yolov4/web_demo/README.md (25-30 FPS). NOTE: The post-processing is currently running on host. It will be moved to device soon which should significantly improve the end to end FPS.
+### Model code running with Trace+2CQ
+- Use the following command to run the yolov4 performant implementation (71 FPS):
+```bash
+pytest models/demos/wormhole/yolov4/test_yolov4_performant_webdemo.py::test_run_yolov4_trace_2cqs_inference[True-1-act_dtype0-weight_dtype0-device_params0]
+```


-- Use the following command to run a single-image demo for visualization. NOTE: the following demos are intented for visualization. It is not the performant implementation yet. And, the post processing is currently done on host which we will be moving to device soon.
-```
+### Single Image Demo
+
+- Use the following command to run the yolov4 with a giraffe image:
+```bash
pytest models/demos/yolov4/demo/demo.py
```
- The output file `ttnn_yolov4_320_prediction_demo.jpg` will be generated.

- Use the following command to run the yolov4 with different input image:
-```
+```bash
pytest --disable-warnings --input-path=<PATH_TO_INPUT_IMAGE> models/demos/yolov4/demo/demo.py
```

+Once you run the command, the output file named `ttnn_prediction_demo.jpg` will be generated.

+### mAP Accuracy Test
+- To be added soon
+
+### Web Demo
+- You may try the interactive web demo (35 FPS end-2-end) following the instructions:
+```
+models/demos/yolov4/web_demo/README.md
+```
231 changes: 110 additions & 121 deletions models/demos/yolov4/demo/demo.py
@@ -140,10 +140,10 @@ def yolo_forward_dynamic(
by_bh /= output.size(2)

# Shape: [batch, num_anchors * H * W, 1]
-    bx = bx_bw[:, :num_anchors].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
-    by = by_bh[:, :num_anchors].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
-    bw = bx_bw[:, num_anchors:].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
-    bh = by_bh[:, num_anchors:].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
+    bx = bx_bw[:, :num_anchors].reshape(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
+    by = by_bh[:, :num_anchors].reshape(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
+    bw = bx_bw[:, num_anchors:].reshape(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
+    bh = by_bh[:, num_anchors:].reshape(output.size(0), num_anchors * output.size(2) * output.size(3), 1)

bx1 = bx - bw * 0.5
by1 = by - bh * 0.5
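Editorial note on the `.view` → `.reshape` change above: `Tensor.view` requires the source tensor's memory to be contiguous and raises a `RuntimeError` on transposed or sliced tensors, while `reshape` silently falls back to a copy when a view is impossible. A standalone illustration:

```python
import torch

x = torch.arange(6).reshape(2, 3)
t = x.t()  # transpose: same storage, non-contiguous strides

print(t.reshape(6))  # works: tensor([0, 3, 1, 4, 2, 5]) via an implicit copy

try:
    t.view(6)  # view() cannot reinterpret non-contiguous memory
except RuntimeError as err:
    print("view failed:", err)
```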
@@ -324,12 +324,6 @@ def nms_cpu(boxes, confs, nms_thresh=0.5, min_mode=False):


def post_processing(img, conf_thresh, nms_thresh, output):
-    # anchors = [12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401]
-    # num_anchors = 9
-    # anchor_masks = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
-    # strides = [8, 16, 32]
-    # anchor_step = len(anchors) // num_anchors

# [batch, num, 1, 4]
box_array = output[0]
# [batch, num, num_classes]
@@ -464,34 +458,7 @@ def do_detect(model, img, conf_thresh, nms_thresh, n_classes, device=None, class
output_tensor3 = output_tensor3.reshape(1, 10, 10, 255)
output_tensor3 = torch.permute(output_tensor3, (0, 3, 1, 2))

-        yolo1 = YoloLayer(
-            anchor_mask=[0, 1, 2],
-            num_classes=n_classes,
-            anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
-            num_anchors=9,
-            stride=8,
-        )
-
-        yolo2 = YoloLayer(
-            anchor_mask=[3, 4, 5],
-            num_classes=n_classes,
-            anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
-            num_anchors=9,
-            stride=16,
-        )
-
-        yolo3 = YoloLayer(
-            anchor_mask=[6, 7, 8],
-            num_classes=n_classes,
-            anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
-            num_anchors=9,
-            stride=32,
-        )
-
-        y1 = yolo1(output_tensor1)
-        y2 = yolo2(output_tensor2)
-        y3 = yolo3(output_tensor3)
-
+        y1, y2, y3 = gen_yolov4_boxes_confs([output_tensor1, output_tensor2, output_tensor3])
output = get_region_boxes([y1, y2, y3])

t2 = time.time()
@@ -511,37 +478,8 @@
else:
t1 = time.time()
output = model(img)

-        yolo1 = YoloLayer(
-            anchor_mask=[0, 1, 2],
-            num_classes=n_classes,
-            anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
-            num_anchors=9,
-            stride=8,
-        )
-
-        yolo2 = YoloLayer(
-            anchor_mask=[3, 4, 5],
-            num_classes=n_classes,
-            anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
-            num_anchors=9,
-            stride=16,
-        )
-
-        yolo3 = YoloLayer(
-            anchor_mask=[6, 7, 8],
-            num_classes=n_classes,
-            anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
-            num_anchors=9,
-            stride=32,
-        )
-
-        y1 = yolo1(output[0])
-        y2 = yolo2(output[1])
-        y3 = yolo3(output[2])
-
+        y1, y2, y3 = gen_yolov4_boxes_confs(output)
output = get_region_boxes([y1, y2, y3])

t2 = time.time()

print("-----------------------------------")
@@ -556,66 +494,117 @@ def do_detect(model, img, conf_thresh, nms_thresh, n_classes, device=None, class
plot_boxes_cv2(img, boxes[0], "torch_prediction_demo.jpg", class_names)


+def gen_yolov4_boxes_confs(output):
+    n_classes = 80
+    anchors_array = [12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401]
+    num_anchors = 9
+    anchor_masks = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+    strides = [8, 16, 32]
+
+    yolo1 = YoloLayer(
+        anchor_mask=anchor_masks[0],
+        num_classes=n_classes,
+        anchors=anchors_array,
+        num_anchors=num_anchors,
+        stride=strides[0],
+    )
+
+    yolo2 = YoloLayer(
+        anchor_mask=anchor_masks[1],
+        num_classes=n_classes,
+        anchors=anchors_array,
+        num_anchors=num_anchors,
+        stride=strides[1],
+    )
+
+    yolo3 = YoloLayer(
+        anchor_mask=anchor_masks[2],
+        num_classes=n_classes,
+        anchors=anchors_array,
+        num_anchors=num_anchors,
+        stride=strides[2],
+    )
+
+    y1 = yolo1(output[0])
+    y2 = yolo2(output[1])
+    y3 = yolo3(output[2])
+
+    return y1, y2, y3

@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
-@pytest.mark.parametrize(
-    "use_pretrained_weight",
-    [True, False],
-    ids=[
-        "pretrained_weight_true",
-        "pretrained_weight_false",
-    ],
-)
-def test_yolov4_model(device, model_location_generator, reset_seeds, input_path, use_pretrained_weight):
+def test_yolov4(device, reset_seeds, model_location_generator):
    torch.manual_seed(0)
    model_path = model_location_generator("models", model_subdir="Yolo")
-    if use_pretrained_weight:
-        if model_path == "models":
-            if not os.path.exists("tests/ttnn/integration_tests/yolov4/yolov4.pth"):  # check if yolov4.th is availble
-                os.system(
-                    "tests/ttnn/integration_tests/yolov4/yolov4_weights_download.sh"
-                )  # execute the yolov4_weights_download.sh file
-
-            weights_pth = "tests/ttnn/integration_tests/yolov4/yolov4.pth"
-        else:
-            weights_pth = str(model_path / "yolov4.pth")
-
-        ttnn_model = TtYOLOv4(device, weights_pth)
-        torch_model = Yolov4()
-        new_state_dict = {}
-        ds_state_dict = {k: v for k, v in ttnn_model.torch_model.items()}
-
-        keys = [name for name, parameter in torch_model.state_dict().items()]
-        values = [parameter for name, parameter in ds_state_dict.items()]
-
-        for i in range(len(keys)):
-            new_state_dict[keys[i]] = values[i]
-
-        torch_model.load_state_dict(new_state_dict)
-        torch_model.eval()
-    else:
-        torch_model = Yolov4.from_random_weights()
-        ttnn_weights = update_weight_parameters(OrderedDict(torch_model.state_dict()))
-        ttnn_model = TtYOLOv4(device, ttnn_weights)
+    if model_path == "models":
+        if not os.path.exists("tests/ttnn/integration_tests/yolov4/yolov4.pth"):  # check if yolov4.th is availble
+            os.system(
+                "tests/ttnn/integration_tests/yolov4/yolov4_weights_download.sh"
+            )  # execute the yolov4_weights_download.sh file
+
+        weights_pth = "tests/ttnn/integration_tests/yolov4/yolov4.pth"
+    else:
+        weights_pth = str(model_path / "yolov4.pth")
+
+    ttnn_model = TtYOLOv4(weights_pth, device)

-    n_classes = 80
-    namesfile = "models/demos/yolov4/demo/coco.names"
-    if input_path == "":
-        imgfile = "models/demos/yolov4/demo/giraffe_320.jpg"
-    else:
-        imgfile = input_path
+    imgfile = "models/demos/yolov4/demo/giraffe_320.jpg"
    width = 320
    height = 320
    img = cv2.imread(imgfile)
-
-    # Inference input size is 416*416 does not mean training size is the same
-    # Training size could be 608*608 or even other sizes
-    # Optional inference sizes:
-    # Hight in {320, 416, 512, 608, ... 320 + 96 * n}
-    # Width in {320, 416, 512, 608, ... 320 + 96 * m}
-    sized = cv2.resize(img, (width, height))
-    sized = cv2.cvtColor(sized, cv2.COLOR_BGR2RGB)
-
-    for i in range(2):  # This 'for' loop is for speed check
-        # Because the first iteration is usually longer
-        do_detect(ttnn_model, sized, 0.3, 0.4, n_classes, device, class_name=namesfile, imgfile=imgfile)
+    img = cv2.resize(img, (width, height))
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    if type(img) == np.ndarray and len(img.shape) == 3:  # cv2 image
+        img = torch.from_numpy(img.transpose(2, 0, 1)).float().div(255.0).unsqueeze(0)
+    elif type(img) == np.ndarray and len(img.shape) == 4:
+        img = torch.from_numpy(img.transpose(0, 3, 1, 2)).float().div(255.0)
+    else:
+        exit()
+    torch_input = torch.autograd.Variable(img)
+
+    input_tensor = torch.permute(torch_input, (0, 2, 3, 1))
+    ttnn_input = ttnn.from_torch(input_tensor, ttnn.bfloat16)
+
+    torch_model = Yolov4()
+    new_state_dict = dict(zip(torch_model.state_dict().keys(), ttnn_model.torch_model.values()))
+    torch_model.load_state_dict(new_state_dict)
+    torch_model.eval()

+    torch_output_tensor = torch_model(torch_input)
+
+    ref1, ref2, ref3 = gen_yolov4_boxes_confs(torch_output_tensor)
+    ref_boxes, ref_confs = get_region_boxes([ref1, ref2, ref3])
+
+    ttnn_output_tensor = ttnn_model(ttnn_input)
+    result_boxes_padded = ttnn.to_torch(ttnn_output_tensor[0])
+    result_confs = ttnn.to_torch(ttnn_output_tensor[1])
+
+    result_boxes_padded = result_boxes_padded.permute(0, 2, 1, 3)
+    result_boxes_list = []
+    # Unpadding
+    # That ttnn tensor is the concat output of 3 padded tensors
+    # As a perf workaround I'm doing the unpadding on the torch output here.
+    # TODO: cleaner ttnn code when ttnn.untilize() is fully optimized
+    box_1_start_i = 0
+    box_1_end_i = 6100
+    box_2_start_i = 6128
+    box_2_end_i = 6228
+    box_3_start_i = 6256
+    box_3_end_i = 6356
+    result_boxes_list.append(result_boxes_padded[:, box_1_start_i:box_1_end_i])
+    result_boxes_list.append(result_boxes_padded[:, box_2_start_i:box_2_end_i])
+    result_boxes_list.append(result_boxes_padded[:, box_3_start_i:box_3_end_i])
+    result_boxes = torch.cat(result_boxes_list, dim=1)
+
+    ## Giraffe image detection
+    conf_thresh = 0.3
+    nms_thresh = 0.4
+    output = [result_boxes.to(torch.float16), result_confs.to(torch.float16)]
+
+    boxes = post_processing(img, conf_thresh, nms_thresh, output)
+    namesfile = "models/demos/yolov4/demo/coco.names"
+    class_names = load_class_names(namesfile)
+    img = cv2.imread(imgfile)
+    plot_boxes_cv2(img, boxes[0], "ttnn_yolov4_320_prediction_demo.jpg", class_names)
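Taken together, the rewritten demo decodes all three YOLO heads through the shared `gen_yolov4_boxes_confs` helper and funnels both the torch reference and the ttnn output through the same CPU post-processing. A hedged sketch of that call chain, assuming (as the diff suggests) that these helpers all live in `models/demos/yolov4/demo/demo.py`:

```python
# Assumed import path for illustration; the in-repo code calls these inside demo.py itself.
from models.demos.yolov4.demo.demo import (
    gen_yolov4_boxes_confs,
    get_region_boxes,
    post_processing,
)

def decode_and_filter(head_outputs, img, conf_thresh=0.3, nms_thresh=0.4):
    """head_outputs: three [batch, 255, H, W] tensors for strides 8/16/32."""
    y1, y2, y3 = gen_yolov4_boxes_confs(head_outputs)  # per-scale boxes and confidences
    boxes_confs = get_region_boxes([y1, y2, y3])       # concatenate across scales
    return post_processing(img, conf_thresh, nms_thresh, boxes_confs)  # thresholding + NMS on CPU
```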