Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not extend NamedVar with Kind.Identifier in DPIA #190

Merged
merged 8 commits into from
Jun 21, 2021
17 changes: 16 additions & 1 deletion meta/src/main/scala/meta/generator/DPIAPrimitives.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import arithexpr.arithmetic._
import shine.DPIA.Phrases._
import shine.DPIA.Types.DataType._
import shine.DPIA.Types._
import shine.DPIA.Types.Kind.{Identifier => _, _}
import shine.DPIA._

${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), params, returnType)}
Expand Down Expand Up @@ -140,7 +141,7 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param
case DPIA.Type.AST.CommType => t"CommType"
case DPIA.Type.AST.PairType(lhs, rhs) => t"PhrasePairType[${generatePhraseType(lhs)}, ${generatePhraseType(rhs)}]"
case DPIA.Type.AST.FunType(inT, outT) => t"FunType[${generatePhraseType(inT)}, ${generatePhraseType(outT)}]"
case DPIA.Type.AST.DepFunType(id, kind, t) => t"DepFunType[${generateKindIdentifierType(kind)}, ${generatePhraseType(t)}]"
case DPIA.Type.AST.DepFunType(id, kind, t) => t"DepFunType[${generateKindIdentifierType(kind)}, ${generateKindIdentifierConstr(kind)}, ${generatePhraseType(t)}]"
case DPIA.Type.AST.Identifier(name) => Type.Name(name)
case DPIA.Type.AST.VariadicType(_, _) => throw new Exception("Can not generate Phrase Type for Variadic Type")
}
Expand All @@ -159,6 +160,20 @@ ${generateCaseClass(Type.Name(name), toParamList(definition, scalaParams), param
case DPIA.Kind.AST.VariadicKind(_, _) => throw new Exception("Can not generate Kind for Variadic Kind")
}

def generateKindIdentifierConstr(kindAST: DPIA.Kind.AST): scala.meta.Type = kindAST match {
case DPIA.Kind.AST.RiseKind(riseKind) => riseKind match {
case rise.Kind.AST.Data => Type.Name("IDataType")
case rise.Kind.AST.Address => Type.Name("IAddressSpace")
case rise.Kind.AST.Nat2Nat => Type.Name("INatToNat")
case rise.Kind.AST.Nat2Data => Type.Name("INatToData")
case rise.Kind.AST.Nat => Type.Name("INat")
case rise.Kind.AST.Fragment => throw new Exception("Can not generate Kind for Fragment")
case rise.Kind.AST.MatrixLayout => throw new Exception("Can not generate Kind for Matrix Layout")
}
case DPIA.Kind.AST.Access => Type.Name("IAccessType")
case DPIA.Kind.AST.VariadicKind(_, _) => throw new Exception("Can not generate Kind for Variadic Kind")
}

// generate type checks in the body of the generated case classes, e.g. for map:
// f :: FunType(expT(dt1, a), expT(dt2, a))
// array :: expT(ArrayType(n, dt1), a)
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/rise/core/types/Kinds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ sealed trait Kind[+T, +I, +KI <: Kind.Identifier] {


object Kind {
trait Identifier { def name: String }
sealed trait Identifier { def name: String }
case class IType(id : TypeIdentifier) extends Identifier { def name : String = id.name }
case class IDataType(id : DataTypeIdentifier) extends Identifier { def name : String = id.name }
case class INat(id : NatIdentifier) extends Identifier { def name : String = id.name }
Expand Down
6 changes: 3 additions & 3 deletions src/main/scala/shine/C/Compilation/CodeGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ class CodeGenerator(val decls: CodeGenerator.Declarations,
case Apply(fun, arg) => Lifting.liftFunction(fun).reducing(arg) |> cmd(env)
case DepApply(kind, fun, arg) => arg match {
case a: Nat =>
Lifting.liftDependentFunction(fun.asInstanceOf[Phrase[NatIdentifier `()->:` CommType]])(a) |> cmd(env)
Lifting.liftDependentFunction(fun.asInstanceOf[Phrase[`(nat)->:`[CommType]]])(a) |> cmd(env)
case a: DataType =>
Lifting.liftDependentFunction(fun.asInstanceOf[Phrase[NatIdentifier `()->:` CommType]])(a) |> cmd(env)
Lifting.liftDependentFunction(fun.asInstanceOf[Phrase[`(nat)->:`[CommType]]])(a) |> cmd(env)
}

case DMatchI(x, inT, _, f, dPair) =>
Expand Down Expand Up @@ -638,7 +638,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations,
case None => error("Parameter missing")
case Some(Left(param)) => generateInlinedCall(l(param), env, args.tail, cont)
}
case ndl: DepLambda[Nat, NatIdentifier, _]@unchecked => args.headOption match {
case ndl: DepLambda[Nat, NatIdentifier, Kind.INat, _]@unchecked => args.headOption match {
case Some(Right(nat)) => generateInlinedCall(ndl(nat), env, args.tail, cont)
case None => error("Parameter missing")
case Some(Left(_)) => error("Expression phrase argument passed but nat expected")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ object AcceptorTranslation {
case DepApply(kind, fun, arg) => arg match {
case a: Nat =>
acc(Lifting.liftDependentFunction(
fun.asInstanceOf[ Phrase[NatIdentifier `()->:` ExpType]])(a))(A)
fun.asInstanceOf[ Phrase[`(nat)->:`[ExpType]]])(a))(A)
case a: DataType =>
acc(Lifting.liftDependentFunction(
fun.asInstanceOf[Phrase[DataTypeIdentifier `()->:` ExpType]])(a))(A)
fun.asInstanceOf[Phrase[`(dt)->:`[ExpType]]])(a))(A)
}

case e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ object ContinuationTranslation {
case DepApply(kind, fun, arg) => arg match {
case a: Nat =>
con(Lifting.liftDependentFunction(
fun.asInstanceOf[Phrase[NatIdentifier `()->:` ExpType]])(a))(C)
fun.asInstanceOf[Phrase[`(nat)->:`[ExpType]]])(a))(C)
case a: DataType =>
con(Lifting.liftDependentFunction(
fun.asInstanceOf[Phrase[DataTypeIdentifier `()->:` ExpType]])(a))(C)
fun.asInstanceOf[Phrase[`(dt)->:`[ExpType]]])(a))(C)
}

case IfThenElse(cond, thenP, elseP) =>
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/shine/DPIA/Compilation/FedeTranslation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ object FedeTranslation {
case DepApply(kind, fun, arg) => arg match {
case a: Nat => fedAcc(env)(
Lifting.liftDependentFunction(
fun.asInstanceOf[Phrase[NatIdentifier `()->:` ExpType]])(a))(C)
fun.asInstanceOf[Phrase[`(nat)->:`[ExpType]]])(a))(C)
case a: DataType => fedAcc(env)(
Lifting.liftDependentFunction(
fun.asInstanceOf[Phrase[DataTypeIdentifier `()->:` ExpType]])(a))(C)
fun.asInstanceOf[Phrase[`(dt)->:`[ExpType]]])(a))(C)
}

case IfThenElse(cond, thenP, elseP) => ???
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/shine/DPIA/Compilation/FunDef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class FunDef(val name: String,
splitBodyAndParams(Lifting.liftDependentFunction(f)(a), ps, defs)
case l: Lambda[ExpType, _]@unchecked =>
splitBodyAndParams(l.body, l.param +: ps, defs)
case ndl: DepLambda[_, _, _] =>
case ndl: DepLambda[_, _, _, _] =>
splitBodyAndParams(ndl.body,
Identifier(ndl.x.name, ExpType(int, read)) +: ps, defs)
Identifier(Kind.idName(ndl.kind, ndl.x), ExpType(int, read)) +: ps, defs)
case ln:LetNat[ExpType, _]@unchecked =>
splitBodyAndParams(ln.body, ps, (ln.binder, ln.defn) +: defs)
case ep: Phrase[ExpType]@unchecked => (ep, ps.reverse, defs.reverse)
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ object StreamTranslation {
case DepApply(_, fun, arg) => arg match {
case a: Nat => str(
Lifting.liftDependentFunction(
fun.asInstanceOf[Phrase[NatIdentifier `()->:` ExpType]])(a)
fun.asInstanceOf[Phrase[`(nat)->:`[ExpType]]])(a)
)(C)
case a: DataType => str(
Lifting.liftDependentFunction(
fun.asInstanceOf[Phrase[DataTypeIdentifier `()->:` ExpType]])(a)
fun.asInstanceOf[Phrase[`(dt)->:`[ExpType]]])(a)
)(C)
}

Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/shine/DPIA/DSL/Core.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ object λ extends funDef

object nFun {
def apply[T <: PhraseType](f: NatIdentifier => Phrase[T],
range: arithexpr.arithmetic.Range): DepLambda[Nat, NatIdentifier, T] = {
range: arithexpr.arithmetic.Range): DepLambda[Nat, NatIdentifier, Kind.INat, T] = {
val x = NatIdentifier(freshName("n"), range)
DepLambda(NatKind, x, f(x))
}
}

trait depFunDef {
def apply[T, I <: Kind.Identifier](kind: Kind[T, I]): Object {
def apply[U <: PhraseType](f: I => Phrase[U]): DepLambda[T, I, U]
def apply[T, I, KI <: Kind.Identifier](kind: Kind[T, I, KI]): Object {
def apply[U <: PhraseType](f: I => Phrase[U]): DepLambda[T, I, KI, U]
} = new {
def apply[U <: PhraseType](f: I => Phrase[U]): DepLambda[T, I, U] = {
def apply[U <: PhraseType](f: I => Phrase[U]): DepLambda[T, I, KI, U] = {
val x = kind.makeIdentifier
DepLambda(kind, x, f(x))
}
Expand Down
53 changes: 25 additions & 28 deletions src/main/scala/shine/DPIA/InferAccessAnnotation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ private class InferAccessAnnotation {
val depLambdaType =
depLambda.x match {
case n: rt.NatIdentifier =>
DepFunType(NatKind, natIdentifier(n), eType)
DepFunType(NatKind, n, eType)
case dt: rt.DataTypeIdentifier =>
DepFunType(DataKind, dataTypeIdentifier(dt), eType)
case ad: rt.AddressSpaceIdentifier =>
Expand Down Expand Up @@ -304,8 +304,8 @@ private class InferAccessAnnotation {
(gs1 `(Nat)->:` (gs2 `(Nat)->:` (gs3 `(Nat)->:`
((t: rt.DataType) ->: (_: rt.DataType))
))))) =>
nFunT(fromRise.natIdentifier(ls1), nFunT(fromRise.natIdentifier(ls2), nFunT(fromRise.natIdentifier(ls3),
nFunT(fromRise.natIdentifier(gs1), nFunT(fromRise.natIdentifier(gs2), nFunT(fromRise.natIdentifier(gs3),
nFunT(ls1, nFunT(ls2, nFunT(ls3,
nFunT(gs1, nFunT(gs2, nFunT(gs3,
expT(t, write) ->: expT(t, write)))))))
case _ => error()
}
Expand Down Expand Up @@ -350,7 +350,7 @@ private class InferAccessAnnotation {
case n `(Nat)->:` ((dt1: rt.DataType) ->: (dt2: rt.DataType)) =>

val ai = accessTypeIdentifier()
nFunT(fromRise.natIdentifier(n), expT(dt1, ai) ->: expT(dt2, ai))
nFunT(n, expT(dt1, ai) ->: expT(dt2, ai))
case _ => error()
}

Expand All @@ -377,7 +377,7 @@ private class InferAccessAnnotation {

case rp.natAsIndex() | rp.take() | rp.drop() => p.t match {
case n `(Nat)->:` ((dt1: rt.DataType) ->: (dt2: rt.DataType)) =>
nFunT(fromRise.natIdentifier(n), expT(dt1, read) ->: expT(dt2, read))
nFunT(n, expT(dt1, read) ->: expT(dt2, read))
case _ => error()
}

Expand Down Expand Up @@ -414,7 +414,7 @@ private class InferAccessAnnotation {
case tile `(Nat)->:`
(((s: rt.DataType) ->: (t: rt.DataType)) ->:
(inT: rt.ArrayType) ->: (outT: rt.ArrayType)) =>
nFunT(fromRise.natIdentifier(tile),
nFunT(tile,
(expT(s, read) ->: expT(t, write)) ->:
expT(inT, read) ->: expT(outT, write))
case _ => error()
Expand All @@ -425,7 +425,7 @@ private class InferAccessAnnotation {
case sz `(Nat)->:`
(((s: rt.DataType) ->: (_: rt.DataType)) ->:
(inT: rt.ArrayType) ->: (outT: rt.ArrayType)) =>
nFunT(fromRise.natIdentifier(sz),
nFunT(sz,
(expT(s, read) ->: expT(s, write)) ->:
expT(inT, read) ->: expT(outT, read))
case _ => error()
Expand All @@ -435,7 +435,7 @@ private class InferAccessAnnotation {
case alloc `(Nat)->:` (sz `(Nat)->:`
(((s: rt.DataType) ->: (t: rt.DataType)) ->:
(inT: rt.ArrayType) ->: (outT: rt.ArrayType))) =>
nFunT(fromRise.natIdentifier(alloc), nFunT(fromRise.natIdentifier(sz),
nFunT(alloc, nFunT(sz,
(expT(s, read) ->: expT(t, write)) ->:
expT(inT, read) ->: expT(outT, read)))
case _ => error()
Expand All @@ -446,7 +446,7 @@ private class InferAccessAnnotation {
(((s: rt.DataType) ->: (_: rt.DataType)) ->:
(inT: rt.ArrayType) ->: (outT: rt.ArrayType))) =>
aFunT(a,
nFunT(fromRise.natIdentifier(sz),
nFunT(sz,
(expT(s, read) ->: expT(s, write)) ->:
expT(inT, read) ->: expT(outT, read)))
case _ => error()
Expand All @@ -457,7 +457,7 @@ private class InferAccessAnnotation {
(((s: rt.DataType) ->: (t: rt.DataType)) ->:
(inT: rt.ArrayType) ->: (outT: rt.ArrayType)))) =>

aFunT(a, nFunT(fromRise.natIdentifier(alloc), nFunT(fromRise.natIdentifier(sz),
aFunT(a, nFunT(alloc, nFunT(sz,
(expT(s, read) ->: expT(t, write)) ->:
expT(inT, read) ->: expT(outT, read))))
case _ => error()
Expand All @@ -466,7 +466,7 @@ private class InferAccessAnnotation {
case rp.slide() | rp.padClamp() => p.t match {
case sz `(Nat)->:` (sp `(Nat)->:`
((dt1: rt.DataType) ->: (dt2: rt.DataType))) =>
nFunT(fromRise.natIdentifier(sz), nFunT(fromRise.natIdentifier(sp),
nFunT(sz, nFunT(sp,
expT(dt1, read) ->: expT(dt2, read)))
case _ => error()
}
Expand All @@ -475,8 +475,8 @@ private class InferAccessAnnotation {
case k `(Nat)->:`
((l `(Nat)->:` ((at1: rt.ArrayType) ->: (at2: rt.ArrayType))) ->:
(at3: rt.ArrayType) ->: (at4: rt.ArrayType)) =>
nFunT(fromRise.natIdentifier(k),
nFunT(fromRise.natIdentifier(l), expT(at1, read) ->: expT(at2, write)) ->:
nFunT(k,
nFunT(l, expT(at1, read) ->: expT(at2, write)) ->:
expT(at3, read) ->: expT(at4, write) )
case _ => error()
}
Expand All @@ -485,8 +485,8 @@ private class InferAccessAnnotation {
case a `(Addr)->:` (k `(Nat)->:`
((l `(Nat)->:` ((at1: rt.ArrayType) ->: (at2: rt.ArrayType))) ->:
(at3: rt.ArrayType) ->: (at4: rt.ArrayType))) =>
aFunT(a, nFunT(fromRise.natIdentifier(k),
nFunT(fromRise.natIdentifier(l), expT(at1, read) ->: expT(at2, write)) ->:
aFunT(a, nFunT(k,
nFunT(l, expT(at1, read) ->: expT(at2, write)) ->:
expT(at3, read) ->: expT(at4, write) ))
case _ => error()
}
Expand All @@ -502,15 +502,15 @@ private class InferAccessAnnotation {

case rp.padEmpty() => p.t match {
case r `(Nat)->:` ((n`.`t) ->: (_`.`_)) =>
nFunT(fromRise.natIdentifier(r), expT(n`.`t, write) ->: expT((n + r)`.`t, write))
nFunT(r, expT(n`.`t, write) ->: expT((n + r)`.`t, write))
case _ => error()
}

case rp.padCst() => p.t match {
case l `(Nat)->:` (q `(Nat)->:`
((t: rt.DataType) ->: (n`.`_) ->: (_`.`_))) =>

nFunT(fromRise.natIdentifier(l), nFunT(fromRise.natIdentifier(q),
nFunT(l, nFunT(q,
expT(t, read) ->: expT(n`.`t, read) ->:
expT((l + n + q)`.`t, read)))
case _ => error()
Expand All @@ -527,7 +527,7 @@ private class InferAccessAnnotation {
case (n `(Nat)->:` (idxF `(NatToNat)->:` (idxFinv `(NatToNat)->:` ((_`.`t) ->: (_`.`_) )))) =>

val ai = accessTypeIdentifier()
nFunT(fromRise.natIdentifier(n), n2nFunT(idxF, n2nFunT(idxFinv, expT(n`.`t, ai) ->: expT(n`.`t, ai))))
nFunT(n, n2nFunT(idxF, n2nFunT(idxFinv, expT(n`.`t, ai) ->: expT(n`.`t, ai))))
case _ => error()
}

Expand Down Expand Up @@ -556,8 +556,7 @@ private class InferAccessAnnotation {
def buildType(t: rt.Type): PhraseType = t match {
case rt.FunType(rt.DepFunType(rt.NatKind, i: rt.NatIdentifier, rt.FunType(elemInT:rt.DataType, elemOutT:rt.DataType)),
rt.FunType([email protected](_, _), [email protected](_, _))) =>
val iNat = natIdentifier(i)
nFunT(iNat, expT(dataType(elemInT), read) ->: expT(dataType(elemOutT), write)) ->:
nFunT(i, expT(dataType(elemInT), read) ->: expT(dataType(elemOutT), write)) ->:
expT(dataType(inArr), read) ->: expT(dataType(outArr), write)
case _ => error("did not expect t")
}
Expand All @@ -570,9 +569,8 @@ private class InferAccessAnnotation {
rt.FunType(rt.DepFunType(rt.NatKind, i: rt.NatIdentifier,
rt.FunType(app1:rt.DataType, outT:rt.DataType)), retT:rt.DataType)) =>

val i_ = natIdentifier(i.asInstanceOf[rt.NatIdentifier])
expT(DepPairType(natIdentifier(x), dataType(elemT)), read) ->:
nFunT(i_, expT(dataType(app1), read) ->: expT(dataType(outT), a)) ->:
expT(DepPairType(x, dataType(elemT)), read) ->:
nFunT(i, expT(dataType(app1), read) ->: expT(dataType(outT), a)) ->:
expT(dataType(retT), a)
case _ => error(s"did not expect t")
}
Expand All @@ -582,8 +580,7 @@ private class InferAccessAnnotation {
def buildType(t: rt.Type): PhraseType = t match {
case rt.DepFunType(rt.NatKind, fst: rt.NatIdentifier, rt.FunType(sndT:rt.DataType, outT:rt.DataType)) =>
val a1 = accessTypeIdentifier()
val fst_ = natIdentifier(fst)
nFunT(fst_, expT(dataType(sndT), a1) ->: expT(dataType(outT), a1))
nFunT(fst, expT(dataType(sndT), a1) ->: expT(dataType(outT), a1))

case _ => error(s"did not expect $t")
}
Expand Down Expand Up @@ -667,7 +664,7 @@ private class InferAccessAnnotation {
case dt: rt.DataTypeIdentifier =>
dataTypeIdentifier(dt) ->: `type`(t)
case n: rt.NatIdentifier =>
natIdentifier(n) ->: `type`(t)
n ->: `type`(t)
case n2n: rt.NatToNatIdentifier =>
natToNatIdentifier(n2n) ->: `type`(t)
case n2d: rt.NatToDataIdentifier =>
Expand All @@ -681,8 +678,8 @@ private class InferAccessAnnotation {
case (rt.FunType(inT, outT), FunType(inPT, outPT)) =>
checkConsistency(inT, inPT)
checkConsistency(outT, outPT)
case (rt.DepFunType(k, x, t), DepFunType(_, y, pt)) =>
if (rt.Kind.idName(k, x) != y.name) error(s"Identifiers $x and $y differ")
case (rt.DepFunType(kx, x, t), DepFunType(ky, y, pt)) =>
if (rt.Kind.idName(kx, x) != Kind.idName(ky, y)) error(s"Identifiers $x and $y differ")
checkConsistency(t, pt)
case (dt: rt.DataType, ExpType(dpt: DataType, _)) =>

Expand Down
12 changes: 6 additions & 6 deletions src/main/scala/shine/DPIA/Lifting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@ import scala.language.{postfixOps, reflectiveCalls}
object Lifting {
import rise.core.lifting.{Expanding, Reducing, Result}

def liftDependentFunction[T, I <: Kind.Identifier, U <: PhraseType](p: Phrase[I `()->:` U]): T => Phrase[U] = {
def liftDependentFunction[T, I, KI <: Kind.Identifier, U <: PhraseType](p: Phrase[DepFunType[I, KI, U]]): T => Phrase[U] = {
p match {
case l: DepLambda[T, I, U]@unchecked =>
(arg: T) => PhraseType.substitute[T, I, U](l.kind, arg, `for`=l.x, in=l.body)
case app: Apply[_, I `()->:` U] =>
case l: DepLambda[T, I, KI, U]@unchecked =>
(arg: T) => PhraseType.substitute[T, I, KI, U](l.kind, arg, `for`=l.x, in=l.body)
case app: Apply[_, DepFunType[I, KI, U]] =>
val fun = liftFunction(app.fun).reducing
liftDependentFunction(fun(app.arg))
case DepApply(_, f, arg) =>
val fun = liftDependentFunction(f)
liftDependentFunction(fun(arg))
case p1: Proj1[I `()->:` U, b] =>
case p1: Proj1[DepFunType[I, KI, U], b] =>
val pair = liftPair(p1.pair)
liftDependentFunction(pair._1)
case p2: Proj2[a, I `()->:` U] =>
case p2: Proj2[a, DepFunType[I, KI, U]] =>
val pair = liftPair(p2.pair)
liftDependentFunction(pair._2)
case Identifier(_, _) | IfThenElse(_, _, _) | LetNat(_, _, _) =>
Expand Down
Loading