Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

types: use slab allocator for type bounds #231

Merged
merged 10 commits (source and target branch names not captured in this page export)
Jul 4, 2024
6 changes: 3 additions & 3 deletions jets-bench/benches/elements/data_structures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
use bitcoin::secp256k1;
use elements::Txid;
use rand::{thread_rng, RngCore};
pub use simplicity::hashes::sha256;
use simplicity::{
bitcoin, elements, hashes::Hash, hex::FromHex, types::Type, BitIter, Error, Value,
bitcoin, elements, hashes::Hash, hex::FromHex, types::{self, Type}, BitIter, Error, Value,
};
use std::sync::Arc;

Expand Down Expand Up @@ -57,7 +56,8 @@ pub fn var_len_buf_from_slice(v: &[u8], mut n: usize) -> Result<Arc<Value>, Erro
assert!(n < 16);
assert!(v.len() < (1 << (n + 1)));
let mut iter = BitIter::new(v.iter().copied());
let types = Type::powers_of_two(n); // size n + 1
let ctx = types::Context::new();
let types = Type::powers_of_two(&ctx, n); // size n + 1
let mut res = None;
while n > 0 {
let v = if v.len() >= (1 << (n + 1)) {
Expand Down
6 changes: 3 additions & 3 deletions jets-bench/benches/elements/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ impl ElementsBenchEnvType {
}

fn jet_arrow(jet: Elements) -> (Arc<types::Final>, Arc<types::Final>) {
let src_ty = jet.source_ty().to_type().final_data().unwrap();
let tgt_ty = jet.target_ty().to_type().final_data().unwrap();
let src_ty = jet.source_ty().to_final();
let tgt_ty = jet.target_ty().to_final();
(src_ty, tgt_ty)
}

Expand Down Expand Up @@ -302,7 +302,7 @@ fn bench(c: &mut Criterion) {
let keypair = bitcoin::key::Keypair::new(&secp_ctx, &mut thread_rng());
let xpk = bitcoin::key::XOnlyPublicKey::from_keypair(&keypair);

let msg = bitcoin::secp256k1::Message::from_slice(&rand::random::<[u8; 32]>()).unwrap();
let msg = bitcoin::secp256k1::Message::from_digest_slice(&rand::random::<[u8; 32]>()).unwrap();
let sig = secp_ctx.sign_schnorr(&msg, &keypair);
let xpk_value = Value::u256_from_slice(&xpk.0.serialize());
let sig_value = Value::u512_from_slice(sig.as_ref());
Expand Down
3 changes: 2 additions & 1 deletion src/bit_encoding/bitwriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,13 @@ mod tests {
use super::*;
use crate::jet::Core;
use crate::node::CoreConstructible;
use crate::types;
use crate::ConstructNode;
use std::sync::Arc;

#[test]
fn vec() {
let program = Arc::<ConstructNode<Core>>::unit();
let program = Arc::<ConstructNode<Core>>::unit(&types::Context::new());
let _ = write_to_vec(|w| program.encode(w));
}

Expand Down
14 changes: 8 additions & 6 deletions src/bit_encoding/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::node::{
ConstructNode, CoreConstructible, DisconnectConstructible, JetConstructible, NoWitness,
WitnessConstructible,
};
use crate::types;
use crate::{BitIter, FailEntropy, Value};
use std::collections::HashSet;
use std::sync::Arc;
Expand Down Expand Up @@ -178,6 +179,7 @@ pub fn decode_expression<I: Iterator<Item = u8>, J: Jet>(
return Err(Error::TooManyNodes(len));
}

let inference_context = types::Context::new();
let mut nodes = Vec::with_capacity(len);
for _ in 0..len {
let new_node = decode_node(bits, nodes.len())?;
Expand All @@ -195,8 +197,8 @@ pub fn decode_expression<I: Iterator<Item = u8>, J: Jet>(
}

let new = match nodes[data.node.0] {
DecodeNode::Unit => Node(ArcNode::unit()),
DecodeNode::Iden => Node(ArcNode::iden()),
DecodeNode::Unit => Node(ArcNode::unit(&inference_context)),
DecodeNode::Iden => Node(ArcNode::iden(&inference_context)),
DecodeNode::InjL(i) => Node(ArcNode::injl(converted[i].get()?)),
DecodeNode::InjR(i) => Node(ArcNode::injr(converted[i].get()?)),
DecodeNode::Take(i) => Node(ArcNode::take(converted[i].get()?)),
Expand All @@ -222,16 +224,16 @@ pub fn decode_expression<I: Iterator<Item = u8>, J: Jet>(
converted[i].get()?,
&Some(Arc::clone(converted[j].get()?)),
)?),
DecodeNode::Witness => Node(ArcNode::witness(NoWitness)),
DecodeNode::Fail(entropy) => Node(ArcNode::fail(entropy)),
DecodeNode::Witness => Node(ArcNode::witness(&inference_context, NoWitness)),
DecodeNode::Fail(entropy) => Node(ArcNode::fail(&inference_context, entropy)),
DecodeNode::Hidden(cmr) => {
if !hidden_set.insert(cmr) {
return Err(Error::SharingNotMaximal);
}
Hidden(cmr)
}
DecodeNode::Jet(j) => Node(ArcNode::jet(j)),
DecodeNode::Word(ref w) => Node(ArcNode::const_word(Arc::clone(w))),
DecodeNode::Jet(j) => Node(ArcNode::jet(&inference_context, j)),
DecodeNode::Word(ref w) => Node(ArcNode::const_word(&inference_context, Arc::clone(w))),
};
converted.push(new);
}
Expand Down
74 changes: 47 additions & 27 deletions src/human_encoding/named_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::node::{
self, Commit, CommitData, CommitNode, Converter, Inner, NoDisconnect, NoWitness, Node, Witness,
WitnessData,
};
use crate::node::{Construct, ConstructData, Constructible};
use crate::node::{Construct, ConstructData, Constructible as _, CoreConstructible as _};
use crate::types;
use crate::types::arrow::{Arrow, FinalArrow};
use crate::{encode, Value, WitnessNode};
Expand Down Expand Up @@ -116,6 +116,7 @@ impl<J: Jet> NamedCommitNode<J> {
struct Populator<'a, J: Jet> {
witness_map: &'a HashMap<Arc<str>, Arc<Value>>,
disconnect_map: &'a HashMap<Arc<str>, Arc<NamedCommitNode<J>>>,
inference_context: types::Context,
phantom: PhantomData<J>,
}

Expand Down Expand Up @@ -153,17 +154,16 @@ impl<J: Jet> NamedCommitNode<J> {
// Like witness nodes (see above), disconnect nodes may be pruned later.
// The finalization will detect missing branches and throw an error.
let maybe_commit = self.disconnect_map.get(hole_name);
// FIXME: Recursive call of to_witness_node
// We cannot introduce a stack
// because we are implementing methods of the trait Converter
// which are used Marker::convert().
// FIXME: recursive call to convert
// We cannot introduce a stack because we are implementing the Converter
// trait and do not have access to the actual algorithm used for conversion
// in order to save its state.
//
// OTOH, if a user writes a program with so many disconnected expressions
// that there is a stack overflow, it's his own fault :)
// This would fail in a fuzz test.
let witness = maybe_commit.map(|commit| {
commit.to_witness_node(self.witness_map, self.disconnect_map)
});
// This may fail in a fuzz test.
let witness = maybe_commit
.map(|commit| commit.convert::<InternalSharing, _, _>(self).unwrap());
Ok(witness)
}
}
Expand All @@ -181,13 +181,15 @@ impl<J: Jet> NamedCommitNode<J> {
let inner = inner
.map(|node| node.cached_data())
.map_witness(|maybe_value| maybe_value.clone());
Ok(WitnessData::from_inner(inner).expect("types are already finalized"))
Ok(WitnessData::from_inner(&self.inference_context, inner)
.expect("types are already finalized"))
}
}

self.convert::<InternalSharing, _, _>(&mut Populator {
witness_map: witness,
disconnect_map: disconnect,
inference_context: types::Context::new(),
phantom: PhantomData,
})
.unwrap()
Expand Down Expand Up @@ -245,13 +247,15 @@ pub struct NamedConstructData<J> {
impl<J: Jet> NamedConstructNode<J> {
/// Construct a named construct node from parts.
pub fn new(
inference_context: &types::Context,
name: Arc<str>,
position: Position,
user_source_types: Arc<[types::Type]>,
user_target_types: Arc<[types::Type]>,
inner: node::Inner<Arc<Self>, J, Arc<Self>, WitnessOrHole>,
) -> Result<Self, types::Error> {
let construct_data = ConstructData::from_inner(
inference_context,
inner
.as_ref()
.map(|data| &data.cached_data().internal)
Expand Down Expand Up @@ -295,6 +299,11 @@ impl<J: Jet> NamedConstructNode<J> {
self.cached_data().internal.arrow()
}

/// Accessor for the node's type inference context.
pub fn inference_context(&self) -> &types::Context {
self.cached_data().internal.inference_context()
}

/// Finalizes the types of the underlying [`crate::ConstructNode`].
pub fn finalize_types_main(&self) -> Result<Arc<NamedCommitNode<J>>, ErrorSet> {
self.finalize_types_inner(true)
Expand Down Expand Up @@ -386,17 +395,23 @@ impl<J: Jet> NamedConstructNode<J> {
.map_disconnect(|_| &NoDisconnect)
.copy_witness();

let ctx = data.node.inference_context();

if !self.for_main {
// For non-`main` fragments, treat the ascriptions as normative, and apply them
// before finalizing the type.
let arrow = data.node.arrow();
for ty in data.node.cached_data().user_source_types.as_ref() {
if let Err(e) = arrow.source.unify(ty, "binding source type annotation") {
if let Err(e) =
ctx.unify(&arrow.source, ty, "binding source type annotation")
{
self.errors.add(data.node.position(), e);
}
}
for ty in data.node.cached_data().user_target_types.as_ref() {
if let Err(e) = arrow.target.unify(ty, "binding target type annotation") {
if let Err(e) =
ctx.unify(&arrow.target, ty, "binding target type annotation")
{
self.errors.add(data.node.position(), e);
}
}
Expand All @@ -413,15 +428,19 @@ impl<J: Jet> NamedConstructNode<J> {
if self.for_main {
// For `main`, only apply type ascriptions *after* inference has completely
// determined the type.
let source_ty = types::Type::complete(Arc::clone(&commit_data.arrow().source));
let source_ty =
types::Type::complete(ctx, Arc::clone(&commit_data.arrow().source));
for ty in data.node.cached_data().user_source_types.as_ref() {
if let Err(e) = source_ty.unify(ty, "binding source type annotation") {
if let Err(e) = ctx.unify(&source_ty, ty, "binding source type annotation")
{
self.errors.add(data.node.position(), e);
}
}
let target_ty = types::Type::complete(Arc::clone(&commit_data.arrow().target));
let target_ty =
types::Type::complete(ctx, Arc::clone(&commit_data.arrow().target));
for ty in data.node.cached_data().user_target_types.as_ref() {
if let Err(e) = target_ty.unify(ty, "binding target type annotation") {
if let Err(e) = ctx.unify(&target_ty, ty, "binding target type annotation")
{
self.errors.add(data.node.position(), e);
}
}
Expand All @@ -442,22 +461,23 @@ impl<J: Jet> NamedConstructNode<J> {
};

if for_main {
let unit_ty = types::Type::unit();
let ctx = self.inference_context();
let unit_ty = types::Type::unit(ctx);
if self.cached_data().user_source_types.is_empty() {
if let Err(e) = self
.arrow()
.source
.unify(&unit_ty, "setting root source to unit")
{
if let Err(e) = ctx.unify(
&self.arrow().source,
&unit_ty,
"setting root source to unit",
) {
finalizer.errors.add(self.position(), e);
}
}
if self.cached_data().user_target_types.is_empty() {
if let Err(e) = self
.arrow()
.target
.unify(&unit_ty, "setting root source to unit")
{
if let Err(e) = ctx.unify(
&self.arrow().target,
&unit_ty,
"setting root target to unit",
) {
finalizer.errors.add(self.position(), e);
}
}
Expand Down
25 changes: 18 additions & 7 deletions src/human_encoding/parse/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,25 @@ pub enum Type {

impl Type {
/// Convert to a Simplicity type
pub fn reify(self) -> types::Type {
pub fn reify(self, ctx: &types::Context) -> types::Type {
match self {
Type::Name(s) => types::Type::free(s),
Type::One => types::Type::unit(),
Type::Two => types::Type::sum(types::Type::unit(), types::Type::unit()),
Type::Product(left, right) => types::Type::product(left.reify(), right.reify()),
Type::Sum(left, right) => types::Type::sum(left.reify(), right.reify()),
Type::TwoTwoN(n) => types::Type::two_two_n(n as usize), // cast OK as we are only using tiny numbers
Type::Name(s) => types::Type::free(ctx, s),
Type::One => types::Type::unit(ctx),
Type::Two => {
let unit_ty = types::Type::unit(ctx);
types::Type::sum(ctx, unit_ty.shallow_clone(), unit_ty)
}
Type::Product(left, right) => {
let left = left.reify(ctx);
let right = right.reify(ctx);
types::Type::product(ctx, left, right)
}
Type::Sum(left, right) => {
let left = left.reify(ctx);
let right = right.reify(ctx);
types::Type::sum(ctx, left, right)
}
Type::TwoTwoN(n) => types::Type::two_two_n(ctx, n as usize), // cast OK as we are only using tiny numbers
}
}
}
Expand Down
8 changes: 5 additions & 3 deletions src/human_encoding/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod ast;
use crate::dag::{Dag, DagLike, InternalSharing};
use crate::jet::Jet;
use crate::node;
use crate::types::Type;
use crate::types::{self, Type};
use std::collections::HashMap;
use std::mem;
use std::sync::atomic::{AtomicUsize, Ordering};
Expand Down Expand Up @@ -181,6 +181,7 @@ pub fn parse<J: Jet + 'static>(
program: &str,
) -> Result<HashMap<Arc<str>, Arc<NamedCommitNode<J>>>, ErrorSet> {
let mut errors = ErrorSet::new();
let inference_context = types::Context::new();
// **
// Step 1: Read expressions into HashMap, checking for dupes and illegal names.
// **
Expand All @@ -205,10 +206,10 @@ pub fn parse<J: Jet + 'static>(
}
}
if let Some(ty) = line.arrow.0 {
entry.add_source_type(ty.reify());
entry.add_source_type(ty.reify(&inference_context));
}
if let Some(ty) = line.arrow.1 {
entry.add_target_type(ty.reify());
entry.add_target_type(ty.reify(&inference_context));
}
}

Expand Down Expand Up @@ -485,6 +486,7 @@ pub fn parse<J: Jet + 'static>(
.unwrap_or_else(|| Arc::from(namer.assign_name(inner.as_ref()).as_str()));

let node = NamedConstructNode::new(
&inference_context,
Arc::clone(&name),
data.node.position,
Arc::clone(&data.node.user_source_types),
Expand Down
3 changes: 2 additions & 1 deletion src/jet/elements/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::Arc;
use crate::jet::elements::{ElementsEnv, ElementsUtxo};
use crate::jet::Elements;
use crate::node::{ConstructNode, JetConstructible};
use crate::types;
use crate::{BitMachine, Cmr, Value};
use elements::secp256k1_zkp::Tweak;
use elements::taproot::ControlBlock;
Expand Down Expand Up @@ -99,7 +100,7 @@ fn test_ffi_env() {
BlockHash::all_zeros(),
);

let prog = Arc::<ConstructNode<_>>::jet(Elements::LockTime);
let prog = Arc::<ConstructNode<_>>::jet(&types::Context::new(), Elements::LockTime);
assert_eq!(
BitMachine::test_exec(prog, &env).expect("executing"),
Value::u32(100),
Expand Down
Loading
Loading