Skip to content

Commit

Permalink
Merge pull request #924 from Xilinx/update/export
Browse files Browse the repository at this point in the history
[Tests] Update export to qonnx export
  • Loading branch information
auphelia authored Nov 24, 2023
2 parents 02ce695 + 156be02 commit 8f9f10c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
16 changes: 12 additions & 4 deletions tests/fpgadataflow/test_fpgadataflow_lookup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2021, Xilinx
# Copyright (C) 2021-2022, Xilinx, Inc.
# Copyright (C) 2023, Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -30,14 +31,15 @@

import numpy as np
import torch
from brevitas.export import FINNManager
from brevitas.export import export_qonnx
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.general import GiveUniqueNodeNames
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.basic import gen_finn_dt_tensor
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from torch import nn

from finn.core.onnx_exec import execute_onnx
Expand All @@ -49,6 +51,9 @@
from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

export_onnx_path = "test_lookup.onnx"


def make_lookup_model(embeddings, ishape, idt, edt):
Expand All @@ -65,8 +70,11 @@ def forward(self, x):

torch_model = LookupModel(num_embeddings, embedding_dim)
input_t = torch.zeros(ishape, dtype=torch.int64)
ret = FINNManager.export(torch_model, input_t=input_t, opset_version=11)
model = ModelWrapper(ret)
export_qonnx(torch_model, input_t, export_onnx_path, opset_version=11)
qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(ConvertQONNXtoFINN())
model = model.transform(InferShapes())
iname = model.graph.input[0].name
ename = model.graph.node[0].input[0]
model.set_tensor_datatype(iname, idt)
Expand Down
14 changes: 9 additions & 5 deletions tests/fpgadataflow/test_fpgadataflow_upsampler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2020, Xilinx
# Copyright (C) 2020-2022, Xilinx, Inc.
# Copyright (C) 2023, Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -32,7 +33,7 @@
import os
import shutil
import torch
from brevitas.export import FINNManager
from brevitas.export import export_qonnx
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.base import Transformation
Expand All @@ -41,6 +42,7 @@
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from torch import nn

import finn.core.onnx_exec as oxe
Expand All @@ -52,6 +54,7 @@
from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.util.basic import make_build_dir

tmpdir = os.environ["FINN_BUILD_DIR"]
Expand Down Expand Up @@ -154,10 +157,11 @@ def test_fpgadataflow_upsampler(dt, IFMDim, scale, NumChannels, exec_mode, is_1d
# Get golden PyTorch and ONNX inputs
golden_torch_float = torch_model(test_in)
export_path = f"{tmpdir}/Upsample_exported.onnx"
FINNManager.export(
torch_model, input_shape=input_shape, export_path=export_path, opset_version=11
)
export_qonnx(torch_model, torch.randn(input_shape), export_path, opset_version=11)
qonnx_cleanup(export_path, out_file=export_path)
model = ModelWrapper(export_path)
model = model.transform(ConvertQONNXtoFINN())
model = model.transform(InferShapes())
input_dict = {model.graph.input[0].name: test_in.numpy().astype(np.int32)}
input_dict = {model.graph.input[0].name: test_in.numpy()}
golden_output_dict = oxe.execute_onnx(model, input_dict, True)
Expand Down

0 comments on commit 8f9f10c

Please sign in to comment.