Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not rely on any explicitness information for inference and constraint resolution #174

Merged
merged 22 commits into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
5c9ff62
Merge infer.preserving and infer.apply
umazalakain May 7, 2021
67f8e15
Make FTVs to preserve local to solveOne (add bound dep lambda ids)
umazalakain May 13, 2021
27bca59
Don't rely on explicitness info for constraint resolution (minus dept…
umazalakain May 13, 2021
689b76b
Use products instead of a custom case class
umazalakain May 13, 2021
c4e6334
Make preserve more granular, add newly created explicit identifiers
umazalakain May 14, 2021
d6ed8ba
Add all explicit type variables to the set of variables to preserve
umazalakain May 15, 2021
786c59d
Use a global preserve set: type identifiers are unique up to their type
umazalakain May 16, 2021
f7c0533
Remove the explicitly dependent constraint solving
umazalakain May 16, 2021
4a5a0c7
Remove typeEnv from constraint gathering
umazalakain May 16, 2021
27450be
Gather FTV in identifier types as well
umazalakain May 16, 2021
5a94837
Inspect type identifiers in datatypes and nats in fragment kinds
umazalakain May 16, 2021
97c89d0
Keep set of variables to substitute instead of variables to preserve
umazalakain May 16, 2021
7552689
Exclude matrix layout and fragment kind identifiers from vars to close
umazalakain May 16, 2021
c8b666a
Preserve order while assuring uniqueness
umazalakain May 27, 2021
2b55f32
Replace getFTVsRec (uses expl info) with IsClosedForm in makeClosed
umazalakain May 16, 2021
daafb0f
Merge branch 'master' into only-preserve
umazalakain May 27, 2021
972f808
Revert "Keep set of variables to substitute instead of variables to p…
umazalakain May 27, 2021
c674e0c
Use IsClosedForm instead of getFTVs in TopLevel
umazalakain May 27, 2021
24aaf8f
Use IsClosedForm to extract FTVs from assertions, remove getFTVs
umazalakain May 27, 2021
7f1719d
Do not rely on explicitness in collectPreserve
umazalakain May 27, 2021
288eb1e
Expose OrderedSet in freeVars
umazalakain May 28, 2021
264abd5
Do not skip matrix layout and fragment kind identifiers in collectPre…
umazalakain May 28, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 31 additions & 41 deletions src/main/scala/rise/core/DSL/infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,19 @@ object infer {
traverse(e, Traversal(Set()))._1
}

def preservingWithEnv(e: Expr, env: Map[String, Type], preserve: Set[Kind.Identifier]): Expr = {
val (typed_e, constraints) = constrainTypes(env)(e)
val solution = Constraint.solve(constraints, preserve, Seq())(
Flags.ExplicitDependence.Off)
solution(typed_e)
}
type ExprEnv = Map[String, Type]
type TypeEnv = Set[Kind.Identifier]

// TODO: Get rid of TypeAssertion and deprecate, instead evaluate !: in place and use `preserving` directly
private [DSL] def apply(e: Expr,
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off,
explDep: Flags.ExplicitDependence = Flags.ExplicitDependence.Off): Expr = {
def apply(e: Expr, exprEnv: ExprEnv = Map(), typeEnv : TypeEnv = Set(),
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off,
explDep: Flags.ExplicitDependence = Flags.ExplicitDependence.Off): Expr = {
// TODO: Get rid of TypeAssertion and deprecate, instead evaluate !: in place and use `preserving` directly
// Collect FTVs in assertions and opaques; transform assertions into annotations
val (preserve, e_wo_assertions) = traverse(e, collectPreserve)
infer.preserving(e_wo_assertions, preserve, printFlag, explDep)
}

private [DSL] def preserving(wo_assertions: Expr, preserve : Set[Kind.Identifier],
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off,
explDep: Flags.ExplicitDependence = Flags.ExplicitDependence.Off): Expr = {
val (e_preserve, e_wo_assertions) = traverse(e, collectPreserve)
// Collect constraints
val (typed_e, constraints) = constrainTypes(Map())(wo_assertions)
val (typed_e, constraints) = constrainTypes(exprEnv, e_preserve ++ typeEnv)(e_wo_assertions)
// Solve constraints while preserving the FTVs in preserve
val solution = Constraint.solve(constraints, preserve, Seq())(explDep)
val solution = Constraint.solve(constraints, Seq())(explDep)
// Apply the solution
val res = traverse(typed_e, Visitor(solution))
if (printFlag == Flags.PrintTypesAndTypeHoles.On) {
Expand Down Expand Up @@ -123,33 +113,34 @@ object infer {

private val genType : Expr => Type =
e => if (e.t == TypePlaceholder) freshTypeIdentifier else e.t
private val constrIfTyped : Type => Constraint => Seq[Constraint] =
private def ifTyped[T] : Type => T => Seq[T] =
t => c => if (t == TypePlaceholder) Nil else Seq(c)

private val constrainTypes : Map[String, Type] => Expr => (Expr, Seq[Constraint]) = env => {
def constrainTypes(exprEnv : ExprEnv, typeEnv : TypeEnv) : Expr => (Expr, Seq[(Constraint, TypeEnv)]) = {
case i: Identifier =>
val t = env.getOrElse(i.name,
val t = exprEnv.getOrElse(i.name,
if (i.t == TypePlaceholder) error(s"$i has no type")(Seq()) else i.t )
val c = TypeConstraint(t, i.t)
val c = (TypeConstraint(t, i.t), typeEnv)
(i.setType(t), Nil :+ c)

case expr@Lambda(x, e) =>
val tx = x.setType(genType(x))
val env1 : Map[String, Type] = env + (tx.name -> tx.t)
val (te, cs) = constrainTypes(env1)(e)
val exprEnv1 : Map[String, Type] = exprEnv + (tx.name -> tx.t)
val (te, cs) = constrainTypes(exprEnv1, typeEnv)(e)
val ft = FunType(tx.t, te.t)
val cs1 = constrIfTyped(expr.t)(TypeConstraint(expr.t, ft))
val cs1 = ifTyped(expr.t)((TypeConstraint(expr.t, ft), typeEnv))
(Lambda(tx, te)(ft), cs ++ cs1)

case expr@App(f, e) =>
val (tf, csF) = constrainTypes(env)(f)
val (te, csE) = constrainTypes(env)(e)
val (tf, csF) = constrainTypes(exprEnv, typeEnv)(f)
val (te, csE) = constrainTypes(exprEnv, typeEnv)(e)
val exprT = genType(expr)
val c = TypeConstraint(tf.t, FunType(te.t, exprT))
val c = (TypeConstraint(tf.t, FunType(te.t, exprT)), typeEnv)
(App(tf, te)(exprT), csF ++ csE :+ c)

case expr@DepLambda(x, e) =>
val (te, csE) = constrainTypes(env)(e)
val typeEnv1 = typeEnv + x
val (te, csE) = constrainTypes(exprEnv, typeEnv1)(e)
val exprT = genType(expr)
val tf = x match {
case n: NatIdentifier =>
Expand All @@ -161,23 +152,23 @@ object infer {
case n2n: NatToNatIdentifier =>
DepLambda[NatToNatKind](n2n, te)(DepFunType[NatToNatKind, Type](n2n, te.t))
}
val csE1 = constrIfTyped(expr.t)(TypeConstraint(expr.t, tf.t))
val csE1 = ifTyped(expr.t)((TypeConstraint(expr.t, tf.t), typeEnv1))
(tf, csE ++ csE1)

case expr@DepApp(f, x) =>
val (tf, csF) = constrainTypes(env)(f)
val (tf, csF) = constrainTypes(exprEnv, typeEnv)(f)
val exprT = genType(expr)
val c = DepConstraint(tf.t, x, exprT)
val c = (DepConstraint(tf.t, x, exprT), typeEnv)
(DepApp(tf, x)(exprT), csF :+ c)

case TypeAnnotation(e, t) =>
val (te, csE) = constrainTypes(env)(e)
val c = TypeConstraint(te.t, t)
val (te, csE) = constrainTypes(exprEnv, typeEnv)(e)
val c = (TypeConstraint(te.t, t), typeEnv)
(te, csE :+ c)

case TypeAssertion(e, t) =>
val (te, csE) = constrainTypes(env)(e)
val c = TypeConstraint(te.t, t)
val (te, csE) = constrainTypes(exprEnv, typeEnv)(e)
val c = (TypeConstraint(te.t, t), typeEnv)
(te, csE :+ c)

case o: Opaque => (o, Nil)
Expand All @@ -200,8 +191,7 @@ object infer {
}

object inferDependent {
def apply(e: ToBeTyped[Expr],
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off): Expr = infer(e match {
case ToBeTyped(e) => e
}, printFlag, Flags.ExplicitDependence.On)
def apply(e: ToBeTyped[Expr], env : Map[String, Type] = Map(), preserve : Set[Kind.Identifier] = Set(),
printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off): Expr =
infer(e match { case ToBeTyped(e) => e }, env, preserve, printFlag, Flags.ExplicitDependence.On)
}
85 changes: 41 additions & 44 deletions src/main/scala/rise/core/types/Constraints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,14 @@ case class NatCollectionConstraint(a: NatCollection, b: NatCollection)
}

object Constraint {
def canBeSubstituted(preserve: Set[Kind.Identifier],
i: Kind.Identifier with Kind.Explicitness): Boolean =
!(preserve.contains(i) || i.isExplicit)
def canBeSubstituted(preserve: Set[Kind.Identifier], i: TypeIdentifier): Boolean =
type Scope = Set[Kind.Identifier]

def canBeSubstituted(preserve: Set[Kind.Identifier], i: Kind.Identifier): Boolean =
!preserve.contains(i)

def solve(cs: Seq[Constraint], preserve: Set[Kind.Identifier], trace: Seq[Constraint])
def solve(cs: Seq[(Constraint, Scope)], trace: Seq[Constraint])
(implicit explDep: Flags.ExplicitDependence): Solution =
solveRec(cs, Nil, preserve, trace)
solveRec(cs, Nil, trace)
/* faster but not always enough:
cs match {
case Nil => Solution()
Expand All @@ -62,22 +61,23 @@ object Constraint {
}
*/

def solveRec(cs: Seq[Constraint], rs: Seq[Constraint], preserve: Set[Kind.Identifier], trace: Seq[Constraint])
def solveRec(cs: Seq[(Constraint, Scope)], rs: Seq[(Constraint, Scope)], trace: Seq[Constraint])
(implicit explDep: Flags.ExplicitDependence): Solution = (cs, rs) match {
case (Nil, Nil) => Solution()
case (Nil, _) => error(s"could not solve constraints ${rs}")(trace)
case (c +: cs, _) =>
val s = try { solveOne(c, preserve, trace) }
case (scoped +: cs, _) =>
val (c, scope) = scoped
val s = try { solveOne(c, scope, trace) }
catch { case e: InferenceException =>
println(e.msg)
return solveRec(cs, rs :+ c, preserve, trace) }
s ++ solve(s.apply(rs ++ cs), preserve, trace)
return solveRec(cs, rs :+ (c, scope), trace) }
s ++ solve(s.apply(rs ++ cs), trace)
}

// scalastyle:off method.length
def solveOne(c: Constraint, preserve : Set[Kind.Identifier], trace: Seq[Constraint]) (implicit explDep: Flags.ExplicitDependence): Solution = {
implicit val _trace: Seq[Constraint] = trace
def decomposed(cs: Seq[Constraint]) = solve(cs, preserve, c +: trace)
def decomposed(cs: Seq[Constraint], preserve : Set[Kind.Identifier]) = solve(cs.map((_, preserve)), c +: trace)

c match {
case TypeConstraint(a, b) =>
Expand All @@ -96,22 +96,21 @@ object Constraint {
_: IndexType | _: VectorType)
if a =~= b => Solution()
case (IndexType(sa), IndexType(sb)) =>
decomposed(Seq(NatConstraint(sa, sb)))
decomposed(Seq(NatConstraint(sa, sb)), preserve)
case (ArrayType(sa, ea), ArrayType(sb, eb)) =>
decomposed(Seq(NatConstraint(sa, sb), TypeConstraint(ea, eb)))
decomposed(Seq(NatConstraint(sa, sb), TypeConstraint(ea, eb)), preserve)
case (VectorType(sa, ea), VectorType(sb, eb)) =>
decomposed(Seq(NatConstraint(sa, sb), TypeConstraint(ea, eb)))
decomposed(Seq(NatConstraint(sa, sb), TypeConstraint(ea, eb)), preserve)
case (FragmentType(rowsa, columnsa, d3a, dta, fragTypea, layouta), FragmentType(rowsb, columnsb, d3b, dtb, fragTypeb, layoutb)) =>
decomposed(Seq(NatConstraint(rowsa, rowsb), NatConstraint(columnsa, columnsb), NatConstraint(d3a, d3b),
TypeConstraint(dta, dtb), FragmentTypeConstraint(fragTypea, fragTypeb), MatrixLayoutConstraint(layouta, layoutb)))
TypeConstraint(dta, dtb), FragmentTypeConstraint(fragTypea, fragTypeb), MatrixLayoutConstraint(layouta, layoutb)),
preserve)
case (DepArrayType(sa, ea), DepArrayType(sb, eb)) =>
decomposed(Seq(NatConstraint(sa, sb), NatToDataConstraint(ea, eb)))
decomposed(Seq(NatConstraint(sa, sb), NatToDataConstraint(ea, eb)), preserve)
case (PairType(pa1, pa2), PairType(pb1, pb2)) =>
decomposed(Seq(TypeConstraint(pa1, pb1), TypeConstraint(pa2, pb2)))
decomposed(Seq(TypeConstraint(pa1, pb1), TypeConstraint(pa2, pb2)), preserve)
case (FunType(ina, outa), FunType(inb, outb)) =>
decomposed(
Seq(TypeConstraint(ina, inb), TypeConstraint(outa, outb))
)
decomposed(Seq(TypeConstraint(ina, inb), TypeConstraint(outa, outb)), preserve)
case (
DepFunType(na: NatIdentifier, ta),
DepFunType(nb: NatIdentifier, tb)
Expand All @@ -132,19 +131,18 @@ object Constraint {
substitute.natInType(n, `for`= nb, tb), n, preserve)
nTaSub ++ nTbSub ++ decomposed(
Seq(
NatConstraint(n, na.asImplicit),
NatConstraint(n, nb.asImplicit),
NatConstraint(n, na),
NatConstraint(n, nb),
TypeConstraint(nTa, nTb)
))
), preserve - na - nb)
case ExplicitDependence.Off =>
val n = NatIdentifier(freshName("n"), isExplicit = true)
decomposed(
Seq(
NatConstraint(n, na.asImplicit),
NatConstraint(n, nb.asImplicit),
NatConstraint(n, na),
NatConstraint(n, nb),
TypeConstraint(ta, tb)
)
)
), preserve - na - nb)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+ n? same in other spots. I wonder if we could also alternatively directly apply the required substitutions (na -> n and nb -> n) where necessary and only generate 1 constraint instead of 3 here. Something like:

val n = NatIdentifier(freshName("n"))
decomposed(Seq(TypeConstraint(
  substitute.natsInType(ta, Map(na -> n)),
  substitute.natsInType(tb, Map(nb -> n))
))), preserve + n)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, I was just picking up on that, I will get back to you!

}
case (
DepFunType(dta: DataTypeIdentifier, ta),
Expand All @@ -153,11 +151,10 @@ object Constraint {
val dt = DataTypeIdentifier(freshName("t"), isExplicit = true)
decomposed(
Seq(
TypeConstraint(dt, dta.asImplicit),
TypeConstraint(dt, dtb.asImplicit),
TypeConstraint(dt, dta),
TypeConstraint(dt, dtb),
TypeConstraint(ta, tb)
)
)
), preserve - dta - dtb)
case (
DepFunType(_: AddressSpaceIdentifier, _),
DepFunType(_: AddressSpaceIdentifier, _)
Expand All @@ -171,10 +168,10 @@ object Constraint {
val n = NatIdentifier(freshName("n"), isExplicit = true)

decomposed(Seq(
NatConstraint(n, x1.asImplicit),
NatConstraint(n, x2.asImplicit),
NatConstraint(n, x1),
NatConstraint(n, x2),
TypeConstraint(t1, t2)
))
), preserve - x1 - x2)

case (
DepPairType(x1: NatCollectionIdentifier, t1),
Expand All @@ -183,10 +180,10 @@ object Constraint {
val n = NatCollectionIdentifier(freshName("n"), isExplicit = true)

decomposed(Seq(
NatCollectionConstraint(n, x1.asImplicit),
NatCollectionConstraint(n, x2.asImplicit),
NatCollectionConstraint(n, x1),
NatCollectionConstraint(n, x2),
TypeConstraint(t1, t2)
))
), preserve - x1 - x2)

case (
NatToDataApply(f: NatToDataIdentifier, _),
Expand Down Expand Up @@ -222,7 +219,7 @@ object Constraint {
df match {
case _: DepFunType[_, _] =>
val applied = liftDependentFunctionType(df)(arg)
decomposed(Seq(TypeConstraint(applied, t)))
decomposed(Seq(TypeConstraint(applied, t)), preserve)
case _ =>
error(s"expected a dependent function type, but got $df")
}
Expand All @@ -237,10 +234,10 @@ object Constraint {
case (NatToDataLambda(x1, dt1), NatToDataLambda(x2, dt2)) =>
val n = NatIdentifier(freshName("n"), isExplicit = true)
decomposed(Seq(
NatConstraint(n, x1.asImplicit),
NatConstraint(n, x2.asImplicit),
NatConstraint(n, x1),
NatConstraint(n, x2),
TypeConstraint(dt1, dt2)
))
), preserve - x1 - x2)

case _ => error(s"cannot unify $a and $b")
}
Expand Down Expand Up @@ -319,7 +316,7 @@ object Constraint {

def unify(a: Nat, b: Nat, preserve : Set[Kind.Identifier])
(implicit trace: Seq[Constraint], explDep: Flags.ExplicitDependence): Solution = {
def decomposed(cs: Seq[Constraint]) = solve(cs, preserve, NatConstraint(a, b) +: trace)
def decomposed(cs: Seq[Constraint]) = solve(cs.map((_, preserve)), NatConstraint(a, b) +: trace)
(a, b) match {
case (i: NatIdentifier, _) => nat.unifyIdent(i, b, preserve)
case (_, i: NatIdentifier) => nat.unifyIdent(i, a, preserve)
Expand Down Expand Up @@ -480,7 +477,7 @@ object Constraint {
def unify(a: BoolExpr, b: BoolExpr, preserve : Set[Kind.Identifier])(
implicit trace: Seq[Constraint], explDep: Flags.ExplicitDependence
): Solution = {
def decomposed(cs: Seq[Constraint]) = solve(cs, preserve, BoolConstraint(a, b) +: trace)
def decomposed(cs: Seq[Constraint]) = solve(cs.map((_, preserve)), BoolConstraint(a, b) +: trace)
(a, b) match {
case _ if a == b => Solution()
case (ArithPredicate(lhs1, rhs1, op1), ArithPredicate(lhs2, rhs2, op2)) if op1 == op2 =>
Expand Down
52 changes: 28 additions & 24 deletions src/main/scala/rise/core/types/Solution.scala
Original file line number Diff line number Diff line change
Expand Up @@ -132,29 +132,33 @@ case class Solution(ts: Map[Type, Type],
combine(this, other)
}

def apply(constraints: Seq[Constraint]): Seq[Constraint] = {
constraints.map {
case TypeConstraint(a, b) => TypeConstraint(apply(a), apply(b))
case NatConstraint(a, b) => NatConstraint(apply(a), apply(b))
case BoolConstraint(a, b) => BoolConstraint(apply(a), apply(b))
case MatrixLayoutConstraint(a, b) => MatrixLayoutConstraint(apply(a), apply(b))
case FragmentTypeConstraint(a, b) => FragmentTypeConstraint(apply(a), apply(b))
case NatToDataConstraint(a, b) => NatToDataConstraint(apply(a), apply(b))
case NatCollectionConstraint(a, b) => NatCollectionConstraint(apply(a), apply(b))
case DepConstraint(df, arg: Nat, t) => DepConstraint[NatKind](apply(df), apply(arg), apply(t))
case DepConstraint(df, arg: DataType, t) =>
DepConstraint[DataKind](apply(df), apply(arg).asInstanceOf[DataType], apply(t))
case DepConstraint(df, arg: Type, t) =>
DepConstraint[TypeKind](apply(df), apply(arg), apply(t))
case DepConstraint(df, arg: AddressSpace, t) =>
DepConstraint[AddressSpaceKind](apply(df), apply(arg), apply(t))
case DepConstraint(df, arg: NatToData, t) =>
DepConstraint[NatToDataKind](apply(df), apply(arg), apply(t))
case DepConstraint(df, arg: NatToNat, t) =>
DepConstraint[NatToNatKind](apply(df), apply(arg), apply(t))
case DepConstraint(df, arg: NatCollection, t) =>
DepConstraint[NatCollectionKind](apply(df), apply(arg), apply(t))
case DepConstraint(_, _, _) => throw new Exception("Impossible case")
}
type Scope = Set[Kind.Identifier]

// def apply(constraints: Seq[Constraint]): Seq[Constraint] = constraints.map(apply)
def apply(constraints: Seq[(Constraint, Scope)]): Seq[(Constraint, Scope)] =
constraints.map {case (c, s) => (apply(c), s)}

def apply(constraint: Constraint): Constraint = constraint match {
case TypeConstraint(a, b) => TypeConstraint(apply(a), apply(b))
case NatConstraint(a, b) => NatConstraint(apply(a), apply(b))
case BoolConstraint(a, b) => BoolConstraint(apply(a), apply(b))
case MatrixLayoutConstraint(a, b) => MatrixLayoutConstraint(apply(a), apply(b))
case FragmentTypeConstraint(a, b) => FragmentTypeConstraint(apply(a), apply(b))
case NatToDataConstraint(a, b) => NatToDataConstraint(apply(a), apply(b))
case NatCollectionConstraint(a, b) => NatCollectionConstraint(apply(a), apply(b))
case DepConstraint(df, arg: Nat, t) => DepConstraint[NatKind](apply(df), apply(arg), apply(t))
case DepConstraint(df, arg: DataType, t) =>
DepConstraint[DataKind](apply(df), apply(arg).asInstanceOf[DataType], apply(t))
case DepConstraint(df, arg: Type, t) =>
DepConstraint[TypeKind](apply(df), apply(arg), apply(t))
case DepConstraint(df, arg: AddressSpace, t) =>
DepConstraint[AddressSpaceKind](apply(df), apply(arg), apply(t))
case DepConstraint(df, arg: NatToData, t) =>
DepConstraint[NatToDataKind](apply(df), apply(arg), apply(t))
case DepConstraint(df, arg: NatToNat, t) =>
DepConstraint[NatToNatKind](apply(df), apply(arg), apply(t))
case DepConstraint(df, arg: NatCollection, t) =>
DepConstraint[NatCollectionKind](apply(df), apply(arg), apply(t))
case DepConstraint(_, _, _) => throw new Exception("Impossible case")
}
}
Loading