diff --git a/luxonis_train/models/luxonis_model.py b/luxonis_train/models/luxonis_model.py index 7cd5e02d..24aa6199 100644 --- a/luxonis_train/models/luxonis_model.py +++ b/luxonis_train/models/luxonis_model.py @@ -398,7 +398,17 @@ def export_onnx(self, save_path: str, **kwargs) -> list[str]: for i in range(len(out)) ] ) - output_names = [ + + if self.cfg.exporter.output_names is not None: + len_names = len(self.cfg.exporter.output_names) + if len_names != len(output_order): + logger.warning( + f"Number of provided output names ({len_names}) does not match " + f"number of outputs ({len(output_order)}). Using default names." + ) + self.cfg.exporter.output_names = None + + output_names = self.cfg.exporter.output_names or [ f"{node_name}/{output_name}/{i}" for node_name, output_name, i in output_order ] diff --git a/luxonis_train/utils/config.py b/luxonis_train/utils/config.py index 9a1552a1..8f004c4a 100644 --- a/luxonis_train/utils/config.py +++ b/luxonis_train/utils/config.py @@ -252,6 +252,7 @@ class ExportConfig(BaseModel): reverse_input_channels: bool = True scale_values: list[float] | None = None mean_values: list[float] | None = None + output_names: list[str] | None = None onnx: OnnxExportConfig = OnnxExportConfig() blobconverter: BlobconverterExportConfig = BlobconverterExportConfig() upload_url: str | None = None