Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added when keyword #43

Merged
merged 8 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/flattening/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1154,18 +1154,30 @@ impl<'l, 'errs> FlatteningContext<'l, 'errs> {

fn flatten_if_statement(&mut self, cursor: &mut Cursor) {
cursor.go_down(kind!("if_statement"), |cursor| {
cursor.field(field!("statement_type"));
let keyword_is_if = cursor.kind() == kw!("if");
let position_statement_keyword = cursor.span();
cursor.field(field!("condition"));
let (condition, condition_is_generative) = self.flatten_expr(cursor);
match(keyword_is_if, condition_is_generative){
(true, false) => {
self.errors.warn(position_statement_keyword, "Used 'if' in a non generative context, use 'when' instead");
},
(false, true) => {
self.errors.error(position_statement_keyword, "Used 'when' in a generative context, use 'if' instead");
},
(_, _) => ()
}

let if_id = self.instructions.alloc(Instruction::IfStatement(IfStatement {
let if_id = self.instructions.alloc(Instruction::IfStatement(IfStatement {
condition,
is_generative: condition_is_generative,// TODO `if` vs `when` https://github.com/pc2/sus-compiler/issues/3
is_generative: keyword_is_if,
then_start: FlatID::PLACEHOLDER,
then_end_else_start: FlatID::PLACEHOLDER,
else_end: FlatID::PLACEHOLDER,
}));
let then_start = self.instructions.get_next_alloc_id();

let then_start = self.instructions.get_next_alloc_id();
cursor.field(field!("then_block"));
self.flatten_code(cursor);

Expand Down
1 change: 1 addition & 0 deletions src/flattening/initialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ impl<'linker> InitializationContext<'linker> {

fn gather_ports_in_if_stmt(&mut self, cursor: &mut Cursor) {
cursor.go_down_no_check(|cursor| {
cursor.field(field!("statement_type"));
cursor.field(field!("condition"));
cursor.field(field!("then_block"));
self.gather_all_ports_in_block(cursor);
Expand Down
25 changes: 11 additions & 14 deletions stl/util.sus
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module DualPortMem #(T, int SIZE) {
domain write_clk
interface write : bool write, int addr, T data

if write {
when write {
mem[addr] = data
}

Expand Down Expand Up @@ -48,16 +48,16 @@ module FIFO #(
CrossDomain mem_to_pop
mem_to_pop.in = mem

if pop {
when pop {
data_valid = read_addr != write_to_pop.out
if data_valid {
when data_valid {
// Add a pipelining register, because it can usually be fitted to the
reg data_out = mem_to_pop.out[read_addr]
read_addr = (read_addr + 1) % DEPTH
}
}

if push {
when push {
mem[write_addr] = data_in
write_addr = (write_addr + 1) % DEPTH
}
Expand All @@ -72,9 +72,6 @@ module FIFO #(
module JoinDomains #(T1, T2, int OFFSET) {
interface identity1 : T1 i1'0 -> T1 o1'0
interface identity2 : T2 i2'OFFSET -> T2 o2'OFFSET

o1 = i1
o2 = i2
}

module Iterator {
Expand All @@ -87,10 +84,10 @@ module Iterator {

valid = value != current_limit

if start {
when start {
current_limit = up_to
value = 0
} else if valid {
} else when valid {
value = value + 1
}
}
Expand All @@ -107,10 +104,10 @@ module FixedSizeIterator #(int UP_TO) {
last = value == UP_TO - 1
valid = value != UP_TO

if start {
when start {
value = 0
} else {
if valid {
when valid {
value = value + 1
}
}
Expand All @@ -121,7 +118,7 @@ module SlowClockGenerator #(int PERIOD) {

initial cur_value = 0

if cur_value == PERIOD-1 {
when cur_value == PERIOD-1 {
cur_value = 0
} else {
cur_value = cur_value + 1
Expand All @@ -142,7 +139,7 @@ module SplitAt #(T, int SIZE, int SPLIT_POINT) {
module Abs {
interface Abs : int a -> int o

if a < 0 {
when a < 0 {
o = -a
} else {
o = a
Expand Down Expand Up @@ -178,7 +175,7 @@ module PopCount #(int WIDTH) {
if WIDTH <= BASE_CASE_SIZE {
int[WIDTH] cvt
for int I in 0..WIDTH {
if bits[I] {
when bits[I] {
cvt[I] = 1
} else {
cvt[I] = 0
Expand Down
71 changes: 46 additions & 25 deletions test.sus
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ module blur2 {

state int prev

if !first {
when !first {
blurred = data + prev
}
prev = data
Expand Down Expand Up @@ -83,7 +83,7 @@ module Accumulator {
initial tot = 0

int new_tot = tot + term
if done {
when done {
reg total = new_tot
tot = 0
} else {
Expand All @@ -99,7 +99,7 @@ module blur {
initial working = false
state int prev

if working {
when working {
reg reg reg result = prev + a // Add a pipeline stage for shits and giggles
}
prev = a
Expand All @@ -119,19 +119,19 @@ module Unpack4 {
initial st = 0
state int[3] stored_packed

if st == INITIAL {
when st == INITIAL {
out_stream = packed[0]
stored_packed[0] = packed[1] // Shorthand notation is possible here "stored_packed[0:2] = packed[1:3]"
stored_packed[1] = packed[2]
stored_packed[2] = packed[3]
st = 1
} else if st == 1 {
} else when st == 1 {
out_stream = stored_packed[0]
st = 2
} else if st == 2 {
} else when st == 2 {
out_stream = stored_packed[1]
st = 3
} else if st == 3 {
} else when st == 3 {
out_stream = stored_packed[2]
st = INITIAL // Must restore initial conditions
//finish // packet is hereby finished.
Expand Down Expand Up @@ -183,22 +183,22 @@ module test_various_assignments {
//timeline (bs -> /, true) | (bs -> v, false)
module first_bit_idx_6 {
interface first_bit_idx_6 : bool[6] bits -> int first, bool all_zeros
if bits[0] {
when bits[0] {
first = 0
all_zeros = false
} else if bits[1] {
} else when bits[1] {
first = 1
all_zeros = false
} else if bits[2] {
} else when bits[2] {
first = 2
all_zeros = false
} else if bits[3] {
} else when bits[3] {
first = 3
all_zeros = false
} else if bits[4] {
} else when bits[4] {
first = 4
all_zeros = false
} else if bits[5] {
} else when bits[5] {
first = 5
all_zeros = false
} else {
Expand Down Expand Up @@ -471,11 +471,11 @@ module fizz_buzz {
bool fizz = v % 3 == 0
bool buzz = v % 5 == 0

if fizz & buzz {
when fizz & buzz {
fb = FIZZ_BUZZ
} else if fizz {
} else when fizz {
fb = FIZZ
} else if buzz {
} else when buzz {
fb = BUZZ
} else {
fb = v
Expand Down Expand Up @@ -633,13 +633,13 @@ module dual_port_mem {

interface read : bool read, int rd_addr -> bool[20] rd_data

if write {
when write {
mem[wr_addr] = wr_data
}

cross_memory cr_m
cr_m.i = mem
if read {
when read {
rd_data = cr_m.o[rd_addr]
}
}
Expand Down Expand Up @@ -737,11 +737,11 @@ module sequenceDownFrom {

valid = index != 0
ready_cr.i = !valid
if valid {
when valid {
index = index - 1
}

if start_cr.o {
when start_cr.o {
index = upTo_cr.o
}
}
Expand All @@ -756,7 +756,7 @@ module sumUpTo {
bool re = sdf.ready

bool iter_valid, int iter_index = sdf.iter()
if iter_valid {
when iter_valid {
int idx = iter_index
}

Expand Down Expand Up @@ -901,13 +901,13 @@ module run_instruction {
instruction_decoder decoder
decoder.from(instr)

if decoder.is_jump() : int target_addr {
when decoder.is_jump() : int target_addr {
// ...
}
if decoder.is_load() : int reg_to, int addr {
when decoder.is_load() : int reg_to, int addr {
// ...
}
if decoder.is_arith() : int reg_a, int reg_b, Operator op {
when decoder.is_arith() : int reg_a, int reg_b, Operator op {
// ...
}
}
Expand Down Expand Up @@ -1012,7 +1012,6 @@ module use_sized_int_add {
c = sized_int_add(a, b)
}


module implicit_domain_forbidden {
input int bad_port

Expand Down Expand Up @@ -1097,3 +1096,25 @@ module UseBuiltinConstants {
module FailingAssert {
assert #(C: 15 + 3 == 19)
}

/// Test if when seperation
module IfTesting #(int WIDTH) {
// Should be chosen based on what's most efficient for the target architecture
gen int BASE_CASE_SIZE = 5

interface PopCount : bool[WIDTH] bits -> int popcount


when WIDTH <= BASE_CASE_SIZE {
int[WIDTH] cvt
for int I in 0..WIDTH {
if bits[I] {
cvt[I] = 1
} else if !bits[I] {
cvt[I] = 0
}
}
reg popcount = +cvt
} else when WIDTH > BASE_CASE_SIZE {
}
}
Loading