performance: parallelize diffgraph application #288

Merged · 15 commits · Jan 27, 2025
Changes from 6 commits
15 changes: 15 additions & 0 deletions README.md
@@ -6,6 +6,21 @@ Code formatting is maintained via
sbt scalafmt Test/scalafmt
```

## Diverse notes
By default, diffgraph application, deserialization from storage, and serialization to storage are all multi-threaded.

For easier debugging, this can be globally disabled via `flatgraph.Misc.force_singlethreaded()`.

In order to quickly inspect flatgraph files, you can extract the manifest JSON with `tail`, e.g. `tail someGraph.fg | jless`.
Our output writer always places the manifest at the end of the file, preceded by enough newlines that the `tail` output contains no binary garbage.

This is suitable for quick command-line debugging. However, this approach fails if, for example, two flatgraph files have been concatenated: deserialization
reads the file from the beginning, takes the offset of the true manifest from the header, and ignores trailing garbage such as an appended fake manifest.
So do not rely on this trick for security checks!




## Core Features
- [x] Access nodes and neighbors
- [x] Add nodes and edges
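To make the new README note above concrete, here is a minimal sketch of the debugging workflow it describes, using the `readGraph` signature added in this PR; the file name and the final `println` are illustrative assumptions.

```scala
import java.nio.file.Paths

object DebugLoad {
  def main(args: Array[String]): Unit = {
    // Process-wide switch added in this PR: every executor request now
    // yields a fresh single-thread executor, which keeps stack traces
    // and breakpoints on one thread while debugging.
    flatgraph.Misc.force_singlethreaded()

    // readGraph resolves its executor via Misc.maybeOverrideExecutor,
    // so the override above takes effect here. "someGraph.fg" is a
    // hypothetical file name.
    val graph = flatgraph.storage.Deserialization.readGraph(
      Paths.get("someGraph.fg"),
      schemaMaybe = None
    )
    println(s"loaded: $graph")
  }
}
```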
454 changes: 258 additions & 196 deletions core/src/main/scala/flatgraph/DiffGraphApplier.scala

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions core/src/main/scala/flatgraph/Misc.scala
@@ -0,0 +1,23 @@
package flatgraph

import java.util.concurrent

object Misc {

@volatile var _overrideExecutor: Option[concurrent.ExecutorService] => concurrent.ExecutorService = defaultExecutorProvider

def force_singlethreaded(): Unit = {
// this one is magic -- it can get garbage collected, no manual shutdown required!
this._overrideExecutor = (something: Option[concurrent.ExecutorService]) => concurrent.Executors.newSingleThreadExecutor()
}

def defaultExecutorProvider(requested: Option[concurrent.ExecutorService]): concurrent.ExecutorService = requested.getOrElse {
java.lang.Thread.currentThread() match {
case fjt: concurrent.ForkJoinWorkerThread => fjt.getPool
case _ => concurrent.ForkJoinPool.commonPool()
}
}

def maybeOverrideExecutor(requested: Option[concurrent.ExecutorService]): concurrent.ExecutorService =
this._overrideExecutor.apply(requested)
}
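A short sketch of how the resolution chain in `Misc` behaves, assuming it is called the same way `readGraph` calls it; the pool size is an arbitrary choice, not part of the PR.

```scala
import java.util.concurrent.Executors

// An explicitly requested pool wins under the default provider; without
// one, defaultExecutorProvider picks the enclosing ForkJoinPool when
// called from a ForkJoin worker thread, else the common pool.
val myPool = Executors.newFixedThreadPool(4) // sizing is arbitrary here
assert(flatgraph.Misc.maybeOverrideExecutor(Some(myPool)) eq myPool)

// Overrides are process-wide: after this call, the requested pool is
// ignored and each request yields a fresh single-thread executor.
flatgraph.Misc.force_singlethreaded()
val single = flatgraph.Misc.maybeOverrideExecutor(Some(myPool))
assert(!(single eq myPool))
myPool.shutdown()
```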
100 changes: 64 additions & 36 deletions core/src/main/scala/flatgraph/storage/Deserialization.scala
@@ -1,7 +1,6 @@
package flatgraph.storage

import com.github.luben.zstd.Zstd
import flatgraph.*
import flatgraph.{AccessHelpers, FreeSchema, GNode, Graph, Schema}
import flatgraph.Edge.Direction
import flatgraph.storage.Manifest.{GraphItem, OutlineStorage}

@@ -11,15 +10,30 @@ import java.nio.file.Path
import java.nio.{ByteBuffer, ByteOrder}
import java.util.Arrays
import scala.collection.mutable
import java.util.concurrent

object Deserialization {

def readGraph(storagePath: Path, schemaMaybe: Option[Schema], persistOnClose: Boolean = true): Graph = {
def readGraph(
storagePath: Path,
schemaMaybe: Option[Schema],
persistOnClose: Boolean = true,
requestedExecutor: Option[concurrent.ExecutorService] = None
): Graph = {
val executor = flatgraph.Misc.maybeOverrideExecutor(requestedExecutor)
val fileChannel = new java.io.RandomAccessFile(storagePath.toAbsolutePath.toFile, "r").getChannel
val queue = mutable.ArrayBuffer[concurrent.Future[Any]]()
val zstdCtx = new ZstdWrapper.ZstdCtx
def submitJob[T](block: => T): concurrent.Future[T] = {
val res = executor.submit((() => block))
queue.addOne(res.asInstanceOf[concurrent.Future[Any]])
res
}

try {
// fixme: Use convenience methods from schema to translate string->id. Fix after we get strict schema checking.
val manifest = GraphItem.read(readManifest(fileChannel))
val pool = readPool(manifest, fileChannel)
val pool = submitJob { readPool(manifest, fileChannel, zstdCtx) }
val schema = schemaMaybe.getOrElse(freeSchemaFromManifest(manifest))
val storagePathMaybe =
if (persistOnClose) Option(storagePath)
@@ -66,11 +80,17 @@ object Deserialization {
val direction = Direction.fromOrdinal(edgeItem.inout)
if (nodeKind.isDefined && edgeKind.isDefined) {
val pos = g.schema.neighborOffsetArrayIndex(nodeKind.get, direction, edgeKind.get)
g.neighbors(pos) = deltaDecode(readArray(fileChannel, edgeItem.qty, nodeRemapper, pool).asInstanceOf[Array[Int]])
g.neighbors(pos + 1) = readArray(fileChannel, edgeItem.neighbors, nodeRemapper, pool)
val property = readArray(fileChannel, edgeItem.property, nodeRemapper, pool)
if (property != null)
g.neighbors(pos + 2) = property
submitJob {
g.neighbors(pos) = deltaDecode(readArray(fileChannel, edgeItem.qty, nodeRemapper, pool, zstdCtx).asInstanceOf[Array[Int]])
}
submitJob {
g.neighbors(pos + 1) = readArray(fileChannel, edgeItem.neighbors, nodeRemapper, pool, zstdCtx)
}
submitJob {
val property = readArray(fileChannel, edgeItem.property, nodeRemapper, pool, zstdCtx)
if (property != null)
g.neighbors(pos + 2) = property
}
}
}

@@ -91,12 +111,18 @@
val propertyKind = propertykinds.get((property.nodeLabel, property.propertyLabel))
if (nodeKind.isDefined && propertyKind.isDefined) {
val pos = g.schema.propertyOffsetArrayIndex(nodeKind.get, propertyKind.get)
g.properties(pos) = deltaDecode(readArray(fileChannel, property.qty, nodeRemapper, pool).asInstanceOf[Array[Int]])
g.properties(pos + 1) = readArray(fileChannel, property.property, nodeRemapper, pool)
submitJob {
g.properties(pos) = deltaDecode(readArray(fileChannel, property.qty, nodeRemapper, pool, zstdCtx).asInstanceOf[Array[Int]])
}
submitJob { g.properties(pos + 1) = readArray(fileChannel, property.property, nodeRemapper, pool, zstdCtx) }
}
}
queue.foreach { _.get() }
g
} finally fileChannel.close()
} catch {
case ex: java.util.concurrent.ExecutionException =>
throw ex.getCause()
} finally { fileChannel.close(); zstdCtx.close(); }
}

private def freeSchemaFromManifest(manifest: Manifest.GraphItem): FreeSchema = {
@@ -171,23 +197,17 @@

}

private def readPool(manifest: GraphItem, fileChannel: FileChannel): Array[String] = {
val stringPoolLength = ZstdWrapper(
Zstd
.decompress(
fileChannel.map(FileChannel.MapMode.READ_ONLY, manifest.stringPoolLength.startOffset, manifest.stringPoolLength.compressedLength),
manifest.stringPoolLength.decompressedLength
)
.order(ByteOrder.LITTLE_ENDIAN)
)
val stringPoolBytes = ZstdWrapper(
Zstd
.decompress(
fileChannel.map(FileChannel.MapMode.READ_ONLY, manifest.stringPoolBytes.startOffset, manifest.stringPoolBytes.compressedLength),
manifest.stringPoolBytes.decompressedLength
)
.order(ByteOrder.LITTLE_ENDIAN)
)
private def readPool(manifest: GraphItem, fileChannel: FileChannel, zstdCtx: ZstdWrapper.ZstdCtx): Array[String] = {
val stringPoolLength = zstdCtx
.decompress(
fileChannel.map(FileChannel.MapMode.READ_ONLY, manifest.stringPoolLength.startOffset, manifest.stringPoolLength.compressedLength),
manifest.stringPoolLength.decompressedLength
)
val stringPoolBytes = zstdCtx
.decompress(
fileChannel.map(FileChannel.MapMode.READ_ONLY, manifest.stringPoolBytes.startOffset, manifest.stringPoolBytes.compressedLength),
manifest.stringPoolBytes.decompressedLength
)
val poolBytes = new Array[Byte](manifest.stringPoolBytes.decompressedLength)
stringPoolBytes.get(poolBytes)
val pool = new Array[String](manifest.stringPoolLength.decompressedLength >> 2)
@@ -215,11 +235,18 @@
a
}

private def readArray(channel: FileChannel, ptr: OutlineStorage, nodes: Array[Array[GNode]], stringPool: Array[String]): Array[?] = {
private def readArray(
channel: FileChannel,
ptr: OutlineStorage,
nodes: Array[Array[GNode]],
stringPoolFuture: concurrent.Future[Array[String]],
zstdCtx: ZstdWrapper.ZstdCtx
): Array[?] = {
if (ptr == null) return null
val dec = ZstdWrapper(
Zstd.decompress(channel.map(FileChannel.MapMode.READ_ONLY, ptr.startOffset, ptr.compressedLength), ptr.decompressedLength)
).order(ByteOrder.LITTLE_ENDIAN)
if (ptr.typ == StorageType.String) stringPoolFuture.get()

val dec =
zstdCtx.decompress(channel.map(FileChannel.MapMode.READ_ONLY, ptr.startOffset, ptr.compressedLength), ptr.decompressedLength)
ptr.typ match {
case StorageType.Bool =>
val bytes = new Array[Byte](dec.limit())
@@ -253,9 +280,10 @@
dec.asDoubleBuffer().get(res)
res
case StorageType.String =>
val res = new Array[String](dec.limit() >> 2)
val intbuf = dec.asIntBuffer()
var idx = 0
val stringPool = stringPoolFuture.get()
val res = new Array[String](dec.limit() >> 2)
val intbuf = dec.asIntBuffer()
var idx = 0
while (idx < res.length) {
val offset = intbuf.get(idx)
if (offset >= 0) res(idx) = stringPool(offset)
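The concurrency idiom threaded through `readGraph` above, reduced to a standalone sketch: fan work out to the resolved executor, queue the futures, then barrier on the queue and unwrap `ExecutionException` so callers see the worker's real failure. `loadAllParallel` and its inputs are illustrative names, not flatgraph API.

```scala
import java.util.concurrent
import scala.collection.mutable

def loadAllParallel(executor: concurrent.ExecutorService, jobs: Seq[() => Unit]): Unit = {
  val queue = mutable.ArrayBuffer[concurrent.Future[Any]]()
  // Mirrors the submitJob helper in readGraph: submit, remember the future.
  def submitJob[T](block: => T): concurrent.Future[T] = {
    val res = executor.submit((() => block))
    queue.addOne(res.asInstanceOf[concurrent.Future[Any]])
    res
  }
  try {
    jobs.foreach(job => submitJob(job()))
    queue.foreach(_.get()) // barrier: wait for every submitted job
  } catch {
    // Future.get wraps worker exceptions; rethrow the underlying cause,
    // as the new catch block in readGraph does.
    case ex: concurrent.ExecutionException => throw ex.getCause()
  }
}
```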
33 changes: 25 additions & 8 deletions core/src/main/scala/flatgraph/storage/Manifest.scala
@@ -34,8 +34,8 @@ object Manifest {
var nodes: Array[NodeItem],
var edges: Array[EdgeItem],
var properties: Array[PropertyItem],
val stringPoolLength: OutlineStorage,
val stringPoolBytes: OutlineStorage
val stringPoolLength: OutlineStorage = new OutlineStorage(StorageType.Int),
val stringPoolBytes: OutlineStorage = new OutlineStorage(StorageType.Byte)
) {
var version = 0
}
@@ -96,9 +96,9 @@ object Manifest {
val nodeLabel: String,
val edgeLabel: String,
val inout: Byte, // 0: Incoming, 1: Outgoing; see Edge.Direction enum
var qty: OutlineStorage,
var neighbors: OutlineStorage,
var property: OutlineStorage
var qty: OutlineStorage = new OutlineStorage,
var neighbors: OutlineStorage = new OutlineStorage,
var property: OutlineStorage = new OutlineStorage
) {
Edge.Direction.verifyEncodingRange(inout)
}
@@ -122,11 +122,20 @@
}
}

class PropertyItem(val nodeLabel: String, val propertyLabel: String, var qty: OutlineStorage, var property: OutlineStorage)
class PropertyItem(
val nodeLabel: String,
val propertyLabel: String,
var qty: OutlineStorage = new OutlineStorage,
var property: OutlineStorage = new OutlineStorage
)

object OutlineStorage {
def write(item: OutlineStorage): ujson.Value = {
if (item == null) return ujson.Null
if (item.typ == null) {
assert(item.startOffset == -1L && item.compressedLength == -1 && item.decompressedLength == -1, s"bad OutlineStorage ${item}")
return ujson.Null
}
val res = ujson.Obj()
res(Keys.Type) = item.typ
res(Keys.StartOffset) = ujson.Num(item.startOffset.toDouble)
@@ -143,17 +152,25 @@

def read(item: ujson.Value): OutlineStorage = {
if (item.isNull) return null
val res = new OutlineStorage(item.obj(Keys.Type).str)
val res = new OutlineStorage
res.typ = item.obj(Keys.Type).str
res.startOffset = item.obj(Keys.StartOffset).num.toLong
res.compressedLength = item.obj(Keys.CompressedLength).num.toInt
res.decompressedLength = item.obj(Keys.DecompressedLength).num.toInt
res
}
}

class OutlineStorage(var typ: String) {
class OutlineStorage {
var typ: String = null
var startOffset: Long = -1L
var compressedLength: Int = -1
var decompressedLength: Int = -1
def this(_typ: String) = {
this()
this.typ = _typ
}

override def toString: String = super.toString + s"($typ, $startOffset, $compressedLength, $decompressedLength)"
}
}
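A small round-trip sketch of the revised `OutlineStorage` serialization; the `StorageType` import location and the sample offsets are assumptions, but the null handling follows the diff: a default-constructed, never-written slot serializes to `ujson.Null`, and `read` maps `ujson.Null` back to `null`.

```scala
import flatgraph.storage.Manifest.OutlineStorage
import flatgraph.storage.StorageType // assumed location of the constants used above

// Never-written slot: typ stays null, offsets stay -1, serializes to Null.
val empty = new OutlineStorage
assert(OutlineStorage.write(empty) == ujson.Null)
assert(OutlineStorage.read(ujson.Null) == null)

// A populated slot round-trips through the manifest JSON.
val filled = new OutlineStorage(StorageType.Int)
filled.startOffset = 128L
filled.compressedLength = 64
filled.decompressedLength = 256
val back = OutlineStorage.read(OutlineStorage.write(filled))
assert(back.typ == filled.typ && back.startOffset == 128L)
```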