Skip to content

Commit

Permalink
Merge pull request #80 from luxonis/cli-yolov10
Browse files Browse the repository at this point in the history
Add CLI YoloV10 support
  • Loading branch information
sokovninn authored Jul 1, 2024
2 parents 6393717 + e2777c6 commit e47735c
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ docker compose run tools-cli shared_with_container/models/yolov6nr4.pt
# Building the package
pip install .
# Running the package
tools --model shared_with_container/models/yolov6nr4.pt --imgsz "416"
tools shared_with_container/models/yolov6nr4.pt --imgsz "416"
```

### Arguments
Expand Down
7 changes: 6 additions & 1 deletion tools/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
YOLOV6R4_CONVERSION,
YOLOV7_CONVERSION,
YOLOV8_CONVERSION,
YOLOV10_CONVERSION,
Config,
detect_version,
upload_file_to_remote,
Expand All @@ -35,7 +36,8 @@
YOLOV6R3_CONVERSION,
YOLOV6R4_CONVERSION,
YOLOV7_CONVERSION,
YOLOV8_CONVERSION
YOLOV8_CONVERSION,
YOLOV10_CONVERSION,
]


Expand Down Expand Up @@ -110,6 +112,9 @@ def convert(
elif version == YOLOV8_CONVERSION:
from tools.yolo.yolov8_exporter import YoloV8Exporter
exporter = YoloV8Exporter(str(model_path), config.imgsz, config.use_rvc2)
elif version == YOLOV10_CONVERSION:
from tools.yolo.yolov10_exporter import YoloV10Exporter
exporter = YoloV10Exporter(str(model_path), config.imgsz, config.use_rvc2)
else:
logger.error("Unrecognized model version.")
raise typer.Exit(code=1)
Expand Down
2 changes: 2 additions & 0 deletions tools/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ClassifyV8,
DetectV5,
DetectV7,
DetectV10,
)
from .exporter import Exporter
from .stage2 import Multiplier
Expand All @@ -33,4 +34,5 @@
"Multiplier",
"DetectV5",
"DetectV7",
"DetectV10",
]
7 changes: 7 additions & 0 deletions tools/modules/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,10 @@ def forward(self, x):
x = torch.cat(x, 1)
x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
return x

class DetectV10(DetectV8):
"""YOLOv10 Detect head for detection models."""
def __init__(self, old_detect, use_rvc2):
super().__init__(old_detect, use_rvc2)
self.cv2 = old_detect.one2one_cv2
self.cv3 = old_detect.one2one_cv3
2 changes: 2 additions & 0 deletions tools/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
YOLOV6R4_CONVERSION,
YOLOV7_CONVERSION,
YOLOV8_CONVERSION,
YOLOV10_CONVERSION,
detect_version,
)
from .in_channels import get_first_conv2d_in_channels
Expand All @@ -31,6 +32,7 @@
"YOLOV6R4_CONVERSION",
"YOLOV7_CONVERSION",
"YOLOV8_CONVERSION",
"YOLOV10_CONVERSION",
"GOLD_YOLO_CONVERSION",
"UNRECOGNIZED",
"resolve_path",
Expand Down
3 changes: 3 additions & 0 deletions tools/utils/version_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
YOLOV6R4_CONVERSION = "yolov6r4"
YOLOV7_CONVERSION = "yolov7"
YOLOV8_CONVERSION = "yolov8"
YOLOV10_CONVERSION = "yolov10"
GOLD_YOLO_CONVERSION = "goldyolo"
UNRECOGNIZED = "none"

Expand Down Expand Up @@ -61,6 +62,8 @@ def detect_version(path: str, debug: bool = False) -> str:
return YOLOV6R3_CONVERSION
elif "yolov7" in content:
return YOLOV7_CONVERSION
elif "yolov10" in content:
return YOLOV10_CONVERSION
elif (
"SPPF" in content
or "yolov5" in content
Expand Down
2 changes: 1 addition & 1 deletion tools/yolo/ultralytics
Submodule ultralytics updated 433 files
82 changes: 82 additions & 0 deletions tools/yolo/yolov10_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import sys

sys.path.append("./tools/yolo/ultralytics")

from ultralytics.nn.modules import Detect
from ultralytics.nn.tasks import attempt_load_one_weight
import torch
from typing import Tuple, List, Optional

from tools.modules import Exporter, DetectV10
from tools.utils import get_first_conv2d_in_channels

class YoloV10Exporter(Exporter):
def __init__(
self,
model_path: str,
imgsz: Tuple[int, int],
use_rvc2: bool,
):
super().__init__(
model_path,
imgsz,
use_rvc2,
subtype="yolov10",
output_names=["output1_yolov10", "output2_yolov10", "output3_yolov10"],
)
self.load_model()

def load_model(self):
# load the model
model, _ = attempt_load_one_weight(
self.model_path, device="cpu", inplace=True, fuse=True
)

if isinstance(model.model[-1], (Detect)):
model.model[-1] = DetectV10(model.model[-1], self.use_rvc2)



self.names = model.module.names if hasattr(model, 'module') else model.names # get class names
# check num classes and labels
assert model.yaml["nc"] == len(self.names), f'Model class count {model.yaml["nc"]} != len(names) {len(self.names)}'

try:
self.number_of_channels = get_first_conv2d_in_channels(model)
# print(f"Number of channels: {self.number_of_channels}")
except Exception as e:
print(f"Error while getting number of channels: {e}")

# check if image size is suitable
gs = max(int(model.stride.max()), 32) # model stride
if isinstance(self.imgsz, int):
self.imgsz = [self.imgsz, self.imgsz]
for sz in self.imgsz:
if sz % gs != 0:
raise ValueError(f"Image size is not a multiple of maximum stride {gs}")

# ensure correct length
if len(self.imgsz) != 2:
raise ValueError("Image size must be of length 1 or 2.")

model.eval()
self.model = model


def export_nn_archive(self, class_names: Optional[List[str]] = None):
"""
Export the model to NN archive format.
Args:
class_list (Optional[List[str]], optional): List of class names. Defaults to None.
"""
names = list(self.model.names.values())

if class_names is not None:
assert len(class_names) == len(names), f"Number of the given class names {len(class_names)} does not match number of classes {len(names)} provided in the model!"
names = class_names

self.f_nn_archive = (self.output_folder / f"{self.model_name}.tar.xz").resolve()

self.make_nn_archive(names, self.model.model[-1].nc)

0 comments on commit e47735c

Please sign in to comment.