Skip to content

Commit

Permalink
Add elementwise, first draft
Browse files Browse the repository at this point in the history
  • Loading branch information
davschneller committed Apr 29, 2024
1 parent cc346f1 commit 413c425
Show file tree
Hide file tree
Showing 12 changed files with 588 additions and 112 deletions.
27 changes: 20 additions & 7 deletions kernelforge/backend/instructions/compute/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from kernelforge.common.operation import ReductionOperator
from typing import Union, List
from kernelforge.generators.optree import Assignment, writeAssignments
from kernelforge.backend.scopes import Scopes

class ElementwiseInstruction(ComputeInstruction):
def __init__(self,
Expand All @@ -21,47 +22,59 @@ def __init__(self,
num_threads: int):
super(ElementwiseInstruction, self).__init__(context)
self._assignments = assignments
self._productOperation = productOperation
self._sumOperation = sumOperation
self._prefer_align = prefer_align
self._is_ready = True
self._user_options = context.get_user_options()
self._gemm_meta_data = None
self._num_threads = num_threads

self._lead_dims = [0]

self.registers = None

# TODO: get index list
seen_tensors = set()
ranges = {}
for assignment in self._assignments:
assignment.assignSymbols(self.scopes)
assignment.assignSymbols(scopes)
ranges = assignment.getRanges(ranges)
for tensor in assignment.symbols():
if tensor not in seen_tensors:
tensor.add_user(self)
seen_tensors.add(tensor)
if not isinstance(op.obj, Tensor):
if not isinstance(tensor.obj, Tensor):
raise InternalError('elementwise: op is not a matrix')
self._ks = [None] * len(ranges)
for i in range(len(ranges)):
assert -i-1 in ranges
self._ks[i] = ranges[-i-1]

def gen_code_inner(self, writer: Writer):
self._assignment_loop(writer)

def _assignment_loop(self, writer: Writer):
loopstack = []

for i, (dimmin, dimmax) in enumerate(self._ns):
if len(self._ks) > 0:
writer(f'int n0 = {self._context.get_vm().get_lexic().thread_idx_x};')
for i, (dimmin, dimmax) in enumerate(self._ks):
if i not in self._lead_dims:
writer.insert_pragma_unroll()
loop = writer.For(f'int n{i} = {dimmin}; n{i} < {dimmax}; ++n{i}')
loop.__enter__()
loopstack += [loop]
else:
loop = writer.If(f'n{i} >= {dimmin} && n{i} < {dimmax}')
loop.__enter__()
loopstack += [loop]

writeAssignments(self._assignments, writer, self._context)

for loop in loopstack[::-1]:
loop.__exit__(None, None, None)

def get_operands(self):
return self._ops
return [] # TODO: for now

def __str__(self):
return ', '.join(f'{assignment.dest} = {assignment.optree}' for assignment in self._assignments)
return ', '.join(str(assignment) for assignment in self._assignments)
2 changes: 1 addition & 1 deletion kernelforge/backend/instructions/control/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

class ControlInstruction:
class ControlConstruct:
def __init__(self):
pass
2 changes: 2 additions & 0 deletions kernelforge/common/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class Operation(Enum):
MOD = 10,
NEG = 11,
RCP = 12,
RSQRT = 13,
RCBRT = 13,
CEIL = 30,
FLOOR = 31,
ROUND = 32,
Expand Down
2 changes: 1 addition & 1 deletion kernelforge/common/vm/lexic/hip_lexic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

class HipLexic(CudaLexic):
def __init__(self, backend, underlying_hardware):
super().__init__(underlying_hardware)
super().__init__(backend, underlying_hardware)
self._backend = backend
self.thread_idx_y = "hipThreadIdx_y"
self.thread_idx_x = "hipThreadIdx_x"
Expand Down
5 changes: 3 additions & 2 deletions kernelforge/generators/descriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ def __init__(self, oplist: List[Assignment],
self._strict_match = False

for op in oplist:
op.dest.set_data_flow_direction(DataFlowDirection.SINK)
for tensor in op.optree.tensors():
for tensor in op.tensors(outtensors=True, intensors=False):
tensor.set_data_flow_direction(DataFlowDirection.SINK)
for tensor in op.tensors(outtensors=False, intensors=True):
tensor.set_data_flow_direction(DataFlowDirection.SOURCE)

def get_num_threads(self, context: Context):
Expand Down
15 changes: 9 additions & 6 deletions kernelforge/generators/generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Union, Type
from copy import deepcopy
import hashlib
from kernelforge.generators.descriptions import OperationDescription
from kernelforge.generators.descriptions import OperationDescription, MultilinearDescr, ElementwiseDescr
from kernelforge.common.context import Context
from kernelforge.common.basic_types import Addressing, GeneralLexicon, DataFlowDirection
from kernelforge.common.aux import get_extra_offset_name
Expand All @@ -10,6 +10,7 @@
from kernelforge.backend.scopes import Scopes
from kernelforge.backend.symbol import Symbol, SymbolType
from kernelforge.backend.instructions.abstract_instruction import AbstractInstruction
from kernelforge.backend.instructions.compute.elementwise import ElementwiseInstruction
from kernelforge.backend.instructions.builders.multilinear_builder import MultilinearBuilder
from kernelforge.backend.instructions.builders.ptr_manip_builder import GetElementPtrBuilder
from kernelforge.backend.instructions.builders.allocator_builder import ShrMemAllocBuilder, RegistersAllocBuilder
Expand Down Expand Up @@ -206,14 +207,16 @@ def _emit_ir(self):
self._scopes.get_symbol(self._register_array_obj),
self._scopes.get_symbol(self._shr_mem_obj),
self._num_threads)

# builder.build_prologue()

for gemm_descr in self.descr_list:
builder.build(ops=[self._scopes.get_symbol(op) for op in gemm_descr.ops],
dest_obj=gemm_descr.dest,
descr=gemm_descr)
self._ir.extend(builder.get_instructions())
if isinstance(gemm_descr, MultilinearDescr):
builder.build(ops=[self._scopes.get_symbol(op) for op in gemm_descr.ops],
dest_obj=gemm_descr.dest,
descr=gemm_descr)
self._ir.extend(builder.get_instructions())
if isinstance(gemm_descr, ElementwiseDescr):
self._ir.append(ElementwiseInstruction(self._context, gemm_descr.oplist, self._scopes, False, self._num_threads))

builder.build_epilogue()
self._ir.extend(builder.get_instructions())
Expand Down
Loading

0 comments on commit 413c425

Please sign in to comment.