Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Graph Partition #1291

Open
wants to merge 10 commits into
base: dev/3.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public record BufferRenderInfo(string Name, string ElemType, ulong Offset, ulong
{
}

public record KernelMainModel(TIR.PrimFunction PrimFunction, TIR.Buffer[] RDataBuffers, CpuTargetOptions Options, ulong Alignment, ulong DataSize, ulong RDataSize)
public record KernelMainModel(TIR.PrimFunction PrimFunction, TIR.Buffer[] RDataBuffers, CpuTargetOptions Options, ulong Alignment, ulong DataSize, ulong RDataSize, ulong LocalRdataPoolSize)
{
public BufferRenderInfo GetInfo(TIR.Buffer buffer)
{
Expand Down Expand Up @@ -64,9 +64,9 @@ public static string CMakeDef(string name)
return content;
}

public static string MakeMain(TIR.PrimFunction primFunction, ulong dataAlign, ulong dataUsage, ulong rdataPoolSize, IEnumerable<TIR.Buffer> rdataBuffers, CpuTargetOptions options)
public static string MakeMain(TIR.PrimFunction primFunction, ulong dataAlign, ulong dataUsage, ulong rdataPoolSize, ulong localRdataPoolSize, IEnumerable<TIR.Buffer> rdataBuffers, CpuTargetOptions options)
{
var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/thread_main.cpp.cshtml", new KernelMainModel(primFunction, rdataBuffers.ToArray(), options, dataAlign, dataUsage, rdataPoolSize)).Result;
var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/thread_main.cpp.cshtml", new KernelMainModel(primFunction, rdataBuffers.ToArray(), options, dataAlign, dataUsage, rdataPoolSize, localRdataPoolSize)).Result;
return content;
}

Expand Down
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public unsafe ILinkableFunction Build(TIR.PrimFunction function)
}

// 3. build function.
var visitor = new KernelCSourceConvertVisitor(function.SchedResult.DataAlign, function.SchedResult.DataUsage, rdataPoolSize, TargetOptions);
var visitor = new KernelCSourceConvertVisitor(function.SchedResult.DataAlign, function.SchedResult.DataUsage, rdataPoolSize, localRdataPoolSize, TargetOptions);
visitor.Visit(function);
var functionCSource = visitor.GetCSource();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,12 @@ internal sealed class KernelCSourceConvertVisitor : ExprFunctor<CSymbol, Unit>,
private readonly StringWriter _sharedWriter;
private ulong _collective_pool_size;

public KernelCSourceConvertVisitor(ulong dataAlign, ulong dataUsage, ulong rdataPoolSize, CpuTargetOptions targetOptions)
public KernelCSourceConvertVisitor(ulong dataAlign, ulong dataUsage, ulong rdataPoolSize, ulong localRdataPoolSize, CpuTargetOptions targetOptions)
{
DataAlign = dataAlign;
DataUsage = dataUsage;
RdataPoolSize = rdataPoolSize;
LocalRdataPoolSize = localRdataPoolSize;
_kernelBuilder = new StringBuilder();
_sharedBuilder = new StringBuilder();
_sharedWriter = new StringWriter(_sharedBuilder);
Expand All @@ -145,11 +146,13 @@ public KernelCSourceConvertVisitor(ulong dataAlign, ulong dataUsage, ulong rdata

public ulong RdataPoolSize { get; }

public ulong LocalRdataPoolSize { get; }

public KernelCSource GetCSource()
{
var ctype = $"void {VisitEntry.Name}({string.Join(", ", VisitEntry.Parameters.AsValueEnumerable().Select(Visit).Select(s => $"{s.Type} {s.Name}").ToArray().Concat(_exprMemo.Keys.OfType<TIR.Buffer>().Where(b => b.MemSpan.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata).Select(Visit).Select(s => $" {s.Type} {s.Name}").ToArray()))}, uint8_t* data)";
return new(
CSourceBuiltn.MakeMain(VisitEntry, DataAlign, DataUsage, RdataPoolSize, _exprMemo.Keys.OfType<TIR.Buffer>().Where(b => b.MemSpan.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata), TargetOptions),
CSourceBuiltn.MakeMain(VisitEntry, DataAlign, DataUsage, RdataPoolSize, LocalRdataPoolSize, _exprMemo.Keys.OfType<TIR.Buffer>().Where(b => b.MemSpan.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata), TargetOptions),
CSourceBuiltn.MakeKernel(ctype, _kernelBuilder.ToString()),
CSourceBuiltn.TopoAwareRuntimeDef(TargetOptions, DataAlign, _collective_pool_size),
CSourceBuiltn.TopologyDef(TargetOptions));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) {
}

std::byte* rdata = (std::byte *)nncase::ntt::runtime::thread_alloc(@Model.RDataSize, align);
std::byte* local_rdata = (std::byte *)nncase::ntt::runtime::thread_alloc(@Model.LocalRdataPoolSize, align);
uint64_t local_rdata_header[@Model.Options.Hierarchies[0][^1] * 2];
for (size_t tid = 0; tid < tdim(); tid++) {
local_rdata_header[tid * 2] = tid * ( @Model.LocalRdataPoolSize / tdim());
}

#ifdef __APPLE__
pthread_key_t cpu_thread_context_key_ = {};
Expand All @@ -73,7 +78,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) {
std::vector<std::thread> blocks;
for (size_t cid = 0; cid < cdim(); cid++) {
for (size_t bid = 0; bid < bdim(); bid++) {
blocks.emplace_back([cid, bid, inputs, rdata
blocks.emplace_back([cid, bid, inputs, rdata, local_rdata_header, local_rdata
#ifdef __APPLE__
, &cpu_thread_context_key_
#endif
Expand All @@ -87,6 +92,8 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) {
.cpu_id_offset = (cid * bdim() + bid) * tdim(),
.inouts = inputs,
.rdata = rdata,
.local_rdata_header = local_rdata_header,
.local_rdata = local_rdata,
#ifdef __APPLE__
.cpu_thread_context_key = cpu_thread_context_key_,
#endif
Expand Down
7 changes: 6 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Nncase.Evaluator.TIR.CPU;

public sealed class BinaryEvaluator : ITypeInferencer<Binary>, IKernelInfoEvaluator<Binary>
public sealed class BinaryEvaluator : ITypeInferencer<Binary>, IKernelInfoEvaluator<Binary>, IOpPrinter<Binary>
{
public IRType Visit(ITypeInferenceContext context, Binary target)
{
Expand All @@ -29,6 +29,11 @@ public MicroKernelInfo Visit(Binary op, MicroKernelContext context)
return new MicroKernelInfo(primitives, multipliers, bufferInfos, GetComputeCycle);
}

/// <summary>
/// Formats a <see cref="Binary"/> op as <c>Binary(property, lhs, rhs, output)</c>.
/// </summary>
/// <param name="context">Printer context used to look up each argument of <paramref name="target"/>.</param>
/// <param name="target">The op being printed.</param>
/// <param name="iLmode">Not consulted by this printer.</param>
/// <returns>The single-line textual form of the op.</returns>
public string Visit(IIRPrinterContext context, Binary target, bool iLmode)
{
    // Gather the pieces in the same left-to-right order they appear in the output.
    var property = target.DisplayProperty();
    var lhs = context.GetArgument(target, Binary.Lhs);
    var rhs = context.GetArgument(target, Binary.Rhs);
    var output = context.GetArgument(target, Binary.Output);

    return $"Binary({property}, {lhs}, {rhs}, {output})";
}

private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
{
var factora = System.Math.Min(context.BufferShapes[0][^1], 32);
Expand Down
7 changes: 6 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace Nncase.Evaluator.TIR.CPU;

public sealed class UnaryEvaluator : ITypeInferencer<Unary>, IKernelInfoEvaluator<Unary>
public sealed class UnaryEvaluator : ITypeInferencer<Unary>, IKernelInfoEvaluator<Unary>, IOpPrinter<Unary>
{
public IRType Visit(ITypeInferenceContext context, Unary target)
{
Expand All @@ -31,6 +31,11 @@ public MicroKernelInfo Visit(Unary op, MicroKernelContext context)
return new MicroKernelInfo(primitives, multipliers, bufferInfos, GetComputeCycle);
}

/// <summary>
/// Formats a <see cref="Unary"/> op as <c>Unary(property, input, output)</c>.
/// </summary>
/// <param name="context">Printer context used to look up each argument of <paramref name="target"/>.</param>
/// <param name="target">The op being printed.</param>
/// <param name="iLmode">Not consulted by this printer.</param>
/// <returns>The single-line textual form of the op.</returns>
public string Visit(IIRPrinterContext context, Unary target, bool iLmode)
{
    // Gather the pieces in the same left-to-right order they appear in the output.
    var property = target.DisplayProperty();
    var input = context.GetArgument(target, Unary.Input);
    var output = context.GetArgument(target, Unary.Output);

    return $"Unary({property}, {input}, {output})";
}

private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
{
var factor = System.Math.Min(context.BufferShapes[0][^1], 32);
Expand Down
Loading
Loading