diff --git a/finn-rtllib/fifo/hdl/Q_srl.v b/finn-rtllib/fifo/hdl/Q_srl.v index d1ce33c41f..0b01973163 100644 --- a/finn-rtllib/fifo/hdl/Q_srl.v +++ b/finn-rtllib/fifo/hdl/Q_srl.v @@ -184,58 +184,58 @@ module Q_srl (clock, reset, i_d, i_v, i_r, o_d, o_v, o_r, count, maxcount); end // always @ (posedge clock or negedge reset) always @* begin // - combi always - srlo_ <= 'bx; - shift_en_o_ <= 1'bx; - shift_en_ <= 1'bx; - addr_ <= 'bx; - state_ <= 2'bx; + srlo_ = 'bx; + shift_en_o_ = 1'bx; + shift_en_ = 1'bx; + addr_ = 'bx; + state_ = 2'bx; case (state) state_empty: begin // - (empty, will not produce) if (i_v) begin // - empty & i_v => consume - srlo_ <= i_d; - shift_en_o_ <= 1; - shift_en_ <= 1'bx; - addr_ <= 0; - state_ <= state_one; + srlo_ = i_d; + shift_en_o_ = 1; + shift_en_ = 1'bx; + addr_ = 0; + state_ = state_one; end else begin // - empty & !i_v => idle - srlo_ <= 'bx; - shift_en_o_ <= 0; - shift_en_ <= 1'bx; - addr_ <= 0; - state_ <= state_empty; + srlo_ = 'bx; + shift_en_o_ = 0; + shift_en_ = 1'bx; + addr_ = 0; + state_ = state_empty; end end state_one: begin // - (contains one) if (i_v && o_b) begin // - one & i_v & o_b => consume - srlo_ <= 'bx; - shift_en_o_ <= 0; - shift_en_ <= 1; - addr_ <= 0; - state_ <= state_more; + srlo_ = 'bx; + shift_en_o_ = 0; + shift_en_ = 1; + addr_ = 0; + state_ = state_more; end else if (i_v && !o_b) begin // - one & i_v & !o_b => cons+prod - srlo_ <= i_d; - shift_en_o_ <= 1; - shift_en_ <= 1; - addr_ <= 0; - state_ <= state_one; + srlo_ = i_d; + shift_en_o_ = 1; + shift_en_ = 1; + addr_ = 0; + state_ = state_one; end else if (!i_v && o_b) begin // - one & !i_v & o_b => idle - srlo_ <= 'bx; - shift_en_o_ <= 0; - shift_en_ <= 1'bx; - addr_ <= 0; - state_ <= state_one; + srlo_ = 'bx; + shift_en_o_ = 0; + shift_en_ = 1'bx; + addr_ = 0; + state_ = state_one; end else if (!i_v && !o_b) begin // - one & !i_v & !o_b => produce - srlo_ <= 'bx; - shift_en_o_ <= 0; - shift_en_ <= 1'bx; - addr_ <= 0; - state_ <= state_empty; + srlo_ = 
'bx; + shift_en_o_ = 0; + shift_en_ = 1'bx; + addr_ = 0; + state_ = state_empty; end end // case: state_one @@ -244,60 +244,60 @@ module Q_srl (clock, reset, i_d, i_v, i_r, o_d, o_v, o_r, count, maxcount); // - (full, will not consume) // - (full here if depth==2) if (o_b) begin // - full & o_b => idle - srlo_ <= 'bx; - shift_en_o_ <= 0; - shift_en_ <= 0; - addr_ <= addr; - state_ <= state_more; + srlo_ = 'bx; + shift_en_o_ = 0; + shift_en_ = 0; + addr_ = addr; + state_ = state_more; end else begin // - full & !o_b => produce - srlo_ <= srl[addr]; - shift_en_o_ <= 1; - shift_en_ <= 0; -// addr_ <= addr-1; -// state_ <= state_more; - addr_ <= addr_zero_ ? 0 : addr-1; - state_ <= addr_zero_ ? state_one : state_more; + srlo_ = srl[addr]; + shift_en_o_ = 1; + shift_en_ = 0; +// addr_ = addr-1; +// state_ = state_more; + addr_ = addr_zero_ ? 0 : addr-1; + state_ = addr_zero_ ? state_one : state_more; end end else begin // - (mid: neither empty nor full) if (i_v && o_b) begin // - mid & i_v & o_b => consume - srlo_ <= 'bx; - shift_en_o_ <= 0; - shift_en_ <= 1; - addr_ <= addr+1; - state_ <= state_more; + srlo_ = 'bx; + shift_en_o_ = 0; + shift_en_ = 1; + addr_ = addr+1; + state_ = state_more; end else if (i_v && !o_b) begin // - mid & i_v & !o_b => cons+prod - srlo_ <= srl[addr]; - shift_en_o_ <= 1; - shift_en_ <= 1; - addr_ <= addr; - state_ <= state_more; + srlo_ = srl[addr]; + shift_en_o_ = 1; + shift_en_ = 1; + addr_ = addr; + state_ = state_more; end else if (!i_v && o_b) begin // - mid & !i_v & o_b => idle - srlo_ <= 'bx; - shift_en_o_ <= 0; - shift_en_ <= 0; - addr_ <= addr; - state_ <= state_more; + srlo_ = 'bx; + shift_en_o_ = 0; + shift_en_ = 0; + addr_ = addr; + state_ = state_more; end else if (!i_v && !o_b) begin // - mid & !i_v & !o_b => produce - srlo_ <= srl[addr]; - shift_en_o_ <= 1; - shift_en_ <= 0; - addr_ <= addr_zero_ ? 0 : addr-1; - state_ <= addr_zero_ ? 
state_one : state_more; + srlo_ = srl[addr]; + shift_en_o_ = 1; + shift_en_ = 0; + addr_ = addr_zero_ ? 0 : addr-1; + state_ = addr_zero_ ? state_one : state_more; end end // else: !if(addr_full) end // case: state_more default: begin - srlo_ <= 'bx; - shift_en_o_ <= 1'bx; - shift_en_ <= 1'bx; - addr_ <= 'bx; - state_ <= 2'bx; + srlo_ = 'bx; + shift_en_o_ = 1'bx; + shift_en_ = 1'bx; + addr_ = 'bx; + state_ = 2'bx; end // case: default endcase // case(state) diff --git a/src/finn/custom_op/fpgadataflow/hlsbackend.py b/src/finn/custom_op/fpgadataflow/hlsbackend.py index 9749a507b2..a0c61ec5b3 100644 --- a/src/finn/custom_op/fpgadataflow/hlsbackend.py +++ b/src/finn/custom_op/fpgadataflow/hlsbackend.py @@ -59,6 +59,8 @@ def get_nodeattr_types(self): "code_gen_dir_cppsim": ("s", False, ""), "executable_path": ("s", False, ""), "res_hls": ("s", False, ""), + # temporary node attribute to keep track of interface style of hls ops + "cpp_interface": ("s", False, "packed", {"packed", "hls_vector"}), } def get_all_verilog_paths(self): @@ -232,7 +234,13 @@ def code_generation_cppsim(self, model): self.dataoutstrm() self.save_as_npy() - template = templates.docompute_template + if self.get_nodeattr("cpp_interface") == "hls_vector": + self.timeout_value() + self.timeout_condition() + self.timeout_read_stream() + template = templates.docompute_template_timeout + else: + template = templates.docompute_template for key in self.code_gen_dict: # transform list into long string separated by '\n' @@ -398,24 +406,40 @@ def read_npy_data(self): if dtype == DataType["BIPOLAR"]: # use binary for bipolar storage dtype = DataType["BINARY"] - elem_bits = dtype.bitwidth() - packed_bits = self.get_instream_width() - packed_hls_type = "ap_uint<%d>" % packed_bits elem_hls_type = dtype.get_hls_datatype_str() npy_type = "float" npy_in = "%s/input_0.npy" % code_gen_dir self.code_gen_dict["$READNPYDATA$"] = [] - self.code_gen_dict["$READNPYDATA$"].append( - 'npy2apintstream<%s, %s, %d, %s>("%s", 
in0_%s);' - % ( - packed_hls_type, - elem_hls_type, - elem_bits, - npy_type, - npy_in, - self.hls_sname(), + + cpp_interface = self.get_nodeattr("cpp_interface") + + if cpp_interface == "packed": + elem_bits = dtype.bitwidth() + packed_bits = self.get_instream_width() + packed_hls_type = "ap_uint<%d>" % packed_bits + self.code_gen_dict["$READNPYDATA$"].append( + 'npy2apintstream<%s, %s, %d, %s>("%s", in0_%s);' + % ( + packed_hls_type, + elem_hls_type, + elem_bits, + npy_type, + npy_in, + self.hls_sname(), + ) + ) + else: + folded_shape = self.get_folded_input_shape() + self.code_gen_dict["$READNPYDATA$"].append( + 'npy2vectorstream<%s, %s, %d>("%s", in0_%s, false);' + % ( + elem_hls_type, + npy_type, + folded_shape[-1], + npy_in, + self.hls_sname(), + ) ) - ) def strm_decl(self): """Function to generate the commands for the stream declaration in c++, @@ -449,27 +473,43 @@ def dataoutstrm(self): if dtype == DataType["BIPOLAR"]: # use binary for bipolar storage dtype = DataType["BINARY"] - elem_bits = dtype.bitwidth() - packed_bits = self.get_outstream_width() - packed_hls_type = "ap_uint<%d>" % packed_bits elem_hls_type = dtype.get_hls_datatype_str() npy_type = "float" npy_out = "%s/output.npy" % code_gen_dir oshape = self.get_folded_output_shape() oshape_cpp_str = str(oshape).replace("(", "{").replace(")", "}") - self.code_gen_dict["$DATAOUTSTREAM$"] = [ - 'apintstream2npy<%s, %s, %d, %s>(out_%s, %s, "%s");' - % ( - packed_hls_type, - elem_hls_type, - elem_bits, - npy_type, - self.hls_sname(), - oshape_cpp_str, - npy_out, - ) - ] + cpp_interface = self.get_nodeattr("cpp_interface") + + if cpp_interface == "packed": + elem_bits = dtype.bitwidth() + packed_bits = self.get_outstream_width() + packed_hls_type = "ap_uint<%d>" % packed_bits + + self.code_gen_dict["$DATAOUTSTREAM$"] = [ + 'apintstream2npy<%s, %s, %d, %s>(out_%s, %s, "%s");' + % ( + packed_hls_type, + elem_hls_type, + elem_bits, + npy_type, + self.hls_sname(), + oshape_cpp_str, + npy_out, + ) + ] + else: + 
folded_shape = self.get_folded_output_shape() + self.code_gen_dict["$DATAOUTSTREAM$"] = [ + 'vectorstream2npy<%s, %s, %d>(strm, %s, "%s");' + % ( + elem_hls_type, + npy_type, + folded_shape[-1], + oshape_cpp_str, + npy_out, + ) + ] def save_as_npy(self): """Function to generate the commands for saving data in .npy file in c++""" @@ -501,3 +541,17 @@ def get_ap_int_max_w(self): ret = max([instream, outstream]) assert ret <= 8191, "AP_INT_MAX_W=%d is larger than allowed maximum of 8191" % ret return ret + + def timeout_value(self): + """Set timeout value for HLS functions defined for one clock cycle""" + self.code_gen_dict["$TIMEOUT_VALUE$"] = ["1000"] + + def timeout_condition(self): + """Set timeout condition for HLS functions defined for one clock cycle""" + self.code_gen_dict["$TIMEOUT_CONDITION$"] = ["out_{}.empty()".format(self.hls_sname())] + + def timeout_read_stream(self): + """Set reading output stream procedure for HLS functions defined for one clock cycle""" + self.code_gen_dict["$TIMEOUT_READ_STREAM$"] = [ + "strm << out_{}.read();".format(self.hls_sname()) + ] diff --git a/src/finn/custom_op/fpgadataflow/templates.py b/src/finn/custom_op/fpgadataflow/templates.py index 88188a1472..8a23e05339 100644 --- a/src/finn/custom_op/fpgadataflow/templates.py +++ b/src/finn/custom_op/fpgadataflow/templates.py @@ -33,6 +33,7 @@ #define HLS_NO_XIL_FPO_LIB #include "cnpy.h" #include "npy2apintstream.hpp" +#include "npy2vectorstream.hpp" #include <vector> #include "bnn-library.h" @@ -59,6 +60,51 @@ """ +# template for single node execution with timeout (for single clock hls operations) +docompute_template_timeout = """ +#define AP_INT_MAX_W $AP_INT_MAX_W$ +#include "cnpy.h" +#include "npy2apintstream.hpp" +#include "npy2vectorstream.hpp" +#include <vector> +#include "bnn-library.h" + +// includes for network parameters +$GLOBALS$ + +// defines for network parameters +$DEFINES$ + +int main(){ +$PRAGMAS$ + +$STREAMDECLARATIONS$ + +$READNPYDATA$ + +unsigned timeout = 0; +while(timeout < 
$TIMEOUT_VALUE$){ + +$DOCOMPUTE$ + +if($TIMEOUT_CONDITION$){ +timeout++; +} + +else{ +$TIMEOUT_READ_STREAM$ +timeout = 0; +} +} + +$DATAOUTSTREAM$ + +$SAVEASCNPY$ + +} + +""" + # templates for single node ip generation # cpp file diff --git a/tests/brevitas/test_brevitas_fc.py b/tests/brevitas/test_brevitas_fc.py index 842d099f57..a7a73a5ed4 100644 --- a/tests/brevitas/test_brevitas_fc.py +++ b/tests/brevitas/test_brevitas_fc.py @@ -45,8 +45,6 @@ from finn.util.basic import make_build_dir from finn.util.test import get_test_model_trained -export_onnx_path = make_build_dir("test_brevitas_fc_") - @pytest.mark.brevitas_export # act bits @@ -61,6 +59,7 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits): if wbits > abits: pytest.skip("No wbits > abits cases at the moment") nname = "%s_%dW%dA" % (size, wbits, abits) + export_onnx_path = make_build_dir("test_brevitas_fc_") finn_onnx = export_onnx_path + "/%s.onnx" % nname fc = get_test_model_trained(size, wbits, abits) ishape = (1, 1, 28, 28) diff --git a/tests/transformation/streamline/test_streamline_cnv.py b/tests/transformation/streamline/test_streamline_cnv.py index 8a91a49278..9e206c843a 100644 --- a/tests/transformation/streamline/test_streamline_cnv.py +++ b/tests/transformation/streamline/test_streamline_cnv.py @@ -50,8 +50,6 @@ from finn.util.basic import make_build_dir from finn.util.test import get_test_model_trained -export_onnx_path = make_build_dir("test_streamline_cnv_") - @pytest.mark.streamline # act bits @@ -64,6 +62,7 @@ def test_streamline_cnv(size, wbits, abits): if wbits > abits: pytest.skip("No wbits > abits cases at the moment") nname = "%s_%dW%dA" % (size, wbits, abits) + export_onnx_path = make_build_dir("test_streamline_cnv_") finn_onnx = export_onnx_path + "/%s.onnx" % nname fc = get_test_model_trained(size, wbits, abits) export_qonnx(fc, torch.randn(1, 3, 32, 32), finn_onnx) diff --git a/tests/transformation/streamline/test_streamline_fc.py 
b/tests/transformation/streamline/test_streamline_fc.py index edc4a96fe2..9ce2f2ab65 100644 --- a/tests/transformation/streamline/test_streamline_fc.py +++ b/tests/transformation/streamline/test_streamline_fc.py @@ -52,8 +52,6 @@ from finn.util.basic import make_build_dir from finn.util.test import get_test_model_trained -export_onnx_path = make_build_dir("test_streamline_fc_") - @pytest.mark.streamline # act bits @@ -68,6 +66,7 @@ def test_streamline_fc(size, wbits, abits): if wbits > abits: pytest.skip("No wbits > abits cases at the moment") nname = "%s_%dW%dA" % (size, wbits, abits) + export_onnx_path = make_build_dir("test_streamline_fc_") finn_onnx = export_onnx_path + "/%s.onnx" % nname fc = get_test_model_trained(size, wbits, abits) export_qonnx(fc, torch.randn(1, 1, 28, 28), finn_onnx)