[RTL SWG] Support SIMD < C in window-parallel mode
fpjentzsch committed Nov 20, 2023
1 parent 46e0661 commit 4c80cf8
Showing 3 changed files with 54 additions and 34 deletions.
9 changes: 8 additions & 1 deletion docs/finn/internals.rst
@@ -311,6 +311,13 @@ Depending on the amount of parallelism requested, one of two implementation styles
- 1
- default
- depthwise-agnostic
* - < C
- 1
- 1
- 1
- K
- parallel
- depthwise only
* - C
- 1
- 1
@@ -343,4 +350,4 @@ The RTL SWG is supported by the basic automatic folding algorithm in FINN (:py:m

**MVAU:** Although it is recommended to unfold SIMD first, SIMD and PE can be set independently. Full (and balanced) parallelism is achieved by using the SWG in parallel window mode and setting MVAU SIMD and PE to their maximum values (SIMD = MW = C_in * K, PE = MH = C_out).

**VVAU:** While the VVAU HLS component supports SIMD unfolding independently from PE, the RTL SWG requires full unfolding across the channel dimension (SIMD of the SWG = PE of the VVAU) before enabling window-parallelism. Unlike the MVAU, the VVAU can't accept datawidth-converted input from a fully-parallel SWG in this case due to the depthwise data layout. As a result, the VVAU should be unfolded by PE first (up to PE = C), followed by SIMD (up to SIMD = K).
**VVAU:** The VVAU component supports SIMD unfolding (up to SIMD = K) independently from PE unfolding (up to PE = C), but due to the depthwise data layout it can't accept datawidth-converted input from a fully-parallel SWG when PE is not fully unfolded. Therefore, SIMD of the SWG must be set equal to PE of the VVAU when window-parallelism is enabled. In this scenario, VVAU SIMD < K is supported via an automatically inserted DWC.
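For illustration only (not part of this commit), the depthwise folding rules stated above can be summarized in a small standalone check. The helper name and its plain arguments are made up for this sketch; the real constraints are enforced inside the FINN custom ops and folding transformations.

```python
def depthwise_parallel_window_folding_ok(C, K, swg_simd, vvau_pe, vvau_simd):
    """Hypothetical validity check mirroring the documented folding rules."""
    if C % swg_simd != 0:
        return False  # SWG SIMD must divide the channel count
    if swg_simd != vvau_pe:
        return False  # with window-parallelism, SIMD of the SWG must equal PE of the VVAU
    if vvau_simd > K:
        return False  # VVAU SIMD unfolds across the kernel only (K = k_h * k_w)
    return True       # VVAU SIMD < K is fine; a DWC is inserted automatically

# Example: 64 channels, 3x3 kernel (K = 9), partially unfolded
assert depthwise_parallel_window_folding_ok(C=64, K=9, swg_simd=16, vvau_pe=16, vvau_simd=3)
# Mismatched SWG SIMD and VVAU PE violates the constraint
assert not depthwise_parallel_window_folding_ok(C=64, K=9, swg_simd=64, vvau_pe=32, vvau_simd=9)
```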
75 changes: 44 additions & 31 deletions src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py
@@ -237,12 +237,11 @@ def get_buffer_depth(self):
mmv_in = 1
mmv_out = 1
channel_factor = int(ifm_ch / simd)

# compute minimal buffer length (assuming it holds 1 complete window)
buffer_min_size = ((k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1) * channel_factor

impl_style = self.select_impl_style()
if impl_style == "default":
buffer_min_size = (
(k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1
) * channel_factor
# add additional buffer space in case of stride > 1
# this minimizes cycle count as it allows an earlier pre-load of inputs
buffer_depth = (
@@ -257,6 +256,9 @@ def get_buffer_depth(self):
)
)
elif impl_style == "parallel":
buffer_min_size = (
(k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w
) * channel_factor + 1
buffer_depth = buffer_min_size + 1
return buffer_depth
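As a quick sanity check with assumed values (3x3 kernel, dilation 1, feature-map width 8, channel_factor = 4), the two formulas above differ only in where the trailing +1 lands: per channel group in default mode, once overall in parallel mode.

```python
# Assumed example configuration, not taken from the commit.
k_h, k_w = 3, 3
dilation_h, dilation_w = 1, 1
w = 8
channel_factor = 4  # ifm_ch / simd

# "default" style: the full window, including its last pixel, is counted per channel group
buf_default = ((k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1) * channel_factor   # 76

# "parallel" style: the trailing +1 is a single element, not one per channel group
buf_parallel = ((k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w) * channel_factor + 1  # 73

print(buf_default, buf_parallel)
```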

@@ -676,6 +678,7 @@ def prepare_codegen_parallel(self):
dilation = self.get_nodeattr("Dilation")
simd = self.get_nodeattr("SIMD")
M = self.get_nodeattr("M")
depthwise = self.get_nodeattr("depthwise")

k_h, k_w = k
h, w = ifm_dim
@@ -691,7 +694,7 @@ def prepare_codegen_parallel(self):
channel_factor = int(ifm_ch / simd)

# compute minimal buffer length (assuming it holds 1 complete window)
buffer_min_size = ((k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w + 1) * channel_factor
buffer_min_size = ((k_h - 1) * dilation_h * w + (k_w - 1) * dilation_w) * channel_factor + 1

buffer_actual_size = self.get_buffer_depth()
code_gen_dict["$BUF_ELEM_TOTAL$"] = [str(buffer_actual_size)]
@@ -710,32 +713,32 @@ def prepare_codegen_parallel(self):
]

# re-use default controller loop structure
code_gen_dict["$IS_DEPTHWISE$"] = ["0"]
code_gen_dict["$IS_DEPTHWISE$"] = ["1"] if depthwise else ["0"]
loop_h_iterations = out_dim_h
loop_w_iterations = out_dim_w # now the innermost loop
loop_kh_iterations = 1
loop_w_iterations = out_dim_w
loop_kh_iterations = channel_factor
loop_kw_iterations = 1
loop_simd_iterations = 1

if loop_w_iterations == 1:
code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_H"]
loop_h_iterations -= 1 # -1 because state is initial state
if loop_kh_iterations == 1:
if loop_w_iterations == 1:
code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_H"]
loop_h_iterations -= 1 # -1 because state is initial state
else:
code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_W"]
loop_w_iterations -= 1 # -1 because state is initial state
else:
code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_W"]
loop_w_iterations -= 1 # -1 because state is initial state
code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_KH"]
loop_kh_iterations -= 1 # -1 because state is initial state

# set head and tail address increment values
addr_incr_end_window = -buffer_min_size + stride_w * channel_factor + 1
addr_incr_end_row = (
-buffer_min_size
+ ((skip_columns + kernel_width) * channel_factor) # remaining line
+ ((stride_h - 1) * w * channel_factor) # skip lines
+ 1
)

tail_incr_w = addr_incr_end_window + buffer_min_size - 1
tail_incr_h = addr_incr_end_row + buffer_min_size - 1
tail_incr_last_window = stride_w
tail_incr_w = (stride_w - 1) * channel_factor + 1
tail_incr_h = (
(skip_columns + (kernel_width - 1)) * channel_factor + 1
) + ( # remaining line
(stride_h - 1) * w * channel_factor
) # skip lines
tail_incr_last_window = stride_w * channel_factor

addr_incr_end_simd = 1
addr_incr_end_window_elem = 1
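A standalone sketch with assumed values (depthwise, window-parallel, 3x3 kernel, stride 1, dilation 1, 8x8 input, channel_factor = 4, no skipped columns) of how the loop bounds and tail increments above work out; variable names mirror the generated-code parameters.

```python
out_dim_h, out_dim_w = 6, 6
stride_h, stride_w = 1, 1
w = 8
kernel_width = 3            # effective kernel width: (k_w - 1) * dilation_w + 1
channel_factor = 4          # ifm_ch / simd = 16 / 4
skip_columns = 0            # assumed: no columns skipped for this configuration

loop_h_iterations = out_dim_h          # 6
loop_w_iterations = out_dim_w          # 6
loop_kh_iterations = channel_factor    # 4: channel groups form the innermost loop
# channel_factor > 1, so STATE_LOOP_KH is the innermost state and absorbs the initial state
loop_kh_iterations -= 1                # 3

tail_incr_w = (stride_w - 1) * channel_factor + 1          # 1
tail_incr_h = ((skip_columns + (kernel_width - 1)) * channel_factor + 1) + (
    (stride_h - 1) * w * channel_factor
)                                                          # 9: rest of the line, no skipped rows
tail_incr_last_window = stride_w * channel_factor          # 4

print(loop_h_iterations, loop_w_iterations, loop_kh_iterations)
print(tail_incr_w, tail_incr_h, tail_incr_last_window)
```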
@@ -810,15 +813,21 @@ def prepare_codegen_parallel(self):
for ky in range(k_h):
reg_fifo = []
for kx in range(k_w):
reg_fifo.append(px_idx)
px_idx += 1
for c in range(channel_factor):
if c < (channel_factor - 1):
if not (ky == 0 and kx == 0):
reg_fifo.append(-1)
px_idx += 1
else:
reg_fifo.append(px_idx)
px_idx += 1
if kx < (k_w - 1):
reg_fifo.extend([-1] * (dilation_w - 1))
px_idx += dilation_w - 1
reg_fifo.extend([-1] * ((dilation_w - 1) * channel_factor))
px_idx += (dilation_w - 1) * channel_factor
reg_fifos.append(reg_fifo)

if ky < (k_h - 1):
line_buffer_len = (w - kernel_width) + w * (dilation_h - 1)
line_buffer_len = ((w - kernel_width) + w * (dilation_h - 1)) * channel_factor
bram_fifos_depth.append(line_buffer_len)
px_idx += line_buffer_len
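To see how channel_factor changes the tap layout built above, here is a standalone trace for a toy 2x2 kernel with channel_factor = 2 and dilation 1; starting px_idx at 0 is an assumption made for this sketch.

```python
k_h, k_w = 2, 2
dilation_h, dilation_w = 1, 1
w = 4
kernel_width = 2
channel_factor = 2

reg_fifos = []
bram_fifos_depth = []
px_idx = 0  # assumed initial value for this sketch
for ky in range(k_h):
    reg_fifo = []
    for kx in range(k_w):
        for c in range(channel_factor):
            if c < (channel_factor - 1):
                # non-final channel groups of a pixel are placeholders (-1),
                # except for the very first window pixel, which is not buffered
                if not (ky == 0 and kx == 0):
                    reg_fifo.append(-1)
                    px_idx += 1
            else:
                # only the final channel group of each pixel is tapped
                reg_fifo.append(px_idx)
                px_idx += 1
        if kx < (k_w - 1):
            gap = (dilation_w - 1) * channel_factor
            reg_fifo.extend([-1] * gap)
            px_idx += gap
    reg_fifos.append(reg_fifo)
    if ky < (k_h - 1):
        line_buffer_len = ((w - kernel_width) + w * (dilation_h - 1)) * channel_factor
        bram_fifos_depth.append(line_buffer_len)
        px_idx += line_buffer_len

print(reg_fifos)          # [[0, -1, 2], [-1, 8, -1, 10]]
print(bram_fifos_depth)   # [4]
```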

@@ -926,6 +935,7 @@ def select_impl_style(self):
"""Selects implementation style based on folding configuration."""
simd = self.get_nodeattr("SIMD")
M = self.get_nodeattr("M")
depthwise = self.get_nodeattr("depthwise")
ifm_ch = self.get_nodeattr("IFMChannels")
ifm_dim = self.get_nodeattr("IFMDim")
stride = self.get_nodeattr("Stride")
@@ -950,7 +960,6 @@ def select_impl_style(self):
if self.get_nodeattr("parallel_window"):
# mmv_in = M * 1
mmv_out = M * k_h * k_w
assert ifm_ch == simd, "Constraint violated: SIMD must be equal to IFMChannels"
else:
# mmv_in = 1
mmv_out = 1
@@ -959,7 +968,11 @@ def select_impl_style(self):
# choose implementation style
if mmv_out > 1 or (k_h == 1 and k_w == 1):
impl_style = "parallel"
assert ifm_ch == simd, "Constraint violated: SIMD must be equal to IFMChannels"
if depthwise:
# allow SIMD < IFM_CH in depthwise mode (VVAU supports the resulting data layout)
assert ifm_ch % simd == 0, "Constraint violated: SIMD must divide IFMChannels"
else:
assert ifm_ch == simd, "Constraint violated: SIMD must be equal to IFMChannels"
else:
impl_style = "default"

@@ -196,8 +196,8 @@ def test_fpgadataflow_slidingwindow_rtl(
pytest.skip("Not all combinations for stride > k edge case supported in default mode")
if k_h == 1 and k_w == 1 and simd != ifm_ch:
pytest.skip("1x1 Kernel only supported in parallel mode (SIMD=C)")
if parallel_window and simd != ifm_ch:
pytest.skip("Parallel window requires SIMD=C")
if parallel_window and simd != ifm_ch and not dw:
pytest.skip("Parallel window requires SIMD=C for non-depthwise case")

ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h)
ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, 0, dilation_w)
