diff --git a/src/main/scala/rise/core/DSL/TopLevel.scala b/src/main/scala/rise/core/DSL/TopLevel.scala index 10fbd3371..7817e93e2 100644 --- a/src/main/scala/rise/core/DSL/TopLevel.scala +++ b/src/main/scala/rise/core/DSL/TopLevel.scala @@ -4,7 +4,7 @@ import Type.impl import util.monads._ import rise.core.traverse._ import rise.core.types._ -import rise.core.{DSL, Expr, Primitive} +import rise.core.{DSL, Expr, IsClosedForm, Primitive} final case class TopLevel(e: Expr, inst: Solution = Solution())( override val t: Type = e.t @@ -23,7 +23,7 @@ final case class TopLevel(e: Expr, inst: Solution = Solution())( object TopLevel { private def instantiate(t: Type): Solution = { import scala.collection.immutable.Map - infer.getFTVs(t).foldLeft(Solution())((subs, ftv) => + IsClosedForm.varsToClose(t).foldLeft(Solution())((subs, ftv) => subs match { case s@Solution(ts, ns, as, ms, fs, n2ds, n2ns, natColls) => ftv match { diff --git a/src/main/scala/rise/core/DSL/infer.scala b/src/main/scala/rise/core/DSL/infer.scala index 2f827773e..bcce60d96 100644 --- a/src/main/scala/rise/core/DSL/infer.scala +++ b/src/main/scala/rise/core/DSL/infer.scala @@ -34,29 +34,17 @@ object infer { traverse(e, Traversal(Set()))._1 } - def preservingWithEnv(e: Expr, env: Map[String, Type], preserve: Set[Kind.Identifier]): Expr = { - val (typed_e, constraints) = constrainTypes(env)(e) - val solution = Constraint.solve(constraints, preserve, Seq())( - Flags.ExplicitDependence.Off) - solution(typed_e) - } + type ExprEnv = Map[String, Type] - // TODO: Get rid of TypeAssertion and deprecate, instead evaluate !: in place and use `preserving` directly - private [DSL] def apply(e: Expr, - printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off, - explDep: Flags.ExplicitDependence = Flags.ExplicitDependence.Off): Expr = { + def apply(e: Expr, exprEnv: ExprEnv = Map(), typeEnv : Set[Kind.Identifier] = Set(), + printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off): Expr = { + // TODO: Get rid of TypeAssertion and deprecate, instead evaluate !: in place and use `preserving` directly // Collect FTVs in assertions and opaques; transform assertions into annotations - val (preserve, e_wo_assertions) = traverse(e, collectPreserve) - infer.preserving(e_wo_assertions, preserve, printFlag, explDep) - } - - private [DSL] def preserving(wo_assertions: Expr, preserve : Set[Kind.Identifier], - printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off, - explDep: Flags.ExplicitDependence = Flags.ExplicitDependence.Off): Expr = { + val (e_preserve, e_wo_assertions) = traverse(e, collectPreserve) // Collect constraints - val (typed_e, constraints) = constrainTypes(Map())(wo_assertions) + val (typed_e, constraints) = constrainTypes(exprEnv)(e_wo_assertions) // Solve constraints while preserving the FTVs in preserve - val solution = Constraint.solve(constraints, preserve, Seq())(explDep) + val solution = Constraint.solve(constraints, e_preserve ++ typeEnv, Seq()) // Apply the solution val res = traverse(typed_e, Visitor(solution)) if (printFlag == Flags.PrintTypesAndTypeHoles.On) { @@ -83,74 +71,55 @@ object infer { } } - val FTVGathering = new PureAccumulatorTraversal[Seq[Kind.Identifier]] { - override val accumulator = SeqMonoid - override def typeIdentifier[I <: Kind.Identifier]: VarType => I => Pair[I] = _ => { - case i: Kind.Explicitness => accumulate(if (!i.isExplicit) Seq(i) else Seq())(i.asInstanceOf[I]) - case i => accumulate(Seq(i))(i) - } - override def nat: Nat => Pair[Nat] = ae => { - val ftvs = mutable.ListBuffer[Kind.Identifier]() - val r = ae.visitAndRebuild({ - case i: NatIdentifier if !i.isExplicit => ftvs += i; i - case n => n - }) - accumulate(ftvs.toSeq)(r) - } - } - - def getFTVs(t: Type): Seq[Kind.Identifier] = { - traverse(t, FTVGathering)._1.distinct - } - - def getFTVsRec(e: Expr): Seq[Kind.Identifier] = { - traverse(e, FTVGathering)._1.distinct - } - private val collectPreserve = new PureAccumulatorTraversal[Set[Kind.Identifier]] { override val accumulator = SetMonoid + + override def typeIdentifier[I <: Kind.Identifier]: VarType => I => Pair[I] = { + case Binding => i => accumulate(Set(i))(i) + case _ => return_ + } + override def expr: Expr => Pair[Expr] = { // Transform assertions into annotations, collect FTVs case TypeAssertion(e, t) => - val (s, e1) = expr(e).unwrap - accumulate(s ++ getFTVs(t))(TypeAnnotation(e1, t) : Expr) + val (s1, e1) = expr(e).unwrap + accumulate(s1 ++ IsClosedForm.freeVars(t).set)(TypeAnnotation(e1, t) : Expr) // Collect FTVs case Opaque(e, t) => - accumulate(getFTVs(t).toSet)(Opaque(e, t) : Expr) + accumulate(IsClosedForm.freeVars(t).set)(Opaque(e, t) : Expr) case e => super.expr(e) } } private val genType : Expr => Type = e => if (e.t == TypePlaceholder) freshTypeIdentifier else e.t - private val constrIfTyped : Type => Constraint => Seq[Constraint] = + private def ifTyped[T] : Type => T => Seq[T] = t => c => if (t == TypePlaceholder) Nil else Seq(c) - private val constrainTypes : Map[String, Type] => Expr => (Expr, Seq[Constraint]) = env => { + def constrainTypes(exprEnv : ExprEnv) : Expr => (Expr, Seq[Constraint]) = { case i: Identifier => - val t = env.getOrElse(i.name, + val t = exprEnv.getOrElse(i.name, if (i.t == TypePlaceholder) error(s"$i has no type")(Seq()) else i.t ) val c = TypeConstraint(t, i.t) (i.setType(t), Nil :+ c) case expr@Lambda(x, e) => val tx = x.setType(genType(x)) - val env1 : Map[String, Type] = env + (tx.name -> tx.t) - val (te, cs) = constrainTypes(env1)(e) + val exprEnv1 = exprEnv + (tx.name -> tx.t) + val (te, cs) = constrainTypes(exprEnv1)(e) val ft = FunType(tx.t, te.t) - val cs1 = constrIfTyped(expr.t)(TypeConstraint(expr.t, ft)) + val cs1 = ifTyped(expr.t)(TypeConstraint(expr.t, ft)) (Lambda(tx, te)(ft), cs ++ cs1) case expr@App(f, e) => - val (tf, csF) = constrainTypes(env)(f) - val (te, csE) = constrainTypes(env)(e) + val (tf, csF) = constrainTypes(exprEnv)(f) + val (te, csE) = constrainTypes(exprEnv)(e) val exprT = genType(expr) val c = TypeConstraint(tf.t, FunType(te.t, exprT)) (App(tf, te)(exprT), csF ++ csE :+ c) case expr@DepLambda(x, e) => - val (te, csE) = constrainTypes(env)(e) - val exprT = genType(expr) + val (te, csE) = constrainTypes(exprEnv)(e) val tf = x match { case n: NatIdentifier => DepLambda[NatKind](n, te)(DepFunType[NatKind, Type](n, te.t)) @@ -161,22 +130,22 @@ object infer { case n2n: NatToNatIdentifier => DepLambda[NatToNatKind](n2n, te)(DepFunType[NatToNatKind, Type](n2n, te.t)) } - val csE1 = constrIfTyped(expr.t)(TypeConstraint(expr.t, tf.t)) + val csE1 = ifTyped(expr.t)(TypeConstraint(expr.t, tf.t)) (tf, csE ++ csE1) case expr@DepApp(f, x) => - val (tf, csF) = constrainTypes(env)(f) + val (tf, csF) = constrainTypes(exprEnv)(f) val exprT = genType(expr) val c = DepConstraint(tf.t, x, exprT) (DepApp(tf, x)(exprT), csF :+ c) case TypeAnnotation(e, t) => - val (te, csE) = constrainTypes(env)(e) + val (te, csE) = constrainTypes(exprEnv)(e) val c = TypeConstraint(te.t, t) (te, csE :+ c) case TypeAssertion(e, t) => - val (te, csE) = constrainTypes(env)(e) + val (te, csE) = constrainTypes(exprEnv)(e) val c = TypeConstraint(te.t, t) (te, csE :+ c) @@ -197,11 +166,4 @@ object infer { override def natToData : NatToData => Pure[NatToData] = n2d => return_(sol(n2d)) override def natToNat : NatToNat => Pure[NatToNat] = n2n => return_(sol(n2n)) } -} - -object inferDependent { - def apply(e: ToBeTyped[Expr], - printFlag: Flags.PrintTypesAndTypeHoles = Flags.PrintTypesAndTypeHoles.Off): Expr = infer(e match { - case ToBeTyped(e) => e - }, printFlag, Flags.ExplicitDependence.On) -} +} \ No newline at end of file diff --git a/src/main/scala/rise/core/IsClosedForm.scala b/src/main/scala/rise/core/IsClosedForm.scala index 4f33a7e7c..462d2f6cb 100644 --- a/src/main/scala/rise/core/IsClosedForm.scala +++ b/src/main/scala/rise/core/IsClosedForm.scala @@ -6,15 +6,32 @@ import rise.core.traverse._ import rise.core.types._ object IsClosedForm { + case class OrderedSet[T](seq : Seq[T], set : Set[T]) + object OrderedSet { + def empty[T] : OrderedSet[T] = OrderedSet(Seq(), Set()) + def add[T] : T => OrderedSet[T] => OrderedSet[T] = t => ts => + if (ts.set.contains(t)) ts else OrderedSet(t +: ts.seq, ts.set + t) + def one[T] : T => OrderedSet[T] = add(_)(empty) + def append[T] : OrderedSet[T] => OrderedSet[T] => OrderedSet[T] = x => y => { + val ordered = x.seq.filter(!y.set.contains(_)) ++ y.seq + val unique = x.set ++ y.set + OrderedSet(ordered, unique) + } + } + implicit def OrderedSetMonoid[T] : Monoid[OrderedSet[T]] = new Monoid[OrderedSet[T]] { + def empty : OrderedSet[T] = OrderedSet.empty + def append : OrderedSet[T] => OrderedSet[T] => OrderedSet[T] = OrderedSet.append + } + case class Visitor(boundV: Set[Identifier], boundT: Set[Kind.Identifier]) - extends PureAccumulatorTraversal[(Set[Identifier], Set[Kind.Identifier])] + extends PureAccumulatorTraversal[(OrderedSet[Identifier], OrderedSet[Kind.Identifier])] { - override val accumulator = PairMonoid(SetMonoid, SetMonoid) + override val accumulator = PairMonoid(OrderedSetMonoid, OrderedSetMonoid) override def identifier[I <: Identifier]: VarType => I => Pair[I] = vt => i => { for { t2 <- `type`(i.t); i2 <- if (vt == Reference && !boundV(i)) { - accumulate((Set(i), Set()))(i) + accumulate((OrderedSet.one(i : Identifier), OrderedSet.empty))(i) } else { return_(i) }} @@ -23,20 +40,27 @@ object IsClosedForm { override def typeIdentifier[I <: Kind.Identifier]: VarType => I => Pair[I] = { case Reference => i => - if (boundT(i)) return_(i) else accumulate((Set(), Set(i)))(i) + if (boundT(i)) return_(i) else accumulate((OrderedSet.empty, OrderedSet.one(i : Kind.Identifier)))(i) case _ => return_ } override def nat: Nat => Pair[Nat] = n => { - val free = n.varList.foldLeft(Set[Kind.Identifier]()) { - case (free, v: NamedVar) if !boundT(NatIdentifier(v)) => free + NatIdentifier(v) + val free = n.varList.foldLeft(OrderedSet.empty[Kind.Identifier]) { + case (free, v: NamedVar) if !boundT(NatIdentifier(v)) => OrderedSet.add(NatIdentifier(v) : Kind.Identifier)(free) case (free, _) => free } - accumulate((Set(), free))(n) + accumulate((OrderedSet.empty, free))(n) } override def expr: Expr => Pair[Expr] = { - case Lambda(x, b) => this.copy(boundV = boundV + x).expr(b) + case l@Lambda(x, e) => + // The binder's type itself might contain free type variables + val ((fVx, fTx), x1) = identifier(Binding)(x).unwrap + val ((fVe, fTe), e1) = this.copy(boundV = boundV + x1).expr(e).unwrap + val ((fVt, fTt), t1) = `type`(l.t).unwrap + val fV = OrderedSet.append(OrderedSet.append(fVx)(fVe))(fVt) + val fT = OrderedSet.append(OrderedSet.append(fTx)(fTe))(fTt) + accumulate((fV, fT))(Lambda(x1, e1)(t1): Expr) case DepLambda(x, b) => this.copy(boundT = boundT + x).expr(b) case e => super.expr(e) } @@ -66,11 +90,32 @@ object IsClosedForm { } } - def freeVars(expr: Expr): (Set[Identifier], Set[Kind.Identifier]) = - traverse(expr, Visitor(Set(), Set()))._1 + def freeVars(expr: Expr): (OrderedSet[Identifier], OrderedSet[Kind.Identifier]) = { + val ((fV, fT), _) = traverse(expr, Visitor(Set(), Set())) + (fV, fT) + } + + def freeVars(t: Type): OrderedSet[Kind.Identifier] = { + val ((_, ftv), _) = traverse(t, Visitor(Set(), Set())) + ftv + } + + // Exclude matrix layout and fragment kind identifiers, since they cannot currently be bound + def needsClosing : Seq[Kind.Identifier] => Seq[Kind.Identifier] = _.flatMap { + case i : MatrixLayoutIdentifier => Seq() + case i : FragmentKindIdentifier => Seq() + case e => Seq(e) + } + + def varsToClose(expr : Expr): (Seq[Identifier], Seq[Kind.Identifier]) = { + val (fV, fT) = freeVars(expr) + (fV.seq, needsClosing(fT.seq)) + } + + def varsToClose(t : Type): Seq[Kind.Identifier] = needsClosing(freeVars(t).seq) def apply(expr: Expr): Boolean = { - val (freeV, freeT) = freeVars(expr) + val (freeV, freeT) = varsToClose(expr) freeV.isEmpty && freeT.isEmpty } } \ No newline at end of file diff --git a/src/main/scala/rise/core/makeClosed.scala b/src/main/scala/rise/core/makeClosed.scala index 7bceec36d..c5d20e199 100644 --- a/src/main/scala/rise/core/makeClosed.scala +++ b/src/main/scala/rise/core/makeClosed.scala @@ -11,7 +11,7 @@ object makeClosed { Map[AddressSpaceIdentifier, AddressSpace], Map[NatToDataIdentifier, NatToData]) = (Map(), Map(), Map(), Map()) - val (expr, (ts, ns, as, n2ds)) = DSL.infer.getFTVsRec(e).foldLeft((e, emptySubs))((acc, ftv) => acc match { + val (expr, (ts, ns, as, n2ds)) = IsClosedForm.varsToClose(e)._2.foldLeft((e, emptySubs))((acc, ftv) => acc match { case (expr, (ts, ns, as, n2ds)) => ftv match { case i: TypeIdentifier => val dt = DataTypeIdentifier(freshName("dt"), isExplicit = true) diff --git a/src/main/scala/rise/core/traverse.scala b/src/main/scala/rise/core/traverse.scala index 724fe1f4f..6f224bb60 100644 --- a/src/main/scala/rise/core/traverse.scala +++ b/src/main/scala/rise/core/traverse.scala @@ -33,10 +33,21 @@ object traverse { case t: TypeIdentifier => typeIdentifier(vt)(t) }).asInstanceOf[M[I]] def natDispatch : VarType => Nat => M[Nat] = vt => { - case i : NatIdentifier => - bind(typeIdentifier(vt)(i))(nat) + case i : NatIdentifier => bind(typeIdentifier(vt)(i))(nat) case n => nat(n) } + def matrixLayoutDispatch : VarType => MatrixLayout => M[MatrixLayout] = vt => { + case i : MatrixLayoutIdentifier => bind(typeIdentifier(vt)(i))(matrixLayout) + case m => matrixLayout(m) + } + def fragmentKindDispatch : VarType => FragmentKind => M[FragmentKind] = vt => { + case i : FragmentKindIdentifier => bind(typeIdentifier(vt)(i))(fragmentKind) + case m => fragmentKind(m) + } + def dataTypeDispatch : VarType => DataType => M[DataType] = vt => { + case i : DataTypeIdentifier => bind(typeIdentifier(vt)(i))(datatype) + case d => datatype(d) + } def addressSpace : AddressSpace => M[AddressSpace] = return_ def matrixLayout : MatrixLayout => M[MatrixLayout] = return_ @@ -46,7 +57,7 @@ object traverse { case NatType => return_(NatType : DataType) case s : ScalarType => return_(s : DataType) case ArrayType(n, d) => - for {n1 <- natDispatch(Reference)(n); d1 <- `type`[DataType](d)} + for {n1 <- natDispatch(Reference)(n); d1 <- dataTypeDispatch(Reference)(d)} yield ArrayType(n1, d1) case DepArrayType(n, n2d) => for {n1 <- natDispatch(Reference)(n); n2d1 <- natToData(n2d)} @@ -54,22 +65,26 @@ object traverse { case PairType(p1, p2) => for {p11 <- `type`(p1); p21 <- `type`(p2)} yield PairType(p11, p21) - case pair@DepPairType(x, e) => - for {x1 <- typeIdentifierDispatch(Binding)(x); e1 <- `type`(e)} - yield DepPairType(x1, e1)(pair.kindName) + case pair@DepPairType(x, d) => + for {x1 <- typeIdentifierDispatch(Binding)(x); d1 <- dataTypeDispatch(Reference)(d)} + yield DepPairType(x1, d1)(pair.kindName) case IndexType(n) => for {n1 <- natDispatch(Reference)(n)} yield IndexType(n1) - case VectorType(n, e) => - for {n1 <- natDispatch(Reference)(n); e1 <- `type`(e)} - yield VectorType(n1, e1) + case VectorType(n, d) => + for {n1 <- natDispatch(Reference)(n); d1 <- dataTypeDispatch(Reference)(d)} + yield VectorType(n1, d1) case ManagedBufferType(dt) => - for {dt1 <- datatype(dt)} + for {dt1 <- dataTypeDispatch(Reference)(dt)} yield ManagedBufferType(dt1) case o: OpaqueType => return_(o: DataType) case FragmentType(rows, columns, d3, dt, fragKind, layout) => - for {rows1 <- nat(rows); columns1 <- nat(columns); d31 <- nat(d3); dt1 <- datatype(dt); - fragKind1 <- fragmentKind(fragKind); layout1 <- matrixLayout(layout)} + for {rows1 <- natDispatch(Reference)(rows); + columns1 <- natDispatch(Reference)(columns); + d31 <- natDispatch(Reference)(d3); + dt1 <- dataTypeDispatch(Reference)(dt); + fragKind1 <- fragmentKindDispatch(Reference)(fragKind); + layout1 <- matrixLayoutDispatch(Reference)(layout)} yield FragmentType(rows1, columns1, d31, dt1, fragKind1, layout1) case NatToDataApply(ntdf, n) => for {ntdf1 <- natToData(ntdf); n1 <- natDispatch(Reference)(n)} @@ -85,9 +100,9 @@ object traverse { def natToData : NatToData => M[NatToData] = { case i : NatToDataIdentifier => return_(i.asInstanceOf[NatToData]) - case NatToDataLambda(x, e) => - for { x1 <- typeIdentifierDispatch(Binding)(x); e1 <- `type`(e) } - yield NatToDataLambda(x1, e1) + case NatToDataLambda(x, d) => + for { x1 <- typeIdentifierDispatch(Binding)(x); d1 <- dataTypeDispatch(Reference)(d) } + yield NatToDataLambda(x1, d1) } def data : Data => M[Data] = { @@ -109,9 +124,8 @@ object traverse { def `type`[T <: Type ] : T => M[T] = t => (t match { case TypePlaceholder => return_(TypePlaceholder) - case i: DataTypeIdentifier => typeIdentifierDispatch(Reference)(i) case i: TypeIdentifier => typeIdentifierDispatch(Reference)(i) - case dt: DataType => datatype(dt) + case dt: DataType => dataTypeDispatch(Reference)(dt) case FunType(a, b) => for {a1 <- `type`(a); b1 <- `type`(b)} yield FunType(a1, b1) diff --git a/src/main/scala/rise/core/types/Constraints.scala b/src/main/scala/rise/core/types/Constraints.scala index f1273605a..bf8839f4c 100644 --- a/src/main/scala/rise/core/types/Constraints.scala +++ b/src/main/scala/rise/core/types/Constraints.scala @@ -44,15 +44,11 @@ case class NatCollectionConstraint(a: NatCollection, b: NatCollection) } object Constraint { - def canBeSubstituted(preserve: Set[Kind.Identifier], - i: Kind.Identifier with Kind.Explicitness): Boolean = - !(preserve.contains(i) || i.isExplicit) - def canBeSubstituted(preserve: Set[Kind.Identifier], i: TypeIdentifier): Boolean = + def canBeSubstituted(preserve: Set[Kind.Identifier], i: Kind.Identifier): Boolean = !preserve.contains(i) - def solve(cs: Seq[Constraint], preserve: Set[Kind.Identifier], trace: Seq[Constraint]) - (implicit explDep: Flags.ExplicitDependence): Solution = - solveRec(cs, Nil, preserve, trace) + def solve(cs: Seq[Constraint], preserve: Set[Kind.Identifier], trace: Seq[Constraint]): Solution = + solveRec(cs, Nil, preserve, trace) /* faster but not always enough: cs match { case Nil => Solution() @@ -62,8 +58,7 @@ object Constraint { } */ - def solveRec(cs: Seq[Constraint], rs: Seq[Constraint], preserve: Set[Kind.Identifier], trace: Seq[Constraint]) - (implicit explDep: Flags.ExplicitDependence): Solution = (cs, rs) match { + def solveRec(cs: Seq[Constraint], rs: Seq[Constraint], preserve: Set[Kind.Identifier], trace: Seq[Constraint]): Solution = (cs, rs) match { case (Nil, Nil) => Solution() case (Nil, _) => error(s"could not solve constraints ${rs}")(trace) case (c +: cs, _) => @@ -75,8 +70,9 @@ object Constraint { } // scalastyle:off method.length - def solveOne(c: Constraint, preserve : Set[Kind.Identifier], trace: Seq[Constraint]) (implicit explDep: Flags.ExplicitDependence): Solution = { + def solveOne(c: Constraint, preserve : Set[Kind.Identifier], trace: Seq[Constraint]): Solution = { implicit val _trace: Seq[Constraint] = trace + def decomposedPreserve(cs: Seq[Constraint], preserve : Set[Kind.Identifier]) = solve(cs, preserve, c +: trace) def decomposed(cs: Seq[Constraint]) = solve(cs, preserve, c +: trace) c match { @@ -109,55 +105,27 @@ object Constraint { case (PairType(pa1, pa2), PairType(pb1, pb2)) => decomposed(Seq(TypeConstraint(pa1, pb1), TypeConstraint(pa2, pb2))) case (FunType(ina, outa), FunType(inb, outb)) => - decomposed( - Seq(TypeConstraint(ina, inb), TypeConstraint(outa, outb)) - ) + decomposed(Seq(TypeConstraint(ina, inb), TypeConstraint(outa, outb))) case ( DepFunType(na: NatIdentifier, ta), DepFunType(nb: NatIdentifier, tb) ) => - explDep match { - case ExplicitDependence.On => - val n = NatIdentifier(freshName("n"), isExplicit = true) - /** Note(federico): - * This step recurses in both functions and makes dependence between type - * variables and n explicit (by replacing type variables with NatToData/NatToNat). - * - * Perhaps this can be moved away from constraint solving, and pulled up in the - * initial constrain-types phase? - */ - val (nTa, nTaSub) = dependence.explicitlyDependent( - substitute.natInType(n, `for`=na, ta), n, preserve) - val (nTb, nTbSub) = dependence.explicitlyDependent( - substitute.natInType(n, `for`= nb, tb), n, preserve) - nTaSub ++ nTbSub ++ decomposed( - Seq( - NatConstraint(n, na.asImplicit), - NatConstraint(n, nb.asImplicit), - TypeConstraint(nTa, nTb) - )) - case ExplicitDependence.Off => - val n = NatIdentifier(freshName("n"), isExplicit = true) - decomposed( - Seq( - NatConstraint(n, na.asImplicit), - NatConstraint(n, nb.asImplicit), - TypeConstraint(ta, tb) - ) - ) - } + val n = NatIdentifier(freshName("n"), isExplicit = true) + decomposedPreserve(Seq( + NatConstraint(n, na), + NatConstraint(n, nb), + TypeConstraint(ta, tb), + ), preserve + n - na - nb) case ( DepFunType(dta: DataTypeIdentifier, ta), DepFunType(dtb: DataTypeIdentifier, tb) ) => val dt = DataTypeIdentifier(freshName("t"), isExplicit = true) - decomposed( - Seq( - TypeConstraint(dt, dta.asImplicit), - TypeConstraint(dt, dtb.asImplicit), - TypeConstraint(ta, tb) - ) - ) + decomposedPreserve(Seq( + TypeConstraint(dt, dta), + TypeConstraint(dt, dtb), + TypeConstraint(ta, tb), + ), preserve + dt - dta - dtb) case ( DepFunType(_: AddressSpaceIdentifier, _), DepFunType(_: AddressSpaceIdentifier, _) @@ -169,24 +137,22 @@ object Constraint { DepPairType(x2: NatIdentifier, t2) ) => val n = NatIdentifier(freshName("n"), isExplicit = true) - - decomposed(Seq( - NatConstraint(n, x1.asImplicit), - NatConstraint(n, x2.asImplicit), - TypeConstraint(t1, t2) - )) + decomposedPreserve(Seq( + NatConstraint(n, x1), + NatConstraint(n, x2), + TypeConstraint(t1, t2), + ), preserve + n - x1 - x2) case ( DepPairType(x1: NatCollectionIdentifier, t1), DepPairType(x2: NatCollectionIdentifier, t2) ) => val n = NatCollectionIdentifier(freshName("n"), isExplicit = true) - - decomposed(Seq( - NatCollectionConstraint(n, x1.asImplicit), - NatCollectionConstraint(n, x2.asImplicit), - TypeConstraint(t1, t2) - )) + decomposedPreserve(Seq( + NatCollectionConstraint(n, x1), + NatCollectionConstraint(n, x2), + TypeConstraint(t1, t2), + ), preserve + n - x1 - x2) case ( NatToDataApply(f: NatToDataIdentifier, _), @@ -236,11 +202,11 @@ object Constraint { case _ if a == b => Solution() case (NatToDataLambda(x1, dt1), NatToDataLambda(x2, dt2)) => val n = NatIdentifier(freshName("n"), isExplicit = true) - decomposed(Seq( - NatConstraint(n, x1.asImplicit), - NatConstraint(n, x2.asImplicit), - TypeConstraint(dt1, dt2) - )) + decomposedPreserve(Seq( + NatConstraint(n, x1), + NatConstraint(n, x2), + TypeConstraint(dt1, dt2), + ), preserve + n - x1 - x2) case _ => error(s"cannot unify $a and $b") } @@ -302,14 +268,14 @@ object Constraint { } else if (canBeSubstituted(preserve, j)) { Solution.subs(j, i) } else { - error(s"cannot unify $i and $j, they are both explicit or in $preserve") + error(s"cannot unify $i and $j, they are both in $preserve") } case _ if occurs(i, t) => error(s"circular use: $i occurs in $t") case _ => if (canBeSubstituted(preserve, i)) { Solution.subs(i, t) } else { - error(s"cannot substitute $i, it is explicit") + error(s"cannot substitute $i, it is $preserve") } } } @@ -318,7 +284,7 @@ object Constraint { import arithexpr.arithmetic._ def unify(a: Nat, b: Nat, preserve : Set[Kind.Identifier]) - (implicit trace: Seq[Constraint], explDep: Flags.ExplicitDependence): Solution = { + (implicit trace: Seq[Constraint]): Solution = { def decomposed(cs: Seq[Constraint]) = solve(cs, preserve, NatConstraint(a, b) +: trace) (a, b) match { case (i: NatIdentifier, _) => nat.unifyIdent(i, b, preserve) @@ -417,7 +383,8 @@ object Constraint { } } - def tryPivots(n: Nat, value: Nat, preserve : Set[Kind.Identifier])(implicit trace: Seq[Constraint]): Solution = { + def tryPivots(n: Nat, value: Nat, preserve : Set[Kind.Identifier]) + (implicit trace: Seq[Constraint]): Solution = { potentialPivots(n, preserve).foreach(pivotSolution(_, n, value) match { case Some(s) => return s case None => @@ -425,19 +392,20 @@ object Constraint { error(s"could not pivot $n = $value") } - def unifyProd(p: Nat, n: Nat, preserve : Set[Kind.Identifier])(implicit trace: Seq[Constraint]): Solution = { + def unifyProd(p: Nat, n: Nat, preserve : Set[Kind.Identifier]) + (implicit trace: Seq[Constraint]): Solution = { // n = p --> 1 = p * (1/n) tryPivots(p /^ n, 1, preserve) } - def unifySum(s: Sum, n: Nat, preserve : Set[Kind.Identifier])(implicit trace: Seq[Constraint]): Solution = { + def unifySum(s: Sum, n: Nat, preserve : Set[Kind.Identifier]) + (implicit trace: Seq[Constraint]): Solution = { // n = s --> 0 = s + (-n) tryPivots(s - n, 0, preserve) } - def unifyIdent(i: NatIdentifier, n: Nat, preserve : Set[Kind.Identifier])( - implicit trace: Seq[Constraint], explDep: Flags.ExplicitDependence - ): Solution = n match { + def unifyIdent(i: NatIdentifier, n: Nat, preserve : Set[Kind.Identifier]) + (implicit trace: Seq[Constraint]): Solution = n match { case j: NatIdentifier => if (i == j) { Solution() @@ -456,9 +424,8 @@ object Constraint { case _ => error(s"cannot unify $i and $n") } - def unifyApply(apply: NatToNatApply, nat: Nat, preserve: Set[Kind.Identifier])( - implicit trace: Seq[Constraint], explDep: Flags.ExplicitDependence - ): Solution = { + def unifyApply(apply: NatToNatApply, nat: Nat, preserve: Set[Kind.Identifier]) + (implicit trace: Seq[Constraint]): Solution = { val NatToNatApply(f1, n1) = apply nat match { case NatToNatApply(f2, n2) => @@ -477,9 +444,8 @@ object Constraint { private object bool { import arithexpr.arithmetic._ - def unify(a: BoolExpr, b: BoolExpr, preserve : Set[Kind.Identifier])( - implicit trace: Seq[Constraint], explDep: Flags.ExplicitDependence - ): Solution = { + def unify(a: BoolExpr, b: BoolExpr, preserve : Set[Kind.Identifier]) + (implicit trace: Seq[Constraint]): Solution = { def decomposed(cs: Seq[Constraint]) = solve(cs, preserve, BoolConstraint(a, b) +: trace) (a, b) match { case _ if a == b => Solution() @@ -491,7 +457,8 @@ object Constraint { } object natToData { - def unifyIdent(i: NatToDataIdentifier, n: NatToData)(implicit trace: Seq[Constraint]): Solution = n match { + def unifyIdent(i: NatToDataIdentifier, n: NatToData) + (implicit trace: Seq[Constraint]): Solution = n match { case j: NatToDataIdentifier => if (i == j) { Solution() @@ -503,9 +470,8 @@ object Constraint { } object natToNat { - def unify(f1: NatToNat, f2: NatToNat, preserve: Set[Kind.Identifier])( - implicit trace: Seq[Constraint], explDep: Flags.ExplicitDependence - ): Solution = f1 match { + def unify(f1: NatToNat, f2: NatToNat, preserve: Set[Kind.Identifier]) + (implicit trace: Seq[Constraint]): Solution = f1 match { case id1: NatToNatIdentifier => Solution.subs(id1, f2) case NatToNatLambda(x1, body1) => f2 match { case id2: NatToNatIdentifier => Solution.subs(id2, f1) @@ -514,15 +480,14 @@ object Constraint { nat.unify( substitute.natInNat(n, `for` = x1, body1), substitute.natInNat(n, `for`=x2, body2), - preserve) + preserve + n) } } } object natCollection { - def unifyIdent(i: NatCollectionIdentifier, n: NatCollection, preserve : Set[Kind.Identifier])( - implicit trace: Seq[Constraint] - ): Solution = n match { + def unifyIdent(i: NatCollectionIdentifier, n: NatCollection, preserve : Set[Kind.Identifier]) + (implicit trace: Seq[Constraint]): Solution = n match { case j: NatCollectionIdentifier => if (i == j) { Solution() @@ -538,43 +503,4 @@ object Constraint { case _ => false } -} - -object dependence { - /* - * Given a type t which is in the scope of a natIdentifier depVar, - * explicitly represent the dependence by replacing identifiers in t - * with applied nat-to-X functions. - */ - def explicitlyDependent(t: Type, depVar: NatIdentifier, preserve : Set[Kind.Identifier]): (Type, Solution) = { - val visitor = new PureAccumulatorTraversal[Seq[Solution]] { - override val accumulator = SeqMonoid - - override def nat: Nat => Pair[Nat] = { - case n2n@NatToNatApply(_, n) if n == depVar => return_(n2n : Nat) - case ident: NatIdentifier - if ident != depVar && Constraint.canBeSubstituted(preserve, ident) => - val sol = Solution.subs(ident, NatToNatApply(NatToNatIdentifier(freshName("nnf")), depVar)) - accumulate(Seq(sol))(ident.asImplicit : Nat) - case n => super.nat(n) - } - - override def `type`[T <: Type] : T => Pair[T] = { - case n2d@NatToDataApply(_, x) if x == depVar => return_(n2d : T) - case ident@TypeIdentifier(i) => - val application = NatToDataApply(NatToDataIdentifier(freshName("nnf")), depVar) - val sol = Solution.subs(ident, application) - accumulate(Seq(sol))(ident.asInstanceOf[T]) - case e => super.`type`(e) - } - - def apply(t: Type): (Type, Solution) = { - val (sols, rewrittenT) = traverse(t, this) - val solution = sols.foldLeft(Solution())(_ ++ _) - (solution.apply(rewrittenT), solution) - } - } - - visitor(t) - } -} +} \ No newline at end of file diff --git a/src/main/scala/rise/core/types/Solution.scala b/src/main/scala/rise/core/types/Solution.scala index fe6629dcb..c00b5fd27 100644 --- a/src/main/scala/rise/core/types/Solution.scala +++ b/src/main/scala/rise/core/types/Solution.scala @@ -132,29 +132,28 @@ case class Solution(ts: Map[Type, Type], combine(this, other) } - def apply(constraints: Seq[Constraint]): Seq[Constraint] = { - constraints.map { - case TypeConstraint(a, b) => TypeConstraint(apply(a), apply(b)) - case NatConstraint(a, b) => NatConstraint(apply(a), apply(b)) - case BoolConstraint(a, b) => BoolConstraint(apply(a), apply(b)) - case MatrixLayoutConstraint(a, b) => MatrixLayoutConstraint(apply(a), apply(b)) - case FragmentTypeConstraint(a, b) => FragmentTypeConstraint(apply(a), apply(b)) - case NatToDataConstraint(a, b) => NatToDataConstraint(apply(a), apply(b)) - case NatCollectionConstraint(a, b) => NatCollectionConstraint(apply(a), apply(b)) - case DepConstraint(df, arg: Nat, t) => DepConstraint[NatKind](apply(df), apply(arg), apply(t)) - case DepConstraint(df, arg: DataType, t) => - DepConstraint[DataKind](apply(df), apply(arg).asInstanceOf[DataType], apply(t)) - case DepConstraint(df, arg: Type, t) => - DepConstraint[TypeKind](apply(df), apply(arg), apply(t)) - case DepConstraint(df, arg: AddressSpace, t) => - DepConstraint[AddressSpaceKind](apply(df), apply(arg), apply(t)) - case DepConstraint(df, arg: NatToData, t) => - DepConstraint[NatToDataKind](apply(df), apply(arg), apply(t)) - case DepConstraint(df, arg: NatToNat, t) => - DepConstraint[NatToNatKind](apply(df), apply(arg), apply(t)) - case DepConstraint(df, arg: NatCollection, t) => - DepConstraint[NatCollectionKind](apply(df), apply(arg), apply(t)) - case DepConstraint(_, _, _) => throw new Exception("Impossible case") - } + def apply(constraints: Seq[Constraint]): Seq[Constraint] = constraints.map(apply) + def apply(constraint: Constraint): Constraint = constraint match { + case TypeConstraint(a, b) => TypeConstraint(apply(a), apply(b)) + case NatConstraint(a, b) => NatConstraint(apply(a), apply(b)) + case BoolConstraint(a, b) => BoolConstraint(apply(a), apply(b)) + case MatrixLayoutConstraint(a, b) => MatrixLayoutConstraint(apply(a), apply(b)) + case FragmentTypeConstraint(a, b) => FragmentTypeConstraint(apply(a), apply(b)) + case NatToDataConstraint(a, b) => NatToDataConstraint(apply(a), apply(b)) + case NatCollectionConstraint(a, b) => NatCollectionConstraint(apply(a), apply(b)) + case DepConstraint(df, arg: Nat, t) => DepConstraint[NatKind](apply(df), apply(arg), apply(t)) + case DepConstraint(df, arg: DataType, t) => + DepConstraint[DataKind](apply(df), apply(arg).asInstanceOf[DataType], apply(t)) + case DepConstraint(df, arg: Type, t) => + DepConstraint[TypeKind](apply(df), apply(arg), apply(t)) + case DepConstraint(df, arg: AddressSpace, t) => + DepConstraint[AddressSpaceKind](apply(df), apply(arg), apply(t)) + case DepConstraint(df, arg: NatToData, t) => + DepConstraint[NatToDataKind](apply(df), apply(arg), apply(t)) + case DepConstraint(df, arg: NatToNat, t) => + DepConstraint[NatToNatKind](apply(df), apply(arg), apply(t)) + case DepConstraint(df, arg: NatCollection, t) => + DepConstraint[NatCollectionKind](apply(df), apply(arg), apply(t)) + case DepConstraint(_, _, _) => throw new Exception("Impossible case") } } diff --git a/src/main/scala/rise/eqsat/NamedRewrite.scala b/src/main/scala/rise/eqsat/NamedRewrite.scala index 1bbcefdab..804cff60e 100644 --- a/src/main/scala/rise/eqsat/NamedRewrite.scala +++ b/src/main/scala/rise/eqsat/NamedRewrite.scala @@ -9,18 +9,18 @@ object NamedRewrite { def init(name: String, rule: (NamedRewriteDSL.Pattern, NamedRewriteDSL.Pattern) ): Rewrite[DefaultAnalysisData] = { - import rise.core.DSL.infer.{preservingWithEnv, collectFreeEnv} + import rise.core.DSL.infer import arithexpr.{arithmetic => ae} val (lhs, rhs) = rule - val untypedFreeV = collectFreeEnv(lhs).map { case (name, t) => + val untypedFreeV = infer.collectFreeEnv(lhs).map { case (name, t) => assert(t == rct.TypePlaceholder) name -> rct.TypeIdentifier("t" + name) } - val typedLhs = preservingWithEnv(lhs, untypedFreeV, Set()) - val freeV = collectFreeEnv(typedLhs) - val (_, freeT) = rise.core.IsClosedForm.freeVars(typedLhs) - val typedRhs = preservingWithEnv(rc.TypeAnnotation(rhs, typedLhs.t), freeV, freeT) + val typedLhs = infer(lhs, untypedFreeV, Set()) + val freeV = infer.collectFreeEnv(typedLhs) + val freeT = rise.core.IsClosedForm.freeVars(typedLhs)._2.set + val typedRhs = infer(rc.TypeAnnotation(rhs, typedLhs.t), freeV, freeT) trait PatVarStatus case object Unknown extends PatVarStatus diff --git a/src/main/scala/shine/DPIA/fromRise.scala b/src/main/scala/shine/DPIA/fromRise.scala index 32be7922a..a362bf72f 100644 --- a/src/main/scala/shine/DPIA/fromRise.scala +++ b/src/main/scala/shine/DPIA/fromRise.scala @@ -16,7 +16,8 @@ import scala.collection.mutable object fromRise { def apply(expr: r.Expr)(implicit ev: Traversable[Rise]): Phrase[_ <: PhraseType] = { if (!r.IsClosedForm(expr)) { - throw new Exception(s"expression is not in closed form: $expr\n\n with type ${expr.t}") + val (fV, fT) = r.IsClosedForm.varsToClose(expr) + throw new Exception(s"expression is not in closed form: $expr\n\n with type ${expr.t}\n free vars: $fV\n free type vars: $fT\n\n") } val bnfExpr = normalize(ev).apply(betaReduction)(expr).get val rwMap = inferAccess(bnfExpr) diff --git a/src/test/scala/rise/core/dependentTypes.scala b/src/test/scala/rise/core/dependentTypes.scala index b4a99fea4..2694055ac 100644 --- a/src/test/scala/rise/core/dependentTypes.scala +++ b/src/test/scala/rise/core/dependentTypes.scala @@ -153,12 +153,13 @@ class dependentTypes extends test_util.Tests { function.asStringFromExpr(inferred) } - test("Simple nested") { + // Relies on the explicitly dependent work + ignore("Simple nested") { val e = depFun((n: Nat) => fun(n `*.` (i => (i+1) `.` f32))(array => depMapSeq(depFun((_: Nat) => mapSeq(fun(x => x))))(array) )) - val inferred: Expr = inferDependent(e) + val inferred: Expr = infer(e) logger.debug(inferred) logger.debug(inferred.t) assert(inferred.t =~= @@ -167,12 +168,13 @@ class dependentTypes extends test_util.Tests { function.asStringFromExpr(inferred) } - test("Simple reduce") { + // Relies on the explicitly dependent work + ignore("Simple reduce") { val e = depFun((n: Nat) => fun(n `*.` (i => (i+1) `.` f32))(array => depMapSeq(depFun((_: Nat) => reduceSeq(fun(x => fun(y => x + y)))(lf32(0.0f))))(array) )) - val inferred: Expr = inferDependent(e) + val inferred: Expr = infer(e) logger.debug(inferred) logger.debug(inferred.t) assert(inferred.t =~= @@ -201,7 +203,7 @@ class dependentTypes extends test_util.Tests { } )))) - val inferred: Expr = inferDependent(e) + val inferred: Expr = infer(e) logger.debug(inferred) logger.debug(inferred.t) function.asStringFromExpr(inferred)