Skip to content

Commit

Permalink
Fix ModuleLinker and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-Nak committed Dec 4, 2024
1 parent 5c79442 commit e28e898
Show file tree
Hide file tree
Showing 15 changed files with 230 additions and 61 deletions.
1 change: 1 addition & 0 deletions crates/codegen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ indexmap = { version = "2.0" }

[dev-dependencies]
sonatina-parser = { path = "../parser", version = "0.0.3-alpha" }
insta = { version = "1.41" }
107 changes: 52 additions & 55 deletions crates/codegen/src/module_linker.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
//! This module defines a module-level linking on sonatina-IR that links
//! multiple sonatina modules into a single module.
use std::mem;

use cranelift_entity::entity_impl;
Expand All @@ -9,7 +12,7 @@ use sonatina_ir::{
module::FuncRef,
types::{CompoundType, CompoundTypeRef, StructData},
visitor::VisitorMut,
Function, GlobalVariableRef, Linkage, Module, Signature, Type, Value,
GlobalVariableRef, Linkage, Module, Signature, Type, Value,
};

/// A struct represents a linked module, that is the result of the
Expand Down Expand Up @@ -51,6 +54,10 @@ pub enum LinkError {
}

impl LinkedModule {
pub fn module(&self) -> &Module {
&self.module
}

/// Links multiple modules into a single module.
/// Returns a linked module and a list of module references.
/// The order of module references are the same as the input modules.
Expand Down Expand Up @@ -94,7 +101,7 @@ impl LinkedModule {
/// An entity representing a module reference.
/// This is used to identify a module in the linked module for mapping from
/// the source module reference to the linked module reference.
#[derive(Clone, PartialEq, Eq, Copy, Hash, PartialOrd, Ord)]
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, PartialOrd, Ord)]
pub struct ModuleRef(pub u32);
entity_impl!(ModuleRef);

Expand Down Expand Up @@ -198,6 +205,7 @@ struct ModuleLinker {

module_ref_map: DashMap<ModuleRef, RefMap>,

/// Modules to be linked.
modules: IndexMap<ModuleRef, Module>,
}

Expand Down Expand Up @@ -234,7 +242,7 @@ impl ModuleLinker {
}
}

/// Registers module as a source module to be linked.
/// Registers a module as a source module to be linked.
fn register_module(&mut self, module: Module) -> ModuleRef {
let next_id = self.module_ref_map.len();
let module_ref = ModuleRef(next_id as u32);
Expand All @@ -251,20 +259,49 @@ impl ModuleLinker {
// Links all references in the source modules to the linked module.
self.link_refs(&module_refs)?;

// Updates the references in the function body.
self.update_funcs(&module_refs);

let modules = mem::take(&mut self.modules);
// Move functions to the linked module.
for (module_ref, module) in modules {
let ref_map = self.module_ref_map.get(&module_ref).unwrap();

module.func_store.par_into_for_each(|func_ref, func| {
let linkage = func.dfg.ctx.func_sig(func_ref, |sig| sig.linkage());
if linkage.is_external() {
module.func_store.par_into_for_each(|func_ref, mut func| {
// If linkage is external, we don't need to move the function definition to the
// linked module.
if func
.dfg
.ctx
.func_sig(func_ref, |sig| sig.linkage())
.is_external()
{
return;
}

// Updates module context to the linked module.
func.dfg.ctx = self.builder.ctx.clone();

// Updates references in values to the linked module.
func.dfg.values.values_mut().for_each(|value| {
ref_map.update_value(value);
});

// Updates the references in instructions to the linked module.
struct InstUpdater<'a> {
ref_map: &'a Ref<'a, ModuleRef, RefMap>,
}
impl VisitorMut for InstUpdater<'_> {
fn visit_func_ref(&mut self, item: &mut FuncRef) {
*item = self.ref_map.lookup_func(*item);
}

fn visit_ty(&mut self, item: &mut Type) {
*item = self.ref_map.lookup_type(*item);
}
}
let mut visitor = InstUpdater { ref_map: &ref_map };
func.dfg
.insts
.values_mut()
.for_each(|inst| inst.accept_mut(&mut visitor));

let linked_func_ref = ref_map.lookup_func(func_ref);
self.builder.func_store.update(linked_func_ref, func);
})
Expand Down Expand Up @@ -479,12 +516,13 @@ impl ModuleLinker {

// Validates the linkage and update the linked gv if needed.
// The allowed combinations are:
// (SourceLinkage, LinkedLinkage) = (External, Public) or (Public, External).
// (SourceLinkage, LinkedLinkage) = (External, Public), (Public, External) or
// (External, External).
//
// Also, in case of LinkedLinkage is External, we need to update it to
// a Public.
match (gv_data.linkage, linked_gv_data.linkage) {
(Linkage::External, Linkage::Public) => {}
(Linkage::External, Linkage::Public) | (Linkage::External, Linkage::External) => {}
(Linkage::Public, Linkage::External) => {
s.update_linkage(linked_gv_ref, Linkage::Public);
}
Expand Down Expand Up @@ -544,11 +582,11 @@ impl ModuleLinker {
});
}

Ok(sig.linkage())
Ok(linked_sig.linkage())
})?;

match (sig.linkage(), linked_func_linkage) {
(Linkage::External, Linkage::Public) => {}
(Linkage::External, Linkage::Public) | (Linkage::External, Linkage::External) => {}
(Linkage::Public, Linkage::External) => {
self.builder
.ctx
Expand All @@ -564,45 +602,4 @@ impl ModuleLinker {
ref_map.func_mapping.insert(func_ref, linked_func_ref);
Ok(linked_func_ref)
}

fn update_funcs(&self, module_refs: &[ModuleRef]) {
for module_ref in module_refs {
let module = self.modules.get(module_ref).unwrap();
module
.func_store
.par_for_each(|_, func| self.update_func(*module_ref, func));
}
}

fn update_func(&self, module_ref: ModuleRef, func: &mut Function) {
let ref_map = self.module_ref_map.get(&module_ref).unwrap();

// Updates module context to the linked module.
func.dfg.ctx = self.builder.ctx.clone();

// Updates values to the linked module.
func.dfg.values.values_mut().for_each(|value| {
ref_map.update_value(value);
});

// Updates the instructions to the linked module.
struct InstUpdater<'a> {
ref_map: Ref<'a, ModuleRef, RefMap>,
}
impl VisitorMut for InstUpdater<'_> {
fn visit_func_ref(&mut self, item: &mut FuncRef) {
*item = self.ref_map.lookup_func(*item);
}

fn visit_ty(&mut self, item: &mut Type) {
*item = self.ref_map.lookup_type(*item);
}
}

let mut visitor = InstUpdater { ref_map };
func.dfg
.insts
.values_mut()
.for_each(|value| value.accept_mut(&mut visitor))
}
}
42 changes: 42 additions & 0 deletions crates/codegen/tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,45 @@ pub fn parse_module(file_path: &str) -> ParsedModule {
}
}
}

// copied from fe test-utils
/// A macro to assert that a value matches a snapshot.
/// If the snapshot does not exist, it will be created in the same directory as
/// the test file.
#[macro_export]
macro_rules! snap_test {
($value:expr, $fixture_path: expr) => {
snap_test!($value, $fixture_path, None)
};

($value:expr, $fixture_path: expr, $suffix: expr) => {
let mut settings = insta::Settings::new();
let fixture_path = ::std::path::Path::new($fixture_path);
let fixture_dir = fixture_path.parent().unwrap();
let fixture_name = fixture_path.file_stem().unwrap().to_str().unwrap();

settings.set_snapshot_path(fixture_dir);
settings.set_input_file($fixture_path);
settings.set_prepend_module_to_snapshot(false);
settings.set_omit_expression(true);

let suffix: Option<&str> = $suffix;
let name = if let Some(suffix) = suffix {
format!("{fixture_name}.{suffix}")
} else {
fixture_name.into()
};
settings.bind(|| {
insta::_macro_support::assert_snapshot(
(name, $value.as_str()).into(),
std::path::Path::new(env!("CARGO_MANIFEST_DIR")),
fixture_name,
module_path!(),
file!(),
line!(),
&$value,
)
.unwrap()
})
};
}
45 changes: 39 additions & 6 deletions crates/codegen/tests/linker.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,52 @@
mod common;
use std::{env, path::Path};
use std::{
env,
path::{Path, PathBuf},
};

use common::parse_module;
use sonatina_codegen::module_linker::LinkedModule;
use sonatina_ir::ir_writer::ModuleWriter;

fn fixture_dir() -> PathBuf {
let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let manifest_dir = Path::new(&manifest_dir);
manifest_dir.join("tests/linker/fixtures")
}

macro_rules! test_ok {
($name:ident) => {
#[test]
fn $name() {
let name = stringify!($name);
let dir = fixture_dir().join("link_ok");

let path_a = dir.join(format!("{name}_a.sntn"));
let path_b = dir.join(format!("{name}_b.sntn"));
let module_a = parse_module(path_a.to_str().unwrap());
let module_b = parse_module(path_b.to_str().unwrap());

let (linked, _) = LinkedModule::link(vec![module_a.module, module_b.module]).unwrap();
let mut writer = ModuleWriter::new(linked.module());
let module_text = writer.dump_string();

let snap_file_path = dir.join(format!("{name}.snap"));
snap_test!(module_text, snap_file_path.to_str().unwrap());
}
};
}

test_ok!(module);

macro_rules! test_error {
($name:ident) => {
#[test]
fn $name() {
let name = stringify!($name);
let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let manifest_dir = Path::new(&manifest_dir);
let fixture_dir = manifest_dir.join("tests/linker/fixtures/errors");
let path_a = fixture_dir.join(format!("{name}_a.sntn"));
let path_b = fixture_dir.join(format!("{name}_b.sntn"));
let dir = fixture_dir().join("link_errors");

let path_a = dir.join(format!("{name}_a.sntn"));
let path_b = dir.join(format!("{name}_b.sntn"));
let module_a = parse_module(path_a.to_str().unwrap());
let module_b = parse_module(path_b.to_str().unwrap());

Expand All @@ -27,3 +59,4 @@ macro_rules! test_error {
test_error!(struct_error);
test_error!(gv_error);
test_error!(func_error);
test_error!(sig_error);
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
target = "evm-ethereum-cancun"

declare external %f(i64) -> i64;
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
target = "evm-ethereum-cancun"

func public %f(v0.i64) {
block0:
return;
}
38 changes: 38 additions & 0 deletions crates/codegen/tests/linker/fixtures/link_ok/module.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
---
source: crates/codegen/tests/linker.rs
input_file: tests/linker/fixtures/link_ok/module.snap
---
target = evm-ethereum-cancun

type @foo = {i8, i16, *i64};
type @bar = {i8, [i8; 31]};

global public const i256 $ZERO = 0
global external const i256 $ONE

declare func external %f_outer(i8) -> i8;

func public %f_b(v0.i8, v1.i8) -> i8 {
block0:
v2.i8 = call %f_a v0;
v3.i8 = add v0 v2;
return;
}

func public %f_a(v0.i8) -> i8 {
block0:
return v0;
}

func private %types_a(v0.*i8, v1.[i8; 2], v2.[*i8; 2], v3.[[i8; 2]; 2], v4.@foo, v5.*@bar) -> i8 {
block0:
v6.i8 = mload v0 i8;
v7.i8 = call %f_b v6 1.i8;
call %f_outer;
return v6;
}

func private %types_b(v0.*i8, v1.[i8; 2], v2.[*i8; 2], v3.[[i8; 2]; 2], v4.@bar, v5.*@foo) {
block0:
return;
}
25 changes: 25 additions & 0 deletions crates/codegen/tests/linker/fixtures/link_ok/module_a.sntn
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
target = "evm-ethereum-cancun"

type @foo = { i8, i16, *i64 };
type @bar = <{ i8, [i8; 31] }>;

declare external %f_b(i8, i8) -> i8;
declare external %f_outer(i8) -> i8;


global public const i256 $ZERO = 0
global external const i256 $ONE


func public %f_a(v0.i8) -> i8 {
block0:
return v0;
}

func private %types_a(v0.*i8, v1.[i8; 2], v2.[*i8; 2], v3.[[i8; 2]; 2], v4.@foo, v5.*@bar) -> i8 {
block0:
v6.i8 = mload v0 i8;
v7.i8 = call %f_b v6 1.i8;
call %f_outer;
return v6;
}
Loading

0 comments on commit e28e898

Please sign in to comment.