Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not rely on any explicitness information for inference and constraint resolution #174

Merged
merged 22 commits into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
5c9ff62
Merge infer.preserving and infer.apply
umazalakain May 7, 2021
67f8e15
Make FTVs to preserve local to solveOne (add bound dep lambda ids)
umazalakain May 13, 2021
27bca59
Don't rely on explicitness info for constraint resolution (minus dept…
umazalakain May 13, 2021
689b76b
Use products instead of a custom case class
umazalakain May 13, 2021
c4e6334
Make preserve more granular, add newly created explicit identifiers
umazalakain May 14, 2021
d6ed8ba
Add all explicit type variables to the set of variables to preserve
umazalakain May 15, 2021
786c59d
Use a global preserve set: type identifiers are unique up to their type
umazalakain May 16, 2021
f7c0533
Remove the explicitly dependent constraint solving
umazalakain May 16, 2021
4a5a0c7
Remove typeEnv from constraint gathering
umazalakain May 16, 2021
27450be
Gather FTV in identifier types as well
umazalakain May 16, 2021
5a94837
Inspect type identifiers in datatypes and nats in fragment kinds
umazalakain May 16, 2021
97c89d0
Keep set of variables to substitute instead of variables to preserve
umazalakain May 16, 2021
7552689
Exclude matrix layout and fragment kind identifiers from vars to close
umazalakain May 16, 2021
c8b666a
Preserve order while assuring uniqueness
umazalakain May 27, 2021
2b55f32
Replace getFTVsRec (uses expl info) with IsClosedForm in makeClosed
umazalakain May 16, 2021
daafb0f
Merge branch 'master' into only-preserve
umazalakain May 27, 2021
972f808
Revert "Keep set of variables to substitute instead of variables to p…
umazalakain May 27, 2021
c674e0c
Use IsClosedForm instead of getFTVs in TopLevel
umazalakain May 27, 2021
24aaf8f
Use IsClosedForm to extract FTVs from assertions, remove getFTVs
umazalakain May 27, 2021
7f1719d
Do not rely on explicitness in collectPreserve
umazalakain May 27, 2021
288eb1e
Expose OrderedSet in freeVars
umazalakain May 28, 2021
264abd5
Do not skip matrix layout and fragment kind identifiers in collectPre…
umazalakain May 28, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/main/scala/rise/core/DSL/TopLevel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import Type.impl
import util.monads._
import rise.core.traverse._
import rise.core.types._
import rise.core.{DSL, Expr, Primitive}
import rise.core.{DSL, Expr, IsClosedForm, Primitive}

final case class TopLevel(e: Expr, inst: Solution = Solution())(
override val t: Type = e.t
Expand All @@ -23,7 +23,7 @@ final case class TopLevel(e: Expr, inst: Solution = Solution())(
object TopLevel {
private def instantiate(t: Type): Solution = {
import scala.collection.immutable.Map
infer.getFTVs(t).foldLeft(Solution())((subs, ftv) =>
IsClosedForm.varsToClose(t).foldLeft(Solution())((subs, ftv) =>
subs match {
case s@Solution(ts, ns, as, ms, fs, n2ds, n2ns, natColls) =>
ftv match {
Expand Down
98 changes: 30 additions & 68 deletions src/main/scala/rise/core/DSL/infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,17 @@ object infer {
traverse(e, Traversal(Set()))._1
}

def preservingWithEnv(e: Expr, env: Map[String, Type], preserve: Set[Kind.Identifier]): Expr = {
val (typed_e, constraints) = constrainTypes(env)(e)
val solution = Constraint.solve(constraints, preserve, Seq())(
Flags.ExplicitDependence.Off)
solution(typed_e)
}
type ExprEnv = Map[String, Type]

// TODO: Get rid of TypeAssertion and deprecate, instead evaluate !: in place and use `preserving` directly
private [DSL] def apply(e: Expr,
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off,
explDep: Flags.ExplicitDependence = Flags.ExplicitDependence.Off): Expr = {
def apply(e: Expr, exprEnv: ExprEnv = Map(), typeEnv : Set[Kind.Identifier] = Set(),
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off): Expr = {
// TODO: Get rid of TypeAssertion and deprecate, instead evaluate !: in place and use `preserving` directly
// Collect FTVs in assertions and opaques; transform assertions into annotations
val (preserve, e_wo_assertions) = traverse(e, collectPreserve)
infer.preserving(e_wo_assertions, preserve, printFlag, explDep)
}

private [DSL] def preserving(wo_assertions: Expr, preserve : Set[Kind.Identifier],
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off,
explDep: Flags.ExplicitDependence = Flags.ExplicitDependence.Off): Expr = {
val (e_preserve, e_wo_assertions) = traverse(e, collectPreserve)
// Collect constraints
val (typed_e, constraints) = constrainTypes(Map())(wo_assertions)
val (typed_e, constraints) = constrainTypes(exprEnv)(e_wo_assertions)
// Solve constraints while preserving the FTVs in preserve
val solution = Constraint.solve(constraints, preserve, Seq())(explDep)
val solution = Constraint.solve(constraints, e_preserve ++ typeEnv, Seq())
// Apply the solution
val res = traverse(typed_e, Visitor(solution))
if (printFlag == Flags.PrintTypesAndTypeHoles.On) {
Expand All @@ -83,74 +71,55 @@ object infer {
}
}

val FTVGathering = new PureAccumulatorTraversal[Seq[Kind.Identifier]] {
override val accumulator = SeqMonoid
override def typeIdentifier[I <: Kind.Identifier]: VarType => I => Pair[I] = _ => {
case i: Kind.Explicitness => accumulate(if (!i.isExplicit) Seq(i) else Seq())(i.asInstanceOf[I])
case i => accumulate(Seq(i))(i)
}
override def nat: Nat => Pair[Nat] = ae => {
val ftvs = mutable.ListBuffer[Kind.Identifier]()
val r = ae.visitAndRebuild({
case i: NatIdentifier if !i.isExplicit => ftvs += i; i
case n => n
})
accumulate(ftvs.toSeq)(r)
}
}

def getFTVs(t: Type): Seq[Kind.Identifier] = {
traverse(t, FTVGathering)._1.distinct
}

def getFTVsRec(e: Expr): Seq[Kind.Identifier] = {
traverse(e, FTVGathering)._1.distinct
}

private val collectPreserve = new PureAccumulatorTraversal[Set[Kind.Identifier]] {
override val accumulator = SetMonoid

override def typeIdentifier[I <: Kind.Identifier]: VarType => I => Pair[I] = {
case Binding => i => accumulate(Set(i))(i)
case _ => return_
}

override def expr: Expr => Pair[Expr] = {
// Transform assertions into annotations, collect FTVs
case TypeAssertion(e, t) =>
val (s, e1) = expr(e).unwrap
accumulate(s ++ getFTVs(t))(TypeAnnotation(e1, t) : Expr)
val (s1, e1) = expr(e).unwrap
accumulate(s1 ++ IsClosedForm.freeVars(t).set)(TypeAnnotation(e1, t) : Expr)
// Collect FTVs
case Opaque(e, t) =>
accumulate(getFTVs(t).toSet)(Opaque(e, t) : Expr)
accumulate(IsClosedForm.freeVars(t).set)(Opaque(e, t) : Expr)
case e => super.expr(e)
}
}

private val genType : Expr => Type =
e => if (e.t == TypePlaceholder) freshTypeIdentifier else e.t
private val constrIfTyped : Type => Constraint => Seq[Constraint] =
private def ifTyped[T] : Type => T => Seq[T] =
t => c => if (t == TypePlaceholder) Nil else Seq(c)

private val constrainTypes : Map[String, Type] => Expr => (Expr, Seq[Constraint]) = env => {
def constrainTypes(exprEnv : ExprEnv) : Expr => (Expr, Seq[Constraint]) = {
case i: Identifier =>
val t = env.getOrElse(i.name,
val t = exprEnv.getOrElse(i.name,
if (i.t == TypePlaceholder) error(s"$i has no type")(Seq()) else i.t )
val c = TypeConstraint(t, i.t)
(i.setType(t), Nil :+ c)

case expr@Lambda(x, e) =>
val tx = x.setType(genType(x))
val env1 : Map[String, Type] = env + (tx.name -> tx.t)
val (te, cs) = constrainTypes(env1)(e)
val exprEnv1 = exprEnv + (tx.name -> tx.t)
val (te, cs) = constrainTypes(exprEnv1)(e)
val ft = FunType(tx.t, te.t)
val cs1 = constrIfTyped(expr.t)(TypeConstraint(expr.t, ft))
val cs1 = ifTyped(expr.t)(TypeConstraint(expr.t, ft))
(Lambda(tx, te)(ft), cs ++ cs1)

case expr@App(f, e) =>
val (tf, csF) = constrainTypes(env)(f)
val (te, csE) = constrainTypes(env)(e)
val (tf, csF) = constrainTypes(exprEnv)(f)
val (te, csE) = constrainTypes(exprEnv)(e)
val exprT = genType(expr)
val c = TypeConstraint(tf.t, FunType(te.t, exprT))
(App(tf, te)(exprT), csF ++ csE :+ c)

case expr@DepLambda(x, e) =>
val (te, csE) = constrainTypes(env)(e)
val exprT = genType(expr)
val (te, csE) = constrainTypes(exprEnv)(e)
val tf = x match {
case n: NatIdentifier =>
DepLambda[NatKind](n, te)(DepFunType[NatKind, Type](n, te.t))
Expand All @@ -161,22 +130,22 @@ object infer {
case n2n: NatToNatIdentifier =>
DepLambda[NatToNatKind](n2n, te)(DepFunType[NatToNatKind, Type](n2n, te.t))
}
val csE1 = constrIfTyped(expr.t)(TypeConstraint(expr.t, tf.t))
val csE1 = ifTyped(expr.t)(TypeConstraint(expr.t, tf.t))
(tf, csE ++ csE1)

case expr@DepApp(f, x) =>
val (tf, csF) = constrainTypes(env)(f)
val (tf, csF) = constrainTypes(exprEnv)(f)
val exprT = genType(expr)
val c = DepConstraint(tf.t, x, exprT)
(DepApp(tf, x)(exprT), csF :+ c)

case TypeAnnotation(e, t) =>
val (te, csE) = constrainTypes(env)(e)
val (te, csE) = constrainTypes(exprEnv)(e)
val c = TypeConstraint(te.t, t)
(te, csE :+ c)

case TypeAssertion(e, t) =>
val (te, csE) = constrainTypes(env)(e)
val (te, csE) = constrainTypes(exprEnv)(e)
val c = TypeConstraint(te.t, t)
(te, csE :+ c)

Expand All @@ -197,11 +166,4 @@ object infer {
override def natToData : NatToData => Pure[NatToData] = n2d => return_(sol(n2d))
override def natToNat : NatToNat => Pure[NatToNat] = n2n => return_(sol(n2n))
}
}

object inferDependent {
def apply(e: ToBeTyped[Expr],
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off): Expr = infer(e match {
case ToBeTyped(e) => e
}, printFlag, Flags.ExplicitDependence.On)
}
}
67 changes: 56 additions & 11 deletions src/main/scala/rise/core/IsClosedForm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,32 @@ import rise.core.traverse._
import rise.core.types._

object IsClosedForm {
case class OrderedSet[T](seq : Seq[T], set : Set[T])
object OrderedSet {
def empty[T] : OrderedSet[T] = OrderedSet(Seq(), Set())
def add[T] : T => OrderedSet[T] => OrderedSet[T] = t => ts =>
if (ts.set.contains(t)) ts else OrderedSet(t +: ts.seq, ts.set + t)
def one[T] : T => OrderedSet[T] = add(_)(empty)
def append[T] : OrderedSet[T] => OrderedSet[T] => OrderedSet[T] = x => y => {
val ordered = x.seq.filter(!y.set.contains(_)) ++ y.seq
val unique = x.set ++ y.set
OrderedSet(ordered, unique)
}
}
implicit def OrderedSetMonoid[T] : Monoid[OrderedSet[T]] = new Monoid[OrderedSet[T]] {
def empty : OrderedSet[T] = OrderedSet.empty
def append : OrderedSet[T] => OrderedSet[T] => OrderedSet[T] = OrderedSet.append
}

case class Visitor(boundV: Set[Identifier], boundT: Set[Kind.Identifier])
extends PureAccumulatorTraversal[(Set[Identifier], Set[Kind.Identifier])]
extends PureAccumulatorTraversal[(OrderedSet[Identifier], OrderedSet[Kind.Identifier])]
{
override val accumulator = PairMonoid(SetMonoid, SetMonoid)
override val accumulator = PairMonoid(OrderedSetMonoid, OrderedSetMonoid)

override def identifier[I <: Identifier]: VarType => I => Pair[I] = vt => i => {
for { t2 <- `type`(i.t);
i2 <- if (vt == Reference && !boundV(i)) {
accumulate((Set(i), Set()))(i)
accumulate((OrderedSet.one(i : Identifier), OrderedSet.empty))(i)
} else {
return_(i)
}}
Expand All @@ -23,20 +40,27 @@ object IsClosedForm {

override def typeIdentifier[I <: Kind.Identifier]: VarType => I => Pair[I] = {
case Reference => i =>
if (boundT(i)) return_(i) else accumulate((Set(), Set(i)))(i)
if (boundT(i)) return_(i) else accumulate((OrderedSet.empty, OrderedSet.one(i : Kind.Identifier)))(i)
case _ => return_
}

override def nat: Nat => Pair[Nat] = n => {
val free = n.varList.foldLeft(Set[Kind.Identifier]()) {
case (free, v: NamedVar) if !boundT(NatIdentifier(v)) => free + NatIdentifier(v)
val free = n.varList.foldLeft(OrderedSet.empty[Kind.Identifier]) {
case (free, v: NamedVar) if !boundT(NatIdentifier(v)) => OrderedSet.add(NatIdentifier(v) : Kind.Identifier)(free)
case (free, _) => free
}
accumulate((Set(), free))(n)
accumulate((OrderedSet.empty, free))(n)
}

override def expr: Expr => Pair[Expr] = {
case Lambda(x, b) => this.copy(boundV = boundV + x).expr(b)
case l@Lambda(x, e) =>
// The binder's type itself might contain free type variables
val ((fVx, fTx), x1) = identifier(Binding)(x).unwrap
val ((fVe, fTe), e1) = this.copy(boundV = boundV + x1).expr(e).unwrap
val ((fVt, fTt), t1) = `type`(l.t).unwrap
val fV = OrderedSet.append(OrderedSet.append(fVx)(fVe))(fVt)
val fT = OrderedSet.append(OrderedSet.append(fTx)(fTe))(fTt)
accumulate((fV, fT))(Lambda(x1, e1)(t1): Expr)
case DepLambda(x, b) => this.copy(boundT = boundT + x).expr(b)
case e => super.expr(e)
}
Expand Down Expand Up @@ -66,11 +90,32 @@ object IsClosedForm {
}
}

def freeVars(expr: Expr): (Set[Identifier], Set[Kind.Identifier]) =
traverse(expr, Visitor(Set(), Set()))._1
def freeVars(expr: Expr): (OrderedSet[Identifier], OrderedSet[Kind.Identifier]) = {
val ((fV, fT), _) = traverse(expr, Visitor(Set(), Set()))
(fV, fT)
}

def freeVars(t: Type): OrderedSet[Kind.Identifier] = {
val ((_, ftv), _) = traverse(t, Visitor(Set(), Set()))
ftv
}

// Exclude matrix layout and fragment kind identifiers, since they cannot currently be bound
def needsClosing : Seq[Kind.Identifier] => Seq[Kind.Identifier] = _.flatMap {
case i : MatrixLayoutIdentifier => Seq()
case i : FragmentKindIdentifier => Seq()
case e => Seq(e)
}

def varsToClose(expr : Expr): (Seq[Identifier], Seq[Kind.Identifier]) = {
val (fV, fT) = freeVars(expr)
(fV.seq, needsClosing(fT.seq))
}

def varsToClose(t : Type): Seq[Kind.Identifier] = needsClosing(freeVars(t).seq)

def apply(expr: Expr): Boolean = {
val (freeV, freeT) = freeVars(expr)
val (freeV, freeT) = varsToClose(expr)
freeV.isEmpty && freeT.isEmpty
}
}
2 changes: 1 addition & 1 deletion src/main/scala/rise/core/makeClosed.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ object makeClosed {
Map[AddressSpaceIdentifier, AddressSpace],
Map[NatToDataIdentifier, NatToData]) = (Map(), Map(), Map(), Map())

val (expr, (ts, ns, as, n2ds)) = DSL.infer.getFTVsRec(e).foldLeft((e, emptySubs))((acc, ftv) => acc match {
val (expr, (ts, ns, as, n2ds)) = IsClosedForm.varsToClose(e)._2.foldLeft((e, emptySubs))((acc, ftv) => acc match {
case (expr, (ts, ns, as, n2ds)) => ftv match {
case i: TypeIdentifier =>
val dt = DataTypeIdentifier(freshName("dt"), isExplicit = true)
Expand Down
48 changes: 31 additions & 17 deletions src/main/scala/rise/core/traverse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,21 @@ object traverse {
case t: TypeIdentifier => typeIdentifier(vt)(t)
}).asInstanceOf[M[I]]
def natDispatch : VarType => Nat => M[Nat] = vt => {
case i : NatIdentifier =>
bind(typeIdentifier(vt)(i))(nat)
case i : NatIdentifier => bind(typeIdentifier(vt)(i))(nat)
case n => nat(n)
}
def matrixLayoutDispatch : VarType => MatrixLayout => M[MatrixLayout] = vt => {
case i : MatrixLayoutIdentifier => bind(typeIdentifier(vt)(i))(matrixLayout)
case m => matrixLayout(m)
}
def fragmentKindDispatch : VarType => FragmentKind => M[FragmentKind] = vt => {
case i : FragmentKindIdentifier => bind(typeIdentifier(vt)(i))(fragmentKind)
case m => fragmentKind(m)
}
def dataTypeDispatch : VarType => DataType => M[DataType] = vt => {
case i : DataTypeIdentifier => bind(typeIdentifier(vt)(i))(datatype)
case d => datatype(d)
}

def addressSpace : AddressSpace => M[AddressSpace] = return_
def matrixLayout : MatrixLayout => M[MatrixLayout] = return_
Expand All @@ -46,30 +57,34 @@ object traverse {
case NatType => return_(NatType : DataType)
case s : ScalarType => return_(s : DataType)
case ArrayType(n, d) =>
for {n1 <- natDispatch(Reference)(n); d1 <- `type`[DataType](d)}
for {n1 <- natDispatch(Reference)(n); d1 <- dataTypeDispatch(Reference)(d)}
yield ArrayType(n1, d1)
case DepArrayType(n, n2d) =>
for {n1 <- natDispatch(Reference)(n); n2d1 <- natToData(n2d)}
yield DepArrayType(n1, n2d1)
case PairType(p1, p2) =>
for {p11 <- `type`(p1); p21 <- `type`(p2)}
yield PairType(p11, p21)
case pair@DepPairType(x, e) =>
for {x1 <- typeIdentifierDispatch(Binding)(x); e1 <- `type`(e)}
yield DepPairType(x1, e1)(pair.kindName)
case pair@DepPairType(x, d) =>
for {x1 <- typeIdentifierDispatch(Binding)(x); d1 <- dataTypeDispatch(Reference)(d)}
yield DepPairType(x1, d1)(pair.kindName)
case IndexType(n) =>
for {n1 <- natDispatch(Reference)(n)}
yield IndexType(n1)
case VectorType(n, e) =>
for {n1 <- natDispatch(Reference)(n); e1 <- `type`(e)}
yield VectorType(n1, e1)
case VectorType(n, d) =>
for {n1 <- natDispatch(Reference)(n); d1 <- dataTypeDispatch(Reference)(d)}
yield VectorType(n1, d1)
case ManagedBufferType(dt) =>
for {dt1 <- datatype(dt)}
for {dt1 <- dataTypeDispatch(Reference)(dt)}
yield ManagedBufferType(dt1)
case o: OpaqueType => return_(o: DataType)
case FragmentType(rows, columns, d3, dt, fragKind, layout) =>
for {rows1 <- nat(rows); columns1 <- nat(columns); d31 <- nat(d3); dt1 <- datatype(dt);
fragKind1 <- fragmentKind(fragKind); layout1 <- matrixLayout(layout)}
for {rows1 <- natDispatch(Reference)(rows);
columns1 <- natDispatch(Reference)(columns);
d31 <- natDispatch(Reference)(d3);
dt1 <- dataTypeDispatch(Reference)(dt);
fragKind1 <- fragmentKindDispatch(Reference)(fragKind);
layout1 <- matrixLayoutDispatch(Reference)(layout)}
yield FragmentType(rows1, columns1, d31, dt1, fragKind1, layout1)
case NatToDataApply(ntdf, n) =>
for {ntdf1 <- natToData(ntdf); n1 <- natDispatch(Reference)(n)}
Expand All @@ -85,9 +100,9 @@ object traverse {

def natToData : NatToData => M[NatToData] = {
case i : NatToDataIdentifier => return_(i.asInstanceOf[NatToData])
case NatToDataLambda(x, e) =>
for { x1 <- typeIdentifierDispatch(Binding)(x); e1 <- `type`(e) }
yield NatToDataLambda(x1, e1)
case NatToDataLambda(x, d) =>
for { x1 <- typeIdentifierDispatch(Binding)(x); d1 <- dataTypeDispatch(Reference)(d) }
yield NatToDataLambda(x1, d1)
}

def data : Data => M[Data] = {
Expand All @@ -109,9 +124,8 @@ object traverse {

def `type`[T <: Type ] : T => M[T] = t => (t match {
case TypePlaceholder => return_(TypePlaceholder)
case i: DataTypeIdentifier => typeIdentifierDispatch(Reference)(i)
case i: TypeIdentifier => typeIdentifierDispatch(Reference)(i)
case dt: DataType => datatype(dt)
case dt: DataType => dataTypeDispatch(Reference)(dt)
case FunType(a, b) =>
for {a1 <- `type`(a); b1 <- `type`(b)}
yield FunType(a1, b1)
Expand Down
Loading