Skip to content

Commit

Permalink
Merge pull request #19 from rhysdg/feat-gdino-extended-ops
Browse files Browse the repository at this point in the history
Feat - GDINO extended ops, extended TensorRT execution provider settings
  • Loading branch information
rhysdg authored Aug 1, 2024
2 parents 44ea384 + d43a061 commit e2a0ff6
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 14 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,6 @@ decrypt.py

**/*.pth
**/*.onnx
**/*.engine
**/*profile

6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,10 @@ Last of all the aim here is to keep up with the latest optimised foundation mode
logging.basicConfig(level=logging.INFO)

output_dir = 'output'

ogd = OnnxGDINO(type='gdino_fp32')

# Modest speedup with TensorRT 10.0.1.6-1 and fp16 on Ampere hardware currently;
# torch with amp autocast and matmul precision at 'high' is still faster for now
ogd = OnnxGDINO(type='gdino_fp32', trt=True)

payload = ogd.preprocess_query("spaceman. spacecraft. water. clouds. space helmet. glove")
img, img_transformed = load_image('images/wave_planet.webp')
Expand Down
67 changes: 55 additions & 12 deletions gdino/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import time
import scipy
import errno
import os
import gdown
import logging
import warnings
from pathlib import Path
from typing import (List,
Tuple,
Expand All @@ -20,14 +22,17 @@

import torch
import onnxruntime as ort
from onnxruntime_extensions import get_library_path as _lib_path
from utils.gdino_utils import (generate_masks_with_special_tokens_and_transfer_map,
create_positive_map_from_span
)
from gdino.gdino_tokenizer import BertTokenizer


logging.basicConfig(level=logging.INFO)

T = TypeVar("T")
logging.basicConfig(level=logging.INFO)
ort.set_default_logger_severity(3)


class OnnxGDINO:
"""
Expand All @@ -40,8 +45,10 @@ def __init__(
batch_size: Optional[int] = None,
type='gdino_fp32',
device='cuda',
trt=False
):
trt=False,
warmup=False,
n_iters=10
):
"""
"""
Expand All @@ -55,13 +62,23 @@ def __init__(
self.providers.insert(0, 'CUDAExecutionProvider')

if trt:
self.providers.insert(0, 'TensorrtExecutionProvider')
self.providers.insert(0, ('TensorrtExecutionProvider', {'trt_engine_cache_enable': True,
'trt_max_workspace_size': 4294967296,
'trt_engine_cache_path': f'{os.path.dirname(os.path.abspath(__file__))}/data',
'trt_engine_hw_compatible': True,
'trt_sparsity_enable': True,
'trt_build_heuristics_enable': True,
'trt_builder_optimization_level': 0,
'trt_fp16_enable': True
}
)
)


if self.providers:
logging.info(
"Available providers for ONNXRuntime: %s", ", ".join(self.providers)
)
"Available providers for ONNXRuntime: ")



self.embedding_size = 512
Expand All @@ -77,8 +94,9 @@ def __init__(

self.tokenizer = BertTokenizer(vocab_file=vocab_dir)
self.model = self._load_model(model_dir)


if warmup:
self.warmup(n_iters=n_iters)

self._batch_size = batch_size

Expand All @@ -90,15 +108,20 @@ def sigmoid(x):
def _load_model(self, path: str):

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_options.log_severity_level = 3
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_options.register_custom_ops_library(_lib_path())



try:
if os.path.exists(path):
# `providers` need to be set explicitly since ORT 1.9
return ort.InferenceSession(
path, providers=self.providers
path,
sess_options,
providers=self.providers,

)
else:
raise FileNotFoundError(
Expand All @@ -120,7 +143,27 @@ def _load_model(self, path: str):
providers=self.providers,

)



def warmup(self, n_iters=10):
    """Run throwaway inference passes to warm up the ONNX Runtime session.

    Executing a few dummy forward passes triggers lazy provider
    initialisation (e.g. TensorRT engine build/load, CUDA kernel
    selection) before any latency-sensitive inference begins.

    Args:
        n_iters: number of dummy forward passes to execute.
    """
    # Tokenize a fixed throwaway prompt so the text inputs have a
    # realistic shape for the warmup passes.
    payload = self.preprocess_query('time. to. warmup')

    # Synthetic image batch: (batch=1, channels=3, H=800, W=1200), float32
    # — presumably a representative input resolution; TODO confirm.
    dummy_image = np.random.randn(1, 3, 800, 1200).astype(np.float32)

    # The feed dict is loop-invariant, so build it once up front.
    feeds = {
        'img': dummy_image,
        'input_ids': np.array(payload['input_ids']),
        'attention_mask': np.array(payload['attention_mask']).astype(bool),
        'position_ids': payload['position_ids'].detach().numpy(),
        'token_type_ids': np.array(payload['token_type_ids']),
        'text_token_mask': payload['text_token_mask'].detach().numpy(),
    }

    for _ in range(n_iters):
        # Outputs are discarded — only the execution itself matters here.
        _logits, _boxes = self.model.run(None, feeds)



def preprocess_query(self,
query,
max_text_len=256
Expand Down Expand Up @@ -207,7 +250,7 @@ def inference(self,
assert text_threshold is not None or token_spans is not None, "text_threshold and token_spans should not be None at the same time!"

image = np.expand_dims(image, 0)

logits, boxes = self.model.run(None, {'img': image,
'input_ids': np.array(payload['input_ids']),
'attention_mask': np.array(payload['attention_mask']).astype(bool),
Expand Down
Binary file modified output/pred.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
torch==2.3.1
torchvision==0.18.1
onnxruntime-gpu==1.18.0
onnxruntime-extensions
sentencepiece==0.2.0
pillow==10.3.0
gdown==5.2.0
Expand All @@ -10,3 +11,4 @@ ftfy==6.2.0
regex==2024.5.15
scipy==1.13.1
gradio==3.26.0

0 comments on commit e2a0ff6

Please sign in to comment.