Skip to content

Commit

Permalink
types: pull unify and bind into inference context
Browse files Browse the repository at this point in the history
Pulls the unify and bind methods out of Type and BoundMutex and
implement them on a new private LockedContext struct. This locks the
entire inference context for the duration of a bind or unify operation,
and because it only locks inside of non-recursive methods, it is
impossible to deadlock.

This is "API-only" in the sense that the actual type bounds continue to
be represented by free-floating Arcs, but it has a semantic change in
that binds and unifications now happen atomically (due to the
continuously held lock on the context) which fixes a likely class of
bugs wherein if you try to unify related variables from multiple threads
at once, the old code probably would due weird things, due to the very
local locking and total lack of other synchronization.

The next commit will finally delete BoundMutex, move the bounds into the
actual context object, and you will see the point of all these massive
code lifts :).
  • Loading branch information
apoelstra committed Jul 3, 2024
1 parent 70ba9d2 commit 598cff3
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 103 deletions.
116 changes: 108 additions & 8 deletions src/types/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
//!
use std::fmt;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, MutexGuard};

use crate::dag::{Dag, DagLike};

use super::bound_mutex::BoundMutex;
use super::{Bound, Error, Final, Type};
use super::{Bound, CompleteBound, Error, Final, Type};

/// Type inference context, or handle to a context.
///
Expand Down Expand Up @@ -156,14 +156,24 @@ impl Context {
///
/// Fails if the type has an existing incompatible bound.
pub fn bind(&self, existing: &Type, new: Bound, hint: &'static str) -> Result<(), Error> {
existing.bind(new, hint)
let existing_root = existing.bound.root();
let lock = self.lock();
lock.bind(existing_root, new, hint)
}

/// Unify the type with another one.
///
/// Fails if the bounds on the two types are incompatible
pub fn unify(&self, ty1: &Type, ty2: &Type, hint: &'static str) -> Result<(), Error> {
ty1.unify(ty2, hint)
let lock = self.lock();
lock.unify(ty1, ty2, hint)
}

/// Locks the underlying slab mutex.
fn lock(&self) -> LockedContext {
LockedContext {
slab: self.slab.lock().unwrap(),
}
}
}

Expand All @@ -184,10 +194,6 @@ impl BoundRef {
);
}

pub fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> {
self.index.bind(bound, hint)
}

/// Creates an "occurs-check ID" which is just a copy of the [`BoundRef`]
/// with `PartialEq` and `Eq` implemented in terms of underlying pointer
/// equality.
Expand Down Expand Up @@ -239,3 +245,97 @@ pub struct OccursCheckId {
// now we set it to an Arc<BoundMutex> to preserve semantics.
index: *const BoundMutex,
}

/// Structure representing an inference context with its slab allocator mutex locked.
///
/// This type is never exposed outside of this module and should only exist
/// ephemerally within function calls into this module.
struct LockedContext<'ctx> {
slab: MutexGuard<'ctx, Vec<Bound>>,
}

impl<'ctx> LockedContext<'ctx> {
/// Unify the type with another one.
///
/// Fails if the bounds on the two types are incompatible
fn unify(&self, existing: &Type, other: &Type, hint: &'static str) -> Result<(), Error> {
existing.bound.unify(&other.bound, |x_bound, y_bound| {
self.bind(x_bound, y_bound.index.get(), hint)
})
}

fn bind(&self, existing: BoundRef, new: Bound, hint: &'static str) -> Result<(), Error> {
let existing_bound = existing.index.get();
let bind_error = || Error::Bind {
existing_bound: existing_bound.shallow_clone(),
new_bound: new.shallow_clone(),
hint,
};

match (&existing_bound, &new) {
// Binding a free type to anything is a no-op
(_, Bound::Free(_)) => Ok(()),
// Free types are simply dropped and replaced by the new bound
(Bound::Free(_), _) => {
// Free means non-finalized, so set() is ok.
existing.index.set(new);
Ok(())
}
// Binding complete->complete shouldn't ever happen, but if so, we just
// compare the two types and return a pass/fail
(Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => {
if existing_final == new_final {
Ok(())
} else {
Err(bind_error())
}
}
// Binding an incomplete to a complete type requires recursion.
(Bound::Complete(complete), incomplete) | (incomplete, Bound::Complete(complete)) => {
match (complete.bound(), incomplete) {
// A unit might match a Bound::Free(..) or a Bound::Complete(..),
// and both cases were handled above. So this is an error.
(CompleteBound::Unit, _) => Err(bind_error()),
(
CompleteBound::Product(ref comp1, ref comp2),
Bound::Product(ref ty1, ref ty2),
)
| (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => {
let bound1 = ty1.bound.root();
let bound2 = ty2.bound.root();
self.bind(bound1, Bound::Complete(Arc::clone(comp1)), hint)?;
self.bind(bound2, Bound::Complete(Arc::clone(comp2)), hint)
}
_ => Err(bind_error()),
}
}
(Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2))
| (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => {
self.unify(x1, y1, hint)?;
self.unify(x2, y2, hint)?;
// This type was not complete, but it may be after unification, giving us
// an opportunity to finaliize it. We do this eagerly to make sure that
// "complete" (no free children) is always equivalent to "finalized" (the
// bound field having variant Bound::Complete(..)), even during inference.
//
// It also gives the user access to more information about the type,
// prior to finalization.
if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) {
existing
.index
.set(Bound::Complete(if let Bound::Sum(..) = existing_bound {
Final::sum(data1, data2)
} else {
Final::product(data1, data2)
}));
}
Ok(())
}
(x, y) => Err(Error::Bind {
existing_bound: x.shallow_clone(),
new_bound: y.shallow_clone(),
hint,
}),
}
}
}
97 changes: 2 additions & 95 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ impl fmt::Display for Error {
impl std::error::Error for Error {}

mod bound_mutex {
use super::{Bound, CompleteBound, Error, Final};
use super::Bound;
use std::fmt;
use std::sync::{Arc, Mutex};
use std::sync::Mutex;

/// Source or target type of a Simplicity expression
pub struct BoundMutex {
Expand Down Expand Up @@ -184,81 +184,6 @@ mod bound_mutex {
);
*lock = new;
}

pub fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> {
let existing_bound = self.get();
let bind_error = || Error::Bind {
existing_bound: existing_bound.shallow_clone(),
new_bound: bound.shallow_clone(),
hint,
};

match (&existing_bound, &bound) {
// Binding a free type to anything is a no-op
(_, Bound::Free(_)) => Ok(()),
// Free types are simply dropped and replaced by the new bound
(Bound::Free(_), _) => {
// Free means non-finalized, so set() is ok.
self.set(bound);
Ok(())
}
// Binding complete->complete shouldn't ever happen, but if so, we just
// compare the two types and return a pass/fail
(Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => {
if existing_final == new_final {
Ok(())
} else {
Err(bind_error())
}
}
// Binding an incomplete to a complete type requires recursion.
(Bound::Complete(complete), incomplete)
| (incomplete, Bound::Complete(complete)) => {
match (complete.bound(), incomplete) {
// A unit might match a Bound::Free(..) or a Bound::Complete(..),
// and both cases were handled above. So this is an error.
(CompleteBound::Unit, _) => Err(bind_error()),
(
CompleteBound::Product(ref comp1, ref comp2),
Bound::Product(ref ty1, ref ty2),
)
| (
CompleteBound::Sum(ref comp1, ref comp2),
Bound::Sum(ref ty1, ref ty2),
) => {
ty1.bind(Bound::Complete(Arc::clone(comp1)), hint)?;
ty2.bind(Bound::Complete(Arc::clone(comp2)), hint)
}
_ => Err(bind_error()),
}
}
(Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2))
| (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => {
x1.unify(y1, hint)?;
x2.unify(y2, hint)?;
// This type was not complete, but it may be after unification, giving us
// an opportunity to finaliize it. We do this eagerly to make sure that
// "complete" (no free children) is always equivalent to "finalized" (the
// bound field having variant Bound::Complete(..)), even during inference.
//
// It also gives the user access to more information about the type,
// prior to finalization.
if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) {
self.set(Bound::Complete(if let Bound::Sum(..) = bound {
Final::sum(data1, data2)
} else {
Final::product(data1, data2)
}));
}
Ok(())
}
(x, y) => Err(Error::Bind {
existing_bound: x.shallow_clone(),
new_bound: y.shallow_clone(),
hint,
}),
}
}
}
}

Expand Down Expand Up @@ -391,24 +316,6 @@ impl Type {
self.clone()
}

/// Binds the type to a given bound. If this fails, attach the provided
/// hint to the error.
///
/// Fails if the type has an existing incompatible bound.
fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> {
let root = self.bound.root();
root.bind(bound, hint)
}

/// Unify the type with another one.
///
/// Fails if the bounds on the two types are incompatible
fn unify(&self, other: &Self, hint: &'static str) -> Result<(), Error> {
self.bound.unify(&other.bound, |x_bound, y_bound| {
x_bound.bind(self.ctx.get(y_bound), hint)
})
}

/// Accessor for this type's bound
pub fn bound(&self) -> Bound {
self.ctx.get(&self.bound.root())
Expand Down

0 comments on commit 598cff3

Please sign in to comment.