From 17f2c27e4124f4ea094542e1551ad087261ea3c6 Mon Sep 17 00:00:00 2001 From: Scott Guest Date: Fri, 1 Dec 2023 16:03:22 +0200 Subject: [PATCH] Various improvements to `POSet` (#3834) This PR makes various changes to `POSet` which were necessary during the development of the new type inference algorithm: - Change Java wrappers to use `Collections.mutable`/`Collections.immutable` rather than `asScala`/`asJava` to avoid an issue where the returned Java `Set` is not mutable - Add `minimalElements` and `maximalElements` methods - Optimize the `upperBounds` computation to iterate over the provided `sorts` rather than filtering keys from the `relations` map Co-authored-by: rv-jenkins --- .../kframework/compile/AddSortInjections.java | 2 +- .../src/main/scala/org/kframework/POSet.scala | 37 ++++++++++--------- .../test/scala/org/kframework/POSetTest.scala | 20 +++++----- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/kernel/src/main/java/org/kframework/compile/AddSortInjections.java b/kernel/src/main/java/org/kframework/compile/AddSortInjections.java index ae2db921a28..ea10a19af4c 100644 --- a/kernel/src/main/java/org/kframework/compile/AddSortInjections.java +++ b/kernel/src/main/java/org/kframework/compile/AddSortInjections.java @@ -510,7 +510,7 @@ private static Sort lub( Set nonParametric = filteredEntries.stream().filter(s -> s.params().isEmpty()).collect(Collectors.toSet()); - Set bounds = mutable(mod.subsorts().upperBounds(immutable(nonParametric))); + Set bounds = mod.subsorts().upperBounds(nonParametric); // Anything less than KBott or greater than K is a syntactic sort from kast.md which should not // be considered bounds.removeIf( diff --git a/kore/src/main/scala/org/kframework/POSet.scala b/kore/src/main/scala/org/kframework/POSet.scala index f39fa63421f..73467b0e117 100644 --- a/kore/src/main/scala/org/kframework/POSet.scala +++ b/kore/src/main/scala/org/kframework/POSet.scala @@ -94,21 +94,27 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { def upperBounds(sorts: Iterable[T]): Set[T] = if (sorts.isEmpty) elements else POSet.upperBounds(sorts, relations) + def upperBounds(sorts: util.Collection[T]): util.Set[T] = + Collections.mutable(upperBounds(Collections.immutable(sorts))) + /** * Return the set of all lower bounds of the input. */ def lowerBounds(sorts: Iterable[T]): Set[T] = if (sorts.isEmpty) elements else POSet.upperBounds(sorts, relationsOp) - lazy val lub: Option[T] = { - val mins = minimal(upperBounds(elements)) - if (mins.size == 1) Some(mins.head) else None - } + def lowerBounds(sorts: util.Collection[T]): util.Set[T] = + Collections.mutable(lowerBounds(Collections.immutable(sorts))) - lazy val glb: Option[T] = { - val maxs = maximal(lowerBounds(elements)) - if (maxs.size == 1) Some(maxs.head) else None - } + lazy val minimalElements: Set[T] = minimal(elements) + + lazy val maximalElements: Set[T] = maximal(elements) + + lazy val maximum: Option[T] = + if (maximalElements.size == 1) Some(maximalElements.head) else None + + lazy val minimum: Option[T] = + if (minimalElements.size == 1) Some(minimalElements.head) else None lazy val asOrdering: Ordering[T] = (x: T, y: T) => if (lessThanEq(x, y)) -1 else if (lessThanEq(y, x)) 1 else 0 @@ -119,10 +125,8 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { def maximal(sorts: Iterable[T]): Set[T] = sorts.filter(s1 => !sorts.exists(s2 => lessThan(s1,s2))).toSet - def maximal(sorts: util.Collection[T]): util.Set[T] = { - import scala.collection.JavaConverters._ - maximal(sorts.asScala).asJava - } + def maximal(sorts: util.Collection[T]): util.Set[T] = + Collections.mutable(maximal(Collections.immutable(sorts))) /** * Return the subset of items from the argument which are not @@ -131,10 +135,8 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable { def minimal(sorts: Iterable[T]): Set[T] = sorts.filter(s1 => !sorts.exists(s2 => >(s1,s2))).toSet - def minimal(sorts: util.Collection[T]): util.Set[T] = { - import scala.collection.JavaConverters._ - minimal(sorts.asScala).asJava - } + def minimal(sorts: util.Collection[T]): util.Set[T] = + Collections.mutable(minimal(Collections.immutable(sorts))) override def toString: String = { "POSet(" + (relations flatMap { case (from, tos) => tos map { to => from + "<" + to } }).mkString(",") + ")" @@ -165,6 +167,5 @@ object POSet { * using the provided relations map. Input must be non-empty. */ private def upperBounds[T](sorts: Iterable[T], relations: Map[T, Set[T]]): Set[T] = - (((sorts filterNot relations.keys.toSet[T]) map {Set.empty + _}) ++ - ((relations filterKeys sorts.toSet) map { case (k, v) => v + k })) reduce { (a, b) => a & b } + sorts map { s => relations.getOrElse(s, Set.empty) + s } reduce { (a, b) => a & b } } diff --git a/kore/src/test/scala/org/kframework/POSetTest.scala b/kore/src/test/scala/org/kframework/POSetTest.scala index c2a5dfaf6ff..cf6f909585b 100644 --- a/kore/src/test/scala/org/kframework/POSetTest.scala +++ b/kore/src/test/scala/org/kframework/POSetTest.scala @@ -39,18 +39,18 @@ class POSetTest { } @Test def lub() { - assertEquals(Some(b2), POSet(b1 -> b2).lub) - assertEquals(Some(b3), POSet(b1 -> b3, b2 -> b3).lub) - assertEquals(Some(b4), POSet(b1 -> b3, b2 -> b3, b3 -> b4).lub) - assertEquals(None, POSet(b1 -> b2, b2 -> b3, b4 -> b5).lub) - assertEquals(None, POSet(b1 -> b2, b2 -> b3, b2 -> b4).lub) + assertEquals(Some(b2), POSet(b1 -> b2).maximum) + assertEquals(Some(b3), POSet(b1 -> b3, b2 -> b3).maximum) + assertEquals(Some(b4), POSet(b1 -> b3, b2 -> b3, b3 -> b4).maximum) + assertEquals(None, POSet(b1 -> b2, b2 -> b3, b4 -> b5).maximum) + assertEquals(None, POSet(b1 -> b2, b2 -> b3, b2 -> b4).maximum) } @Test def glb() { - assertEquals(Some(b2), POSet(b2 -> b1).glb) - assertEquals(Some(b3), POSet(b3 -> b1, b3 -> b2).glb) - assertEquals(Some(b4), POSet(b3 -> b1, b3 -> b2, b4 -> b3).glb) - assertEquals(None, POSet(b2 -> b1, b3 -> b2, b5 -> b4).glb) - assertEquals(None, POSet(b2 -> b1, b3 -> b2, b4 -> b2).glb) + assertEquals(Some(b2), POSet(b2 -> b1).minimum) + assertEquals(Some(b3), POSet(b3 -> b1, b3 -> b2).minimum) + assertEquals(Some(b4), POSet(b3 -> b1, b3 -> b2, b4 -> b3).minimum) + assertEquals(None, POSet(b2 -> b1, b3 -> b2, b5 -> b4).minimum) + assertEquals(None, POSet(b2 -> b1, b3 -> b2, b4 -> b2).minimum) } }