diff --git a/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala b/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala index 574c525..1a8063e 100644 --- a/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala +++ b/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala @@ -288,11 +288,12 @@ class ElabApplyCustom(pipelineContext: PipelineContext) { if (!subtype.subType(mapTy, anyMapTy)) throw ExpectedSubtype(map.pos, map, expected = anyMapTy, got = mapTy) val mapType = narrow.asMapType(mapTy) - keyTy match { - case AtomLitType(key) => - val valTy = narrow.getValType(key, mapType) - (valTy, env1) - case _ => + val atomKeys = narrow.asAtomLits(keyTy) + atomKeys match { + case Some(atoms) => + val valTys = atoms.map(narrow.getValType(_, mapType)) + (subtype.join(valTys), env1) + case None => val valTy = narrow.getValType(mapType) (valTy, env1) } @@ -334,11 +335,12 @@ class ElabApplyCustom(pipelineContext: PipelineContext) { if (!subtype.subType(mapTy, anyMapTy)) throw ExpectedSubtype(map.pos, map, expected = anyMapTy, got = mapTy) val mapType = narrow.asMapType(mapTy) - keyTy match { - case AtomLitType(key) => - val valTy = narrow.getValType(key, mapType) - (subtype.join(valTy, defaultValTy), env1) - case _ => + val atomKeys = narrow.asAtomLits(keyTy) + atomKeys match { + case Some(atoms) => + val valTys = atoms.map(narrow.getValType(_, mapType)) + (subtype.join(valTys + defaultValTy), env1) + case None => val valTy = narrow.getValType(mapType) (subtype.join(valTy, defaultValTy), env1) } diff --git a/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala b/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala index b8b0d81..56355be 100644 --- a/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala +++ b/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala @@ -538,4 +538,20 @@ class Narrow(pipelineContext: PipelineContext) { throw new IllegalStateException() } } + + // Recursion is sound since we don't unfold under constructors + def asAtomLits(t: Type): Option[Set[String]] = + t match { + case AtomLitType(s) => Some(Set(s)) + case BoundedDynamicType(bound) => + asAtomLits(bound) + case UnionType(ts) => + ts.foldLeft[Option[Set[String]]](Some(Set())) { (acc, ty) => + acc.flatMap(atoms => asAtomLits(ty).map(atoms2 => atoms ++ atoms2)) + } + case RemoteType(rid, args) => + val body = util.getTypeDeclBody(rid, args) + asAtomLits(body) + case _ => None + } }