Skip to content

Commit

Permalink
LVM: First "Working" version
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinjoseph1995 committed Oct 26, 2024
1 parent 770d716 commit 62eaabe
Show file tree
Hide file tree
Showing 6 changed files with 508 additions and 93 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[workspace]
resolver = "2"
members = ["bril/bril-rs/bril2json", "common", "driver", "optimizations"]
members = ["bril/bril-rs/bril2json", "common", "driver", "optimizations", "bril/brilirs"]
7 changes: 5 additions & 2 deletions driver/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{error::Error, io::Read};

use clap::Parser;
use optimizations::OptimizationPass;
use optimizations::{OptimizationPass, PassManager};

#[derive(Parser)]
#[command(version, about, long_about = None)]
Expand Down Expand Up @@ -34,9 +34,12 @@ fn main() -> Result<(), Box<dyn Error>> {

let mut optimizations = args.optimizations;
optimizations.dedup_by(|a, b| a == b);
let mut pass_manager = PassManager::new();
for optimization in optimizations.iter() {
optimization.apply(&mut program);
pass_manager.register_pass(*optimization);
}
program = pass_manager.run(program);

println!("{}", program);
Ok(())
}
2 changes: 2 additions & 0 deletions optimizations/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ edition = "2021"
clap = "4.5.20"
common = { version = "0.1.0", path = "../common" }
indoc = "2.0.5"
smallstr = "0.3.0"
smallvec = "1.13.2"

[dependencies.bril-rs]
version = "0.1.0"
Expand Down
50 changes: 40 additions & 10 deletions optimizations/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,60 @@ mod local_value_numbering;
use bril_rs::{Code, Program};
use clap::ValueEnum;
use common::BasicBlock;
#[derive(ValueEnum, Clone, Debug, PartialEq)]
#[derive(ValueEnum, Clone, Debug, PartialEq, Copy)]
pub enum OptimizationPass {
LocalDeadCodeElimination,
LocalValueNumbering,
}

impl OptimizationPass {
pub fn apply(&self, program: &mut Program) {
match &self {
pub struct PassManager {
registered_passes: Vec<Box<dyn Pass>>,
}

impl PassManager {
fn construct_pass(pass_type: OptimizationPass) -> Box<dyn Pass> {
match pass_type {
OptimizationPass::LocalDeadCodeElimination => {
local_dead_code_elimination::apply(program)
Box::new(local_dead_code_elimination::LocalDeadCodeEliminationPass::new())
}
OptimizationPass::LocalValueNumbering => {
Box::new(local_value_numbering::LocalValueNumberingPass::new())
}
OptimizationPass::LocalValueNumbering => local_value_numbering::apply(program),
}
}
pub fn new() -> Self {
Self {
registered_passes: vec![],
}
}
pub fn register_pass(&mut self, pass_type: OptimizationPass) {
self.registered_passes.push(Self::construct_pass(pass_type));
}
pub fn run(&mut self, mut program: Program) -> Program {
for pass in self.registered_passes.iter_mut() {
program = pass.apply(program);
}
program
}
}

pub trait Pass {
fn apply(&mut self, program: Program) -> Program;
}

pub trait LocalScopePass {
fn apply(&mut self, input_block: BasicBlock) -> BasicBlock;
}

fn apply_for_each_block<F>(program: &mut bril_rs::Program, block_optimization : F)
where F: Fn(&BasicBlock) -> BasicBlock,
pub fn apply_for_each_block<P>(mut program: bril_rs::Program, pass_manager: &mut P) -> Program
where
P: LocalScopePass,
{
program.functions.iter_mut().for_each(|function| {
// For every function optimze the basic blocks within it.
function.instrs = common::construct_basic_block_stream(&function.instrs)
.iter_mut() // For every block
.map(|block| block_optimization(block)) // Optimize this block
.into_iter() // For every block
.map(|block| pass_manager.apply(block)) // Optimize this block
.map(|optimized_block| -> Vec<Code> {
// Re-form instruction stream from blocks
match optimized_block.name {
Expand All @@ -46,4 +75,5 @@ where F: Fn(&BasicBlock) -> BasicBlock,
.flatten()
.collect::<Vec<Code>>();
});
return program;
}
174 changes: 99 additions & 75 deletions optimizations/src/local_dead_code_elimination.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,92 @@
use bril_rs::{Code, Instruction};
use bril_rs::{Code, Instruction, Program};
use common::BasicBlock;
use std::collections::HashMap;

use crate::{LocalScopePass, Pass};

pub struct LocalDeadCodeEliminationPass {
instruction_stream_workspace: Vec<Code>,
}

impl LocalDeadCodeEliminationPass {
pub fn new() -> Self {
Self {
instruction_stream_workspace: Vec::new(),
}
}
}

impl LocalScopePass for LocalDeadCodeEliminationPass {
fn apply(&mut self, block: BasicBlock) -> BasicBlock {
let mut instuction_stream = block.instruction_stream;
loop {
// Iterate in a loop till convergence
// We need to find all the dead stores
let mut last_defined = HashMap::<&str, usize>::new();
let mut deletion_mask = Vec::<bool>::new();
deletion_mask.resize(instuction_stream.len(), false);
let mut atleast_one_marked_for_deleteion = false;

for index in 0..instuction_stream.len() {
let instruction = match &instuction_stream[index] {
Code::Label { label: _, pos: _ } => panic!("Invalid pre-condition"),
Code::Instruction(instruction) => instruction,
};
// Check for loads
for variable_loaded in get_load_sources(&instruction) {
let _ = last_defined.remove(variable_loaded.as_str());
}
// Check for stores
let variable_stored = get_store_destination(&instruction);
if variable_stored.is_some() {
match last_defined.get(variable_stored.unwrap()) {
Some(dead_store_index) => {
deletion_mask[*dead_store_index] = true;
atleast_one_marked_for_deleteion = true;
} // Mark for deletion,
None => {}
}
last_defined.insert(variable_stored.unwrap(), index);
}
}
// Iterate through all the stores that were not read from
for (_label, index) in last_defined {
deletion_mask[index] = true;
atleast_one_marked_for_deleteion = true;
}
if !atleast_one_marked_for_deleteion {
// We have converged. No more dead stores in this local block
break;
}
{
// Perform the actual deletion of instructions
self.instruction_stream_workspace.clear();
self.instruction_stream_workspace
.reserve(instuction_stream.len());
for (index, instr) in instuction_stream.iter().enumerate() {
if !deletion_mask[index] {
self.instruction_stream_workspace.push(instr.clone());
}
}
std::mem::swap(
&mut instuction_stream,
&mut self.instruction_stream_workspace,
);
}
}
return BasicBlock {
name: block.name.clone(),
instruction_stream: instuction_stream,
};
}
}

impl Pass for LocalDeadCodeEliminationPass {
fn apply(&mut self, program: bril_rs::Program) -> Program {
crate::apply_for_each_block(program, self)
}
}

fn get_load_sources(instruction: &Instruction) -> &[String] {
match instruction {
bril_rs::Instruction::Value {
Expand Down Expand Up @@ -52,74 +137,9 @@ fn get_store_destination(instruction: &Instruction) -> Option<&str> {
}
}

fn dead_code_elimination(block: &BasicBlock) -> BasicBlock {
let mut instuction_stream = block.instruction_stream.clone();
loop {
// Iterate in a loop till convergence
// We need to find all the dead stores
let mut last_defined = HashMap::<&str, usize>::new();
let mut deletion_mask = Vec::<bool>::new();
deletion_mask.resize(instuction_stream.len(), false);
let mut atleast_one_marked_for_deleteion = false;

for index in 0..instuction_stream.len() {
let instruction = match &instuction_stream[index] {
Code::Label { label: _, pos: _ } => panic!("Invalid pre-condition"),
Code::Instruction(instruction) => instruction,
};
// Check for loads
for variable_loaded in get_load_sources(&instruction) {
let _ = last_defined.remove(variable_loaded.as_str());
}
// Check for stores
let variable_stored = get_store_destination(&instruction);
if variable_stored.is_some() {
match last_defined.get(variable_stored.unwrap()) {
Some(dead_store_index) => {
deletion_mask[*dead_store_index] = true;
atleast_one_marked_for_deleteion = true;
} // Mark for deletion,
None => {}
}
last_defined.insert(variable_stored.unwrap(), index);
}
}
// Iterate through all the stores that were not read from
for (_label, index) in last_defined {
deletion_mask[index] = true;
atleast_one_marked_for_deleteion = true;
}
if !atleast_one_marked_for_deleteion {
// We have converged. No more dead stores in this local block
break;
}
instuction_stream = {
let mut new_instruction_stream = Vec::<Code>::new();
new_instruction_stream.reserve(instuction_stream.len());
for (index, instr) in instuction_stream.iter().enumerate() {
if !deletion_mask[index] {
new_instruction_stream.push(instr.clone());
}
}
new_instruction_stream
};
}
return BasicBlock {
name: block.name.clone(),
instruction_stream: instuction_stream,
};
}

pub fn apply(program: &mut bril_rs::Program) {
crate::apply_for_each_block(
program,
dead_code_elimination,
);
}

#[cfg(test)]
mod tests {
use crate::OptimizationPass;
use super::*;

#[test]
fn test_local_dead_code_elimination_1() {
Expand All @@ -133,8 +153,9 @@ mod tests {
"#};
let program = common::parse_bril_text(&BRIL_PROGRAM_TEXT);
assert!(program.is_ok());
let mut program = program.unwrap();
OptimizationPass::LocalDeadCodeElimination.apply(&mut program);
let program = program.unwrap();
let mut manager = LocalDeadCodeEliminationPass::new();
let program = Pass::apply(&mut manager, program);
assert!(program.functions[0].instrs.is_empty());
}
#[test]
Expand All @@ -150,8 +171,9 @@ mod tests {
"#};
let program = common::parse_bril_text(&BRIL_PROGRAM_TEXT);
assert!(program.is_ok());
let mut program = program.unwrap();
OptimizationPass::LocalDeadCodeElimination.apply(&mut program);
let program = program.unwrap();
let mut manager = LocalDeadCodeEliminationPass::new();
let program = Pass::apply(&mut manager, program);
assert!(program.functions[0].instrs.len() == 4); // "v3: int = const 3;" is a dead store and will get deleted
}

Expand All @@ -166,8 +188,9 @@ mod tests {
"#};
let program = common::parse_bril_text(&BRIL_PROGRAM_TEXT);
assert!(program.is_ok());
let mut program = program.unwrap();
OptimizationPass::LocalDeadCodeElimination.apply(&mut program);
let program = program.unwrap();
let mut manager = LocalDeadCodeEliminationPass::new();
let program = Pass::apply(&mut manager, program);
assert!(program.functions[0].instrs.len() == 2); // "a: int = const 100"; is a dead store and will get deleted
}

Expand All @@ -185,8 +208,9 @@ mod tests {
"#};
let program = common::parse_bril_text(&BRIL_PROGRAM_TEXT);
assert!(program.is_ok());
let mut program = program.unwrap();
OptimizationPass::LocalDeadCodeElimination.apply(&mut program);
let program = program.unwrap();
let mut manager = LocalDeadCodeEliminationPass::new();
let program = Pass::apply(&mut manager, program);
/*
Output program:
@main {
Expand Down
Loading

0 comments on commit 62eaabe

Please sign in to comment.