From 4e9af82c18948caa4c202b57bb7de2a6a4c1a646 Mon Sep 17 00:00:00 2001 From: Lennart Van Hirtum Date: Mon, 8 Jan 2024 20:28:05 +0100 Subject: [PATCH] First Block Instruction! if Statement. Also generative now --- multiply_add.sus | 13 +++ src/arena_alloc.rs | 33 ++++--- src/dev_aid/syntax_highlighting.rs | 2 +- src/flattening.rs | 139 +++++++++++++++-------------- src/instantiation/mod.rs | 119 +++++++++++++++--------- src/typing.rs | 63 +++++++------ src/value.rs | 6 +- 7 files changed, 219 insertions(+), 156 deletions(-) diff --git a/multiply_add.sus b/multiply_add.sus index caa89cc..740d913 100644 --- a/multiply_add.sus +++ b/multiply_add.sus @@ -310,6 +310,11 @@ module generative : int in -> int o, int o2 { a[0] = 10; gen int[3] xx = a; + gen bool test = true; + + if test { + + } o = a[in]; o2 = a[a[0]]; @@ -339,6 +344,14 @@ module first_bit_idx_6 : bool[6] bits -> int first, bool all_zeros { } else { all_zeros = true; } + + /*first int i in 0..6 where bits[i] { + first = i; + all_zeros = false; + } else { + all_zeros = true; + }*/ + } module first_bit_idx_24 : bool[24] bits -> int first { diff --git a/src/arena_alloc.rs b/src/arena_alloc.rs index 68f54d1..19d3a42 100644 --- a/src/arena_alloc.rs +++ b/src/arena_alloc.rs @@ -29,23 +29,13 @@ impl UUID { } } -#[derive(Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct UUIDRange(pub UUID, pub UUID); -impl IntoIterator for &UUIDRange { - type Item = UUID; - - type IntoIter = UUIDRangeIter; - - fn into_iter(self) -> UUIDRangeIter { - UUIDRangeIter(UUID(self.0.0, PhantomData), UUID(self.1.0, PhantomData)) - } -} - #[derive(Clone, Copy, PartialEq, Eq, Hash)] pub struct UUIDRangeIter(UUID, UUID); -impl Iterator for UUIDRangeIter { +impl Iterator for UUIDRange { type Item = UUID; fn next(&mut self) -> Option { @@ -59,6 +49,19 @@ impl Iterator for UUIDRangeIter { } } +impl UUIDRange { + pub fn skip_to(&mut self, to : UUID) { + assert!(to.0 >= self.0.0); + assert!(to.0 <= self.1.0); + self.0 = to; + } + pub fn len(&self) -> usize { + self.1.0 - self.0.0 + } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} @@ -275,6 +278,9 @@ impl ListAllocator { pub fn get_next_alloc_id(&self) -> UUID { UUID(self.data.len(), PhantomData) } + pub fn id_range(&self) -> UUIDRange { + UUIDRange(UUID(0, PhantomData), UUID(self.data.len(), PhantomData)) + } pub fn iter<'a>(&'a self) -> ListAllocIterator<'a, T, IndexMarker> { self.into_iter() } @@ -387,6 +393,9 @@ impl FlatAlloc { pub fn len(&self) -> usize { self.data.len() } + pub fn id_range(&self) -> UUIDRange { + UUIDRange(UUID(0, PhantomData), UUID(self.data.len(), PhantomData)) + } pub fn is_empty(&self) -> bool { self.data.is_empty() } diff --git a/src/dev_aid/syntax_highlighting.rs b/src/dev_aid/syntax_highlighting.rs index 00b0553..8dec9f5 100644 --- a/src/dev_aid/syntax_highlighting.rs +++ b/src/dev_aid/syntax_highlighting.rs @@ -138,7 +138,7 @@ fn walk_name_color(all_objects : &[NamedUUID], links : &Links, result : &mut [ID if decl.identifier_type == IdentifierType::Virtual {continue;} // Virtual wires don't appear in the program text result[conn.to.span.0].typ = IDETokenType::Identifier(IDEIdentifierType::Value(decl.identifier_type)); } - Instantiation::SubModule(_) => {} + Instantiation::SubModule(_) | Instantiation::IfStatement(_) => {} } } } diff --git a/src/flattening.rs b/src/flattening.rs index b5b1bc4..4fa6453 100644 --- a/src/flattening.rs +++ b/src/flattening.rs @@ -2,8 +2,8 @@ use std::{ops::{Deref, Range}, iter::zip}; use crate::{ ast::{Span, Module, Expression, SpanExpression, LocalOrGlobal, Operator, AssignableExpression, SpanAssignableExpression, Statement, CodeBlock, IdentifierType, TypeExpression, DeclIDMarker, SignalDeclaration}, - linker::{Linker, Named, Linkable, get_builtin_uuid, FileUUID, NamedUUID}, - errors::{ErrorCollector, error_info}, arena_alloc::{UUID, UUIDMarker, FlatAlloc}, tokenizer::kw, typing::{Type, typecheck_unary_operator, get_binary_operator_types, typecheck, typecheck_is_array_indexer}, value::Value + linker::{Linker, Named, Linkable, FileUUID, NamedUUID}, + errors::{ErrorCollector, error_info}, arena_alloc::{UUID, UUIDMarker, FlatAlloc, UUIDRange}, typing::{Type, typecheck_unary_operator, get_binary_operator_types, typecheck, typecheck_is_array_indexer, BOOL_TYPE, INT_TYPE}, value::Value }; #[derive(Debug,Clone,Copy,PartialEq,Eq,Hash)] @@ -11,6 +11,8 @@ pub struct FlatIDMarker; impl UUIDMarker for FlatIDMarker {const DISPLAY_NAME : &'static str = "obj_";} pub type FlatID = UUID; +pub type FlatIDRange = UUIDRange; + pub type FieldID = usize; #[derive(Debug)] @@ -41,8 +43,7 @@ impl ConnectionWrite { pub struct Connection { pub num_regs : i64, pub from : FlatID, - pub to : ConnectionWrite, - pub condition : Option + pub to : ConnectionWrite } #[derive(Debug)] @@ -107,12 +108,22 @@ impl SubModuleInstance { } } +#[derive(Debug)] +pub struct IfStatement{ + pub is_compiletime : bool, + pub condition : FlatID, + pub then_start : FlatID, + pub then_end_else_start : FlatID, + pub else_end : FlatID +} + #[derive(Debug)] pub enum Instantiation { SubModule(SubModuleInstance), WireDeclaration(WireDeclaration), Wire(WireInstance), - Connection(Connection) + Connection(Connection), + IfStatement(IfStatement) } impl Instantiation { @@ -136,6 +147,7 @@ impl Instantiation { match self { Instantiation::SubModule(_) => {} Instantiation::Connection(_) => {} + Instantiation::IfStatement(_) => {} Instantiation::WireDeclaration(decl) => { f(&decl.typ, decl.typ_span); } @@ -168,7 +180,7 @@ impl<'l, 'm> FlatteningContext<'l, 'm> { TypeExpression::Array(b) => { let (array_type_expr, array_size_expr) = b.deref(); let array_element_type = self.map_to_type(&array_type_expr.0); - if let Some(array_size_wire_id) = self.flatten_single_expr(array_size_expr, None) { + if let Some(array_size_wire_id) = self.flatten_single_expr(array_size_expr) { let array_size_wire = self.instantiations[array_size_wire_id].extract_wire(); if !array_size_wire.is_compiletime { self.errors.error_basic(array_size_expr.1, "Array size must be compile time"); @@ -233,7 +245,7 @@ impl<'l, 'm> FlatteningContext<'l, 'm> { })) } // Returns the module, full interface, and the output range for the function call syntax - fn desugar_func_call(&mut self, func_and_args : &[SpanExpression], closing_bracket_pos : usize, condition : Option) -> Option<(&Module, Box<[FlatID]>, Range)> { + fn desugar_func_call(&mut self, func_and_args : &[SpanExpression], closing_bracket_pos : usize) -> Option<(&Module, Box<[FlatID]>, Range)> { let (name_expr, name_expr_span) = &func_and_args[0]; // Function name is always there let func_instantiation_id = match name_expr { Expression::Named(LocalOrGlobal::Local(l)) => { @@ -276,23 +288,23 @@ impl<'l, 'm> FlatteningContext<'l, 'm> { } for (field, arg_expr) in zip(inputs, args) { - if let Some(arg_read_side) = self.flatten_single_expr(arg_expr, condition) { + if let Some(arg_read_side) = self.flatten_single_expr(arg_expr) { /*if self.typecheck(arg_read_side, &md.interface.interface_wires[field].typ, "submodule output") == None { continue; }*/ let func_input_port = &submodule_local_wires[field]; - self.instantiations.alloc(Instantiation::Connection(Connection { num_regs: 0, from: arg_read_side, to: ConnectionWrite::simple(*func_input_port, *name_expr_span), condition })); + self.instantiations.alloc(Instantiation::Connection(Connection{num_regs: 0, from: arg_read_side, to: ConnectionWrite::simple(*func_input_port, *name_expr_span)})); } } Some((md, submodule_local_wires, output_range)) } - fn flatten_single_expr(&mut self, (expr, expr_span) : &SpanExpression, condition : Option) -> Option { + fn flatten_single_expr(&mut self, (expr, expr_span) : &SpanExpression) -> Option { let (is_compiletime, source) = match expr { Expression::Named(LocalOrGlobal::Local(l)) => { let from_wire = self.decl_to_flat_map[*l].unwrap(); - let WireDeclaration { typ: _, typ_span:_, read_only:_, identifier_type, name:_, name_token:_ } = self.instantiations[from_wire].extract_wire_declaration(); - (*identifier_type == IdentifierType::Generative, WireSource::WireRead{from_wire}) + let decl = self.instantiations[from_wire].extract_wire_declaration(); + (decl.identifier_type == IdentifierType::Generative, WireSource::WireRead{from_wire}) } Expression::Named(LocalOrGlobal::Global(g)) => { let r = self.module.link_info.global_references[*g]; @@ -304,14 +316,14 @@ impl<'l, 'm> FlatteningContext<'l, 'm> { } Expression::UnaryOp(op_box) => { let (op, _op_pos, operate_on) = op_box.deref(); - let right = self.flatten_single_expr(operate_on, condition)?; + let right = self.flatten_single_expr(operate_on)?; let right_wire = self.instantiations[right].extract_wire(); (right_wire.is_compiletime, WireSource::UnaryOp{op : *op, right}) } Expression::BinOp(binop_box) => { let (left_expr, op, _op_pos, right_expr) = binop_box.deref(); - let left = self.flatten_single_expr(left_expr, condition)?; - let right = self.flatten_single_expr(right_expr, condition)?; + let left = self.flatten_single_expr(left_expr)?; + let right = self.flatten_single_expr(right_expr)?; let left_wire = self.instantiations[left].extract_wire(); let right_wire = self.instantiations[right].extract_wire(); let is_compiletime = left_wire.is_compiletime && right_wire.is_compiletime; @@ -319,14 +331,14 @@ impl<'l, 'm> FlatteningContext<'l, 'm> { } Expression::Array(arr_box) => { let (left, right, _bracket_span) = arr_box.deref(); - let arr = self.flatten_single_expr(left, condition)?; - let arr_idx = self.flatten_single_expr(right, condition)?; + let arr = self.flatten_single_expr(left)?; + let arr_idx = self.flatten_single_expr(right)?; let arr_wire = self.instantiations[arr].extract_wire(); let arr_idx_wire = self.instantiations[arr_idx].extract_wire(); (arr_wire.is_compiletime && arr_idx_wire.is_compiletime, WireSource::ArrayAccess{arr, arr_idx}) } Expression::FuncCall(func_and_args) => { - let (md, interface_wires, outputs_range) = self.desugar_func_call(func_and_args, expr_span.1, condition)?; + let (md, interface_wires, outputs_range) = self.desugar_func_call(func_and_args, expr_span.1)?; if outputs_range.len() != 1 { let info = error_info(md.link_info.span, md.link_info.file, "Module Defined here"); @@ -338,9 +350,10 @@ impl<'l, 'm> FlatteningContext<'l, 'm> { } }; - Some(self.instantiations.alloc(Instantiation::Wire(WireInstance{typ : Type::Unknown, span : *expr_span, is_compiletime, source}))) + let wire_instance = WireInstance{typ : Type::Unknown, span : *expr_span, is_compiletime, source}; + Some(self.instantiations.alloc(Instantiation::Wire(wire_instance))) } - fn flatten_assignable_expr(&mut self, (expr, span) : &SpanAssignableExpression, condition : Option) -> Option { + fn flatten_assignable_expr(&mut self, (expr, span) : &SpanAssignableExpression) -> Option { Some(match expr { AssignableExpression::Named{local_idx} => { let root = self.decl_to_flat_map[*local_idx].unwrap(); @@ -355,9 +368,9 @@ impl<'l, 'm> FlatteningContext<'l, 'm> { } AssignableExpression::ArrayIndex(arr_box) => { let (arr, idx_expr, _bracket_span) = arr_box.deref(); - let flattened_arr_expr_opt = self.flatten_assignable_expr(arr, condition); + let flattened_arr_expr_opt = self.flatten_assignable_expr(arr); - let idx = self.flatten_single_expr(idx_expr, condition)?; + let idx = self.flatten_single_expr(idx_expr)?; let mut flattened_arr_expr = flattened_arr_expr_opt?; // only unpack the subexpr after flattening the idx, so we catch all errors @@ -367,17 +380,6 @@ impl<'l, 'm> FlatteningContext<'l, 'm> { } }) } - fn extend_condition(&mut self, condition : Option, additional_condition : FlatID) -> FlatID { - if let Some(condition) = condition { - let bool_typ = Type::Named(get_builtin_uuid("bool")); - let prev_condition_wire = self.instantiations[condition].extract_wire(); - let additional_condition_wire = self.instantiations[condition].extract_wire(); - assert!(!prev_condition_wire.is_compiletime); // Conditions are only used for runtime conditions. Compile time ifs are handled at instantiation time - self.instantiations.alloc(Instantiation::Wire(WireInstance{typ : bool_typ, is_compiletime : false, span : additional_condition_wire.span, source : WireSource::BinaryOp{op: Operator{op_typ : kw("&")}, left : condition, right : additional_condition}})) - } else { - additional_condition - } - } fn flatten_declaration(&mut self, decl : &SignalDeclaration) -> FlatID { assert!(decl.identifier_type != IdentifierType::Input); assert!(decl.identifier_type != IdentifierType::Output); @@ -416,7 +418,7 @@ impl<'l, 'm> FlatteningContext<'l, 'm> { name_token : Some(decl.name_token) })) } - fn flatten_code(&mut self, code : &CodeBlock, condition : Option) { + fn flatten_code(&mut self, code : &CodeBlock) { for (stmt, stmt_span) in &code.statements { match stmt { Statement::Declaration(decl_id) => { @@ -426,26 +428,29 @@ impl<'l, 'm> FlatteningContext<'l, 'm> { self.decl_to_flat_map[*decl_id] = Some(wire_id); } Statement::If{condition : condition_expr, then, els} => { - let Some(if_statement_condition) = self.flatten_single_expr(condition_expr, condition) else {continue;}; + let Some(condition) = self.flatten_single_expr(condition_expr) else {continue;}; - let condition_is_const = self.instantiations[if_statement_condition].extract_wire().is_compiletime; + let is_compiletime = self.instantiations[condition].extract_wire().is_compiletime; - if condition_is_const { - println!("TODO generative if statements"); - } + let if_id = self.instantiations.get_next_alloc_id(); + let if_id_proper = self.instantiations.alloc(Instantiation::IfStatement(IfStatement{is_compiletime, condition, then_start : if_id, then_end_else_start : if_id, else_end : if_id})); + assert!(if_id == if_id_proper); + let then_start = self.instantiations.get_next_alloc_id(); - //let bool_typ = Type::Named(get_builtin_uuid("bool")); - //if self.typecheck(if_statement_condition, &bool_typ, "if statement condition") == None {continue;} - let then_condition = self.extend_condition(condition, if_statement_condition); - self.flatten_code(then, Some(then_condition)); + self.flatten_code(then); + let then_end_else_start = self.instantiations.get_next_alloc_id(); if let Some(e) = els { - let else_condition_bool = self.instantiations.alloc(Instantiation::Wire(WireInstance{typ : Type::Unknown, is_compiletime : false/* Generative If */, span : condition_expr.1, source : WireSource::UnaryOp{op : Operator{op_typ : kw("!")}, right : if_statement_condition}})); - let else_condition = self.extend_condition(condition, else_condition_bool); - self.flatten_code(e, Some(else_condition)); + self.flatten_code(e); } + let else_end = self.instantiations.get_next_alloc_id(); + + let Instantiation::IfStatement(ifstmt) = &mut self.instantiations[if_id] else {unreachable!()}; + ifstmt.then_start = then_start; + ifstmt.then_end_else_start = then_end_else_start; + ifstmt.else_end = else_end; } Statement::Assign{to, expr : (Expression::FuncCall(func_and_args), func_span), eq_sign_position} => { - let Some((md, interface, outputs_range)) = self.desugar_func_call(&func_and_args, func_span.1, condition) else {return;}; + let Some((md, interface, outputs_range)) = self.desugar_func_call(&func_and_args, func_span.1) else {return;}; let outputs = &interface[outputs_range]; let func_name_span = func_and_args[0].1; @@ -463,26 +468,26 @@ impl<'l, 'm> FlatteningContext<'l, 'm> { } for (field, to_i) in zip(outputs, to) { - let Some(write_side) = self.flatten_assignable_expr(&to_i.expr, condition) else {return;}; + let Some(write_side) = self.flatten_assignable_expr(&to_i.expr) else {return;}; // temporary let module_port_wire_decl = self.instantiations[*field].extract_wire_declaration(); let module_port_proxy = self.instantiations.alloc(Instantiation::Wire(WireInstance{typ : module_port_wire_decl.typ.clone(), is_compiletime : module_port_wire_decl.identifier_type == IdentifierType::Generative, span : *func_span, source : WireSource::WireRead { from_wire: *field }})); - self.instantiations.alloc(Instantiation::Connection(Connection{num_regs : to_i.num_regs, from: module_port_proxy, to: write_side, condition})); + self.instantiations.alloc(Instantiation::Connection(Connection{num_regs : to_i.num_regs, from: module_port_proxy, to: write_side})); } }, Statement::Assign{to, expr : non_func_expr, eq_sign_position : _} => { if to.len() == 1 { - let Some(read_side) = self.flatten_single_expr(non_func_expr, condition) else {return;}; + let Some(read_side) = self.flatten_single_expr(non_func_expr) else {return;}; let t = &to[0]; - let Some(write_side) = self.flatten_assignable_expr(&t.expr, condition) else {return;}; - self.instantiations.alloc(Instantiation::Connection(Connection{num_regs : t.num_regs, from: read_side, to: write_side, condition})); + let Some(write_side) = self.flatten_assignable_expr(&t.expr) else {return;}; + self.instantiations.alloc(Instantiation::Connection(Connection{num_regs : t.num_regs, from: read_side, to: write_side})); } else { self.errors.error_basic(*stmt_span, format!("Non-function assignments must only output exactly 1 instead of {}", to.len())); } }, Statement::Block(inner_code) => { - self.flatten_code(inner_code, condition); + self.flatten_code(inner_code); }, Statement::TimelineStage(_) => {/*TODO */} } @@ -544,7 +549,7 @@ impl FlattenedModule { let interface_ports = context.initialize_interface(); - context.flatten_code(&module.code, None); + context.flatten_code(&module.code); let flat_mod = FlattenedModule { instantiations : context.instantiations, @@ -567,6 +572,10 @@ impl FlattenedModule { match &self.instantiations[elem_id] { Instantiation::SubModule(_) => {} Instantiation::WireDeclaration(_) => {}, + Instantiation::IfStatement(stm) => { + let wire = &self.instantiations[stm.condition].extract_wire(); + self.typecheck_wire_is_of_type(wire, &BOOL_TYPE, "if statement condition", linker) + } Instantiation::Wire(w) => { let result_typ = match &w.source { &WireSource::WireRead{from_wire} => { @@ -588,7 +597,7 @@ impl FlattenedModule { let arr_wire = self.instantiations[arr].extract_wire(); let arr_idx_wire = self.instantiations[arr_idx].extract_wire(); - self.typecheck_wire_is_of_type(arr_idx_wire, &Type::Named(get_builtin_uuid("int")), "array index", linker); + self.typecheck_wire_is_of_type(arr_idx_wire, &INT_TYPE, "array index", linker); if let Some(typ) = typecheck_is_array_indexer(&arr_wire.typ, arr_wire.span, linker, &self.errors) { typ.clone() } else { @@ -611,7 +620,7 @@ impl FlattenedModule { match p { &ConnectionWritePathElement::ArrayIdx{idx, idx_span} => { let idx_wire = self.instantiations[idx].extract_wire(); - self.typecheck_wire_is_of_type(idx_wire, &Type::Named(get_builtin_uuid("int")), "array index", linker); + self.typecheck_wire_is_of_type(idx_wire, &INT_TYPE, "array index", linker); if let Some(wr) = write_to_type { write_to_type = typecheck_is_array_indexer(wr, idx_span, linker, &self.errors); } @@ -630,12 +639,6 @@ impl FlattenedModule { if let Some(target_type) = write_to_type { self.typecheck_wire_is_of_type(from_wire, &target_type, "connection", linker); } - - // Typecheck condition is bool - if let Some(condition) = conn.condition { - let condition_wire = self.instantiations[condition].extract_wire(); - self.typecheck_wire_is_of_type(condition_wire, &Type::Named(get_builtin_uuid("bool")), "assignment condition", linker); - } } } } @@ -659,9 +662,6 @@ impl FlattenedModule { match inst { Instantiation::Connection(conn) => { gathered_connection_fanin[conn.to.root].push(conn.from); - if let Some(cond) = conn.condition { - gathered_connection_fanin[conn.to.root].push(cond); - } } Instantiation::SubModule(sm) => { for w in sm.outputs() { @@ -670,6 +670,13 @@ impl FlattenedModule { } Instantiation::WireDeclaration(_) => {} // Handle these outside Instantiation::Wire(_) => {} + Instantiation::IfStatement(stm) => { + for id in UUIDRange(stm.then_start, stm.else_end) { + if let Instantiation::Connection(conn) = &self.instantiations[id] { + gathered_connection_fanin[conn.to.root].push(stm.condition); + } + } + } } } @@ -701,7 +708,7 @@ impl FlattenedModule { mark_not_unused(*port); } } - Instantiation::Connection(_) => {unreachable!()} + Instantiation::Connection(_) | Instantiation::IfStatement(_) => {unreachable!()} } for from in &gathered_connection_fanin[item] { mark_not_unused(*from); diff --git a/src/instantiation/mod.rs b/src/instantiation/mod.rs index 132e1bb..395fc0f 100644 --- a/src/instantiation/mod.rs +++ b/src/instantiation/mod.rs @@ -2,7 +2,7 @@ use std::{rc::Rc, ops::Deref, cell::RefCell}; use num::BigInt; -use crate::{arena_alloc::{UUID, UUIDMarker, FlatAlloc}, ast::{Operator, IdentifierType, Span}, typing::{ConcreteType, Type}, flattening::{FlatID, Instantiation, FlatIDMarker, ConnectionWritePathElement, WireSource, WireInstance, Connection, ConnectionWritePathElementComputed, FlattenedModule}, errors::ErrorCollector, linker::{Linker, get_builtin_uuid}, value::{Value, compute_unary_op, compute_binary_op}}; +use crate::{arena_alloc::{UUID, UUIDMarker, FlatAlloc, UUIDRange}, ast::{Operator, IdentifierType, Span}, typing::{ConcreteType, Type, BOOL_CONCRETE_TYPE, INT_CONCRETE_TYPE}, flattening::{FlatID, Instantiation, FlatIDMarker, ConnectionWritePathElement, WireSource, WireInstance, Connection, ConnectionWritePathElementComputed, FlattenedModule, FlatIDRange}, errors::ErrorCollector, linker::Linker, value::{Value, compute_unary_op, compute_binary_op}, tokenizer::kw}; pub mod latency; @@ -125,11 +125,6 @@ impl SubModuleOrWire { *result } #[track_caller] - fn extract_submodule(&self) -> SubModuleID { - let Self::SubModule(result) = self else {panic!("Failed SubModule extraction! Is {self:?} instead")}; - *result - } - #[track_caller] fn extract_generation_value(&self) -> &Value { let Self::CompileTimeValue(result) = self else {panic!("Failed GenerationValue extraction! Is {self:?} instead")}; result @@ -159,8 +154,24 @@ struct InstantiationContext<'fl, 'l> { } impl<'fl, 'l> InstantiationContext<'fl, 'l> { + fn get_generation_value(&self, v : FlatID) -> Option<&Value> { + if let SubModuleOrWire::CompileTimeValue(vv) = &self.generation_state[v] { + if let Value::Unset = vv { + self.errors.error_basic(self.flattened.instantiations[v].extract_wire().span, "This variable is set but it's Value::Unset!"); + None + } else if let Value::Error = vv { + self.errors.error_basic(self.flattened.instantiations[v].extract_wire().span, "This variable is set but it's Value::Error!"); + None + } else { + Some(vv) + } + } else { + self.errors.error_basic(self.flattened.instantiations[v].extract_wire().span, "This variable is not set at this point!"); + None + } + } fn extract_integer_from_value<'v, IntT : TryFrom<&'v BigInt>>(&self, val : &'v Value, span : Span) -> Option { - let Value::Integer(val) = val else {self.errors.error_basic(span, format!("Value is not an int, it is {val:?} instead")); return None}; + let val = val.extract_integer(); // Typecheck should cover this match IntT::try_from(val) { Ok(val) => Some(val), Err(_) => { @@ -169,6 +180,10 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { } } } + fn extract_bool_from_value(&self, val : &Value, span : Span) -> Option { + let Value::Bool(val) = val else {self.errors.error_basic(span, format!("Value is not a bool, it is {val:?} instead")); return None}; + Some(*val) + } fn concretize_type(&self, typ : &Type, span : Span) -> Option { match typ { Type::Error | Type::Unknown => unreachable!("Bad types should be caught in flattening: {}", typ.to_string(self.linker)), @@ -185,7 +200,7 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { } } } - fn process_connection_to_wire(&mut self, to_path : &[ConnectionWritePathElement], from : ConnectFrom, wire_id : WireID) { + fn process_connection_to_wire(&mut self, to_path : &[ConnectionWritePathElement], from : ConnectFrom, wire_id : WireID) -> Option<()> { let mut new_path : Vec = Vec::new(); let mut write_to_typ = &self.wires[wire_id].typ; @@ -197,12 +212,12 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { SubModuleOrWire::SubModule(_) => unreachable!(), SubModuleOrWire::Unnasigned => unreachable!(), SubModuleOrWire::Wire(idx_wire) => { - assert!(self.wires[*idx_wire].typ == ConcreteType::Named(get_builtin_uuid("int"))); + assert!(self.wires[*idx_wire].typ == INT_CONCRETE_TYPE); new_path.push(ConnectToPathElem::MuxArrayWrite{idx_wire : *idx_wire}); } SubModuleOrWire::CompileTimeValue(v) => { - let Some(idx) = self.extract_integer_from_value(v, *idx_span) else {return}; + let idx = self.extract_integer_from_value(v, *idx_span)?; new_path.push(ConnectToPathElem::ConstArrayWrite{idx}); } } @@ -221,6 +236,8 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { let RealWireDataSource::Multiplexer{is_state : _, sources} = &mut self.wires[wire_id].source else {unreachable!("Should only be a writeable wire here")}; sources.push(MultiplexerSource{from, path : new_path}); + + Some(()) } fn convert_connection_path_to_known_values(&self, conn_path : &[ConnectionWritePathElement]) -> Option> { let mut result = Vec::new(); @@ -228,24 +245,22 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { for p in conn_path { match p { ConnectionWritePathElement::ArrayIdx{idx, idx_span} => { - let Some(idx_val) = self.get_generation_value(*idx) else {return None}; - let Some(idx_val) = self.extract_integer_from_value::(idx_val, *idx_span) else {return None}; + let idx_val = self.get_generation_value(*idx)?; + let idx_val = self.extract_integer_from_value::(idx_val, *idx_span)?; result.push(ConnectionWritePathElementComputed::ArrayIdx(idx_val)) } } } Some(result) } - fn process_connection(&mut self, conn : &Connection, original_wire : FlatID) { + fn process_connection(&mut self, conn : &Connection, original_wire : FlatID, condition : Option) -> Option<()> { match &self.generation_state[conn.to.root] { SubModuleOrWire::SubModule(_) => unreachable!(), SubModuleOrWire::Unnasigned => unreachable!(), SubModuleOrWire::Wire(w) => { // Runtime wire let deref_w = *w; - let condition = conn.condition.map(|found_conn| self.generation_state[found_conn].extract_wire()); - - let Some(from) = self.get_wire_or_constant_as_wire(conn.from) else {return;}; + let from = self.get_wire_or_constant_as_wire(conn.from)?; let conn_from = ConnectFrom { num_regs: conn.num_regs, from, @@ -253,27 +268,17 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { original_wire }; - self.process_connection_to_wire(&conn.to.path, conn_from, deref_w); - - return; + self.process_connection_to_wire(&conn.to.path, conn_from, deref_w)?; } SubModuleOrWire::CompileTimeValue(_original_value) => { // Compiletime wire let found_v = self.generation_state[conn.from].extract_generation_value().clone(); - let Some(cvt_path) = self.convert_connection_path_to_known_values(&conn.to.path) else {return}; + let cvt_path = self.convert_connection_path_to_known_values(&conn.to.path)?; // Hack to get around the borrow rules here let SubModuleOrWire::CompileTimeValue(v_writable) = &mut self.generation_state[conn.to.root] else {unreachable!()}; write_gen_variable(v_writable, &cvt_path, found_v); } }; - - } - fn get_generation_value(&self, v : FlatID) -> Option<&Value> { - if let SubModuleOrWire::CompileTimeValue(vv) = &self.generation_state[v] { - Some(vv) - } else { - self.errors.error_basic(self.flattened.instantiations[v].extract_wire().span, "This variable is not set at this point!"); - None - } + Some(()) } fn compute_compile_time(&self, wire_inst : &WireSource) -> Option { Some(match wire_inst { @@ -358,9 +363,17 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { let name = self.get_unique_name(); Some(self.wires.alloc(RealWire{ name, typ, original_wire, source})) } - fn instantiate_flattened_module(&mut self) { - for (original_wire, inst) in &self.flattened.instantiations { - let instance_to_add : SubModuleOrWire = match inst { + fn extend_condition(&mut self, condition : Option, additional_condition : WireID, original_wire : FlatID) -> WireID { + if let Some(condition) = condition { + self.wires.alloc(RealWire{typ : BOOL_CONCRETE_TYPE, name : self.get_unique_name(), original_wire, source : RealWireDataSource::BinaryOp{op: Operator{op_typ : kw("&")}, left : condition, right : additional_condition}}) + } else { + additional_condition + } + } + fn instantiate_flattened_module(&mut self, flat_range : FlatIDRange, condition : Option) -> Option<()> { + let mut instruction_range = flat_range.into_iter(); + while let Some(original_wire) = instruction_range.next() { + let instance_to_add : SubModuleOrWire = match &self.flattened.instantiations[original_wire] { Instantiation::SubModule(submodule) => { let Some(instance) = self.linker.instantiate(submodule.module_uuid) else {continue}; // Avoid error from submodule let interface_real_wires = submodule.local_wires.iter().map(|port| { @@ -369,9 +382,7 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { SubModuleOrWire::SubModule(self.submodules.alloc(SubModule { original_flat: original_wire, instance, wires : interface_real_wires, name : submodule.name.clone()})) } Instantiation::WireDeclaration(wire_decl) => { - let Some(typ) = self.concretize_type(&wire_decl.typ, wire_decl.typ_span) else { - return; // Exit early, do not produce invalid wires in InstantiatedModule - }; + let typ = self.concretize_type(&wire_decl.typ, wire_decl.typ_span)?; if wire_decl.identifier_type == IdentifierType::Generative { /*Do nothing (in fact re-initializes the wire to 'empty'), just corresponds to wire declaration*/ if wire_decl.read_only { @@ -392,25 +403,49 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { } } Instantiation::Wire(w) => { - let Some(typ) = self.concretize_type(&w.typ, w.span) else { - return; // Exit early, do not produce invalid wires in InstantiatedModule - }; + let typ = self.concretize_type(&w.typ, w.span)?; if w.is_compiletime { - let Some(value_computed) = self.compute_compile_time(&w.source) else {return}; + let value_computed = self.compute_compile_time(&w.source)?; assert!(value_computed.is_of_type(&typ)); SubModuleOrWire::CompileTimeValue(value_computed) } else { - let Some(wire_found) = self.wire_to_real_wire(w, typ, original_wire) else {return}; + let wire_found = self.wire_to_real_wire(w, typ, original_wire)?; SubModuleOrWire::Wire(wire_found) } } Instantiation::Connection(conn) => { - self.process_connection(conn, original_wire); + self.process_connection(conn, original_wire, condition); + continue; + } + Instantiation::IfStatement(stm) => { + let then_range = UUIDRange(stm.then_start, stm.then_end_else_start); + let else_range = UUIDRange(stm.then_end_else_start, stm.else_end); + if stm.is_compiletime { + let condition_val = self.get_generation_value(stm.condition)?; + let run_range = if condition_val.extract_bool() { + then_range + } else { + else_range + }; + self.instantiate_flattened_module(run_range, condition); + } else { + let condition_wire = self.generation_state[stm.condition].extract_wire(); + let then_cond = self.extend_condition(condition, condition_wire, original_wire); + self.instantiate_flattened_module(then_range, Some(then_cond)); + + if !else_range.is_empty() { + let else_condition_bool = self.wires.alloc(RealWire{typ : BOOL_CONCRETE_TYPE, name : self.get_unique_name(), original_wire, source : RealWireDataSource::UnaryOp{op : Operator{op_typ : kw("!")}, right : condition_wire}}); + let else_cond = self.extend_condition(condition, else_condition_bool, original_wire); + self.instantiate_flattened_module(else_range, Some(else_cond)); + } + } + instruction_range.skip_to(stm.else_end); continue; } }; self.generation_state[original_wire] = instance_to_add; } + Some(()) } // Returns a proper interface if all ports involved did not produce an error. If a port did produce an error then returns None. @@ -467,7 +502,7 @@ impl InstantiationList { errors : ErrorCollector::new(flattened.errors.file) }; - context.instantiate_flattened_module(); + context.instantiate_flattened_module(flattened.instantiations.id_range(), None); let interface = context.make_interface(); cache_borrow.push(Rc::new(InstantiatedModule{ diff --git a/src/typing.rs b/src/typing.rs index 28ee05c..a24486b 100644 --- a/src/typing.rs +++ b/src/typing.rs @@ -60,23 +60,26 @@ impl Type { } } + +pub const BOOL_TYPE : Type = Type::Named(get_builtin_uuid("bool")); +pub const INT_TYPE : Type = Type::Named(get_builtin_uuid("int")); +pub const BOOL_CONCRETE_TYPE : ConcreteType = ConcreteType::Named(get_builtin_uuid("bool")); +pub const INT_CONCRETE_TYPE : ConcreteType = ConcreteType::Named(get_builtin_uuid("int")); + pub fn typecheck_unary_operator(op : Operator, input_typ : &Type, span : Span, linker : &Linker, errors : &ErrorCollector) -> Type { - const BOOL : Type = Type::Named(get_builtin_uuid("bool")); - const INT : Type = Type::Named(get_builtin_uuid("int")); - if op.op_typ == kw("!") { - typecheck(input_typ, span, &BOOL, "! input", linker, errors); - BOOL + typecheck(input_typ, span, &BOOL_TYPE, "! input", linker, errors); + BOOL_TYPE } else if op.op_typ == kw("-") { - typecheck(input_typ, span, &INT, "- input", linker, errors); - INT + typecheck(input_typ, span, &INT_TYPE, "- input", linker, errors); + INT_TYPE } else { let gather_type = match op.op_typ { - x if x == kw("&") => BOOL, - x if x == kw("|") => BOOL, - x if x == kw("^") => BOOL, - x if x == kw("+") => INT, - x if x == kw("*") => INT, + x if x == kw("&") => BOOL_TYPE, + x if x == kw("|") => BOOL_TYPE, + x if x == kw("^") => BOOL_TYPE, + x if x == kw("+") => INT_TYPE, + x if x == kw("*") => INT_TYPE, _ => unreachable!() }; if let Some(arr_content_typ) = typecheck_is_array_indexer(input_typ, span, linker, errors) { @@ -86,27 +89,23 @@ pub fn typecheck_unary_operator(op : Operator, input_typ : &Type, span : Span, l } } pub fn get_binary_operator_types(op : Operator) -> ((Type, Type), Type) { - const BOOL : NamedUUID = get_builtin_uuid("bool"); - const INT : NamedUUID = get_builtin_uuid("int"); - - let (a, b, o) = match op.op_typ { - x if x == kw("&") => (BOOL, BOOL, BOOL), - x if x == kw("|") => (BOOL, BOOL, BOOL), - x if x == kw("^") => (BOOL, BOOL, BOOL), - x if x == kw("+") => (INT, INT, INT), - x if x == kw("-") => (INT, INT, INT), - x if x == kw("*") => (INT, INT, INT), - x if x == kw("/") => (INT, INT, INT), - x if x == kw("%") => (INT, INT, INT), - x if x == kw("==") => (INT, INT, BOOL), - x if x == kw("!=") => (INT, INT, BOOL), - x if x == kw(">=") => (INT, INT, BOOL), - x if x == kw("<=") => (INT, INT, BOOL), - x if x == kw(">") => (INT, INT, BOOL), - x if x == kw("<") => (INT, INT, BOOL), + match op.op_typ { + x if x == kw("&") => ((BOOL_TYPE, BOOL_TYPE), BOOL_TYPE), + x if x == kw("|") => ((BOOL_TYPE, BOOL_TYPE), BOOL_TYPE), + x if x == kw("^") => ((BOOL_TYPE, BOOL_TYPE), BOOL_TYPE), + x if x == kw("+") => ((INT_TYPE, INT_TYPE), INT_TYPE), + x if x == kw("-") => ((INT_TYPE, INT_TYPE), INT_TYPE), + x if x == kw("*") => ((INT_TYPE, INT_TYPE), INT_TYPE), + x if x == kw("/") => ((INT_TYPE, INT_TYPE), INT_TYPE), + x if x == kw("%") => ((INT_TYPE, INT_TYPE), INT_TYPE), + x if x == kw("==") => ((INT_TYPE, INT_TYPE), BOOL_TYPE), + x if x == kw("!=") => ((INT_TYPE, INT_TYPE), BOOL_TYPE), + x if x == kw(">=") => ((INT_TYPE, INT_TYPE), BOOL_TYPE), + x if x == kw("<=") => ((INT_TYPE, INT_TYPE), BOOL_TYPE), + x if x == kw(">") => ((INT_TYPE, INT_TYPE), BOOL_TYPE), + x if x == kw("<") => ((INT_TYPE, INT_TYPE), BOOL_TYPE), _ => unreachable!() - }; - ((Type::Named(a), Type::Named(b)), Type::Named(o)) + } } fn type_compare(expected : &Type, found : &Type) -> bool { diff --git a/src/value.rs b/src/value.rs index 4017d3b..8f007a0 100644 --- a/src/value.rs +++ b/src/value.rs @@ -2,7 +2,7 @@ use std::ops::Deref; use num::BigInt; -use crate::{typing::{Type, ConcreteType}, linker::get_builtin_uuid, ast::Operator, tokenizer::kw}; +use crate::{typing::{Type, ConcreteType, BOOL_TYPE, INT_TYPE}, linker::get_builtin_uuid, ast::Operator, tokenizer::kw}; #[derive(Debug,Clone,PartialEq,Eq)] pub enum Value { @@ -16,8 +16,8 @@ pub enum Value { impl Value { pub fn get_type_of_constant(&self) -> Type { match self { - Value::Bool(_) => Type::Named(get_builtin_uuid("bool")), - Value::Integer(_) => Type::Named(get_builtin_uuid("int")), + Value::Bool(_) => BOOL_TYPE, + Value::Integer(_) => INT_TYPE, Value::Array(_b) => { unreachable!("Can't express arrays as constants (yet?)"); /*let content_typ = if let Some(b_first) = b.first() {