Skip to content

Commit

Permalink
Rewrite ElabGuard
Browse files Browse the repository at this point in the history
Summary:
Several components to this change:
- merge `elabTest` and `elabTestT`, propagate an upper bound everywhere (possibly `AnyType`)
- gate a lot of "smart" logic behind condition `upper == trueType`
- restore smart elaboration of `not` disabled  in D58184266, gated behind upper bound condition
- smarter elaboration of `or` and `orelse`

Reviewed By: ilya-klyuchnikov

Differential Revision: D60389213

fbshipit-source-id: cefa2945c4d7f08b0d0f5ee57a5de969e6f7f364
  • Loading branch information
VLanvin authored and facebook-github-bot committed Jul 31, 2024
1 parent fe2b133 commit 9b8a196
Showing 1 changed file with 64 additions and 79 deletions.
143 changes: 64 additions & 79 deletions eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabGuard.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,24 +65,49 @@ final class ElabGuard(pipelineContext: PipelineContext) {
envAcc
}

private def elabTest(test: Test, env: Env): Env = {
def elabTestT(test: Test, upper: Type, env: Env): Env = {
test match {
case TestVar(v) =>
// safe because we assume no unbound vars
val ty = env.getOrElse(
v,
AnyType,
)
typeInfo.add(test.pos, ty)
env
val testType = env.get(v) match {
case Some(vt) =>
narrow.meet(vt, upper)
case None => upper
}
typeInfo.add(test.pos, testType)
env + (v -> testType)
case TestCall(Id(pred, 1), List(arg)) if upper == trueType && elabPredicateType1.isDefinedAt(pred) =>
elabTestT(arg, elabPredicateType1(pred), env)
case TestCall(Id(pred, 2), List(arg1, arg2))
if upper == trueType && elabPredicateType22.isDefinedAt((pred, arg2)) =>
elabTestT(arg1, elabPredicateType22(pred, arg2), env)
case TestCall(Id(pred, 2), List(arg1, arg2))
if upper == trueType && elabPredicateType21.isDefinedAt((pred, arg1)) =>
elabTestT(arg2, elabPredicateType21(pred, arg1), env)
case TestCall(Id(pred, 3), List(arg1, arg2, _))
if upper == trueType && elabPredicateType22.isDefinedAt((pred, arg2)) =>
elabTestT(arg1, elabPredicateType22(pred, arg2), env)
case TestBinOp("and" | "andalso", arg1, arg2) if upper == trueType =>
val env1 = elabTestT(arg1, trueType, env)
elabTestT(arg2, trueType, env1)
case TestBinOp("orelse", arg1, arg2) if upper == trueType =>
val envTrue = elabTestT(arg1, trueType, env)
val envFalse = elabTestT(arg1, falseType, env)
val envFalse2 = elabTestT(arg2, trueType, envFalse)
subtype.joinEnvs(List(envTrue, envFalse2))
case TestBinOp("or", arg1, arg2) if upper == trueType =>
val envTrue = elabTestT(arg1, trueType, env)
val envTrue2 = elabTestT(arg2, booleanType, envTrue)
val envFalse = elabTestT(arg1, booleanType, env)
val envFalse2 = elabTestT(arg2, trueType, envFalse)
subtype.joinEnvs(List(envTrue2, envFalse2))
case TestAtom(_) =>
env
case TestNumber(_) =>
env
case TestTuple(elems) =>
var envAcc: Env = env
for (elem <- elems) {
val elemEnv = elabTest(elem, envAcc)
val elemEnv = elabTestT(elem, AnyType, envAcc)
envAcc = elemEnv
}
envAcc
Expand All @@ -91,28 +116,28 @@ final class ElabGuard(pipelineContext: PipelineContext) {
case TestNil() =>
env
case TestCons(head, tail) =>
val env1 = elabTest(head, env)
val env2 = elabTest(tail, env1)
val env1 = elabTestT(head, AnyType, env)
val env2 = elabTestT(tail, AnyType, env1)
env2
case TestMapCreate(kvs) =>
var envAcc: Env = env
for ((k, v) <- kvs) {
val kEnv = elabTest(k, envAcc)
val vEnv = elabTest(v, kEnv)
val kEnv = elabTestT(k, AnyType, envAcc)
val vEnv = elabTestT(v, AnyType, kEnv)
envAcc = vEnv
}
envAcc
case TestCall(_, args) =>
var envAcc: Env = env
for (arg <- args) {
val argEnv = elabTest(arg, envAcc)
val argEnv = elabTestT(arg, AnyType, envAcc)
envAcc = argEnv
}
env
case unOp: TestUnOp =>
elabUnOp(unOp, env)
elabUnOp(unOp, upper, env)
case binOp: TestBinOp =>
elabBinOp(binOp, env)
elabBinOp(binOp, upper, env)
case TestBinaryLit() =>
env
case TestRecordIndex(_, _) =>
Expand Down Expand Up @@ -162,54 +187,12 @@ final class ElabGuard(pipelineContext: PipelineContext) {
}
}

def elabTestT(test: Test, upper: Type, env: Env): Env = {
test match {
case TestVar(v) =>
val testType = env.get(v) match {
case Some(vt) =>
narrow.meet(vt, upper)
case None => upper
}
typeInfo.add(test.pos, testType)
env + (v -> testType)
case TestCall(Id(pred, 1), List(arg)) if upper == trueType && elabPredicateType1.isDefinedAt(pred) =>
elabTestT(arg, elabPredicateType1(pred), env)
case TestCall(Id(pred, 2), List(arg1, arg2))
if upper == trueType && elabPredicateType22.isDefinedAt((pred, arg2)) =>
elabTestT(arg1, elabPredicateType22(pred, arg2), env)
case TestCall(Id(pred, 2), List(arg1, arg2))
if upper == trueType && elabPredicateType21.isDefinedAt((pred, arg1)) =>
elabTestT(arg2, elabPredicateType21(pred, arg1), env)
case TestCall(Id(pred, 3), List(arg1, arg2, _))
if upper == trueType && elabPredicateType22.isDefinedAt((pred, arg2)) =>
elabTestT(arg1, elabPredicateType22(pred, arg2), env)
case TestBinOp("and", arg1, arg2) =>
val env1 = elabTestT(arg1, AtomLitType("true"), env)
elabTestT(arg2, upper, env1)
case TestBinOp("andalso", arg1, arg2) =>
val env1 = elabTestT(arg1, AtomLitType("true"), env)
elabTestT(arg2, upper, env1)
case TestBinOp("orelse", arg1, arg2) =>
val envTrue = elabTestT(arg1, trueType, env)
val envFalse = elabTestT(arg2, upper, env)
subtype.joinEnvs(List(envTrue, envFalse))
case TestBinOp("or", arg1, arg2) =>
val env1 = elabTestT(arg1, booleanType, env)
// "or" is not short-circuiting
elabTestT(arg2, booleanType, env1)
case _ =>
elabTest(test, env)
}
}

def elabUnOp(unOp: TestUnOp, env: Env): Env = {
private def elabUnOp(unOp: TestUnOp, upper: Type, env: Env): Env = {
val TestUnOp(op, arg) = unOp
op match {
case "not" =>
arg match {
case TestVar(_) => elabTestT(arg, booleanType, env)
case _ => env
}
case "not" if upper == trueType => elabTestT(arg, falseType, env)
case "not" if upper == falseType => elabTestT(arg, trueType, env)
case "not" => elabTestT(arg, booleanType, env)
case "bnot" | "+" | "-" =>
elabTestT(arg, NumberType, env)
case _ =>
Expand All @@ -226,84 +209,86 @@ final class ElabGuard(pipelineContext: PipelineContext) {
}
}

private def elabComparison(binOp: TestBinOp, env: Env): Env =
private def elabComparison(binOp: TestBinOp, upper: Type, env: Env): Env =
binOp match {
case TestBinOp("=:=" | "==", TestVar(v), NumTest()) =>
case TestBinOp("=:=" | "==", TestVar(v), NumTest()) if upper == trueType =>
env.get(v) match {
case Some(ty) =>
env + (v -> narrow.meet(ty, NumberType))
case None =>
env
}
case TestBinOp("=:=" | "==", NumTest(), TestVar(v)) =>
case TestBinOp("=:=" | "==", NumTest(), TestVar(v)) if upper == trueType =>
env.get(v) match {
case Some(ty) =>
env + (v -> narrow.meet(ty, NumberType))
case None =>
env
}
case TestBinOp("=:=" | "==", TestVar(v), TestString()) =>
case TestBinOp("=:=" | "==", TestVar(v), TestString()) if upper == trueType =>
env.get(v) match {
case Some(ty) =>
env + (v -> narrow.meet(ty, stringType))
case None =>
env
}
case TestBinOp("=:=" | "==", TestString(), TestVar(v)) =>
case TestBinOp("=:=" | "==", TestString(), TestVar(v)) if upper == trueType =>
env.get(v) match {
case Some(ty) =>
env + (v -> narrow.meet(ty, stringType))
case None =>
env
}
case TestBinOp("=:=" | "==", TestVar(v), TestAtom(a)) =>
case TestBinOp("=:=" | "==", TestVar(v), TestAtom(a)) if upper == trueType =>
env.get(v) match {
case Some(ty) =>
env + (v -> narrow.meet(ty, AtomLitType(a)))
case None =>
env
}
case TestBinOp("=:=" | "==", TestAtom(a), TestVar(v)) =>
case TestBinOp("=:=" | "==", TestAtom(a), TestVar(v)) if upper == trueType =>
env.get(v) match {
case Some(ty) =>
env + (v -> narrow.meet(ty, AtomLitType(a)))
case None =>
env
}
case TestBinOp("=:=" | "==", TestCall(Id("element", 2), List(TestNumber(Some(i)), TestVar(v))), TestAtom(a)) =>
case TestBinOp("=:=" | "==", TestCall(Id("element", 2), List(TestNumber(Some(i)), TestVar(v))), TestAtom(a))
if upper == trueType =>
env.get(v) match {
case Some(ty) =>
env + (v -> narrow.filterTupleType(ty, i, AtomLitType(a)))
case None =>
env
}
case TestBinOp("=:=" | "==", TestAtom(a), TestCall(Id("element", 2), List(TestNumber(Some(i)), TestVar(v)))) =>
case TestBinOp("=:=" | "==", TestAtom(a), TestCall(Id("element", 2), List(TestNumber(Some(i)), TestVar(v))))
if upper == trueType =>
env.get(v) match {
case Some(ty) =>
env + (v -> narrow.filterTupleType(ty, i, AtomLitType(a)))
case None =>
env
}
case TestBinOp("=/=" | "/=", TestVar(v), TestAtom(a)) =>
case TestBinOp("=/=" | "/=", TestVar(v), TestAtom(a)) if upper == trueType =>
env.get(v) match {
case Some(ty) =>
env + (v -> occurrence.remove(ty, AtomLitType(a)))
case None =>
env
}
case TestBinOp("=/=" | "/=", TestAtom(a), TestVar(v)) =>
case TestBinOp("=/=" | "/=", TestAtom(a), TestVar(v)) if upper == trueType =>
env.get(v) match {
case Some(ty) =>
env + (v -> occurrence.remove(ty, AtomLitType(a)))
case None =>
env
}
case TestBinOp(_, arg1, arg2) =>
val env1 = elabTest(arg1, env)
elabTest(arg2, env1)
val env1 = elabTestT(arg1, AnyType, env)
elabTestT(arg2, AnyType, env1)
}

private def elabBinOp(binOp: TestBinOp, env: Env): Env = {
private def elabBinOp(binOp: TestBinOp, upper: Type, env: Env): Env = {
val TestBinOp(op, arg1, arg2) = binOp
op match {
case "/" | "*" | "-" | "+" | "div" | "rem" | "band" | "bor" | "bxor" | "bsl" | "bsr" =>
Expand All @@ -313,13 +298,13 @@ final class ElabGuard(pipelineContext: PipelineContext) {
val env1 = elabTestT(arg1, booleanType, env)
elabTestT(arg2, booleanType, env1)
case ">=" | ">" | "=<" | "<" | "/=" | "=/=" | "==" | "=:=" =>
elabComparison(binOp, env)
elabComparison(binOp, upper, env)
case "andalso" =>
val env1 = elabTestT(arg1, booleanType, env)
elabTest(arg2, env1)
elabTestT(arg2, upper, env1)
case "orelse" =>
val env1 = elabTestT(arg1, booleanType, env)
elabTest(arg2, env1)
elabTestT(arg2, upper, env1)
case _ =>
throw new IllegalStateException(s"unexpected $op")
}
Expand Down

0 comments on commit 9b8a196

Please sign in to comment.