Skip to content

Commit

Permalink
Various improvements to POSet (#3834)
Browse files Browse the repository at this point in the history
This PR makes various changes to `POSet` which were necessary during the
development of the new type inference algorithm:

- Change Java wrappers to use
`Collections.mutable`/`Collections.immutable` rather than
`asScala`/`asJava` to avoid an issue where the returned Java `Set` is
not mutable
- Add `minimalElements` and `maximalElements` methods
- Optimize the `upperBounds` computation to iterate over the provided
`sorts` rather than filtering keys from the `relations` map

Co-authored-by: rv-jenkins <[email protected]>
  • Loading branch information
Scott-Guest and rv-jenkins authored Dec 1, 2023
1 parent 5c76463 commit 17f2c27
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ private static Sort lub(

Set<Sort> nonParametric =
filteredEntries.stream().filter(s -> s.params().isEmpty()).collect(Collectors.toSet());
Set<Sort> bounds = mutable(mod.subsorts().upperBounds(immutable(nonParametric)));
Set<Sort> bounds = mod.subsorts().upperBounds(nonParametric);
// Anything less than KBott or greater than K is a syntactic sort from kast.md which should not
// be considered
bounds.removeIf(
Expand Down
37 changes: 19 additions & 18 deletions kore/src/main/scala/org/kframework/POSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,21 +94,27 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {
def upperBounds(sorts: Iterable[T]): Set[T] =
if (sorts.isEmpty) elements else POSet.upperBounds(sorts, relations)

def upperBounds(sorts: util.Collection[T]): util.Set[T] =
Collections.mutable(upperBounds(Collections.immutable(sorts)))

/**
* Return the set of all lower bounds of the input.
*/
def lowerBounds(sorts: Iterable[T]): Set[T] =
if (sorts.isEmpty) elements else POSet.upperBounds(sorts, relationsOp)

lazy val lub: Option[T] = {
val mins = minimal(upperBounds(elements))
if (mins.size == 1) Some(mins.head) else None
}
def lowerBounds(sorts: util.Collection[T]): util.Set[T] =
Collections.mutable(lowerBounds(Collections.immutable(sorts)))

lazy val glb: Option[T] = {
val maxs = maximal(lowerBounds(elements))
if (maxs.size == 1) Some(maxs.head) else None
}
lazy val minimalElements: Set[T] = minimal(elements)

lazy val maximalElements: Set[T] = maximal(elements)

lazy val maximum: Option[T] =
if (maximalElements.size == 1) Some(maximalElements.head) else None

lazy val minimum: Option[T] =
if (minimalElements.size == 1) Some(minimalElements.head) else None

lazy val asOrdering: Ordering[T] = (x: T, y: T) => if (lessThanEq(x, y)) -1 else if (lessThanEq(y, x)) 1 else 0

Expand All @@ -119,10 +125,8 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {
def maximal(sorts: Iterable[T]): Set[T] =
sorts.filter(s1 => !sorts.exists(s2 => lessThan(s1,s2))).toSet

def maximal(sorts: util.Collection[T]): util.Set[T] = {
import scala.collection.JavaConverters._
maximal(sorts.asScala).asJava
}
def maximal(sorts: util.Collection[T]): util.Set[T] =
Collections.mutable(maximal(Collections.immutable(sorts)))

/**
* Return the subset of items from the argument which are not
Expand All @@ -131,10 +135,8 @@ class POSet[T](val directRelations: Set[(T, T)]) extends Serializable {
def minimal(sorts: Iterable[T]): Set[T] =
sorts.filter(s1 => !sorts.exists(s2 => >(s1,s2))).toSet

def minimal(sorts: util.Collection[T]): util.Set[T] = {
import scala.collection.JavaConverters._
minimal(sorts.asScala).asJava
}
def minimal(sorts: util.Collection[T]): util.Set[T] =
Collections.mutable(minimal(Collections.immutable(sorts)))

override def toString: String = {
"POSet(" + (relations flatMap { case (from, tos) => tos map { to => from + "<" + to } }).mkString(",") + ")"
Expand Down Expand Up @@ -165,6 +167,5 @@ object POSet {
* using the provided relations map. Input must be non-empty.
*/
private def upperBounds[T](sorts: Iterable[T], relations: Map[T, Set[T]]): Set[T] =
(((sorts filterNot relations.keys.toSet[T]) map {Set.empty + _}) ++
((relations filterKeys sorts.toSet) map { case (k, v) => v + k })) reduce { (a, b) => a & b }
sorts map { s => relations.getOrElse(s, Set.empty) + s } reduce { (a, b) => a & b }
}
20 changes: 10 additions & 10 deletions kore/src/test/scala/org/kframework/POSetTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ class POSetTest {
}

@Test def lub() {
assertEquals(Some(b2), POSet(b1 -> b2).lub)
assertEquals(Some(b3), POSet(b1 -> b3, b2 -> b3).lub)
assertEquals(Some(b4), POSet(b1 -> b3, b2 -> b3, b3 -> b4).lub)
assertEquals(None, POSet(b1 -> b2, b2 -> b3, b4 -> b5).lub)
assertEquals(None, POSet(b1 -> b2, b2 -> b3, b2 -> b4).lub)
assertEquals(Some(b2), POSet(b1 -> b2).maximum)
assertEquals(Some(b3), POSet(b1 -> b3, b2 -> b3).maximum)
assertEquals(Some(b4), POSet(b1 -> b3, b2 -> b3, b3 -> b4).maximum)
assertEquals(None, POSet(b1 -> b2, b2 -> b3, b4 -> b5).maximum)
assertEquals(None, POSet(b1 -> b2, b2 -> b3, b2 -> b4).maximum)
}

@Test def glb() {
assertEquals(Some(b2), POSet(b2 -> b1).glb)
assertEquals(Some(b3), POSet(b3 -> b1, b3 -> b2).glb)
assertEquals(Some(b4), POSet(b3 -> b1, b3 -> b2, b4 -> b3).glb)
assertEquals(None, POSet(b2 -> b1, b3 -> b2, b5 -> b4).glb)
assertEquals(None, POSet(b2 -> b1, b3 -> b2, b4 -> b2).glb)
assertEquals(Some(b2), POSet(b2 -> b1).minimum)
assertEquals(Some(b3), POSet(b3 -> b1, b3 -> b2).minimum)
assertEquals(Some(b4), POSet(b3 -> b1, b3 -> b2, b4 -> b3).minimum)
assertEquals(None, POSet(b2 -> b1, b3 -> b2, b5 -> b4).minimum)
assertEquals(None, POSet(b2 -> b1, b3 -> b2, b4 -> b2).minimum)
}
}

0 comments on commit 17f2c27

Please sign in to comment.