Skip to content

Commit

Permalink
Update several parts in kernelforge
Browse files Browse the repository at this point in the history
  • Loading branch information
davschneller committed Apr 22, 2024
1 parent 98a6245 commit 788580b
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 125 deletions.
42 changes: 10 additions & 32 deletions example/four_matrices_test.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,28 @@
from kernelforge.common.matrix.dense import DenseMatrix
from kernelforge.common.matrix.tensor import Tensor
from kernelforge.common.matrix.boundingbox import BoundingBox
from kernelforge.common.context import Context
from kernelforge.common.aux import generate_tmp_matrix
from kernelforge.generators.descriptions import GemmDescr, PointwiseDescr
from kernelforge.generators.descriptions import ElementwiseDescr
from kernelforge.common.basic_types import FloatingPointType, Addressing
from kernelforge.generators.generator import Generator
from kernelforge.common.operation import Operation
from kernelforge.generators import optree

# Q += A x ((B x C) x D)
mat_q = DenseMatrix(num_rows=56,
num_cols=56,
addressing=Addressing.PTR_BASED,
bbox=[0, 0, 56, 9],)
mat_q = Tensor([56, 56], Addressing.PTR_BASED, BoundingBox([0,0], [56,9]))

mat_a = DenseMatrix(num_rows=56,
num_cols=56,
addressing=Addressing.NONE,
bbox=[0, 0, 56, 20])
mat_a = Tensor([56, 56], Addressing.PTR_BASED, BoundingBox([0,0], [56,20]))

mat_b = DenseMatrix(num_rows=56,
num_cols=56,
addressing=Addressing.STRIDED,
bbox=[0, 0, 20, 56])
mat_b = Tensor([56, 56], Addressing.PTR_BASED, BoundingBox([0,0], [20,56]))

mat_c = DenseMatrix(num_rows=56,
num_cols=9,
bbox=[0, 0, 56, 9],
addressing=Addressing.STRIDED)
mat_c = Tensor([56, 9], Addressing.PTR_BASED, BoundingBox([0,0], [56,9]))

mat_d = DenseMatrix(num_rows=9,
num_cols=9,
bbox=[0, 0, 9, 9],
addressing=Addressing.STRIDED)
mat_d = Tensor([9, 9], Addressing.PTR_BASED, BoundingBox([0,0], [9,9]))


tmp1 = generate_tmp_matrix(mat_b, mat_c)
tmp2 = generate_tmp_matrix(tmp1, mat_d)

gemm_list = [PointwiseDescr(False, False, a=mat_a, b=tmp2, c=mat_q, operation=Operation.EXP)]
[
GemmDescr(trans_a=False,
trans_b=False,
a=tmp1, b=mat_d, c=tmp2),
GemmDescr(trans_a=False, trans_b=False,
a=mat_a, b=tmp2, c=mat_q,
alpha='alpha',
beta='beta')]
gemm_list = [ElementwiseDescr([optree.Assignment(mat_a, optree.TensorVar(mat_a, []))])]

context = Context(arch='sm_60',
backend='cuda',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def __init__(self,

self._dest_regs = self._temp_regs

self._use_registers_always = True
self._deferred_stores = {}
self._temporaries = {}

def build(self, ops: List[Symbol], dest_obj: Tensor, descr: MultilinearDescr):
self._reset()
Expand Down Expand Up @@ -150,7 +152,7 @@ def _alloc_register_array(self):
regmem = RegMemObject(name, regsize)
registers = Symbol(name=name, stype=SymbolType.Register, obj=regmem)
self._scopes.add_symbol(registers)
registerAlloc = RegisterAlloc(self._context, registers, regsize)
registerAlloc = RegisterAlloc(self._context, registers, regsize, 0)
self._instructions.append(registerAlloc)
return registers
# self._dest_regs = registers
Expand Down Expand Up @@ -188,7 +190,7 @@ def _make_store(self):
shr_mem=self._shr_mem,
num_threads=self._num_threads))
elif dest_symbol.stype == SymbolType.Global:
if True:
if self._use_registers_always:
if dest_symbol.name not in self._deferred_stores:
dest_registers = self._alloc_register_array()
self._deferred_stores[dest_symbol.name] = (dest_registers, dest_symbol)
Expand Down
102 changes: 38 additions & 64 deletions kernelforge/backend/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ class SymbolType(enum.Enum):
Register = 4
Scratch = 5
Scalar = 6
Data = 7
Data = 7,
WarpwideSource = 8,
WarpwideAccumulator = 9

def determine_dim_index(term, index, shape, permute):
divpos = reduce(lambda x,y: shape[x]*shape[y], permute[:index], 1)
Expand Down Expand Up @@ -89,59 +91,24 @@ def get_address(self, lead_dim, nonlead_dim):
def __str__(self):
return f'shape: {self.shape}, permute: {self._permute}'

class OldDataView:
def __init__(self, rows: int, columns: int, is_transposed: bool, bbox: List[int] = None):
self._rows = rows
self._columns = columns
self.is_transposed = is_transposed
if not bbox:
bbox = [0, 0, rows, columns]
self._bbox = bbox
self._lead_dim = self.get_lead_dim()
self._offset = self.get_offset()

def get_bbox(self):
return deepcopy(self._bbox)

def reset_bbox(self, bbox):
assert bbox[2] - bbox[0] <= self._rows
assert bbox[3] - bbox[1] <= self._columns
self._bbox = bbox
self._offset = self.get_offset()

def get_offset(self):
return self._bbox[0] + self._bbox[1] * self._lead_dim

def get_volume(self):
return self._rows * self._columns

def get_lead_dim(self):
return self._rows

def get_dim_size(self, index):
assert index >= 0 and index < 2
return self._bbox[2 + index] - self._bbox[index]

def get_address(self, row_idx, column_idx):
addr = f'{row_idx} + {column_idx} * {self._lead_dim}'
if self._offset:
addr = f'{self._offset} + {addr}'
return addr

def __str__(self):
return f'rows: {self.rows}, cols: {self.columns}, lid: {self.lead_dim}, trans: {self.is_transposed}'

class Immediate:
    """Constant operand: a value plus the ``FloatingPointType`` used to render it."""

    def __init__(self, value, fptype: FloatingPointType):
        # Both are stored untouched; only ``write`` interprets them.
        self._type = fptype
        self._value = value

    def is_thread_dependent(self):
        """Always ``False`` for an immediate."""
        return False

    def write(self, context: Context):
        """Return the literal text for the stored value.

        The text is produced by the stored type's ``literal`` method.
        ``context`` is accepted but not used here.
        """
        literal = self._type.literal(self._value)
        return literal

class Variable:
    """Named operand: an identifier plus its associated ``FloatingPointType``."""

    def __init__(self, name, fptype: FloatingPointType):
        # ``name`` is emitted verbatim by ``write``; ``fptype`` is only stored here.
        self._name = name
        self._type = fptype

    def is_thread_dependent(self):
        """Always ``False`` for a plain variable."""
        return False

    def write(self, context: Context):
        """Return the stored name unchanged; ``context`` is not used here."""
        return self._name
Expand All @@ -150,31 +117,38 @@ class LeadIndex:
def __init__(self, lane, stride):
self._lane = lane
self._stride = stride

def is_thread_dependent(self):
return True

def write(self, context: Context):
return f'(({context.get_vm().get_lexic().thread_idx_x} % {self._lane}) / {self._stride})'

class Loop:
    """Single counted loop over ``range(start, end, step)``.

    With ``unroll=True`` the loop is expanded while generating: the body
    callback runs once per index value. Otherwise a ``for`` statement is
    emitted through the writer and the body callback runs once inside it.
    """

    def __init__(self, start, end, step=1, unroll=False):
        self.start = start
        self.end = end
        self.step = step
        self.unroll = unroll

    def write(self, context: Context, writer: Writer, inner):
        """Generate the loop and call ``inner`` with the loop index operand.

        :param context: accepted for interface parity; not used here
        :param writer: must provide ``insert_pragma_unroll``, ``varalloc``
            and ``For`` (the latter usable as a context manager)
        :param inner: callable taking one argument, the current index as an
            ``Immediate`` (unrolled case) or a ``Variable`` (emitted loop)
        """
        # The flag lives on the instance; a bare ``unroll`` is not defined here.
        if self.unroll:
            for value in range(self.start, self.end, self.step):
                inner(Immediate(value, FloatingPointType.INT))
        else:
            # NOTE(review): an unroll pragma is emitted although ``unroll`` is
            # False - presumably unrolling is left to the target compiler; confirm.
            writer.insert_pragma_unroll()
            var = writer.varalloc('i')
            # The upper bound is stored as ``end``; no ``stop`` attribute exists.
            with writer.For(f'int {var}={self.start}; {var} < {self.end}; {var} += {self.step}'):
                inner(Variable(var, FloatingPointType.INT))

def write_loops(context: Context, writer: Writer, loops: List[Loop], inner):
    """Emit ``loops`` as a nest, outermost first, and run ``inner`` innermost.

    :param context: forwarded unchanged to every ``Loop.write`` call
    :param writer: forwarded unchanged to every ``Loop.write`` call
    :param loops: loops to nest; ``loops[0]`` becomes the outermost one
    :param inner: callable invoked with one argument, the list of index
        operands handed out by the loops, ordered outermost to innermost.
        With an empty ``loops`` it is called once with ``[]``.
    """
    def write_loops_inner(remaining, varlist):
        # Recurse until no loop is left, so the last loop is emitted too and
        # an empty list does not index past the end.
        if not remaining:
            inner(varlist)
            return
        head, rest = remaining[0], remaining[1:]
        # Each level appends its own index operand before descending.
        head.write(context, writer,
                   lambda v: write_loops_inner(rest, varlist + [v]))

    write_loops_inner(loops, [])

class Symbol:
def __init__(self,
Expand Down Expand Up @@ -210,7 +184,7 @@ def address(self):
else:
return f'{self.name}'

def access_address(self, context: Context, index: List[Union[str, int]]):
def access_address(self, context: Context, index: List[Union[str, int, Immediate, Variable]]):
if self.stype == SymbolType.Global or self.stype == SymbolType.Batch or self.stype == SymbolType.SharedMem:
# lead_dim + nonlead_dim
dimstr = " + ".join(f"{var} * {stride}" for var, stride in zip(index, self.data_view.get_dim_strides()) if var != 0)
Expand All @@ -222,7 +196,7 @@ def access_address(self, context: Context, index: List[Union[str, int]]):
return dimstr if len(dimstr) > 0 else "0"
raise NotImplementedError('Not supposed to be called')

def access(self, context: Context, index: List[Union[str, int]]):
def access(self, context: Context, index: List[Union[str, int, Immediate, Variable]]):
if self.stype == SymbolType.Global or self.stype == SymbolType.Batch or self.stype == SymbolType.SharedMem or self.stype == SymbolType.Register or self.stype == SymbolType.Scratch:
return f'{self.name}[{self.access_address(context, index)}]'
if self.stype == SymbolType.Scalar:
Expand Down Expand Up @@ -267,7 +241,7 @@ def load(self, writer, context: Context, variable, index: List[Union[str, int]],
writer(f'{self.get_fptype(context)} {variable} = {access};')

def store(self, writer, context, variable, index: List[Union[str, int]], nontemp):
assert self.stype != SymbolType.Scalar and self.stype != SymbolType.Data
assert self.stype != SymbolType.Data

access = self.access(context, index)

Expand Down
12 changes: 12 additions & 0 deletions kernelforge/backend/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@

from io import StringIO

class VarAlloc:
    """Hands out numbered names: ``<prefix>0``, ``<prefix>1``, ..."""

    def __init__(self):
        # Number of the most recently issued name; -1 means none issued yet.
        self.counter = -1

    def alloc(self, prefix='v'):
        """Return ``prefix`` followed by the next counter value.

        One counter is shared by all prefixes, so a number is never issued
        twice by the same allocator.
        """
        issued = self.counter + 1
        self.counter = issued
        return f'{prefix}{issued}'

class NoScope:
def __enter__(self):
Expand Down Expand Up @@ -121,6 +128,11 @@ def __init__(self, stream=sys.stdout, factor=2):
self.stream = StringIO()
self.indent = 0
self.factor = 2
self.alloc = VarAlloc()

def varalloc(self, prefix='v'):
    """Return a fresh name from this object's allocator (``self.alloc``)."""
    # TODO: maybe move out?
    allocator = self.alloc
    return allocator.alloc(prefix)

def get_src(self):
return self.stream.getvalue()
Expand Down
31 changes: 23 additions & 8 deletions kernelforge/common/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ class Operation(Enum):

class OperationType(Enum):
    """Kinds of data an operator declares support for via ``datatype()``.

    Values are plain integers. A trailing comma after a value would turn it
    into a one-element tuple, so none is used.
    """
    FLOAT = 0
    # presumably signed / unsigned integer - confirm against the users of this enum
    SINT = 1
    UINT = 2
    BOOLEAN = 3

class Operator:
@abstractmethod
Expand Down Expand Up @@ -84,7 +85,7 @@ def format(self, *ops):
return f'({ops[0]} + {ops[1]})'

def datatype(self):
return [OperationType.FLOAT, OperationType.INTEGER]
return [OperationType.FLOAT, OperationType.SINT, OperationType.UINT]

def __str__(self):
return '+'
Expand All @@ -98,7 +99,7 @@ def format(self, *ops):
return f'({ops[0]} * {ops[1]})'

def datatype(self):
return [OperationType.FLOAT, OperationType.INTEGER]
return [OperationType.FLOAT, OperationType.SINT, OperationType.UINT]

def __str__(self):
return '*'
Expand All @@ -112,7 +113,7 @@ def format(self, *ops):
return f'min({ops[0]}, {ops[1]})'

def datatype(self):
return [OperationType.FLOAT, OperationType.INTEGER]
return [OperationType.FLOAT, OperationType.SINT, OperationType.UINT]

def __str__(self):
return 'min'
Expand All @@ -126,7 +127,7 @@ def format(self, *ops):
return f'max({ops[0]}, {ops[1]})'

def datatype(self):
return [OperationType.FLOAT, OperationType.INTEGER]
return [OperationType.FLOAT, OperationType.SINT, OperationType.UINT]

def __str__(self):
return 'max'
Expand All @@ -140,7 +141,7 @@ def format(self, *ops):
return f'({ops[0]} & {ops[1]})'

def datatype(self):
return [OperationType.BOOLEAN, OperationType.INTEGER]
return [OperationType.BOOLEAN, OperationType.UINT]

def __str__(self):
return '&'
Expand All @@ -154,11 +155,25 @@ def format(self, *ops):
return f'({ops[0]} | {ops[1]})'

def datatype(self):
return [OperationType.BOOLEAN, OperationType.INTEGER]
return [OperationType.BOOLEAN, OperationType.UINT]

def __str__(self):
return '|'

class XorOperator(ReductionOperator):
    """Reduction operator rendered as ``^``."""

    def neutral(self):
        """Return the neutral element: ``x ^ False == x``."""
        return False

    # Concrete implementation, so it is not marked abstract.
    def format(self, *ops):
        """Return the parenthesised expression combining the first two operands."""
        return f'({ops[0]} ^ {ops[1]})'

    def datatype(self):
        """Return the operand kinds this operator declares support for."""
        return [OperationType.BOOLEAN, OperationType.UINT]

    def __str__(self):
        return '^'

class UnaryOperator(Operator):
def num_operands(self):
return 1
Expand Down
6 changes: 2 additions & 4 deletions kernelforge/generators/descriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,10 @@ def is_strict_match(self):
return self._strict_match

def matrix_list(self):
    """Return the tensors of every operation in ``self.oplist``, in order.

    Each operation contributes whatever its ``tensors()`` method yields;
    the results are concatenated without de-duplication.
    """
    # The operation list is an instance attribute; a bare ``oplist`` is unbound.
    tensors = []
    for op in self.oplist:
        tensors.extend(op.tensors())
    return tensors

def __str__(self):
    """Return the string forms of the operations in ``self.oplist`` joined by ``'; '``."""
    return '; '.join(map(str, self.oplist))

class GemmDescr(MultilinearDescr):
def __init__(self,
Expand Down
Loading

0 comments on commit 788580b

Please sign in to comment.