custom_utils.py
from pathlib import Path

import cv2
import torch
from PIL import Image
from torchvision.transforms import ToTensor, Resize

from utils.dataloaders import LoadImages
from utils.general import non_max_suppression, Profile, scale_boxes, xyxy2xywh
from utils.plots import Annotator


def get_yolo_results(image_path, model):
    """Run YOLO inference on an image and return a list of
    (cls, x_center, y_center, width, height, conf) tuples, with xywh normalized to [0, 1]."""
    source = image_path
    imgsz = (640, 640)
    stride = model.stride
    pt = model.pt
    dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)

    conf_thres = 0.25
    iou_thres = 0.45
    classes = None
    agnostic_nms = False
    max_det = 1000
    webcam = False
    save_dir = Path('D:/Upwork/KS/MOHIT/')
    save_crop = False
    line_thickness = 3
    names = model.names

    total_detections = []
    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
    for path, im, im0s, vid_cap, s in dataset:
        with dt[0]:
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim

        # Inference
        with dt[1]:
            visualize = False
            augment = False
            pred = model(im, augment=augment, visualize=visualize)

        # NMS
        with dt[2]:
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

        for i, det in enumerate(pred):  # per image
            seen += 1
            if webcam:  # batch_size >= 1
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                s += f'{i}: '
            else:
                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # im.jpg
            # txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
            txt_path = str(save_dir / p.stem)
            s += '%gx%g ' % im.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            imc = im0.copy() if save_crop else im0  # for save_crop
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, 5].unique():
                    n = (det[:, 5] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                save_conf = True
                for *xyxy, conf, cls in reversed(det):
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
                    # with open(f'{txt_path}.txt', 'a') as f:
                    #     f.write(('%g ' * len(line)).rstrip() % line + '\n')
                    total_detections.append(line)

    # Convert any tensor elements to plain Python scalars
    converted_detections = [tuple(value.item() if isinstance(value, torch.Tensor) else value for value in item)
                            for item in total_detections]
    return converted_detections


def get_bboxes(image_path, detections):
    """Convert normalized (cls, x_center, y_center, w, h, conf) detections into
    pixel-space boxes, padded by 10 px on every side."""
    image = cv2.imread(image_path)
    bounding_boxes = []
    for cls, x, y, w, h, confidence in detections:
        cls = int(cls)
        left = int((x - w / 2) * image.shape[1]) - 10
        top = int((y - h / 2) * image.shape[0]) - 10
        width = int(w * image.shape[1]) + 20
        height = int(h * image.shape[0]) + 20
        bounding_boxes.append({
            'class': cls,
            'left': left,
            'top': top,
            'width': width,
            'height': height
        })
    return bounding_boxes


def calculate_bbox_area(bbox):
    width = bbox['width']
    height = bbox['height']
    return width * height


def remove_lowest_area_bbox(detections):
    if len(detections) <= 2:
        return detections
    areas = [calculate_bbox_area(bbox) for bbox in detections]
    min_area_index = areas.index(min(areas))
    # Return detections without the detection with the lowest area
    return detections[:min_area_index] + detections[min_area_index + 1:]


def perform_inference(model, image, model_args):
    """Preprocess a cropped region and run a single forward pass of the OCR model."""
    transform = ToTensor()
    resize = Resize((32, 128))  # resize the crop to match the model's input size
    image = Image.fromarray(image)
    image = image.convert("RGB")
    image = resize(image)
    image_tensor = transform(image).unsqueeze(0).to(model_args['device'])
    with torch.no_grad():
        output = model(image_tensor)
    return output


def ocr_text(model, image, model_args):
    """Decode the OCR output greedily and return (label, per-character confidences)."""
    inference_result = perform_inference(model, image, model_args)
    # Greedy decoding
    pred = inference_result.softmax(-1)
    label, confidence = model.tokenizer.decode(pred)
    return (label[0], ["{:.2%}".format(value) for value in confidence[0].tolist()[:-1]])


def draw_detections_with_text(image, detections, texts):
    for i, bbox in enumerate(detections):
        left = bbox['left']
        top = bbox['top']
        width = bbox['width']
        height = bbox['height']
        text = texts[i]
        # Draw the bbox and its recognized text on the image
        color = (0, 255, 0)  # Green (BGR format)
        thickness = 2  # Thickness of the bbox lines
        cv2.rectangle(image, (left, top), (left + width, top + height), color, thickness)
        cv2.putText(image, text, (left, top - 5), cv2.FONT_HERSHEY_SIMPLEX, 2, color, 10)
    return image


def process_detections(detections, image_path, model_ocr, model_args):
    """Crop each detected region and OCR it; class 0 is treated as dot-peened text,
    anything else as machine-peened text."""
    results = {}
    img = cv2.imread(image_path)
    for detection in detections:
        cropped_img = img[detection['top']:detection['top'] + detection['height'],
                          detection['left']:detection['left'] + detection['width']]
        if detection['class'] == 0:
            results['DotPeenText'] = ocr_text(model_ocr, cropped_img, model_args)
            # predicted_Dot_Peening_Text_list.append(ocr.ocr(sharpen_image(cropped_img))[0][0][1][0].replace(' ', ''))
        else:
            results['MachinePeenText'] = ocr_text(model_ocr, cropped_img, model_args)
    return results
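

# --- Example usage (a minimal sketch, not part of the original module) ---
# Assumes a YOLOv5 DetectMultiBackend detector and a PARSeq-style OCR model that
# exposes `tokenizer.decode()`; the weights file and image path below are placeholders.
if __name__ == '__main__':
    from models.common import DetectMultiBackend  # from the YOLOv5 repo, assumed on the path

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    detector = DetectMultiBackend('best.pt', device=device)  # hypothetical weights file
    ocr_model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)
    model_args = {'device': device}

    image_path = 'sample.jpg'  # hypothetical input image
    detections = get_yolo_results(image_path, detector)          # normalized xywh + conf
    bboxes = remove_lowest_area_bbox(get_bboxes(image_path, detections))
    print(process_detections(bboxes, image_path, ocr_model, model_args))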