Fix ruff & Add ignore style
Only-bottle committed Sep 10, 2024
1 parent ec2de83 commit d7debc3
Showing 10 changed files with 44 additions and 47 deletions.
2 changes: 1 addition & 1 deletion netspresso_inference_package/__init__.py
@@ -1 +1 @@
-from .inference.inference_service import InferenceService
+from .inference.inference_service import InferenceService
3 changes: 2 additions & 1 deletion netspresso_inference_package/enums.py
@@ -1,5 +1,6 @@
 from enum import Enum
 
+
 class Enums(Enum):
     FRAMEWORK = "framework"
     INPUTS = "inputs"
@@ -87,4 +88,4 @@ def _missing_(cls, value):
             return cls[value]
         except KeyError:
             msg = f"{cls.__name__} expected {', '.join(list(cls.__members__.keys()))} but got `{value}`"
-            raise KeyError(msg)
+            raise KeyError(msg)
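
Aside, not part of the commit: the `_missing_` override shown in this hunk is the standard Enum hook for falling back to name lookup when a value lookup fails. A minimal self-contained sketch of the same pattern, using a made-up `Color` enum:

from enum import Enum


class Color(Enum):
    RED = "red"
    BLUE = "blue"

    @classmethod
    def _missing_(cls, value):
        # Called by Enum when lookup by value fails; fall back to lookup by
        # member name, mirroring the fragment from enums.py above.
        try:
            return cls[value]
        except KeyError:
            msg = f"{cls.__name__} expected {', '.join(cls.__members__)} but got `{value}`"
            raise KeyError(msg)


print(Color("red"))   # normal value lookup -> Color.RED
print(Color("RED"))   # falls through to _missing_ and name lookup -> Color.RED
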
2 changes: 1 addition & 1 deletion netspresso_inference_package/exceptions.py
@@ -38,4 +38,4 @@ def __init__(self, message="Can not load this tflite file."):
 
 class WrongDatsetFile(BaseException):
     def __init__(self, message="Dataset file is a not archive file neither '.npy' file.",):
-        super().__init__(message)
+        super().__init__(message)
8 changes: 3 additions & 5 deletions netspresso_inference_package/inference/abs.py
@@ -1,13 +1,11 @@
 import abc
-from typing import Any
-import os
 
 
-class Basemodel(metaclass=abc.ABCMeta):
+class Basemodel(metaclass=abc.ABCMeta):
     @abc.abstractmethod
     def model_input_output_attributes(self):
         raise NotImplementedError()
 
     @abc.abstractmethod
     def inference(self, **kwargs):
-        raise NotImplementedError()
+        raise NotImplementedError()
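
Aside, not part of the commit: `Basemodel` is the abstract interface that the ONNX and TFLite wrappers later in this commit implement. A minimal sketch of a conforming subclass, with a hypothetical `DummyModel` standing in for a real backend:

import abc
from typing import Any, Dict


class Basemodel(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def model_input_output_attributes(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def inference(self, **kwargs):
        raise NotImplementedError()


class DummyModel(Basemodel):
    # Hypothetical backend: echoes its inputs so the interface can be
    # exercised without onnxruntime or tflite_runtime installed.
    def model_input_output_attributes(self):
        inputs = {"input_0": [1, 3, 224, 224]}
        outputs = {"output_0": [1, 1000]}
        return inputs, outputs

    def inference(self, **kwargs) -> Dict[str, Any]:
        return dict(kwargs)


model = DummyModel()
print(model.model_input_output_attributes())
print(model.inference(input_0=[0.0, 1.0]))
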
18 changes: 7 additions & 11 deletions netspresso_inference_package/inference/data_loader.py
@@ -1,28 +1,24 @@
 import os
 from pathlib import Path
-from typing import Dict, Any, Union
+from typing import Any, Dict, Union
 
 import numpy as np
 from loguru import logger
 
 from ..enums import Enums
-from ..utils import extract_archive
 from ..exceptions import NumpyLoadError, WrongDatsetFile
+from ..utils import extract_archive
 
 
 class NumpyDataLoader:
     def __init__(self, dataset_file_path:str, inputs:Dict[Union[str, int], Any]):
         self.npy = self.load_datasets(dataset_file_path, inputs)
 
-    def load_datasets(
-        self,
-        dataset_file_path:str,
-        inputs:Dict[Union[str, int], Any]
-    )->Dict[Union[str, int], np.ndarray]:
+    def load_datasets(self, dataset_file_path:str, inputs:Dict[Union[str, int], Any])->Dict[Union[str, int], np.ndarray]:
         outputs = {}
         # for single numpy file case
         if Path(dataset_file_path).is_file() and Path(dataset_file_path).suffix == ".npy":
-            for key in inputs.keys():
+            for key in inputs:
                 try:
                     outputs[key]=np.load(dataset_file_path)
                 except ValueError:
@@ -33,7 +29,7 @@ def load_datasets(
             # Unzip dataset_file_path
             extract_archive(dataset_file_path, os.path.dirname(dataset_file_path))
             parent_directory = os.path.dirname(dataset_file_path)
-            for key in inputs.keys():
+            for key in inputs:
                 file_name = f"{key}.npy"
                 file_path = os.path.join(parent_directory, file_name)
                 if os.path.exists(file_path):
@@ -45,4 +41,4 @@
                     logger.warning(f"Warning: File {file_path} does not exist.")
             return outputs
         else:
-            raise WrongDatsetFile()
+            raise WrongDatsetFile()
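
Aside, not part of the commit: the `for key in inputs.keys()` to `for key in inputs` changes above reflect the usual lint advice (ruff's SIM118, if I am matching the rule code correctly) that iterating a dict already yields its keys. A tiny sketch:

inputs = {"input_0": [1, 3, 224, 224], "input_1": [1, 10]}

# Both loops are equivalent; the second avoids the redundant .keys() call.
for key in inputs.keys():
    print(key)

for key in inputs:
    print(key)

# When the values are needed as well, .items() is the idiomatic form.
for key, shape in inputs.items():
    print(key, shape)
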
20 changes: 10 additions & 10 deletions netspresso_inference_package/inference/inference_service.py
@@ -1,15 +1,15 @@
 import os
 from pathlib import Path
-from typing import Dict, Any
+from typing import Any, Dict
 
 import numpy as np
 from loguru import logger
 
-from .tflite_inference import TFLITE
-from .onnx_inference import ONNX
-from .data_loader import NumpyDataLoader
-from ..utils import make_temp_dir, compress_files, delete_parent_directory
 from ..exceptions import NotSupportedFramework
+from ..utils import compress_files, make_temp_dir
+from .data_loader import NumpyDataLoader
+from .onnx_inference import ONNX
+from .tflite_inference import TFLITE
 
 
 class InferenceService:
@@ -33,7 +33,7 @@ def set_model_obj(self, model_file_path:str, num_threads:int):
         # model_obj = TFLITE(model_file_path, kwargs["num_threads"])
         # elif model_info[Enums.FRAMEWORK] == Enums.ONNX:
         # model_obj = ONNX(model_file_path)
 
         if len(model_obj.inputs) > 1:
             logger.info(f'{self.model_file_path} has {len(model_obj.inputs)} nodes for input layer')
         if len(model_obj.outputs) > 1:
@@ -52,14 +52,14 @@ def postprocess(self, inference_results:Dict[Any, np.ndarray]):
         # save npy file for each layers result
         files_path = []
         result_file_path = os.path.join(self.result_save_path, "archive.zip")
-        for k, v in inference_results.items():
+        for k, _v in inference_results.items():
             npy_file_path = os.path.join(self.result_save_path, f"{k}.npy")
             np.save(npy_file_path, inference_results[k])
             files_path.append(npy_file_path)
         # zip npy files
         compress_files(files_path, result_file_path)
         self.result_file_path = result_file_path
 
     def run(self, dataset_file_path):
         self.dataset_file_path = dataset_file_path
         inference_results = self.inference(dataset_file_path)
@@ -71,5 +71,5 @@ def run(self, dataset_file_path):
 
 inf_service = InferenceService(
     model_file_path="/app/tests/people_detection.onnx"
-)
-inf_service.run(dataset_file_path="/app/tests/dataset_for_onnx.npy")
+)
+inf_service.run(dataset_file_path="/app/tests/dataset_for_onnx.npy")
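
Aside, not part of the commit: a rough sketch of what `postprocess` above does, saving each output tensor to a `.npy` file and bundling them into an archive. The package's own `compress_files` helper is not shown in this diff, so a plain `zipfile` stand-in and made-up result names are used:

import os
import tempfile
import zipfile

import numpy as np

# Hypothetical inference results keyed by output name (names are made up).
inference_results = {"output_0": np.zeros((1, 4)), "output_1": np.ones((2, 2))}

result_save_path = tempfile.mkdtemp()
result_file_path = os.path.join(result_save_path, "archive.zip")

files_path = []
for name, array in inference_results.items():
    npy_file_path = os.path.join(result_save_path, f"{name}.npy")
    np.save(npy_file_path, array)          # one .npy per output layer
    files_path.append(npy_file_path)

# Stand-in for the package's compress_files(files_path, result_file_path) helper.
with zipfile.ZipFile(result_file_path, "w") as zf:
    for path in files_path:
        zf.write(path, arcname=os.path.basename(path))

print(result_file_path)
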
20 changes: 10 additions & 10 deletions netspresso_inference_package/inference/onnx_inference.py
@@ -1,20 +1,20 @@
-from typing import Dict, Union, Any
+from typing import Any, Dict, Union
 
 import numpy as np
-import onnxruntime
 import onnx
+import onnxruntime
 
-from .abs import Basemodel
 from ..exceptions import NotLoadableONNX
+from .abs import Basemodel
 
 
 def create_outputs(outputs:Dict[Union[str, int], Any], results:Dict[Union[str, int], Any]):
     if len(outputs) != len(results):
         raise ValueError("The length of 'outputs' and 'results' must be the same.")
 
-    for (key, value), result in zip(outputs.items(), results):
+    for (key, _value), result in zip(outputs.items(), results):
         outputs[key] = result
 
     return outputs
 
 
@@ -41,20 +41,20 @@ def model_input_output_attributes(self, model_file_path:str):
         for oup in model.graph.output:
             shape = str(oup.type.tensor_type.shape.dim)
             outputs[oup.name] = [int(s) for s in shape.split() if s.isdigit()]
 
         return inputs, outputs
 
     def inference(self, preprocess_result: Dict[int, np.ndarray]) -> Dict[int, np.ndarray]:
         inputs = {}
-        for k, v in preprocess_result.items():
+        for k, _v in preprocess_result.items():
             ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(preprocess_result[k])
             inputs[k] = ortvalue
             # TODO: make function which return npy generator when len(preprocess_result[k]) > 1
 
         outputs_list = []
-        for k in self.outputs.keys():
+        for k in self.outputs:
             outputs_list.append(k)
 
         results = self.model_obj.run(outputs_list, inputs)
         output_dict = create_outputs(self.outputs, results)
-        return output_dict
+        return output_dict
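
Aside, not part of the commit: the `v` to `_v` and `value` to `_value` renames in this file are the usual fix for an unused loop variable warning (flake8-bugbear's B007, if I am matching the rule correctly); the underscore prefix marks the variable as intentionally unused. A tiny sketch:

preprocess_result = {"input_0": [0.1, 0.2], "input_1": [0.3]}

# Only the key is used, so the value is renamed with a leading underscore.
for k, _v in preprocess_result.items():
    print(k)

# When only keys are needed, iterating the dict directly sidesteps the issue.
for k in preprocess_result:
    print(k)
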
8 changes: 4 additions & 4 deletions netspresso_inference_package/inference/tflite_inference.py
@@ -3,9 +3,9 @@
 
 import numpy as np
 
-from .abs import Basemodel
 from ..enums import EnumInputNodeShapeFormat, EnumNodeRawDataType
 from ..exceptions import NotLoadableTFLITE
+from .abs import Basemodel
 
 try:
     import tflite_runtime.interpreter as tflite
@@ -151,10 +151,10 @@ def model_input_output_attributes(self):
             output_data_attribute.quantization = output_detail.get("quantization")
             outputs[output_data_attribute.key] = output_data_attribute
 
-        return inputs, outputs
+        return inputs, outputs
 
     def inference(self, preprocess_result: Dict[int, np.ndarray], **kwargs) -> Dict[int, np.ndarray]:
-        for k, v in self.inputs.items():
+        for _k, v in self.inputs.items():
             if v.dtype in [np.uint8, np.int8, "int8", "unit8"]:
                 pass
 
@@ -166,4 +166,4 @@ def inference(self, preprocess_result: Dict[int, np.ndarray], **kwargs) -> Dict[
         output_dict = {}
         for output_location in iter(self.outputs):
             output_dict[output_location] = self.interpreter_obj.get_tensor(output_location)
-        return output_dict
+        return output_dict
8 changes: 4 additions & 4 deletions netspresso_inference_package/utils.py
@@ -1,9 +1,9 @@
 import os
+import shutil
+import tarfile
 import tempfile
-from pathlib import Path
 import zipfile
-import tarfile
-import shutil
+from pathlib import Path
 
 import py7zr
 import rarfile
@@ -50,4 +50,4 @@ def delete_parent_directory(file_path:str):
     except FileNotFoundError:
         logger.info(f"Directory {file_path} does not exist.")
     except PermissionError:
-        logger.info(f"Permission denied to delete {file_path}.")
+        logger.info(f"Permission denied to delete {file_path}.")
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -17,6 +17,8 @@ ignore = [
     "C901",
     "B008",
     "SIM115",
+    "E722",
+    "B904",
 ]

[tool.ruff.per-file-ignores]
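
Aside, not part of the commit: the two new entries extend ruff's ignore list. As I understand the rule codes, E722 flags bare `except:` clauses and B904 flags a `raise` inside an `except` block that does not chain with `from`, so adding them keeps code like the following passing lint:

def load(path: str):
    try:
        return open(path).read()
    except:  # E722 would flag this bare except (it also catches KeyboardInterrupt and SystemExit)
        return None


def parse(text: str) -> int:
    try:
        return int(text)
    except ValueError:
        # B904 would flag this re-raise because it lacks `from`,
        # which drops the explicit link to the original exception.
        raise RuntimeError(f"could not parse {text!r}")

Ignoring these rules is a project choice; the stricter alternative is to catch specific exception types and re-raise with `raise ... from err`.
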
