Skip to content

Commit

Permalink
fix formatting (triton-inference-server#1859)
Browse files Browse the repository at this point in the history
  • Loading branch information
CoderHam authored Aug 2, 2020
1 parent e06a6b7 commit 79320a5
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 19 deletions.
5 changes: 3 additions & 2 deletions qa/L0_custom_ops/cuda_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@
sys.exit(1)

for i in range(elements):
print(str(i) + ": input " + str(input_data[i]) + ", output " +
str(output_data[i]))
print(
str(i) + ": input " + str(input_data[i]) + ", output " +
str(output_data[i]))
if output_data[i] != (input_data[i] + 1):
print("error: incorrect value")
sys.exit(1)
5 changes: 3 additions & 2 deletions qa/L0_custom_ops/mod_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,9 @@
sys.exit(1)

for i in range(elements):
print(str(i) + ": " + str(input_data[0][i]) + " % " +
str(input_data[1][i]) + " = " + str(output_data[i]))
print(
str(i) + ": " + str(input_data[0][i]) + " % " +
str(input_data[1][i]) + " = " + str(output_data[i]))
if ((input_data[0][i] % input_data[1][i]) != output_data[i]):
print("error: incorrect value")
sys.exit(1)
5 changes: 3 additions & 2 deletions qa/L0_custom_ops/onnx_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@

for i in range(3):
for j in range(5):
print(str(input_data[0][i][j]) + " + " + str(input_data[1][i][j]) +
" = " + str(output_data[i][j]))
print(
str(input_data[0][i][j]) + " + " + str(input_data[1][i][j]) +
" = " + str(output_data[i][j]))
if ((input_data[0][i][j] + input_data[1][i][j]) !=
output_data[i][j]):
print("error: incorrect value")
Expand Down
39 changes: 29 additions & 10 deletions qa/L0_custom_ops/vision_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,36 @@

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-v', '--verbose', action="store_true", required=False, default=False,
parser.add_argument('-v',
'--verbose',
action="store_true",
required=False,
default=False,
help='Enable verbose output')
parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8000',
parser.add_argument('-u',
'--url',
type=str,
required=False,
default='localhost:8000',
help='Inference server URL. Default is localhost:8000.')
parser.add_argument('-i', '--protocol', type=str, required=False, default='http',
help='Protocol ("http"/"grpc") used to ' +
'communicate with inference service. Default is "http".')
parser.add_argument('-m', '--model', type=str, required=True,
parser.add_argument(
'-i',
'--protocol',
type=str,
required=False,
default='http',
help='Protocol ("http"/"grpc") used to ' +
'communicate with inference service. Default is "http".')
parser.add_argument('-m',
'--model',
type=str,
required=True,
help='Name of model.')

FLAGS = parser.parse_args()
if (FLAGS.protocol != "http") and (FLAGS.protocol != "grpc"):
print("unexpected protocol \"{}\", expects \"http\" or \"grpc\"".format(FLAGS.protocol))
print("unexpected protocol \"{}\", expects \"http\" or \"grpc\"".format(
FLAGS.protocol))
exit(1)

client_util = httpclient if FLAGS.protocol == "http" else grpcclient
Expand All @@ -66,8 +83,9 @@
input_data = np.random.rand(1, 16, 10, 10).astype(np.float32)

inputs = []
inputs.append(client_util.InferInput(
"INPUT__0", input_data.shape, np_to_triton_dtype(input_data.dtype)))
inputs.append(
client_util.InferInput("INPUT__0", input_data.shape,
np_to_triton_dtype(input_data.dtype)))
inputs[0].set_data_from_numpy(input_data)

results = client.infer(model_name, inputs)
Expand All @@ -79,5 +97,6 @@
sys.exit(1)

if (output_data.shape != (1, 33, 12, 14)):
print("error: incorrect shape "+ str(output_data.shape) +"for 'OUTPUT__0'")
print("error: incorrect shape " + str(output_data.shape) +
"for 'OUTPUT__0'")
sys.exit(1)
5 changes: 3 additions & 2 deletions qa/L0_custom_ops/zero_out_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@
sys.exit(1)

for i in range(elements):
print(str(i) + ": input " + str(input_data[i]) + ", output " +
str(output_data[i]))
print(
str(i) + ": input " + str(input_data[i]) + ", output " +
str(output_data[i]))
if (i == 0) and (input_data[i] != output_data[i]):
print("error: incorrect value")
sys.exit(1)
Expand Down
6 changes: 5 additions & 1 deletion qa/common/gen_qa_custom_ops_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,17 @@ def create_moduloop_modelconfig(models_dir, model_version):
with open(config_dir + "/config.pbtxt", "w") as cfile:
cfile.write(config)


# Use Torchvision ops
def create_visionop_modelfile(models_dir, model_version):
model_name = "libtorch_visionop"

class CustomVisionNet(nn.Module):

def __init__(self):
super(CustomVisionNet, self).__init__()
self.conv2 = ops.misc.ConvTranspose2d(16, 33, (3, 5))

def forward(self, input0):
return self.conv2(input0)

Expand Down Expand Up @@ -403,6 +406,7 @@ def create_visionop_modelconfig(models_dir, model_version):
with open(config_dir + "/config.pbtxt", "w") as cfile:
cfile.write(config)


def create_zero_out_models(models_dir):
model_version = 1

Expand Down

0 comments on commit 79320a5

Please sign in to comment.