Skip to content

Commit

Permalink
Remove superset for performance
Browse files Browse the repository at this point in the history
  • Loading branch information
Rigidity committed Aug 2, 2024
1 parent f0447a1 commit ae9fc62
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 87 deletions.
99 changes: 30 additions & 69 deletions crates/rue-typing/src/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ pub enum Comparison {
Equal,
Assignable,
Castable,
Superset,
Incompatible,
NotEqual,
}

pub(crate) struct ComparisonContext<'a> {
Expand Down Expand Up @@ -81,7 +80,7 @@ pub(crate) fn compare_type(
ctx.inferred.last_mut().unwrap().insert(rhs, lhs);
Comparison::Assignable
} else {
Comparison::Incompatible
Comparison::NotEqual
}
}

Expand All @@ -96,7 +95,7 @@ pub(crate) fn compare_type(
} else if lhs == rhs {
Comparison::Equal
} else {
Comparison::Incompatible
Comparison::NotEqual
}
}

Expand All @@ -114,7 +113,7 @@ pub(crate) fn compare_type(
| (Type::False, Type::Nil)
| (Type::Nil, Type::False) => Comparison::Castable,

// These are a superset since the right hand side is castable to the left hand side.
// These are incompatible since the structure differs.
(
Type::Any,
Type::Bytes
Expand Down Expand Up @@ -149,10 +148,9 @@ pub(crate) fn compare_type(
| Type::Pair(..)
| Type::Callable(..),
Type::Never,
) => Comparison::Superset,

// These are incompatible since the structure differs.
(
)
| (Type::Bytes32 | Type::PublicKey, Type::Value(..))
| (
Type::Pair(..),
Type::Bytes
| Type::Bytes32
Expand Down Expand Up @@ -192,14 +190,14 @@ pub(crate) fn compare_type(
| Type::Value(..)
| Type::Pair(..),
Type::Callable(..),
) => Comparison::Incompatible,
) => Comparison::NotEqual,

// Value is a subtype of Int, so it's castable to Bytes32 if it's 32 bytes long.
(Type::Value(value), Type::Bytes32) => {
if bigint_to_bytes(value.clone()).len() == 32 {
Comparison::Castable
} else {
Comparison::Incompatible
Comparison::NotEqual
}
}

Expand All @@ -208,44 +206,26 @@ pub(crate) fn compare_type(
if bigint_to_bytes(value.clone()).len() == 48 {
Comparison::Castable
} else {
Comparison::Incompatible
Comparison::NotEqual
}
}

// Bytes32 is a superset of Value only if the value is 32 bytes long.
(Type::Bytes32, Type::Value(value)) => {
if bigint_to_bytes(value.clone()).len() == 32 {
Comparison::Superset
} else {
Comparison::Incompatible
}
}

// PublicKey is a superset of Value only if the value is 48 bytes long.
(Type::PublicKey, Type::Value(value)) => {
if bigint_to_bytes(value.clone()).len() == 48 {
Comparison::Superset
} else {
Comparison::Incompatible
}
}

// Nil and False are supersets of Value only if the value is zero.
// Nil and False are castable to Value only if the value is zero.
(Type::Nil | Type::False, Type::Value(value))
| (Type::Value(value), Type::Nil | Type::False) => {
if value == &BigInt::ZERO {
Comparison::Castable
} else {
Comparison::Incompatible
Comparison::NotEqual
}
}

// True is a superset of Value only if the value is one.
// True is castable to Value only if the value is one.
(Type::True, Type::Value(value)) | (Type::Value(value), Type::True) => {
if value == &BigInt::one() {
Comparison::Castable
} else {
Comparison::Incompatible
Comparison::NotEqual
}
}

Expand All @@ -254,7 +234,7 @@ pub(crate) fn compare_type(
if lhs == rhs {
Comparison::Equal
} else {
Comparison::Incompatible
Comparison::NotEqual
}
}

Expand All @@ -270,29 +250,18 @@ pub(crate) fn compare_type(
let items = items.clone();
let mut result = Comparison::Assignable;

let mut any_castable = false;

for item in items {
let cmp = compare_type(db, item, rhs, ctx);
result = max(result, cmp);

if compare_type(db, rhs, item, ctx) <= Comparison::Castable {
any_castable = true;
}
}

if result == Comparison::Incompatible && any_castable {
Comparison::Superset
} else {
result
}
result
}

// Anything can be assigned to a union so long as it's assignable to at least one of the items.
(_, Type::Union(items)) => {
let items = items.clone();
let mut result = Comparison::Incompatible;
let mut any_incompatible = false;
let mut result = Comparison::NotEqual;

for item in &items {
if matches!(db.get_recursive(*item), Type::Never) {
Expand All @@ -301,17 +270,9 @@ pub(crate) fn compare_type(

let cmp = compare_type(db, lhs, *item, ctx);
result = min(result, cmp);

if cmp == Comparison::Incompatible {
any_incompatible = true;
}
}

if any_incompatible && result == Comparison::Superset {
Comparison::Incompatible
} else {
max(result, Comparison::Assignable)
}
max(result, Comparison::Assignable)
}

// Resolve the alias to the type that it's pointing to.
Expand Down Expand Up @@ -421,7 +382,7 @@ mod tests {
fn test_compare_bytes_bytes32() {
let db = TypeSystem::new();
let types = db.std();
assert_eq!(db.compare(types.bytes, types.bytes32), Comparison::Superset);
assert_eq!(db.compare(types.bytes, types.bytes32), Comparison::NotEqual);
}

#[test]
Expand All @@ -440,7 +401,7 @@ mod tests {
let types = db.std();
assert_eq!(
db.compare(types.bytes, types.public_key),
Comparison::Superset
Comparison::NotEqual
);
}

Expand All @@ -460,7 +421,7 @@ mod tests {
let types = db.std();
assert_eq!(
db.compare(types.bytes32, types.public_key),
Comparison::Incompatible
Comparison::NotEqual
);
}

Expand All @@ -470,7 +431,7 @@ mod tests {
let types = db.std();
assert_eq!(
db.compare(types.public_key, types.bytes32),
Comparison::Incompatible
Comparison::NotEqual
);
}

Expand All @@ -485,7 +446,7 @@ mod tests {
fn test_compare_any_bytes() {
let db = TypeSystem::new();
let types = db.std();
assert_eq!(db.compare(types.any, types.bytes), Comparison::Superset);
assert_eq!(db.compare(types.any, types.bytes), Comparison::NotEqual);
}

#[test]
Expand All @@ -499,7 +460,7 @@ mod tests {
fn test_compare_any_bytes32() {
let db = TypeSystem::new();
let types = db.std();
assert_eq!(db.compare(types.any, types.bytes32), Comparison::Superset);
assert_eq!(db.compare(types.any, types.bytes32), Comparison::NotEqual);
}

#[test]
Expand Down Expand Up @@ -574,7 +535,7 @@ mod tests {
let types = db.std();
let lhs = db.alloc(Type::Pair(types.int, types.public_key));
let rhs = db.alloc(Type::Pair(types.bytes, types.nil));
assert_eq!(db.compare(lhs, rhs), Comparison::Incompatible);
assert_eq!(db.compare(lhs, rhs), Comparison::NotEqual);
}

#[test]
Expand Down Expand Up @@ -636,7 +597,7 @@ mod tests {
);
assert_eq!(
db.compare_with_generics(types.any, generic, &mut stack, infer),
Comparison::Superset
Comparison::NotEqual
);
}
}
Expand Down Expand Up @@ -689,7 +650,7 @@ mod tests {

let pair = db.alloc(Type::Pair(types.int, types.public_key));
let union = db.alloc(Type::Union(vec![types.bytes32, pair, types.nil]));
assert_eq!(db.compare(union, types.bytes), Comparison::Incompatible);
assert_eq!(db.compare(union, types.bytes), Comparison::NotEqual);
}

#[test]
Expand All @@ -699,7 +660,7 @@ mod tests {

let pair = db.alloc(Type::Pair(types.int, types.public_key));
let union = db.alloc(Type::Union(vec![types.bytes, pair]));
assert_eq!(db.compare(union, types.bytes), Comparison::Superset);
assert_eq!(db.compare(union, types.bytes), Comparison::NotEqual);
}

#[test]
Expand All @@ -718,7 +679,7 @@ mod tests {

let pair = db.alloc(Type::Pair(types.int, types.public_key));
let union = db.alloc(Type::Union(vec![types.bytes32, pair, types.nil]));
assert_eq!(db.compare(types.bytes, union), Comparison::Incompatible);
assert_eq!(db.compare(types.bytes, union), Comparison::NotEqual);
}

#[test]
Expand Down Expand Up @@ -831,7 +792,7 @@ mod tests {
let mut db = TypeSystem::new();
let types = db.std();
let generic = db.alloc(Type::Generic);
assert_eq!(db.compare(types.int, generic), Comparison::Incompatible);
assert_eq!(db.compare(types.int, generic), Comparison::NotEqual);
assert_eq!(db.compare(generic, generic), Comparison::Equal);
}

Expand Down
2 changes: 1 addition & 1 deletion crates/rue-typing/src/difference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,6 @@ mod tests {
let non_nil = db.difference(list, types.nil);

assert_eq!(db.compare(non_nil, list), Comparison::Assignable);
assert_eq!(db.compare(types.nil, non_nil), Comparison::Incompatible);
assert_eq!(db.compare(types.nil, non_nil), Comparison::NotEqual);
}
}
19 changes: 2 additions & 17 deletions crates/rue-typing/src/type_system.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::time::{Duration, Instant};

use id_arena::{Arena, Id};

use crate::{
Expand Down Expand Up @@ -179,9 +177,7 @@ impl TypeSystem {
substitution_stack: &mut Vec<HashMap<TypeId, TypeId>>,
infer_generics: bool,
) -> Comparison {
let start = Instant::now();

let result = compare_type(
compare_type(
self,
lhs,
rhs,
Expand All @@ -192,18 +188,7 @@ impl TypeSystem {
inferred: substitution_stack,
infer_generics,
},
);
let duration = start.elapsed();
if duration > Duration::from_millis(1) {
println!(
"\n\n\n{duration:?} between {} => {}",
self.stringify(lhs),
self.stringify(rhs)
);
println!("LHS {}", self.debug(lhs));
println!("RHS {}", self.debug(rhs));
}
result
)
}

pub fn substitute(
Expand Down

0 comments on commit ae9fc62

Please sign in to comment.