Skip to content

Commit

Permalink
Add custom eqWAlizer type checking for maps:merge
Browse files Browse the repository at this point in the history
Summary:
`maps:merge` currently loses all type information of the maps being merged. On this diff I'm adding custom handling of that function to deal with more map manipulation cases in WASERVER.

EqWAlizer already had some nice utilities to merge map types, but the existing implementation was a bit confusing and it the merging logic didn't assume any overwrite behavior (in `maps:merge`, when there are conflicting keys the right-ahnd side map's values are used). I expanded the existing merge logic to support this case, making the code more readable (I hope?) in the process.

Reviewed By: VLanvin

Differential Revision: D59677724

fbshipit-source-id: d0664130eda5e607a92aa110b221d83db95f6a19
  • Loading branch information
ruippeixotog authored and facebook-github-bot committed Jul 22, 2024
1 parent 3ff553c commit d6e6421
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 25 deletions.
2 changes: 2 additions & 0 deletions eqwalizer/src/main/resources/application.conf
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ eqwalizer {
clause_coverage = ${?EQWALIZER_CLAUSE_COVERAGE}
overloaded_spec_dynamic_result = false
overloaded_spec_dynamic_result = ${?EQWALIZER_OVERLOADED_SPEC_DYNAMIC_RESULT}
custom_maps_merge = false
custom_maps_merge = ${?EQWALIZER_CUSTOM_MAPS_MERGE}
}
2 changes: 2 additions & 0 deletions eqwalizer/src/main/scala/com/whatsapp/eqwalizer/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ package object eqwalizer {
checkRedundantGuards: Boolean,
clauseCoverage: Boolean,
overloadedSpecDynamicResult: Boolean,
customMapsMerge: Boolean,
mode: Mode.Mode,
errorDepth: Int,
) {
Expand Down Expand Up @@ -80,6 +81,7 @@ package object eqwalizer {
checkRedundantGuards = config.getBoolean("check_redundant_guards"),
clauseCoverage = config.getBoolean("clause_coverage"),
overloadedSpecDynamicResult = config.getBoolean("overloaded_spec_dynamic_result"),
customMapsMerge = config.getBoolean("custom_maps_merge"),
mode,
errorDepth = config.getInt("error_depth"),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
RemoteId("maps", "with", 2),
RemoteId("maps", "without", 2),
RemoteId(CompilerMacro.fake_module, "record_info", 2),
)
) ++ experimentalCustom

private def experimentalCustom: Set[RemoteId] =
if (pipelineContext.customMapsMerge) Set(RemoteId("maps", "merge", 2))
else Set()

private lazy val customPredicate: Set[RemoteId] =
Set(RemoteId("lists", "member", 2))
Expand Down Expand Up @@ -458,6 +462,15 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
val pairTys = narrow.asMapType(coerce(args.head, argTys.head, anyMapTy))
(mapToList(pairTys), env1)

case RemoteId("maps", "merge", 2) =>
val List(mapTy1, mapTy2) = args
.zip(argTys)
.map { case (arg, ty) => narrow.asMapType(coerce(arg, ty, anyMapTy)) }
val resMapTys = util
.cartesianProduct(mapTy1, mapTy2)
.map { case (ty1, ty2) => narrow.joinAndMergeShapes(List(ty1, ty2), true) }
(subtype.join(resMapTys), env1)

case RemoteId("maps", "with", 2) =>
@tailrec
def toKey(ty: Type)(implicit pos: Pos): Option[String] = ty match {
Expand Down
48 changes: 24 additions & 24 deletions eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import com.whatsapp.eqwalizer.ast.Forms.RecDeclTyped
import com.whatsapp.eqwalizer.ast.{RemoteId, TypeVars}
import com.whatsapp.eqwalizer.ast.Types._

import scala.collection.mutable

class Narrow(pipelineContext: PipelineContext) {
private val subtype = pipelineContext.subtype
private val util = pipelineContext.util
Expand Down Expand Up @@ -606,30 +604,32 @@ class Narrow(pipelineContext: PipelineContext) {
case _ => None
}

private def mergeShapes(s1: ShapeMap, s2: ShapeMap): ShapeMap = {
var optProps = mutable.HashMap.empty[String, Type]
var reqProps = mutable.HashMap.empty[String, Type]
val commonReqKeys = s1.props.collect { case ReqProp(key, _) => key }.toSet & s2.props.collect {
case ReqProp(key, _) => key
}.toSet
for (p <- s1.props.toSet ++ s2.props.toSet) {
p match {
case OptProp(key, tp) => optProps.updateWith(key)(ty => Some(subtype.join(tp, ty.getOrElse(NoneType))))
case ReqProp(key, tp) =>
if (commonReqKeys.contains(key)) {
reqProps.updateWith(key)(ty => Some(subtype.join(tp, ty.getOrElse(NoneType))))
} else {
optProps.updateWith(key)(ty => Some(subtype.join(tp, ty.getOrElse(NoneType))))
}
}
private def mergeShapes(s1: ShapeMap, s2: ShapeMap, inOrder: Boolean): ShapeMap = {
ShapeMap {
(s1.props ++ s2.props)
.groupBy(_.key)
.values
.map {
// prop is only defined in one of the maps
case List(p) if inOrder => p
case List(p) => OptProp(p.key, p.tp)
// prop is optional on both sides
case List(OptProp(key, tp1), OptProp(_, tp2)) => OptProp(key, subtype.join(tp1, tp2))
// prop is required on both sides
case List(ReqProp(key, tp1), ReqProp(_, tp2)) if inOrder => ReqProp(key, tp2)
case List(ReqProp(key, tp1), ReqProp(_, tp2)) => ReqProp(key, subtype.join(tp1, tp2))
// prop is required on one side and optional on the other
case List(OptProp(key, tp1), ReqProp(_, tp2)) if inOrder => ReqProp(key, tp2)
case List(ReqProp(key, tp1), OptProp(_, tp2)) if inOrder => ReqProp(key, subtype.join(tp1, tp2))
case List(p1, p2) => OptProp(p1.key, subtype.join(p1.tp, p2.tp))

case _ => throw new IllegalStateException()
}
.toList
}
val allProps = optProps.map { case (k, t) => OptProp(k, t) }.toList ++ reqProps.map { case (k, t) =>
ReqProp(k, t)
}.toList
ShapeMap(allProps)
}

def joinAndMergeShapes(tys: Iterable[Type]): Type = {
def joinAndMergeShapes(tys: Iterable[Type], inOrder: Boolean = false): Type = {
val (shapes, notShapes) = tys.partition {
case s: ShapeMap => true
case _ => false
Expand All @@ -640,7 +640,7 @@ class Narrow(pipelineContext: PipelineContext) {
joinedNotShapes
} else {
subtype.join(
shapesCoerced.tail.foldLeft(shapesCoerced.head)((acc, shape) => mergeShapes(acc, shape)),
shapesCoerced.tail.foldLeft(shapesCoerced.head)((acc, shape) => mergeShapes(acc, shape, inOrder)),
joinedNotShapes,
)
}
Expand Down
8 changes: 8 additions & 0 deletions eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Util.scala
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,12 @@ class Util(pipelineContext: PipelineContext) {
case BoundedDynamicType(_) => true
case _ => false
}

def cartesianProduct(ty1: Type, ty2: Type): List[(Type, Type)] = {
def expand(ty: Type): List[Type] = ty match {
case UnionType(tys) => tys.toList.flatMap(expand)
case ty => List(ty)
}
for (t1 <- expand(ty1); t2 <- expand(ty2)) yield (t1, t2)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,6 @@ package object tc {
options.errorDepth.getOrElse(config.errorDepth)
val clauseCoverage: Boolean = config.clauseCoverage
val overloadedSpecDynamicResult: Boolean = config.overloadedSpecDynamicResult
val customMapsMerge: Boolean = config.customMapsMerge
}
}
2 changes: 2 additions & 0 deletions eqwalizer/src/test/resources/application.conf
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@ eqwalizer {
clause_coverage = ${?EQWALIZER_CLAUSE_COVERAGE}
overloaded_spec_dynamic_result = false
overloaded_spec_dynamic_result = ${?EQWALIZER_OVERLOADED_SPEC_DYNAMIC_RESULT}
custom_maps_merge = true
custom_maps_merge = ${?EQWALIZER_CUSTOM_MAPS_MERGE}
}
30 changes: 30 additions & 0 deletions eqwalizer/test_projects/check/src/custom.erl
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,36 @@ maps_to_list_6(M) -> maps:to_list(M).
-spec maps_to_list_7_neg(number()) -> dynamic().
maps_to_list_7_neg(Num) -> maps:to_list(Num).

-spec maps_merge_1(#{a => string(), b => number()}, #{b => number(), c => atom()}) ->
#{a => string(), b => number(), c => atom()}.
maps_merge_1(M1, M2) -> maps:merge(M1, M2).

-spec maps_merge_2(#{a => string(), b => number()}, #{b => boolean(), c => atom()}) ->
#{a => string(), b => number() | boolean(), c => atom()}.
maps_merge_2(M1, M2) -> maps:merge(M1, M2).

-spec maps_merge_3(#{a := string(), b => number()}, #{b := boolean(), c => atom()}) ->
#{a := string(), b := boolean(), c => atom()}.
maps_merge_3(M1, M2) -> maps:merge(M1, M2).

-spec maps_merge_4(#{a => string(), b => number()}, #{atom() => boolean()}) ->
#{atom() => boolean() | string() | number()}.
maps_merge_4(M1, M2) -> maps:merge(M1, M2).

-spec maps_merge_5(#{string() => number()}, #{atom() => boolean()}) ->
#{string() | atom() => boolean() | number()}.
maps_merge_5(M1, M2) -> maps:merge(M1, M2).

-spec maps_merge_6(#{a => binary()}, map()) -> map().
maps_merge_6(M1, M2) -> maps:merge(M1, M2).

-spec maps_merge_7_neg(#{a => binary()}, number()) -> term().
maps_merge_7_neg(M1, M2) -> maps:merge(M1, M2).

-spec maps_merge_8(#{a := atom()}, #{b := number()} | #{}) ->
#{a := atom(), b := number()} | #{a := atom()}.
maps_merge_8(M1, M2) -> maps:merge(M1, M2).

-spec lists_filtermap_1() -> [number()].
lists_filtermap_1() ->
lists:filtermap(
Expand Down
32 changes: 32 additions & 0 deletions eqwalizer/test_projects/check/src/custom.erl.check
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,38 @@ maps_to_list_7_neg(Num) -> maps:to_list(Nu…… ERROR | Num.
| | Expression has type: number()
| | Context expected type: #D{term() => term()}
| |
-spec maps_merge_1(#{a => string(), b => n…… |
#{a => string(), b => number(), c => a…… |
maps_merge_1(M1, M2) -> maps:merge(M1, M2)…… OK |
| |
-spec maps_merge_2(#{a => string(), b => n…… |
#{a => string(), b => number() | boole…… |
maps_merge_2(M1, M2) -> maps:merge(M1, M2)…… OK |
| |
-spec maps_merge_3(#{a := string(), b => n…… |
#{a := string(), b := boolean(), c => …… |
maps_merge_3(M1, M2) -> maps:merge(M1, M2)…… OK |
| |
-spec maps_merge_4(#{a => string(), b => n…… |
#{atom() => boolean() | string() | num…… |
maps_merge_4(M1, M2) -> maps:merge(M1, M2)…… OK |
| |
-spec maps_merge_5(#{string() => number()}…… |
#{string() | atom() => boolean() | num…… |
maps_merge_5(M1, M2) -> maps:merge(M1, M2)…… OK |
| |
-spec maps_merge_6(#{a => binary()}, map()…… |
maps_merge_6(M1, M2) -> maps:merge(M1, M2)…… OK |
| |
-spec maps_merge_7_neg(#{a => binary()}, n…… |
maps_merge_7_neg(M1, M2) -> maps:merge(M1,…… ERROR | M2.
| | Expression has type: number()
| | Context expected type: #D{term() => term()}
| |
-spec maps_merge_8(#{a := atom()}, #{b := …… |
#{a := atom(), b := number()} | #{a :=…… |
maps_merge_8(M1, M2) -> maps:merge(M1, M2)…… OK |
| |
-spec lists_filtermap_1() -> [number()]. | |
lists_filtermap_1() -> | OK |
lists:filtermap( | |
Expand Down

0 comments on commit d6e6421

Please sign in to comment.