diff --git a/shared_utils.py b/shared_utils.py index 4e777932..34482b1b 100644 --- a/shared_utils.py +++ b/shared_utils.py @@ -6,6 +6,7 @@ import pprint import re from itertools import chain +from typing import Dict, TypedDict from constants import * @@ -16,30 +17,35 @@ logging.basicConfig(level=LOG_LEVEL, format=LOG_FORMAT) +# Log an error message +def log_and_exit(message: str): + """Log an error message and exit the program.""" + logging.error(message) + raise SystemExit(1) + + # Initialize encoding to 32-bit '-' values -def initialize_encoding(bits=32): +def initialize_encoding(bits: int = 32) -> "list[str]": """Initialize encoding with '-' to represent don't care bits.""" return ["-"] * bits # Validate bit range and value -def validate_bit_range(msb, lsb, entry_value, line): +def validate_bit_range(msb: int, lsb: int, entry_value: int, line: str): """Validate the bit range and entry value.""" if msb < lsb: - logging.error( + log_and_exit( f'{line.split(" ")[0]:<10} has position {msb} less than position {lsb} in its encoding' ) - raise SystemExit(1) if entry_value >= (1 << (msb - lsb + 1)): - logging.error( + log_and_exit( f'{line.split(" ")[0]:<10} has an illegal value {entry_value} assigned as per the bit width {msb - lsb}' ) - raise SystemExit(1) # Split the instruction line into name and remaining part -def parse_instruction_line(line): +def parse_instruction_line(line: str) -> "tuple[str, str]": """Parse the instruction name and the remaining encoding details.""" name, remaining = line.split(" ", 1) name = name.replace(".", "_") # Replace dots for compatibility @@ -48,17 +54,18 @@ def parse_instruction_line(line): # Verify Overlapping Bits -def check_overlapping_bits(encoding, ind, line): +def check_overlapping_bits(encoding: "list[str]", ind: int, line: str): """Check for overlapping bits in the encoding.""" if encoding[31 - ind] != "-": - logging.error( + log_and_exit( f'{line.split(" ")[0]:<10} has {ind} bit overlapping in its opcodes' ) - raise SystemExit(1) # Update encoding for fixed ranges -def update_encoding_for_fixed_range(encoding, msb, lsb, entry_value, line): +def update_encoding_for_fixed_range( + encoding: "list[str]", msb: int, lsb: int, entry_value: int, line: str +): """ Update encoding bits for a given bit range. Checks for overlapping bits and assigns the value accordingly. @@ -70,7 +77,7 @@ def update_encoding_for_fixed_range(encoding, msb, lsb, entry_value, line): # Process fixed bit patterns -def process_fixed_ranges(remaining, encoding, line): +def process_fixed_ranges(remaining: str, encoding: "list[str]", line: str): """Process fixed bit ranges in the encoding.""" for s2, s1, entry in fixed_ranges.findall(remaining): msb, lsb, entry_value = int(s2), int(s1), int(entry, 0) @@ -83,9 +90,9 @@ def process_fixed_ranges(remaining, encoding, line): # Process single bit assignments -def process_single_fixed(remaining, encoding, line): +def process_single_fixed(remaining: str, encoding: "list[str]", line: str): """Process single fixed assignments in the encoding.""" - for lsb, value, drop in single_fixed.findall(remaining): + for lsb, value, _drop in single_fixed.findall(remaining): lsb = int(lsb, 0) value = int(value, 0) @@ -94,7 +101,7 @@ def process_single_fixed(remaining, encoding, line): # Main function to check argument look-up table -def check_arg_lut(args, encoding_args, name): +def check_arg_lut(args: "list[str]", encoding_args: "list[str]", name: str): """Check if arguments are present in arg_lut.""" for arg in args: if arg not in arg_lut: @@ -104,30 +111,28 @@ def check_arg_lut(args, encoding_args, name): # Handle missing argument mappings -def handle_arg_lut_mapping(arg, name): +def handle_arg_lut_mapping(arg: str, name: str): """Handle cases where an argument needs to be mapped to an existing one.""" parts = arg.split("=") if len(parts) == 2: - existing_arg, new_arg = parts + existing_arg, _new_arg = parts if existing_arg in arg_lut: arg_lut[arg] = arg_lut[existing_arg] else: - logging.error( + log_and_exit( f" Found field {existing_arg} in variable {arg} in instruction {name} " f"whose mapping in arg_lut does not exist" ) - raise SystemExit(1) else: - logging.error( + log_and_exit( f" Found variable {arg} in instruction {name} " f"whose mapping in arg_lut does not exist" ) - raise SystemExit(1) return arg # Update encoding args with variables -def update_encoding_args(encoding_args, arg, msb, lsb): +def update_encoding_args(encoding_args: "list[str]", arg: str, msb: int, lsb: int): """Update encoding arguments and ensure no overlapping.""" for ind in range(lsb, msb + 1): check_overlapping_bits(encoding_args, ind, arg) @@ -135,15 +140,26 @@ def update_encoding_args(encoding_args, arg, msb, lsb): # Compute match and mask -def convert_encoding_to_match_mask(encoding): +def convert_encoding_to_match_mask(encoding: "list[str]") -> "tuple[str, str]": """Convert the encoding list to match and mask strings.""" match = "".join(encoding).replace("-", "0") mask = "".join(encoding).replace("0", "1").replace("-", "0") return hex(int(match, 2)), hex(int(mask, 2)) +class SingleInstr(TypedDict): + encoding: str + variable_fields: "list[str]" + extension: "list[str]" + match: str + mask: str + + +InstrDict = Dict[str, SingleInstr] + + # Processing main function for a line in the encoding file -def process_enc_line(line, ext): +def process_enc_line(line: str, ext: str) -> "tuple[str, SingleInstr]": """ This function processes each line of the encoding files (rv*). As part of the processing, the function ensures that the encoding is legal through the @@ -199,13 +215,13 @@ def process_enc_line(line, ext): # Extract ISA Type -def extract_isa_type(ext_name): +def extract_isa_type(ext_name: str) -> str: """Extracts the ISA type from the extension name.""" return ext_name.split("_")[0] # Verify the types for RV* -def is_rv_variant(type1, type2): +def is_rv_variant(type1: str, type2: str) -> bool: """Checks if the types are RV variants (rv32/rv64).""" return (type2 == "rv" and type1 in {"rv32", "rv64"}) or ( type1 == "rv" and type2 in {"rv32", "rv64"} @@ -213,77 +229,79 @@ def is_rv_variant(type1, type2): # Check for same base ISA -def has_same_base_isa(type1, type2): +def has_same_base_isa(type1: str, type2: str) -> bool: """Determines if the two ISA types share the same base.""" return type1 == type2 or is_rv_variant(type1, type2) # Compare the base ISA type of a given extension name against a list of extension names -def same_base_isa(ext_name, ext_name_list): +def same_base_isa(ext_name: str, ext_name_list: "list[str]") -> bool: """Checks if the base ISA type of ext_name matches any in ext_name_list.""" type1 = extract_isa_type(ext_name) return any(has_same_base_isa(type1, extract_isa_type(ext)) for ext in ext_name_list) # Pad two strings to equal length -def pad_to_equal_length(str1, str2, pad_char="-"): +def pad_to_equal_length(str1: str, str2: str, pad_char: str = "-") -> "tuple[str, str]": """Pads two strings to equal length using the given padding character.""" max_len = max(len(str1), len(str2)) return str1.rjust(max_len, pad_char), str2.rjust(max_len, pad_char) # Check compatibility for two characters -def has_no_conflict(char1, char2): +def has_no_conflict(char1: str, char2: str) -> bool: """Checks if two characters are compatible (either matching or don't-care).""" return char1 == "-" or char2 == "-" or char1 == char2 # Conflict check between two encoded strings -def overlaps(x, y): +def overlaps(x: str, y: str) -> bool: """Checks if two encoded strings overlap without conflict.""" x, y = pad_to_equal_length(x, y) return all(has_no_conflict(x[i], y[i]) for i in range(len(x))) # Check presence of keys in dictionary. -def is_in_nested_dict(a, key1, key2): +def is_in_nested_dict(a: "dict[str, set[str]]", key1: str, key2: str) -> bool: """Checks if key2 exists in the dictionary under key1.""" return key1 in a and key2 in a[key1] # Overlap allowance -def overlap_allowed(a, x, y): +def overlap_allowed(a: "dict[str, set[str]]", x: str, y: str) -> bool: """Determines if overlap is allowed between x and y based on nested dictionary checks""" return is_in_nested_dict(a, x, y) or is_in_nested_dict(a, y, x) # Check overlap allowance between extensions -def extension_overlap_allowed(x, y): +def extension_overlap_allowed(x: str, y: str) -> bool: """Checks if overlap is allowed between two extensions using the overlapping_extensions dictionary.""" return overlap_allowed(overlapping_extensions, x, y) # Check overlap allowance between instructions -def instruction_overlap_allowed(x, y): +def instruction_overlap_allowed(x: str, y: str) -> bool: """Checks if overlap is allowed between two instructions using the overlapping_instructions dictionary.""" return overlap_allowed(overlapping_instructions, x, y) # Check 'nf' field -def is_segmented_instruction(instruction): +def is_segmented_instruction(instruction: SingleInstr) -> bool: """Checks if an instruction contains the 'nf' field.""" return "nf" in instruction["variable_fields"] # Expand 'nf' fields -def update_with_expanded_instructions(updated_dict, key, value): +def update_with_expanded_instructions( + updated_dict: InstrDict, key: str, value: SingleInstr +): """Expands 'nf' fields in the instruction dictionary and updates it with new instructions.""" for new_key, new_value in expand_nf_field(key, value): updated_dict[new_key] = new_value # Process instructions, expanding segmented ones and updating the dictionary -def add_segmented_vls_insn(instr_dict): +def add_segmented_vls_insn(instr_dict: InstrDict) -> InstrDict: """Processes instructions, expanding segmented ones and updating the dictionary.""" # Use dictionary comprehension for efficiency return dict( @@ -299,7 +317,9 @@ def add_segmented_vls_insn(instr_dict): # Expand the 'nf' field in the instruction dictionary -def expand_nf_field(name, single_dict): +def expand_nf_field( + name: str, single_dict: SingleInstr +) -> "list[tuple[str, SingleInstr]]": """Validate and prepare the instruction dictionary.""" validate_nf_field(single_dict, name) remove_nf_field(single_dict) @@ -322,29 +342,33 @@ def expand_nf_field(name, single_dict): # Validate the presence of 'nf' -def validate_nf_field(single_dict, name): +def validate_nf_field(single_dict: SingleInstr, name: str): """Validates the presence of 'nf' in variable fields before expansion.""" if "nf" not in single_dict["variable_fields"]: - logging.error(f"Cannot expand nf field for instruction {name}") - raise SystemExit(1) + log_and_exit(f"Cannot expand nf field for instruction {name}") # Remove 'nf' from variable fields -def remove_nf_field(single_dict): +def remove_nf_field(single_dict: SingleInstr): """Removes 'nf' from variable fields in the instruction dictionary.""" single_dict["variable_fields"].remove("nf") # Update the mask to include the 'nf' field -def update_mask(single_dict): +def update_mask(single_dict: SingleInstr): """Updates the mask to include the 'nf' field in the instruction dictionary.""" single_dict["mask"] = hex(int(single_dict["mask"], 16) | 0b111 << 29) # Create an expanded instruction def create_expanded_instruction( - name, single_dict, nf, name_expand_index, base_match, encoding_prefix -): + name: str, + single_dict: SingleInstr, + nf: int, + name_expand_index: int, + base_match: int, + encoding_prefix: str, +) -> "tuple[str, SingleInstr]": """Creates an expanded instruction based on 'nf' value.""" new_single_dict = copy.deepcopy(single_dict) @@ -363,7 +387,7 @@ def create_expanded_instruction( # Return a list of relevant lines from the specified file -def read_lines(file): +def read_lines(file: str) -> "list[str]": """Reads lines from a file and returns non-blank, non-comment lines.""" with open(file) as fp: lines = (line.rstrip() for line in fp) @@ -371,7 +395,9 @@ def read_lines(file): # Update the instruction dictionary -def process_standard_instructions(lines, instr_dict, file_name): +def process_standard_instructions( + lines: "list[str]", instr_dict: InstrDict, file_name: str +): """Processes standard instructions from the given lines and updates the instruction dictionary.""" for line in lines: if "$import" in line or "$pseudo" in line: @@ -409,7 +435,12 @@ def process_standard_instructions(lines, instr_dict, file_name): # Incorporate pseudo instructions into the instruction dictionary based on given conditions def process_pseudo_instructions( - lines, instr_dict, file_name, opcodes_dir, include_pseudo, include_pseudo_ops + lines: "list[str]", + instr_dict: InstrDict, + file_name: str, + opcodes_dir: str, + include_pseudo: bool, + include_pseudo_ops: "list[str]", ): """Processes pseudo instructions from the given lines and updates the instruction dictionary.""" for line in lines: @@ -433,12 +464,15 @@ def process_pseudo_instructions( else: if single_dict["match"] != instr_dict[name]["match"]: instr_dict[f"{name}_pseudo"] = single_dict - elif single_dict["extension"] not in instr_dict[name]["extension"]: + # TODO: This expression is always false since both sides are list[str]. + elif single_dict["extension"] not in instr_dict[name]["extension"]: # type: ignore instr_dict[name]["extension"].extend(single_dict["extension"]) # Integrate imported instructions into the instruction dictionary -def process_imported_instructions(lines, instr_dict, file_name, opcodes_dir): +def process_imported_instructions( + lines: "list[str]", instr_dict: InstrDict, file_name: str, opcodes_dir: str +): """Processes imported instructions from the given lines and updates the instruction dictionary.""" for line in lines: if "$import" not in line: @@ -464,7 +498,7 @@ def process_imported_instructions(lines, instr_dict, file_name, opcodes_dir): # Locate the path of the specified extension file, checking fallback directories -def find_extension_file(ext, opcodes_dir): +def find_extension_file(ext: str, opcodes_dir: str): """Finds the extension file path, considering the unratified directory if necessary.""" ext_file = f"{opcodes_dir}/{ext}" if not os.path.exists(ext_file): @@ -475,7 +509,9 @@ def find_extension_file(ext, opcodes_dir): # Confirm the presence of an original instruction in the corresponding extension file. -def validate_instruction_in_extension(inst, ext_file, file_name, pseudo_inst): +def validate_instruction_in_extension( + inst: str, ext_file: str, file_name: str, pseudo_inst: str +): """Validates if the original instruction exists in the dependent extension.""" found = False for oline in open(ext_file): @@ -489,7 +525,11 @@ def validate_instruction_in_extension(inst, ext_file, file_name, pseudo_inst): # Construct a dictionary of instructions filtered by specified criteria -def create_inst_dict(file_filter, include_pseudo=False, include_pseudo_ops=[]): +def create_inst_dict( + file_filter: "list[str]", + include_pseudo: bool = False, + include_pseudo_ops: "list[str]" = [], +) -> InstrDict: """Creates a dictionary of instructions based on the provided file filters.""" """ @@ -522,7 +562,7 @@ def create_inst_dict(file_filter, include_pseudo=False, include_pseudo_ops=[]): is not already present; otherwise, it is skipped. """ opcodes_dir = os.path.dirname(os.path.realpath(__file__)) - instr_dict = {} + instr_dict: InstrDict = {} file_names = [ file @@ -559,17 +599,10 @@ def create_inst_dict(file_filter, include_pseudo=False, include_pseudo_ops=[]): # Extracts the extensions used in an instruction dictionary -def instr_dict_2_extensions(instr_dict): +def instr_dict_2_extensions(instr_dict: InstrDict) -> "list[str]": return list({item["extension"][0] for item in instr_dict.values()}) # Returns signed interpretation of a value within a given width -def signed(value, width): +def signed(value: int, width: int) -> int: return value if 0 <= value < (1 << (width - 1)) else value - (1 << width) - - -# Log an error message -def log_and_exit(message): - """Log an error message and exit the program.""" - logging.error(message) - raise SystemExit(1)