diff --git a/example/four_matrices_test.py b/example/four_matrices_test.py index c386199..04102cd 100644 --- a/example/four_matrices_test.py +++ b/example/four_matrices_test.py @@ -1,50 +1,28 @@ -from kernelforge.common.matrix.dense import DenseMatrix +from kernelforge.common.matrix.tensor import Tensor +from kernelforge.common.matrix.boundingbox import BoundingBox from kernelforge.common.context import Context from kernelforge.common.aux import generate_tmp_matrix -from kernelforge.generators.descriptions import GemmDescr, PointwiseDescr +from kernelforge.generators.descriptions import ElementwiseDescr from kernelforge.common.basic_types import FloatingPointType, Addressing from kernelforge.generators.generator import Generator from kernelforge.common.operation import Operation +from kernelforge.generators import optree -# Q += A x ((B x C) x D) -mat_q = DenseMatrix(num_rows=56, - num_cols=56, - addressing=Addressing.PTR_BASED, - bbox=[0, 0, 56, 9],) +mat_q = Tensor([56, 56], Addressing.PTR_BASED, BoundingBox([0,0], [56,9])) -mat_a = DenseMatrix(num_rows=56, - num_cols=56, - addressing=Addressing.NONE, - bbox=[0, 0, 56, 20]) +mat_a = Tensor([56, 56], Addressing.PTR_BASED, BoundingBox([0,0], [56,20])) -mat_b = DenseMatrix(num_rows=56, - num_cols=56, - addressing=Addressing.STRIDED, - bbox=[0, 0, 20, 56]) +mat_b = Tensor([56, 56], Addressing.PTR_BASED, BoundingBox([0,0], [20,56])) -mat_c = DenseMatrix(num_rows=56, - num_cols=9, - bbox=[0, 0, 56, 9], - addressing=Addressing.STRIDED) +mat_c = Tensor([56, 9], Addressing.PTR_BASED, BoundingBox([0,0], [56,9])) -mat_d = DenseMatrix(num_rows=9, - num_cols=9, - bbox=[0, 0, 9, 9], - addressing=Addressing.STRIDED) +mat_d = Tensor([9, 9], Addressing.PTR_BASED, BoundingBox([0,0], [9,9])) tmp1 = generate_tmp_matrix(mat_b, mat_c) tmp2 = generate_tmp_matrix(tmp1, mat_d) -gemm_list = [PointwiseDescr(False, False, a=mat_a, b=tmp2, c=mat_q, operation=Operation.EXP)] -[ - GemmDescr(trans_a=False, - trans_b=False, - a=tmp1, b=mat_d, c=tmp2), - GemmDescr(trans_a=False, trans_b=False, - a=mat_a, b=tmp2, c=mat_q, - alpha='alpha', - beta='beta')] +gemm_list = [ElementwiseDescr([optree.Assignment(mat_a, optree.TensorVar(mat_a, []))])] context = Context(arch='sm_60', backend='cuda', diff --git a/kernelforge/backend/instructions/builders/multilinear_builder.py b/kernelforge/backend/instructions/builders/multilinear_builder.py index 5abbf65..e0f4a3a 100644 --- a/kernelforge/backend/instructions/builders/multilinear_builder.py +++ b/kernelforge/backend/instructions/builders/multilinear_builder.py @@ -40,7 +40,9 @@ def __init__(self, self._dest_regs = self._temp_regs + self._use_registers_always = True self._deferred_stores = {} + self._temporaries = {} def build(self, ops: List[Symbol], dest_obj: Tensor, descr: MultilinearDescr): self._reset() @@ -150,7 +152,7 @@ def _alloc_register_array(self): regmem = RegMemObject(name, regsize) registers = Symbol(name=name, stype=SymbolType.Register, obj=regmem) self._scopes.add_symbol(registers) - registerAlloc = RegisterAlloc(self._context, registers, regsize) + registerAlloc = RegisterAlloc(self._context, registers, regsize, 0) self._instructions.append(registerAlloc) return registers # self._dest_regs = registers @@ -188,7 +190,7 @@ def _make_store(self): shr_mem=self._shr_mem, num_threads=self._num_threads)) elif dest_symbol.stype == SymbolType.Global: - if True: + if self._use_registers_always: if dest_symbol.name not in self._deferred_stores: dest_registers = self._alloc_register_array() self._deferred_stores[dest_symbol.name] = (dest_registers, dest_symbol) diff --git a/kernelforge/backend/symbol.py b/kernelforge/backend/symbol.py index b10e710..b243c89 100644 --- a/kernelforge/backend/symbol.py +++ b/kernelforge/backend/symbol.py @@ -14,7 +14,9 @@ class SymbolType(enum.Enum): Register = 4 Scratch = 5 Scalar = 6 - Data = 7 + Data = 7, + WarpwideSource = 8, + WarpwideAccumulator = 9 def determine_dim_index(term, index, shape, permute): divpos = reduce(lambda x,y: shape[x]*shape[y], permute[:index], 1) @@ -89,59 +91,24 @@ def get_address(self, lead_dim, nonlead_dim): def __str__(self): return f'shape: {self.shape}, permute: {self._permute}' -class OldDataView: - def __init__(self, rows: int, columns: int, is_transposed: bool, bbox: List[int] = None): - self._rows = rows - self._columns = columns - self.is_transposed = is_transposed - if not bbox: - bbox = [0, 0, rows, columns] - self._bbox = bbox - self._lead_dim = self.get_lead_dim() - self._offset = self.get_offset() - - def get_bbox(self): - return deepcopy(self._bbox) - - def reset_bbox(self, bbox): - assert bbox[2] - bbox[0] <= self._rows - assert bbox[3] - bbox[1] <= self._columns - self._bbox = bbox - self._offset = self.get_offset() - - def get_offset(self): - return self._bbox[0] + self._bbox[1] * self._lead_dim - - def get_volume(self): - return self._rows * self._columns - - def get_lead_dim(self): - return self._rows - - def get_dim_size(self, index): - assert index >= 0 and index < 2 - return self._bbox[2 + index] - self._bbox[index] - - def get_address(self, row_idx, column_idx): - addr = f'{row_idx} + {column_idx} * {self._lead_dim}' - if self._offset: - addr = f'{self._offset} + {addr}' - return addr - - def __str__(self): - return f'rows: {self.rows}, cols: {self.columns}, lid: {self.lead_dim}, trans: {self.is_transposed}' - class Immediate: def __init__(self, value, fptype: FloatingPointType): self._value = value self._type = fptype + def is_thread_dependent(self): + return False + def write(self, context: Context): return self._type.literal(self._value) class Variable: - def __init__(self, name): + def __init__(self, name, fptype: FloatingPointType): self._name = name + self._type = fptype + + def is_thread_dependent(self): + return False def write(self, context: Context): return self._name @@ -150,31 +117,38 @@ class LeadIndex: def __init__(self, lane, stride): self._lane = lane self._stride = stride + + def is_thread_dependent(self): + return True def write(self, context: Context): return f'(({context.get_vm().get_lexic().thread_idx_x} % {self._lane}) / {self._stride})' -class LoopDimension: - def __init__(self, unroll): - pass +class Loop: + def __init__(self, start, end, step=1, unroll=False): + self.start = start + self.end = end + self.step = step + self.unroll = unroll def write(self, context: Context, writer: Writer, inner): if unroll: - for value in TODO: - inner(Immediate(value, TODO)) + for value in range(self.start, self.end, self.step): + inner(Immediate(value, FloatingPointType.INT)) else: - with writer.For(''): - inner(Variable('TODO')) - -class Loop: - def __init__(self, dimensions: List[LoopDimension]): - pass - - def write(self, context: Context, writer: Writer, inner): - pass - - def indices(self): - pass + writer.insert_pragma_unroll() + var = writer.varalloc('i') + with writer.For(f'int {var}={self.start}; {var} < {self.stop}; {var} += {self.step}'): + inner(Variable(var, FloatingPointType.INT)) + +def write_loops(context: Context, writer: Writer, loops: List[Loop], inner): + def write_loops_inner(context: Context, writer: Writer, loops: List[Loop], inner, varlist): + if len(loops) == 1: + inner() + else: + inner_next = lambda v: write_loops_inner(context, writer, loops[1:], inner, varlist + [v]) + loops[0].write(context, writer, inner_next) + write_loops_inner(context, writer, loops, inner, []) class Symbol: def __init__(self, @@ -210,7 +184,7 @@ def address(self): else: return f'{self.name}' - def access_address(self, context: Context, index: List[Union[str, int]]): + def access_address(self, context: Context, index: List[Union[str, int, Immediate, Variable]]): if self.stype == SymbolType.Global or self.stype == SymbolType.Batch or self.stype == SymbolType.SharedMem: # lead_dim + nonlead_dim dimstr = " + ".join(f"{var} * {stride}" for var, stride in zip(index, self.data_view.get_dim_strides()) if var != 0) @@ -222,7 +196,7 @@ def access_address(self, context: Context, index: List[Union[str, int]]): return dimstr if len(dimstr) > 0 else "0" raise NotImplementedError('Not supposed to be called') - def access(self, context: Context, index: List[Union[str, int]]): + def access(self, context: Context, index: List[Union[str, int, Immediate, Variable]]): if self.stype == SymbolType.Global or self.stype == SymbolType.Batch or self.stype == SymbolType.SharedMem or self.stype == SymbolType.Register or self.stype == SymbolType.Scratch: return f'{self.name}[{self.access_address(context, index)}]' if self.stype == SymbolType.Scalar: @@ -267,7 +241,7 @@ def load(self, writer, context: Context, variable, index: List[Union[str, int]], writer(f'{self.get_fptype(context)} {variable} = {access};') def store(self, writer, context, variable, index: List[Union[str, int]], nontemp): - assert self.stype != SymbolType.Scalar and self.stype != SymbolType.Data + assert self.stype != SymbolType.Data access = self.access(context, index) diff --git a/kernelforge/backend/writer.py b/kernelforge/backend/writer.py index 4b79f6b..7466184 100644 --- a/kernelforge/backend/writer.py +++ b/kernelforge/backend/writer.py @@ -41,6 +41,13 @@ from io import StringIO +class VarAlloc: + def __init__(self): + self.counter = -1 + + def alloc(self, prefix='v'): + self.counter += 1 + return f'{prefix}{self.counter}' class NoScope: def __enter__(self): @@ -121,6 +128,11 @@ def __init__(self, stream=sys.stdout, factor=2): self.stream = StringIO() self.indent = 0 self.factor = 2 + self.alloc = VarAlloc() + + def varalloc(self, prefix='v'): + # TODO: maybe move out? + return self.alloc.alloc(prefix) def get_src(self): return self.stream.getvalue() diff --git a/kernelforge/common/operation.py b/kernelforge/common/operation.py index 3a0730a..937e23b 100644 --- a/kernelforge/common/operation.py +++ b/kernelforge/common/operation.py @@ -51,8 +51,9 @@ class Operation(Enum): class OperationType(Enum): FLOAT = 0, - INTEGER = 1, - BOOLEAN = 2 + SINT = 1, + UINT = 2, + BOOLEAN = 3 class Operator: @abstractmethod @@ -84,7 +85,7 @@ def format(self, *ops): return f'({ops[0]} + {ops[1]})' def datatype(self): - return [OperationType.FLOAT, OperationType.INTEGER] + return [OperationType.FLOAT, OperationType.SINT, OperationType.UINT] def __str__(self): return '+' @@ -98,7 +99,7 @@ def format(self, *ops): return f'({ops[0]} * {ops[1]})' def datatype(self): - return [OperationType.FLOAT, OperationType.INTEGER] + return [OperationType.FLOAT, OperationType.SINT, OperationType.UINT] def __str__(self): return '*' @@ -112,7 +113,7 @@ def format(self, *ops): return f'min({ops[0]}, {ops[1]})' def datatype(self): - return [OperationType.FLOAT, OperationType.INTEGER] + return [OperationType.FLOAT, OperationType.SINT, OperationType.UINT] def __str__(self): return 'min' @@ -126,7 +127,7 @@ def format(self, *ops): return f'max({ops[0]}, {ops[1]})' def datatype(self): - return [OperationType.FLOAT, OperationType.INTEGER] + return [OperationType.FLOAT, OperationType.SINT, OperationType.UINT] def __str__(self): return 'max' @@ -140,7 +141,7 @@ def format(self, *ops): return f'({ops[0]} & {ops[1]})' def datatype(self): - return [OperationType.BOOLEAN, OperationType.INTEGER] + return [OperationType.BOOLEAN, OperationType.UINT] def __str__(self): return '&' @@ -154,11 +155,25 @@ def format(self, *ops): return f'({ops[0]} | {ops[1]})' def datatype(self): - return [OperationType.BOOLEAN, OperationType.INTEGER] + return [OperationType.BOOLEAN, OperationType.UINT] def __str__(self): return '|' +class XorOperator(ReductionOperator): + def neutral(self): + return False + + @abstractmethod + def format(self, *ops): + return f'({ops[0]} ^ {ops[1]})' + + def datatype(self): + return [OperationType.BOOLEAN, OperationType.UINT] + + def __str__(self): + return '^' + class UnaryOperator(Operator): def num_operands(self): return 1 diff --git a/kernelforge/generators/descriptions.py b/kernelforge/generators/descriptions.py index 0912490..7c4e8b4 100644 --- a/kernelforge/generators/descriptions.py +++ b/kernelforge/generators/descriptions.py @@ -77,12 +77,10 @@ def is_strict_match(self): return self._strict_match def matrix_list(self): - return [tensor for op in oplist for tensor in op.tensors()] + return [tensor for op in self.oplist for tensor in op.tensors()] def __str__(self): - destdim = len(self.dest.shape) - desttarget = [i for i in range(destdim)] - return f'{self.dest}{desttarget} = {"×".join(f"{op}{optarget}" for op, optarget in zip(self.ops, self.target))}' + return '; '.join(str(op) for op in self.oplist) class GemmDescr(MultilinearDescr): def __init__(self, diff --git a/kernelforge/generators/generator.py b/kernelforge/generators/generator.py index 8a5f6b9..9ff9a46 100644 --- a/kernelforge/generators/generator.py +++ b/kernelforge/generators/generator.py @@ -47,7 +47,7 @@ def __init__(self, gemm_list: List[OperationDescription], context: Context, thread_block_policy_type: Type[AbstractThreadBlockPolicy] = SimpleThreadBlockPolicy): - self.gemm_list: List[OperationDescription] = gemm_list + self.descr_list: List[OperationDescription] = gemm_list self._context: Context = context self._thread_block_policy_type: Type[AbstractThreadBlockPolicy] = thread_block_policy_type self._base_kernel_name: Union[str, None] = None @@ -71,7 +71,7 @@ def __init__(self, self._ir: List[AbstractInstruction] = [] self._check_consistency_with_user_options() - self._name_operands(self.gemm_list) + self._name_operands(self.descr_list) self._persistent_threading = False @@ -168,15 +168,15 @@ def _generate_header(self): self._header = f'{self._generate_launcher_proto(with_defaults=True)};\n' def _deduce_num_threads(self): - for gemm in self.gemm_list: - num_threads, num_active_threads = gemm.get_num_threads(self._context) + for descr in self.descr_list: + num_threads, num_active_threads = descr.get_num_threads(self._context) self._num_threads = max(num_threads, self._num_threads) self._num_active_threads = max(num_active_threads, self._num_active_threads) def _deduce_accumulator_size(self): - for gemm in self.gemm_list: - local_acc_size = gemm.get_accumulator_size() + for descr in self.descr_list: + local_acc_size = descr.get_accumulator_size() self._accumulator_size = max(self._accumulator_size, local_acc_size) def _emit_ir(self): @@ -209,7 +209,7 @@ def _emit_ir(self): # builder.build_prologue() - for gemm_descr in self.gemm_list: + for gemm_descr in self.descr_list: builder.build(ops=[self._scopes.get_symbol(op) for op in gemm_descr.ops], dest_obj=gemm_descr.dest, descr=gemm_descr) @@ -236,10 +236,10 @@ def get_header(self): def _check_consistency_with_user_options(self): user_options = self._context.get_user_options() - for gemm in self.gemm_list: - if not gemm.is_strict_match() == user_options.exact_contraction_length: + for descr in self.descr_list: + if not descr.is_strict_match() == user_options.exact_contraction_length: msg = 'gemm list is not consistent with user options. ' - msg += f'`strict_math` in gemm descr. set to {gemm.is_strict_match()}, ' + msg += f'`strict_math` in gemm descr. set to {descr.is_strict_match()}, ' msg += f'but `exact_contraction_length` is set to {user_options.exact_contraction_length}' raise RuntimeError(msg) @@ -287,14 +287,14 @@ def _generate_kernel_name(self): for item in global_symbols: long_name.append(item.obj.gen_descr()) - for gemm in self.gemm_list: + for descr in self.descr_list: long_name.extend([ - str(gemm.dest) + str(descr) ]) result = hashlib.md5(', '.join(long_name).encode()) md5encoding = result.hexdigest() - self._base_kernel_name = f'cf_gemms_{md5encoding[:Generator.NAME_ENCODING_LENGTH]}' + self._base_kernel_name = f'kernel_{md5encoding[:Generator.NAME_ENCODING_LENGTH]}' def get_base_name(self): return self._base_kernel_name @@ -312,7 +312,7 @@ def _write_kernel_meta_data(self, writer): writer(f'// {matrix.obj.gen_descr()}') writer.new_line() - for item in self.gemm_list: + for item in self.descr_list: writer(f'// {item}') writer.new_line() diff --git a/kernelforge/generators/optree.py b/kernelforge/generators/optree.py index 16f18b7..69681f1 100644 --- a/kernelforge/generators/optree.py +++ b/kernelforge/generators/optree.py @@ -34,10 +34,13 @@ def store(self, writer: Writer, context: Context, value: str): pass class Assignment: - def __init__(dest: Variable, optree: Node): + def __init__(self, dest: Variable, optree: Node): self.dest = dest self.optree = optree + def tensors(self): + return [self.dest] + self.optree.tensors() + def assignSymbols(self, scopes: Scopes): self.dest.assignSymbols(scopes) self.optree.assignSymbols(scopes)