Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Circ #318

Open
wants to merge 13 commits into
base: develop
Choose a base branch
from
1 change: 1 addition & 0 deletions argon/src/argon/Ref.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ abstract class ExpType[+C:ClassTag,A](implicit protected[argon] val evRef: A <:<
* 4. The result of an operation of type S (has a defining node)
*/
sealed trait Exp[+C,+A] extends Equals { self =>
type R = A@uV
type L = C@uV

private[argon] var _tp: ExpType[C@uV,A@uV] = _
Expand Down
24 changes: 21 additions & 3 deletions src/spatial/Spatial.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ trait Spatial extends Compiler with ParamLoader {
lazy val printer = IRPrinter(state, enable = config.enDbg)
lazy val finalIRPrinter = IRPrinter(state, enable = true)

// --- Desugaring
lazy val circDesugaring = CircDesugaring(state)

// --- Checking
lazy val userSanityChecks = UserSanityChecks(state, enable = !spatialConfig.allowInsanity)
lazy val transformerChecks = CompilerSanityChecks(state, enable = spatialConfig.enLog && !spatialConfig.allowInsanity)
Expand Down Expand Up @@ -139,12 +142,12 @@ trait Spatial extends Compiler with ParamLoader {
/** More black box lowering */
(blackboxLowering2) ==> printer ==> transformerChecks ==>
/** DSE */
((spatialConfig.enableArchDSE) ? paramAnalyzer) ==>
((spatialConfig.enableArchDSE) ? paramAnalyzer) ==>
/** Optional scala model generator */
((spatialConfig.enableRuntimeModel) ? retimingAnalyzer) ==>
((spatialConfig.enableRuntimeModel) ? initiationAnalyzer) ==>
((spatialConfig.enableRuntimeModel) ? dseRuntimeModelGen) ==>
(spatialConfig.enableArchDSE ? dsePass) ==>
(spatialConfig.enableArchDSE ? dsePass) ==>
//blackboxLowering ==> printer ==> transformerChecks ==>
switchTransformer ==> printer ==> transformerChecks ==>
switchOptimizer ==> printer ==> transformerChecks ==>
Expand All @@ -157,8 +160,23 @@ trait Spatial extends Compiler with ParamLoader {
/** Dead code elimination */
useAnalyzer ==>
transientCleanup ==> printer ==> transformerChecks ==>
/** #################################################################### */
/** FIXME: Desugaring */
circDesugaring ==> printer ==> transformerChecks ==>
switchTransformer ==> printer ==> transformerChecks ==>
switchOptimizer ==> printer ==> transformerChecks ==>
memoryDealiasing ==> printer ==> transformerChecks ==>
((!spatialConfig.vecInnerLoop) ? laneStaticTransformer) ==> printer ==>
/** Control insertion */
pipeInserter ==> printer ==> transformerChecks ==>
/** CSE on regs */
regReadCSE ==>
/** Dead code elimination */
useAnalyzer ==>
transientCleanup ==> printer ==> transformerChecks ==>
/** #################################################################### */
/** Stream controller rewrites */
(spatialConfig.distributeStreamCtr ? streamTransformer) ==> printer ==>
(spatialConfig.distributeStreamCtr ? streamTransformer) ==> printer ==>
/** Memory analysis */
retimingAnalyzer ==>
accessAnalyzer ==>
Expand Down
3 changes: 3 additions & 0 deletions src/spatial/lang/Aliases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ trait ExternalAliases extends InternalAliases {

// --- Primitives

type Circ[A,B] = spatial.lang.Circ[A,B]
lazy val Circ = spatial.lang.Circ

type Counter[F] = spatial.lang.Counter[F]
lazy val Counter = spatial.lang.Counter

Expand Down
97 changes: 97 additions & 0 deletions src/spatial/lang/Circ.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package spatial.lang

import argon._
import forge.tags._
import spatial.node._

@ref class Circ[A:Bits,B:Bits] extends Ref[Any,Circ[A,B]] {
override protected val __neverMutable: Boolean = true

private var numApps = 0
def getNumApps: Int = numApps

@api def apply(x: A): B = {
val id = numApps
numApps += 1
stage(CircApply(this,id,x))
}

// TODO: we need access to the underlying function
// @api def func: A => B = ...
// @api def duplicate(): Circ[A,B] = Circ(func())
}

object Circ {
@api def apply[A:Bits,B:Bits](func: A => B, factory: CircExecutorFactory[A,B]): Circ[A,B] = {
stage(CircNew(func, factory))
}

@api def apply[A:Bits,B:Bits](func: A => B): Circ[A,B] = {
Circ(func, PriorityCircExecutorFactory(8, 8))
}
}

abstract class CircExecutorFactory[A:Bits,B:Bits] {
type Executor <: CircExecutor[A,B]
def stageExecutor(nApps: Int, func: A => B): Executor
}

abstract class CircExecutor[A:Bits,B:Bits] {
def stageDone(appId: Int): Void
def stageEnq(appId: Int, data: A): Void
def stageDeq(appId: Int): B
}

case class PriorityCircExecutorFactory[A:Bits,B:Bits](
input_fifo_depth: Int,
output_fifo_depth: Int,
)(
implicit IR: State
) extends CircExecutorFactory[A,B] with Mirrorable[PriorityCircExecutorFactory[A,B]] {
override def mirror(f: Tx): PriorityCircExecutorFactory[A,B] = this

type Id = I32
type Input = Tup2[Id, A]

private val Id = I32
private def none: Input = Tup2(Id(-1), Bits[A].zero)
private def some(i: Id, a: A): Input = Tup2(i, a)
private def id(i: Input): Id = i._1
private def data(i: Input): A = i._2

case class PriorityCircExecutor(
kill: Reg[Bit],
inputs: Seq[FIFO[Input]],
outputs: Seq[FIFO[B]],
) extends CircExecutor[A,B] {
override def stageDone(appId: Int): Void = inputs(appId).enq(none)
override def stageEnq(appId: Int, data: A): Void = inputs(appId).enq(some(Id(appId), data))
override def stageDeq(appId: Int): B = outputs(appId).deq()
}

override type Executor = PriorityCircExecutor

override def stageExecutor(nApps: Int, func: A => B): Executor = {
val kill = Reg[Bit](Bit(false))
val inputs = Range(0, nApps).map(_ => FIFO[Input](input_fifo_depth))
val outputs = Range(0, nApps).map(_ => FIFO[B](output_fifo_depth))
val count = Reg[Id](0)
val executor = PriorityCircExecutor(kill, inputs, outputs)

Sequential(breakWhen = kill).Foreach(*) { _ =>
val input = priorityDeq(inputs: _*)
val output = func(data(input))
outputs.zipWithIndex foreach {
case (fifo, idx) =>
val writeEnable = Id(idx) === id(input)
fifo.enq(output, writeEnable)
}
retimeGate()
val newCount = count.value + mux(id(input) === Id(-1), Id(1), Id(0))
count.write(newCount)
kill.write(true, newCount === Id(nApps))
}

executor
}
}
22 changes: 22 additions & 0 deletions src/spatial/node/Circ.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package spatial.node

import argon.node._
import forge.tags._
import spatial.lang._

@op case class CircNew[_A:Bits,_B:Bits](
func: _A => _B,
factory: CircExecutorFactory[_A,_B]
) extends Alloc[Circ[_A,_B]] {
type A = _A
type B = _B
val evA: Bits[A] = implicitly[Bits[A]]
val evB: Bits[B] = implicitly[Bits[B]]
}

@op case class CircApply[_A:Bits,_B:Bits](circ: Circ[_A,_B], id: Int, arg: _A) extends Primitive[_B] {
type A = _A
type B = _B
val evA: Bits[A] = implicitly[Bits[A]]
val evB: Bits[B] = implicitly[Bits[B]]
}
Loading