Skip to content

Commit

Permalink
Batch inference added
Browse files Browse the repository at this point in the history
  • Loading branch information
Anakha Vasanthakumaribabu committed Sep 19, 2022
1 parent 30fcf1a commit a3b0ec8
Show file tree
Hide file tree
Showing 5 changed files with 539 additions and 0 deletions.
193 changes: 193 additions & 0 deletions adSimServer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
#!/usr/bin/env python

import time, threading, queue, argparse
import numpy as np
import pvaccess as pva
import os, os.path

__version__ = pva.__version__

class AdSimServer:
    """Simulated area-detector server that publishes cached image frames over PVA.

    Frames are loaded from ``.npy``-style files (via ``np.load``) or generated
    randomly, converted once into ``pva.NtNdArray`` objects by
    :meth:`frame_producer`, and then published repeatedly on a single channel
    by :meth:`frame_publisher`.
    """

    # Seconds subtracted from every computed timer delay in frame_publisher.
    DELAY_CORRECTION = 0.0001

    # Maps a numpy dtype to the key used when assigning the NtNdArray 'value' field.
    PVA_TYPE_KEY_MAP = {
        np.dtype('uint8'): 'ubyteValue',
        np.dtype('int8'): 'byteValue',
        np.dtype('uint16'): 'ushortValue',
        np.dtype('int16'): 'shortValue',
        np.dtype('uint32'): 'uintValue',
        np.dtype('int32'): 'intValue',
        np.dtype('uint64'): 'ulongValue',
        np.dtype('int64'): 'longValue',
        np.dtype('float32'): 'floatValue',
        np.dtype('float64'): 'doubleValue',
    }

    def __init__(self, input_directory, input_file, frame_rate, nf, nx, ny, runtime, channel_name, start_delay, report_frequency):
        """Load or generate the frames and create the PVA server record.

        Args:
            input_directory: Directory whose regular files are loaded as frame
                stacks, or None.
            input_file: Single additional file to load, or None.
            frame_rate: Target frames per second; values <= 0 disable pacing.
            nf, nx, ny: Shape of the randomly generated stack, used only when
                no input file could be loaded.
            runtime: Publishing stops once this many seconds have elapsed
                since the first published frame.
            channel_name: Name of the PVA record to publish on.
            start_delay: Seconds to wait in start() before the first publish.
            report_frequency: Print a status line every N published frames;
                <= 0 disables reporting.

        Raises:
            ValueError: If the frame dtype has no entry in PVA_TYPE_KEY_MAP
                (or if the frame array is not 3-dimensional, from unpacking
                its shape).
        """
        self.arraySize = None
        self.delta_t = 0
        if frame_rate > 0:
            self.delta_t = 1.0/frame_rate
        self.runtime = runtime
        self.report_frequency = report_frequency

        input_files = []
        if input_directory is not None:
            # os.listdir() returns entries in arbitrary order; sort them so the
            # resulting frame sequence is reproducible between runs.
            candidates = (os.path.join(input_directory, f) for f in sorted(os.listdir(input_directory)))
            input_files = [path for path in candidates if os.path.isfile(path)]
        if input_file is not None:
            input_files.append(input_file)

        self.frames = None
        for f in input_files:
            # Deliberately best-effort: any file that cannot be loaded or
            # appended (wrong format, mismatching shape, ...) is reported and skipped.
            try:
                new_frames = np.load(f)
                if self.frames is None:
                    self.frames = new_frames
                else:
                    self.frames = np.append(self.frames, new_frames, axis=0)
                print('Loaded input file %s' % (f))
            except Exception as ex:
                print('Cannot load input file %s, skipping it: %s' % (f, ex))
        if self.frames is None:
            print('Generating random frames')
            self.frames = np.random.randint(0, 256, size=(nf, nx, ny), dtype=np.int16)
        self.n_input_frames, self.rows, self.cols = self.frames.shape

        self.pva_type_key = self.PVA_TYPE_KEY_MAP.get(self.frames.dtype)
        if self.pva_type_key is None:
            # Fail here rather than later inside the producer thread, where a
            # missing key would only surface as an obscure error.
            raise ValueError('Unsupported frame data type: %s' % (self.frames.dtype))
        print('Number of input frames: %s (size: %sx%s, type: %s)' % (self.n_input_frames, self.rows, self.cols, self.frames.dtype))

        self.channel_name = channel_name
        self.frame_rate = frame_rate
        self.server = pva.PvaServer()
        self.server.addRecord(self.channel_name, pva.NtNdArray())
        self.frame_map = {}              # input frame index -> prepared NtNdArray
        self.current_frame_id = 0        # id stamped on the most recently prepared frame
        self.n_published_frames = 0
        self.start_time = 0              # time.time() of the first published frame
        self.last_published_time = 0
        self.start_delay = start_delay
        self.is_done = False             # set by stop(); polled by producer and publisher

    @staticmethod
    def _average_frame_rate(runtime, n_published_frames):
        """Return the mean frame rate over ``n_published_frames - 1`` intervals.

        Returns 0.0 when fewer than two frames were published or when no time
        has elapsed, instead of dividing by zero.
        """
        if n_published_frames <= 1:
            return 0.0
        delta_t = runtime/(n_published_frames - 1)
        if delta_t <= 0:
            return 0.0
        return 1.0/delta_t

    def get_timestamp(self):
        """Return the current wall-clock time as a ``pva.PvTimeStamp``."""
        s = time.time()
        ns = int((s-int(s))*1000000000)
        s = int(s)
        return pva.PvTimeStamp(s, ns)

    def frame_producer(self, extraFieldsPvObject=None):
        """Build one NtNdArray per input frame and cache it in ``frame_map``.

        Args:
            extraFieldsPvObject: Optional PV object; when given, its structure
                is used to create each NtNdArray and its values are copied in.

        Returns early as soon as ``is_done`` is set.
        """
        for frame_id in range(0, self.n_input_frames):
            if self.is_done:
                return

            if extraFieldsPvObject is None:
                nda = pva.NtNdArray()
            else:
                nda = pva.NtNdArray(extraFieldsPvObject.getStructureDict())

            nda['uniqueId'] = frame_id
            nda['codec'] = pva.PvCodec('pvapyc', pva.PvInt(5))
            dims = [
                pva.PvDimension(self.rows, 0, self.rows, 1, False),
                pva.PvDimension(self.cols, 0, self.cols, 1, False),
            ]
            nda['dimension'] = dims
            # NOTE(review): both sizes are element counts (rows*cols), not byte
            # counts -- confirm that consumers do not expect bytes here.
            nda['compressedSize'] = self.rows*self.cols
            nda['uncompressedSize'] = self.rows*self.cols
            ts = self.get_timestamp()
            nda['timeStamp'] = ts
            nda['dataTimeStamp'] = ts
            nda['descriptor'] = 'PvaPy Simulated Image'
            nda['value'] = {self.pva_type_key: self.frames[frame_id].flatten()}
            attrs = [pva.NtAttribute('ColorMode', pva.PvInt(0))]
            nda['attribute'] = attrs
            if extraFieldsPvObject is not None:
                nda.set(extraFieldsPvObject)
            self.frame_map[frame_id] = nda

    def prepare_frame(self):
        """Return the next cached frame with a fresh id and timestamps.

        The cached object itself is updated in place and returned.
        """
        # Get cached frame
        # NOTE(review): frame_map is filled by the producer thread started in
        # start(); if the publisher gets ahead of it this lookup raises
        # KeyError. The start delay presumably exists to avoid that -- confirm.
        cached_frame_id = self.current_frame_id % self.n_input_frames
        frame = self.frame_map[cached_frame_id]

        # Correct image id and timestamps
        self.current_frame_id += 1
        frame['uniqueId'] = self.current_frame_id
        ts = self.get_timestamp()
        frame['timeStamp'] = ts
        frame['dataTimeStamp'] = ts
        return frame

    def frame_publisher(self):
        """Publish frames until stopped or until the configured runtime elapses.

        After each publish, when pacing is enabled and the next slot lies in
        the future, the method re-schedules itself through a
        ``threading.Timer`` and returns; otherwise it publishes the next frame
        immediately.
        """
        while True:
            if self.is_done:
                return

            frame = self.prepare_frame()
            self.server.update(self.channel_name, frame)
            self.last_published_time = time.time()
            self.n_published_frames += 1

            should_report = self.report_frequency > 0 and (self.n_published_frames % self.report_frequency) == 0
            runtime = 0
            if self.n_published_frames > 1:
                runtime = self.last_published_time - self.start_time
                if should_report:
                    frame_rate = self._average_frame_rate(runtime, self.n_published_frames)
                    print("Published frame id %6d @ %.3f (frame rate: %.4f fps)" % (self.current_frame_id, self.last_published_time, frame_rate))
            else:
                # The first published frame defines the start of the run.
                self.start_time = self.last_published_time
                if should_report:
                    print("Published frame id %6d @ %.3f" % (self.current_frame_id, self.last_published_time))

            if runtime > self.runtime:
                print("Server will exit after reaching runtime of %s seconds" % (self.runtime))
                return

            if self.delta_t > 0:
                # Schedule against the ideal time grid so that timing errors
                # do not accumulate from frame to frame.
                next_publish_time = self.start_time + self.n_published_frames*self.delta_t
                delay = next_publish_time - time.time() - self.DELAY_CORRECTION
                if delay > 0:
                    threading.Timer(delay, self.frame_publisher).start()
                    return

    def start(self):
        """Start the producer thread and the server, then schedule publishing after ``start_delay``."""
        threading.Thread(target=self.frame_producer, daemon=True).start()
        self.server.start()
        threading.Timer(self.start_delay, self.frame_publisher).start()

    def stop(self):
        """Signal the worker threads to finish, stop the server and print a summary.

        Safe to call before any frame has been published (for example when
        interrupted during the start delay); the reported rate is then 0.
        """
        self.is_done = True
        self.server.stop()
        runtime = self.last_published_time - self.start_time
        frame_rate = self._average_frame_rate(runtime, self.n_published_frames)
        print('\nServer runtime: %.4f seconds' % (runtime))
        print('Published frames: %6d @ %.4f fps' % (self.n_published_frames, frame_rate))

def main():
    """Parse the command line, run the simulator server and stop it again.

    The process sleeps for ``runtime + 2*start_delay`` seconds (or until
    Ctrl-C) and then calls ``server.stop()``. Exits with status 1 when
    unrecognized arguments are given.
    """
    parser = argparse.ArgumentParser(description='PvaPy Area Detector Simulator')
    parser.add_argument('--input-directory', '-id', type=str, dest='input_directory', default=None, help='Directory containing input files to be streamed; if input directory or input file are not provided, random images will be generated')
    parser.add_argument('--input-file', '-if', type=str, dest='input_file', default=None, help='Input file to be streamed; if input directory or input file are not provided, random images will be generated')
    parser.add_argument('--frame-rate', '-fps', type=float, dest='frame_rate', default=20, help='Frames per second (default: 20 fps)')
    # The help texts below now state the real defaults and the right axis.
    parser.add_argument('--n-x-pixels', '-nx', type=int, dest='n_x_pixels', default=2048, help='Number of pixels in x dimension (default: 2048 pixels; does not apply if input file is given)')
    parser.add_argument('--n-y-pixels', '-ny', type=int, dest='n_y_pixels', default=256, help='Number of pixels in y dimension (default: 256 pixels; does not apply if input file is given)')
    parser.add_argument('--n-frames', '-nf', type=int, dest='n_frames', default=1000, help='Number of different frames to generate and cache; those images will be published over and over again as long as the server is running')
    parser.add_argument('--runtime', '-rt', type=float, dest='runtime', default=300, help='Server runtime in seconds (default: 300 seconds)')
    parser.add_argument('--channel-name', '-cn', type=str, dest='channel_name', default='pvapy:image', help='Server PVA channel name (default: pvapy:image)')
    parser.add_argument('--start-delay', '-sd', type=float, dest='start_delay', default=10.0, help='Server start delay in seconds (default: 10 seconds)')
    parser.add_argument('--report-frequency', '-rf', type=int, dest='report_frequency', default=1, help='Reporting frequency for publishing frames; if set to <=0 no frames will be reported as published (default: 1)')
    parser.add_argument('-v', '--version', action='version', version='%(prog)s {version}'.format(version=__version__))

    args, unparsed = parser.parse_known_args()
    if unparsed:
        print('Unrecognized argument(s): %s' % ' '.join(unparsed))
        # Equivalent to exit(1) but does not depend on the `site` module
        # having installed the interactive exit() helper.
        raise SystemExit(1)

    server = AdSimServer(
        input_directory=args.input_directory,
        input_file=args.input_file,
        frame_rate=args.frame_rate,
        nf=args.n_frames,
        nx=args.n_x_pixels,
        ny=args.n_y_pixels,
        runtime=args.runtime,
        channel_name=args.channel_name,
        start_delay=args.start_delay,
        report_frequency=args.report_frequency,
    )

    server.start()
    try:
        runtime = args.runtime + 2*args.start_delay
        time.sleep(runtime)
    except KeyboardInterrupt:
        # Ctrl-C simply ends the wait early; fall through to a clean stop.
        pass
    server.stop()

if __name__ == '__main__':
main()
64 changes: 64 additions & 0 deletions helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pycuda.driver as cuda
import tensorrt as trt
import logging, torch


def engine_build_from_onnx(onnx_mdl):
    """Parse the ONNX file ``onnx_mdl`` and build a TensorRT engine from it.

    Every error reported by the ONNX parser is printed. Returns the result of
    ``builder.build_engine`` on success, or None when parsing fails.
    """
    trt_logger = trt.Logger(trt.Logger.ERROR)
    explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    builder = trt.Builder(trt_logger)
    build_config = builder.create_builder_config()
    # TF32 is enabled here; an FP16 flag was left disabled by the author.
    build_config.set_flag(trt.BuilderFlag.TF32)
    # Upper bound on the scratch memory any single layer may use (1 GiB).
    build_config.max_workspace_size = 1 << 30

    # Populate the TensorRT network definition from the ONNX model.
    network = builder.create_network(explicit_batch_flag)
    onnx_parser = trt.OnnxParser(network, trt_logger)
    parsed_ok = onnx_parser.parse_from_file(onnx_mdl)
    for err_idx in range(onnx_parser.num_errors):
        print(onnx_parser.get_error(err_idx))

    return builder.build_engine(network, build_config) if parsed_ok else None

def mem_allocation(engine):
    """Allocate host and device buffers plus a CUDA stream for ``engine``.

    Binding 0 is treated as the input and binding 1 as the output. Host
    buffers are page-locked float32 arrays sized from the binding shape times
    ``engine.max_batch_size``; device buffers match their byte sizes.

    Returns:
        Tuple ``(h_input, h_output, d_input, d_output, stream)``.
    """
    host_buffers = []
    for binding_idx in (0, 1):
        n_elems = trt.volume(engine.get_binding_shape(binding_idx)) * engine.max_batch_size
        # Page-locked host memory for this binding.
        host_buffers.append(cuda.pagelocked_empty(n_elems, dtype='float32'))
    h_input, h_output = host_buffers

    # Matching device-side allocations.
    d_input = cuda.mem_alloc(h_input.nbytes)
    d_output = cuda.mem_alloc(h_output.nbytes)

    # Stream used for the copies and for running inference.
    stream = cuda.Stream()

    return h_input, h_output, d_input, d_output, stream

def inference(context, h_input, h_output, d_input, d_output, stream):
    """Run one inference pass on ``stream`` and return the host output buffer.

    Copies ``h_input`` to the device, executes ``context``, copies the result
    into ``h_output`` and waits for the stream to finish. The returned object
    is ``h_output`` itself, not a copy.
    """
    # Host -> device.
    cuda.memcpy_htod_async(d_input, h_input, stream)

    # Execute the engine with the device addresses of input and output.
    device_bindings = [int(d_input), int(d_output)]
    context.execute_async_v2(bindings=device_bindings, stream_handle=stream.handle)

    # Device -> host.
    cuda.memcpy_dtoh_async(h_output, d_output, stream)

    # Block until all queued work on the stream has completed.
    stream.synchronize()

    return h_output

## can change this later def pth2onnx(pth):

#def pth2onnx(pth, bsz, in_size):

76 changes: 76 additions & 0 deletions inferPtychoNN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from PIL import Image
import numpy as np
import tensorrt as trt
#import pycuda.autoinit
#import pycuda.driver as cuda
import threading
import time
import math
import os
import logging
#import GPUtil
#import common_v1

from multiprocessing import Process, Queue
from skimage.transform import resize
from helper import inference
from pvaClient import *

class inferPtychoNNtrt:
    """Runs TensorRT inference on queued batches and publishes the results.

    Input batches are taken from ``tq_diff`` and the matching frame ids from
    ``frm_id_q``; each predicted frame is pushed to the client's PVA server.
    """

    def __init__(self, client, mbsz, onnx_mdl, tq_diff, frm_id_q):
        """Build the TensorRT engine and allocate its buffers.

        Args:
            client: Object providing ``server``, ``channel_name``,
                ``frame_producer``, ``recv_frames`` and ``msg2`` (all used by
                batch_infer).
            mbsz: Mini-batch size (stored; not otherwise used in this class).
            onnx_mdl: Path of the ONNX model passed to the engine builder.
            tq_diff: Queue delivering input batches.
            frm_id_q: Queue delivering the list of frame ids for each batch.

        Raises:
            RuntimeError: If the engine could not be built from ``onnx_mdl``.
        """
        self.tq_diff = tq_diff
        self.mbsz = mbsz
        self.onnx_mdl = onnx_mdl
        self.client = client
        self.frm_id_q = frm_id_q
        self.processed_count = 0
        self.msg1 = ''
        self.msg2 = ''
        self.frame_loss = 0
        from helper import engine_build_from_onnx, mem_allocation
        import pycuda.autoinit  # must be in the same thread as the actual cuda execution
        # NOTE(review): start() runs batch_infer in a separate thread, which
        # seems at odds with the requirement above -- confirm how callers
        # actually drive this class.

        self.trt_engine = engine_build_from_onnx(self.onnx_mdl)
        if self.trt_engine is None:
            # The builder returns None when the ONNX file cannot be parsed;
            # fail with a clear message instead of an AttributeError later.
            raise RuntimeError('Could not build TensorRT engine from %s' % (self.onnx_mdl))

        self.trt_hin, self.trt_hout, self.trt_din, self.trt_dout, self.trt_stream = \
            mem_allocation(self.trt_engine)
        self.trt_context = self.trt_engine.create_execution_context()
        logging.info("TensorRT Inference engine initialization completed!")

        # Reference time for the throughput report; starting from "now"
        # (rather than 0) keeps the first reported rate meaningful.
        self.t0 = time.time()

    def start(self):
        """Run :meth:`batch_infer` once in a daemon thread."""
        threading.Thread(target=self.batch_infer, daemon=True).start()

    def batch_infer(self):
        """Process exactly one batch: infer, then publish every frame in it.

        Blocks until a batch and its frame-id list are available on the two
        queues. The TensorRT engine is created once in ``__init__`` and
        reused here. The original ``while True`` loop is disabled, so each
        call handles a single batch.
        """
        in_mb = self.tq_diff.get()
        frm_id_list = self.frm_id_q.get()

        batch_tick = time.time()
        np.copyto(self.trt_hin, in_mb.astype(np.float32).ravel())
        # np.array() copies the result out of the reusable host output buffer.
        pred = np.array(inference(self.trt_context, self.trt_hin, self.trt_hout,
                                  self.trt_din, self.trt_dout, self.trt_stream))
        t_batch = 1000 * (time.time() - batch_tick)

        logging.info(" Time %.3f ms " % (t_batch))

        # NOTE(review): the output layout is hard-coded as 8 rows of 16384
        # values; presumably 8 is the batch size (cf. self.mbsz) -- confirm
        # before changing the model or batch size.
        pred = pred.reshape(8, 16384)

        for j, frm_id in enumerate(frm_id_list):
            self.processed_count += 1
            if self.processed_count % 1000 == 0:
                elapsed = time.time() - self.t0
                rate = 1000/elapsed if elapsed > 0 else 0.0
                self.msg1 = "Inference @ {0:.0f}Hz | {1} frames remaining".format(rate, self.client.recv_frames - self.processed_count)
                self.t0 = time.time()
            print(self.client.msg2+ " | "+ self.msg1+" \r", end="")
            self.client.server.update(self.client.channel_name, self.client.frame_producer(int(frm_id), pred[j]))
Loading

0 comments on commit a3b0ec8

Please sign in to comment.