diff --git a/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala b/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala index 38e092b..6680743 100644 --- a/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala +++ b/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala @@ -46,6 +46,7 @@ class ElabApplyCustom(pipelineContext: PipelineContext) { RemoteId("maps", "fold", 3), RemoteId("maps", "get", 2), RemoteId("maps", "get", 3), + RemoteId("maps", "intersect", 2), RemoteId("maps", "map", 2), RemoteId("maps", "put", 3), RemoteId("maps", "remove", 2), @@ -267,6 +268,17 @@ class ElabApplyCustom(pipelineContext: PipelineContext) { val resTy = UnionType(Set(TupleType(List(AtomLitType("ok"), valTy)), AtomLitType("error"))) (resTy, env1) + case RemoteId("maps", "intersect", 2) => + val List(map1, map2) = args + val List(ty1, ty2) = argTys + val coercedTy1 = coerce(map1, ty1, anyMapTy) + val coercedTy2 = coerce(map2, ty2, anyMapTy) + val mapTy1 = narrow.asMapType(coercedTy1) + val mapTy2 = narrow.asMapType(coercedTy2) + val reqKeys = narrow.getKeyType(mapTy1)(reqOnly = true) + val allKeys = narrow.getKeyType(mapTy1) + (narrow.selectKeys(reqKeys, allKeys, mapTy2), env1) + case RemoteId("maps", "fold", 3) => val List(funArg, _acc, collection) = args val List(funArgTy, accTy1, collectionTy) = argTys diff --git a/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala b/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala index 428b882..d82b84b 100644 --- a/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala +++ b/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala @@ -160,8 +160,9 @@ class Narrow(pipelineContext: PipelineContext) { case _ => NoneType } - def getKeyType(t: Type): Type = + def getKeyType(t: Type)(implicit reqOnly: Boolean = false): Type = t match { + case MapType(props, _, _) if reqOnly => subtype.join(props.filter(_._2.req).keySet.map(Key.asType)) case MapType(props, kType, _) if props.isEmpty => kType case MapType(props, kType, _) => subtype.join(kType, UnionType(props.keySet.map(Key.asType))) case UnionType(ts) => subtype.join(ts.map(getKeyType)) @@ -204,6 +205,20 @@ class Narrow(pipelineContext: PipelineContext) { NoneType } + def selectKeys(reqKeyT: Type, optKeyT: Type, mapT: Type): Type = + mapT match { + case MapType(props, kType, vType) => + val selectProps = props.collect { + case (key, Prop(true, tp)) if subtype.subType(Key.asType(key), reqKeyT) => (key, Prop(req = true, tp)) + case (key, Prop(_, tp)) if subtype.subType(Key.asType(key), optKeyT) => (key, Prop(req = false, tp)) + } + MapType(selectProps, meet(kType, optKeyT), vType) + case UnionType(ts) => + subtype.join(ts.map(selectKeys(reqKeyT, optKeyT, _))) + case _ => + NoneType + } + private def extractListElem(t: Type): List[Type] = t match { case DynamicType =>