Skip to content

Commit

Permalink
Custom type checking of maps:intersect
Browse files Browse the repository at this point in the history
Summary: Add custom type-checking of `maps:intersect`

Reviewed By: ilya-klyuchnikov

Differential Revision: D65605653

fbshipit-source-id: f52dd5fc82815cd881f72d5a7169ce3c38d02c29
  • Loading branch information
VLanvin authored and facebook-github-bot committed Nov 8, 2024
1 parent 0ddd302 commit 24c7788
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
RemoteId("maps", "fold", 3),
RemoteId("maps", "get", 2),
RemoteId("maps", "get", 3),
RemoteId("maps", "intersect", 2),
RemoteId("maps", "map", 2),
RemoteId("maps", "put", 3),
RemoteId("maps", "remove", 2),
Expand Down Expand Up @@ -267,6 +268,17 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
val resTy = UnionType(Set(TupleType(List(AtomLitType("ok"), valTy)), AtomLitType("error")))
(resTy, env1)

case RemoteId("maps", "intersect", 2) =>
val List(map1, map2) = args
val List(ty1, ty2) = argTys
val coercedTy1 = coerce(map1, ty1, anyMapTy)
val coercedTy2 = coerce(map2, ty2, anyMapTy)
val mapTy1 = narrow.asMapType(coercedTy1)
val mapTy2 = narrow.asMapType(coercedTy2)
val reqKeys = narrow.getKeyType(mapTy1)(reqOnly = true)
val allKeys = narrow.getKeyType(mapTy1)
(narrow.selectKeys(reqKeys, allKeys, mapTy2), env1)

case RemoteId("maps", "fold", 3) =>
val List(funArg, _acc, collection) = args
val List(funArgTy, accTy1, collectionTy) = argTys
Expand Down
17 changes: 16 additions & 1 deletion eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,9 @@ class Narrow(pipelineContext: PipelineContext) {
case _ => NoneType
}

def getKeyType(t: Type): Type =
def getKeyType(t: Type)(implicit reqOnly: Boolean = false): Type =
t match {
case MapType(props, _, _) if reqOnly => subtype.join(props.filter(_._2.req).keySet.map(Key.asType))
case MapType(props, kType, _) if props.isEmpty => kType
case MapType(props, kType, _) => subtype.join(kType, UnionType(props.keySet.map(Key.asType)))
case UnionType(ts) => subtype.join(ts.map(getKeyType))
Expand Down Expand Up @@ -204,6 +205,20 @@ class Narrow(pipelineContext: PipelineContext) {
NoneType
}

def selectKeys(reqKeyT: Type, optKeyT: Type, mapT: Type): Type =
mapT match {
case MapType(props, kType, vType) =>
val selectProps = props.collect {
case (key, Prop(true, tp)) if subtype.subType(Key.asType(key), reqKeyT) => (key, Prop(req = true, tp))
case (key, Prop(_, tp)) if subtype.subType(Key.asType(key), optKeyT) => (key, Prop(req = false, tp))
}
MapType(selectProps, meet(kType, optKeyT), vType)
case UnionType(ts) =>
subtype.join(ts.map(selectKeys(reqKeyT, optKeyT, _)))
case _ =>
NoneType
}

private def extractListElem(t: Type): List[Type] =
t match {
case DynamicType =>
Expand Down

0 comments on commit 24c7788

Please sign in to comment.