Skip to content

Commit

Permalink
Merge branch 'dev/3.0' into feature/tile_graph_fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 authored Jan 14, 2025
2 parents 1124db7 + 1443a1f commit 403a577
Show file tree
Hide file tree
Showing 75 changed files with 2,368 additions and 1,441 deletions.
9 changes: 2 additions & 7 deletions modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,8 @@ public static class CSourceBuiltn

public static string TopoAwareRuntimeDef(CpuTargetOptions options, ulong dataAlign, ulong collective_pool_size)
{
if (options.Hierarchies[0].Any(i => i > 1))
{
var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/topo_aware_runtime.cshtml", new CpuTargetOptionsModel(options, dataAlign, collective_pool_size)).Result;
return content;
}

return string.Empty;
var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/topo_aware_runtime.cshtml", new CpuTargetOptionsModel(options, dataAlign, collective_pool_size)).Result;
return content;
}

public static string TopologyDef(CpuTargetOptions options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ protected override CSymbol VisitCall(Call expr)

break;
case TIR.Memcopy copy:
IndentScope.Writer.Write($"tensor_copy({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name});\n");
IndentScope.Writer.Write($"tensor_copy({VisitBuffer(args[1], local: true).Name}, {VisitBuffer(args[0], local: true).Name});\n");
break;
case TIR.CPU.Gather gather:
IndentScope.Writer.Write($"gather<{gather.Axis}>({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name}, {VisitBuffer(args[2], local: true).Name});\n");
Expand Down Expand Up @@ -505,10 +505,11 @@ protected override CSymbol VisitCall(Call expr)
break;
case TIR.CPU.GatherReduceScatter grs:
{
if (grs.InType.NdSBP.Any(s => s is SBPPartialSum))
if (grs.InType.NdSBP.Any(s => s is SBPPartial))
{
var reduceKind = "tar::reduce_kind::" + string.Join("_", grs.InType.NdSBP.Select((s, i) => (s is SBPPartialSum ? "r" : string.Empty) + TargetOptions.HierarchyNames[i]));
IndentScope.Writer.IndWrite($"tac::tensor_reduce_sync<ops::add, {reduceKind}>({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name});\n");
var sbpPartial = (SBPPartial)grs.InType.NdSBP.Where(s => s is SBPPartial).Distinct().First();
var reduceKind = "tar::reduce_kind::" + string.Join("_", grs.InType.NdSBP.Select((s, i) => (s is SBPPartial ? "r" : string.Empty) + TargetOptions.HierarchyNames[i]));
IndentScope.Writer.IndWrite($"tac::tensor_reduce_sync<reduce_op::{sbpPartial.Op.ToC()}, {reduceKind}>({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name});\n");
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/CodeGen/CPU/KernelUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public static string DistributedToC(DistributedType distributedType)
}
}

var implicitPolicy = ndSBP.Any(x => x is SBPPartialSum) ? "P<reduce_op::sum>" : "B";
var implicitPolicy = ndSBP.Any(x => x is SBPPartial) ? "P<reduce_op::sum>" : "B";
sb.Append($">, {implicitPolicy}");

for (int axis = 0; axis < distributedType.TensorType.Shape.Rank; axis++)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ template <tar::reduce_kind Kind> class group_heirarchy_getter;
@:};
}

template <template <class, class> class Op, tar::reduce_kind Kind>
template <ntt::reduce_op Op, tar::reduce_kind Kind>
class tensor_reduce_sync_impl {
public:
void reduce_group_sync() const noexcept {
Expand Down Expand Up @@ -154,6 +154,20 @@ class tensor_reduce_sync_impl {
var cur_index = string.Join(", ", Enumerable.Range(0, hierarchy.Length).Select(i => "ntt::distributed::" + hierarchyNames[i] + "id()"));
}

template <class TSliceIn, class TSliceOut>
void reduce_impl(TSliceIn &local, TSliceIn &remote, TSliceOut &dest) {
if constexpr (Op == ntt::reduce_op::max) {
ntt::binary<ntt::ops::max>(local, remote, dest);
} else if constexpr (Op == ntt::reduce_op::sum ||
Op == ntt::reduce_op::mean) {
ntt::binary<ntt::ops::add>(local, remote, dest);
} else if constexpr (Op == ntt::reduce_op::min) {
ntt::binary<ntt::ops::min>(local, remote, dest);
} else if constexpr (Op == ntt::reduce_op::prod) {
ntt::binary<ntt::ops::mul>(local, remote, dest);
}
}

template <class TIn, class TOut> void operator()(TIn &src, TOut &&dest) {
// collect all tensors pointer for access tensor from other nodes.
using TElem = typename TIn::element_type;
Expand Down Expand Up @@ -198,8 +212,8 @@ class tensor_reduce_sync_impl {
auto next_index = ntt::make_ranked_shape(@(cur_index));
index_group2global(next_index_g, next_index);

// reduce-scatter
for (auto i = 0; i < group_size - 1; i++) // communicate (group_size - 1) times
// reduce-scatter, communicate (group_size - 1) times
for (auto i = 0; i < group_size - 1; i++)
{
// check when the last time.
auto offset = (node_number_g + i + 2) % group_size;
Expand All @@ -216,21 +230,30 @@ class tensor_reduce_sync_impl {
auto viewed_dest_tensor = dest.view(starts, new_shape);

if (i == 0) {
auto src2_tensor = ntt::tensor_view<TElem, typename TIn::shape_type, typename TIn::strides_type>(
std::span<TElem, TIn::shape().length()>((TElem *)tar::src_ptr_tensor(next_index), src.shape().length()));
auto src2_tensor =
ntt::tensor_view<TElem, typename TIn::shape_type,
typename TIn::strides_type>(
std::span<TElem, TIn::shape().length()>(
(TElem *)tar::src_ptr_tensor(next_index),
src.shape().length()));
auto viewed_src2_tensor = src2_tensor.view(starts, new_shape);
ntt::binary<Op>(viewed_src1_tensor, viewed_src2_tensor, viewed_dest_tensor);
reduce_impl(viewed_src1_tensor, viewed_src2_tensor,
viewed_dest_tensor);
} else {
auto src2_tensor = ntt::tensor_view<TElem, typename TOutBase::shape_type, typename TOutBase::strides_type>(
std::span<TElem, TOutBase::shape().length()>((TElem *)tar::dest_ptr_tensor(next_index), dest.shape().length()));
auto src2_tensor =
ntt::tensor_view<TElem, typename TOutBase::shape_type,
typename TOutBase::strides_type>(
std::span<TElem, TOutBase::shape().length()>(
(TElem *)tar::dest_ptr_tensor(next_index),
dest.shape().length()));
auto viewed_src2_tensor = src2_tensor.view(starts, new_shape);
ntt::binary<Op>(viewed_src1_tensor, viewed_src2_tensor, viewed_dest_tensor);
reduce_impl(viewed_src1_tensor, viewed_src2_tensor,
viewed_dest_tensor);
}


reduce_group_sync();
}

// all gather
auto dest_index_g = ntt::unravel_index((node_number_g + group_size - 1) % group_size, group_heirarchy);
auto dest_index = ntt::make_ranked_shape(@(cur_index));
Expand All @@ -244,27 +267,30 @@ class tensor_reduce_sync_impl {
for (size_t j = 0; j < Rank; j++) {
if (j == axis) {
starts[j] = offset * frac;
}
else {
} else {
starts[j] = 0;
}
}

new_shape[axis] = (offset == (group_size - 1)) ? frac + remain : frac;
auto viewed_src_tensor = dest.view(starts, new_shape);
auto viewed_dest_tensor = dest_tensor.view(starts, new_shape);
ntt::tensor_copy(std::move(viewed_src_tensor), std::move(viewed_dest_tensor));
ntt::tensor_copy(std::move(viewed_src_tensor),
std::move(viewed_dest_tensor));

reduce_group_sync();
}

if (Op == ntt::reduce_op::mean) {
ntt::binary<ntt::ops::div>(dest, ntt::tensor<TElem, ntt::fixed_shape<1>>((TElem)group_size), dest);
}
}
};
} // namespace detail

template <template <class, class> class Op, tar::reduce_kind Kind, class TIn,
class TOut>
void tensor_reduce_sync(TIn &input, TOut &&output){
detail::tensor_reduce_sync_impl<Op, Kind> impl;
impl(input, output);
}
template <ntt::reduce_op Op, tar::reduce_kind Kind, class TIn, class TOut>
void tensor_reduce_sync(TIn &input, TOut &&output) {
detail::tensor_reduce_sync_impl<Op, Kind> impl;
impl(input, output);
}
} // namespace tac
Loading

0 comments on commit 403a577

Please sign in to comment.