Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Traced Transformers #74

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 67 additions & 12 deletions pero_ocr/ocr_engine/transformer_ocr_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import print_function

import logging
import torch
import numpy as np
from .line_ocr_engine import BaseEngineLineOCR
Expand All @@ -18,18 +18,34 @@ def __init__(self, json_def, device, batch_size=16, substitute_output_atomic: bo
self.sentence_boundary_ind = len(self.characters) - 2
self.ignore_ind = len(self.characters) - 1

self.net = transformer.build_net(net=self.net_name,
input_height=self.line_px_height,
input_channels=3,
nb_output_symbols=len(self.characters) - 2)
self.logger = logging.getLogger(__name__)

print(self.net)
self.exported = False
self.net = None
self.load_net()
self.logger.info(self.net)

self.net.load_state_dict(torch.load(self.checkpoint, map_location=device))
self.net.eval()
self.net = self.net.to(device)
self.max_decoded_seq_length = 210

def load_net(self):
    """Load the OCR network and move it to ``self.device``.

    Two loading paths, selected by the ``exported`` flag in ``self.config``:

    * exported: the checkpoint is a TorchScript archive — graph and weights
      are loaded together via ``torch.jit.load`` and ``self.exported`` is set
      so ``run_ocr`` dispatches to the exported decoding path.
    * otherwise: the transformer net is built from its definition and its
      weights are restored with ``load_state_dict``.

    Side effect: on a CPU device, ``self.checkpoint`` gets a ``.cpu`` suffix
    appended (CPU-specific checkpoint naming convention). NOTE(review): this
    mutation is not idempotent — calling ``load_net`` twice would append the
    suffix twice; presumed to be called exactly once from ``__init__``.
    """
    if self.device.type == "cpu":
        self.checkpoint += ".cpu"

    if self.config.get('exported'):
        # TorchScript archive: graph and weights load together.
        net = torch.jit.load(self.checkpoint, map_location=self.device)
        self.exported = True
    else:
        net = transformer.build_net(net=self.net_name,
                                    input_height=self.line_px_height,
                                    input_channels=3,
                                    nb_output_symbols=len(self.characters) - 2)
        net.load_state_dict(torch.load(self.checkpoint, map_location=self.device))

    # Fix: always switch to eval mode. The original only did this in the
    # non-exported branch, so a TorchScript module exported while in training
    # mode would keep dropout/batch-norm in training behaviour at inference.
    net.eval()

    self.net = net.to(self.device)

def run_ocr(self, batch_data):
with torch.no_grad():
batch_data = np.transpose(batch_data, (0, 3, 1, 2))
Expand All @@ -40,13 +56,52 @@ def run_ocr(self, batch_data):
new_batch_data[:, :, :, s:s+batch_data.shape[3]] = batch_data
batch_data = new_batch_data

labels, logits = self.transcribe_batch(batch_data, is_cached=True)
if self.exported:
labels, logits = self.transcribe_batch_exported(batch_data)
else:
labels, logits = self.transcribe_batch(batch_data, is_cached=True)

logits = logits.cpu().numpy()
decoded = self.decode(labels)

return decoded, logits

def transcribe_batch_exported(self, inputs):
    """Greedily decode a batch of text lines with an exported (TorchScript) net.

    :param inputs: numpy array of shape (batch, channels, height, width) with
        pixel values in [0, 255].
    :return: tuple ``(outs, logits)`` where ``outs`` is whatever
        ``self.postprocess_decoded`` produces from the decoded symbol ids
        (initial sentence-boundary symbol stripped) and ``logits`` is a tensor
        of per-step output scores shaped (batch, steps, nb_symbols).
    """
    lines = torch.from_numpy(inputs).to(self.device).float() / 255.0

    encoded_lines = self.net.encode(lines)
    encoded_lines = self.net.adapt(encoded_lines)

    # Every transcription starts with the sentence-boundary symbol;
    # partial_transcriptions has shape (batch, seq).
    partial_transcriptions = torch.tensor([self.sentence_boundary_ind] * lines.shape[0], dtype=torch.long,
                                          device=self.device).unsqueeze(1)
    # 1 while the corresponding line has not yet emitted a boundary symbol.
    alive_mask = torch.full((lines.shape[0],), 1, dtype=torch.long, device=self.device)

    logits = []

    for counter in range(self.max_decoded_seq_length):
        step_logits = self.net.decode_step(encoded_lines, partial_transcriptions)
        logits.append(step_logits)

        # Greedy decoding: pick the most likely symbol for every line.
        sampled_characters = torch.argmax(step_logits, dim=-1)

        surviving_lines = (sampled_characters != self.sentence_boundary_ind)
        alive_mask *= surviving_lines

        # All lines have produced a sentence boundary -> done.
        # (tensor .any() instead of Python sum() — avoids per-element iteration)
        if not alive_mask.any():
            break

        # Four pixels per decoded character is already ridiculous; abort.
        # Fix: the sequence length is dim 1 of (batch, seq) — the original
        # tested shape[0], the batch size, which aborted decoding immediately
        # for any batch larger than width // 4.
        if partial_transcriptions.shape[1] > lines.shape[-1] // 4:
            self.logger.warning(f'The transcription is getting way too long '
                                f'({partial_transcriptions.shape[1]}) for the line ({lines.shape}), '
                                f'aborting it at shape {partial_transcriptions.shape}')
            break

        partial_transcriptions = torch.cat([partial_transcriptions, sampled_characters.unsqueeze(1)], dim=1)

    # Strip the leading boundary symbol before post-processing.
    outs = self.postprocess_decoded(partial_transcriptions[:, 1:], self.ignore_ind, self.sentence_boundary_ind)
    # (steps, batch, vocab) -> (batch, steps, vocab)
    logits = torch.stack(logits).permute(1, 0, 2)

    return outs, logits

def transcribe_batch(self, inputs, is_cached=False):
lines = torch.from_numpy(inputs).to(self.device).float()
lines /= 255.0
Expand Down Expand Up @@ -77,8 +132,8 @@ def transcribe_batch(self, inputs, is_cached=False):
break

if len(partial_transcripts) > inputs.shape[-1] // 4: # four pixels per letter is already ridiculous
print(f'The transcription is getting way too long ({len(partial_transcripts)}) for the line '
f'({inputs.shape}), aborting it at shape {partial_transcripts.shape}')
self.logger.warning(f'The transcription is getting way too long ({len(partial_transcripts)}) for '
f'the line ({inputs.shape}), aborting it at shape {partial_transcripts.shape}')
break

partial_transcripts = torch.cat([partial_transcripts, samples.unsqueeze(0)], dim=0)
Expand Down