From 413c425dfa9e3cd661a9432286788524ceb44fa9 Mon Sep 17 00:00:00 2001
From: David Schneller
Date: Tue, 30 Apr 2024 01:43:44 +0200
Subject: [PATCH] Add elementwise, first draft

---
 .../instructions/compute/elementwise.py      |  27 +-
 .../backend/instructions/control/__init__.py |   2 +-
 kernelforge/common/operation.py              |   2 +
 kernelforge/common/vm/lexic/hip_lexic.py     |   2 +-
 kernelforge/generators/descriptions.py       |   5 +-
 kernelforge/generators/generator.py          |  15 +-
 kernelforge/generators/optree.py             | 283 ++++++++++++++----
 tensorforge/functions.py                     | 154 ++++++++++
 yateto/ast/node.py                           |  79 +++++
 yateto/ast/transformer.py                    |   8 +
 yateto/codegen/factory.py                    |   6 +
 yateto/codegen/gpukernel.py                  | 117 +++++---
 12 files changed, 588 insertions(+), 112 deletions(-)
 create mode 100644 tensorforge/functions.py

diff --git a/kernelforge/backend/instructions/compute/elementwise.py b/kernelforge/backend/instructions/compute/elementwise.py
index 499358b..881fc09 100644
--- a/kernelforge/backend/instructions/compute/elementwise.py
+++ b/kernelforge/backend/instructions/compute/elementwise.py
@@ -11,6 +11,7 @@
 from kernelforge.common.operation import ReductionOperator
 from typing import Union, List
 from kernelforge.generators.optree import Assignment, writeAssignments
+from kernelforge.backend.scopes import Scopes
 
 class ElementwiseInstruction(ComputeInstruction):
   def __init__(self,
@@ -21,26 +22,32 @@ def __init__(self,
                num_threads: int):
     super(ElementwiseInstruction, self).__init__(context)
     self._assignments = assignments
-    self._productOperation = productOperation
-    self._sumOperation = sumOperation
     self._prefer_align = prefer_align
     self._is_ready = True
     self._user_options = context.get_user_options()
     self._gemm_meta_data = None
     self._num_threads = num_threads
+    self._lead_dims = [0]
+    self.registers = None
 
     # TODO: get index list
     seen_tensors = set()
+    ranges = {}
     for assignment in self._assignments:
-      assignment.assignSymbols(self.scopes)
+      assignment.assignSymbols(scopes)
+      ranges = assignment.getRanges(ranges)
       for tensor in assignment.symbols():
         if tensor not in seen_tensors:
           tensor.add_user(self)
           seen_tensors.add(tensor)
 
-          if not isinstance(op.obj, Tensor):
+          if not isinstance(tensor.obj, Tensor):
            raise InternalError('elementwise: op is not a matrix')
 
+    self._ks = [None] * len(ranges)
+    for i in range(len(ranges)):
+      assert -i-1 in ranges
+      self._ks[i] = ranges[-i-1]
 
   def gen_code_inner(self, writer: Writer):
     self._assignment_loop(writer)
@@ -48,12 +55,18 @@ def gen_code_inner(self, writer: Writer):
 
   def _assignment_loop(self, writer: Writer):
     loopstack = []
-    for i, (dimmin, dimmax) in enumerate(self._ns):
+    if len(self._ks) > 0:
+      writer(f'int n0 = {self._context.get_vm().get_lexic().thread_idx_x};')
+    for i, (dimmin, dimmax) in enumerate(self._ks):
       if i not in self._lead_dims:
         writer.insert_pragma_unroll()
         loop = writer.For(f'int n{i} = {dimmin}; n{i} < {dimmax}; ++n{i}')
         loop.__enter__()
         loopstack += [loop]
+      else:
+        loop = writer.If(f'n{i} >= {dimmin} && n{i} < {dimmax}')
+        loop.__enter__()
+        loopstack += [loop]
 
     writeAssignments(self._assignments, writer, self._context)
 
@@ -61,7 +74,7 @@
       loop.__exit__(None, None, None)
 
   def get_operands(self):
-    return self._ops
+    return [] # TODO: for now
 
   def __str__(self):
-    return ', '.join(f'{assignment.dest} = {assignment.optree}' for assignment in self._assignments)
+    return ', '.join(str(assignment) for assignment in self._assignments)
diff --git a/kernelforge/backend/instructions/control/__init__.py b/kernelforge/backend/instructions/control/__init__.py
index c464317..f31dcae 100644
--- a/kernelforge/backend/instructions/control/__init__.py
+++ b/kernelforge/backend/instructions/control/__init__.py
@@ -1,4 +1,4 @@
-class ControlInstruction:
+class ControlConstruct:
   def __init__(self):
     pass
 
diff --git a/kernelforge/common/operation.py b/kernelforge/common/operation.py
index 937e23b..fcf0862 100644
--- a/kernelforge/common/operation.py
+++ b/kernelforge/common/operation.py
@@ -17,6 +17,8 @@ class Operation(Enum):
   MOD = 10,
   NEG = 11,
   RCP = 12,
+  RSQRT = 13,
+  RCBRT = 14,
   CEIL = 30,
   FLOOR = 31,
   ROUND = 32,
diff --git a/kernelforge/common/vm/lexic/hip_lexic.py b/kernelforge/common/vm/lexic/hip_lexic.py
index 8fad769..8bde449 100644
--- a/kernelforge/common/vm/lexic/hip_lexic.py
+++ b/kernelforge/common/vm/lexic/hip_lexic.py
@@ -4,7 +4,7 @@
 
 class HipLexic(CudaLexic):
   def __init__(self, backend, underlying_hardware):
-    super().__init__(underlying_hardware)
+    super().__init__(backend, underlying_hardware)
     self._backend = backend
     self.thread_idx_y = "hipThreadIdx_y"
     self.thread_idx_x = "hipThreadIdx_x"
diff --git a/kernelforge/generators/descriptions.py b/kernelforge/generators/descriptions.py
index 7c4e8b4..2376e95 100644
--- a/kernelforge/generators/descriptions.py
+++ b/kernelforge/generators/descriptions.py
@@ -62,8 +62,9 @@ def __init__(self, oplist: List[Assignment],
     self._strict_match = False
 
     for op in oplist:
-      op.dest.set_data_flow_direction(DataFlowDirection.SINK)
-      for tensor in op.optree.tensors():
+      for tensor in op.tensors(outtensors=True, intensors=False):
+        tensor.set_data_flow_direction(DataFlowDirection.SINK)
+      for tensor in op.tensors(outtensors=False, intensors=True):
         tensor.set_data_flow_direction(DataFlowDirection.SOURCE)
 
   def get_num_threads(self, context: Context):
diff --git a/kernelforge/generators/generator.py b/kernelforge/generators/generator.py
index 9ff9a46..d832c97 100644
--- a/kernelforge/generators/generator.py
+++ b/kernelforge/generators/generator.py
@@ -1,7 +1,7 @@
 from typing import List, Union, Type
 from copy import deepcopy
 import hashlib
-from kernelforge.generators.descriptions import OperationDescription
+from kernelforge.generators.descriptions import OperationDescription, MultilinearDescr, ElementwiseDescr
 from kernelforge.common.context import Context
 from kernelforge.common.basic_types import Addressing, GeneralLexicon, DataFlowDirection
 from kernelforge.common.aux import get_extra_offset_name
@@ -10,6 +10,7 @@
 from kernelforge.backend.scopes import Scopes
 from kernelforge.backend.symbol import Symbol, SymbolType
 from kernelforge.backend.instructions.abstract_instruction import AbstractInstruction
+from kernelforge.backend.instructions.compute.elementwise import ElementwiseInstruction
 from kernelforge.backend.instructions.builders.multilinear_builder import MultilinearBuilder
 from kernelforge.backend.instructions.builders.ptr_manip_builder import GetElementPtrBuilder
 from kernelforge.backend.instructions.builders.allocator_builder import ShrMemAllocBuilder, RegistersAllocBuilder
@@ -206,14 +207,16 @@ def _emit_ir(self):
                     self._scopes.get_symbol(self._register_array_obj),
                     self._scopes.get_symbol(self._shr_mem_obj),
                     self._num_threads)
-    # builder.build_prologue()
 
    for gemm_descr in self.descr_list:
-      builder.build(ops=[self._scopes.get_symbol(op) for op in gemm_descr.ops],
-                    dest_obj=gemm_descr.dest,
-                    descr=gemm_descr)
-      self._ir.extend(builder.get_instructions())
+      if isinstance(gemm_descr, MultilinearDescr):
+        builder.build(ops=[self._scopes.get_symbol(op) for op in gemm_descr.ops],
+                      dest_obj=gemm_descr.dest,
+                      descr=gemm_descr)
+        self._ir.extend(builder.get_instructions())
+      if isinstance(gemm_descr, ElementwiseDescr):
+        self._ir.append(ElementwiseInstruction(self._context, gemm_descr.oplist, self._scopes, False, self._num_threads))
 
     builder.build_epilogue()
     self._ir.extend(builder.get_instructions())
diff --git a/kernelforge/generators/optree.py b/kernelforge/generators/optree.py
index 69681f1..fe8dc45 100644
--- a/kernelforge/generators/optree.py
+++ b/kernelforge/generators/optree.py
@@ -14,15 +14,24 @@ def alloc(self):
     return f'v{self.counter}'
 
 class Node:
-  def tensors(self):
+  def tensors(self, intensors=True, outtensors=True):
+    pass
+
+  def pretensors(self, intensors=True, outtensors=True):
     pass
 
-  def symbols(self):
+  def symbols(self, intensors=True, outtensors=True):
+    pass
+
+  def getRanges(self, ranges):
     pass
 
   def assignSymbols(self, scopes: Scopes):
     pass
 
+  def assignTensor(self, assigner):
+    pass
+
   def declare(self, alloc: VarAlloc, writer: Writer, context: Context):
     pass
 
@@ -33,72 +42,154 @@ class Variable(Node):
   def store(self, writer: Writer, context: Context, value: str):
     pass
 
-class Assignment:
+class Statement:
+  def tensors(self, intensors=True, outtensors=True):
+    pass
+
+  def pretensors(self, intensors=True, outtensors=True):
+    pass
+
+  def symbols(self, intensors=True, outtensors=True):
+    pass
+
+  def getRanges(self, ranges):
+    pass
+
+  def assignSymbols(self, scopes: Scopes):
+    pass
+
+  def assignTensor(self, assigner):
+    pass
+
+  def declare(self, alloc: VarAlloc, writer: Writer, context: Context):
+    pass
+
+  def write(self, alloc: VarAlloc, writer: Writer, context: Context):
+    pass
+
+class Assignment(Statement):
   def __init__(self, dest: Variable, optree: Node):
     self.dest = dest
     self.optree = optree
+
+  def symbols(self, intensors=True, outtensors=True):
+    tensorlist = []
+    if intensors:
+      tensorlist += self.optree.symbols(intensors, outtensors)
+    if outtensors:
+      tensorlist += self.dest.symbols(intensors, outtensors)
+    return tensorlist
+
+  def tensors(self, intensors=True, outtensors=True):
+    tensorlist = []
+    if intensors:
+      tensorlist += self.optree.tensors(intensors, outtensors)
+    if outtensors:
+      tensorlist += self.dest.tensors(intensors, outtensors)
+    return tensorlist
+
+  def pretensors(self, intensors=True, outtensors=True):
+    tensorlist = []
+    if intensors:
+      tensorlist += self.optree.pretensors(intensors, outtensors)
+    if outtensors:
+      tensorlist += self.dest.pretensors(intensors, outtensors)
+    return tensorlist
 
-  def tensors(self):
-    return [self.dest] + self.optree.tensors()
+  def getRanges(self, ranges):
+    ranges = self.dest.getRanges(ranges)
+    ranges = self.optree.getRanges(ranges)
+    return ranges
+
+  def assignTensor(self, assigner):
+    self.dest.assignTensor(assigner)
+    self.optree.assignTensor(assigner)
 
   def assignSymbols(self, scopes: Scopes):
     self.dest.assignSymbols(scopes)
     self.optree.assignSymbols(scopes)
 
   def declare(self, alloc: VarAlloc, writer: Writer, context: Context):
-    optree.declare(alloc, writer, context)
     self.dest.declare(alloc, writer, context)
+    self.optree.declare(alloc, writer, context)
 
   def write(self, alloc: VarAlloc, writer: Writer, context: Context):
-    value = optree.write(alloc, writer, context)
+    value = self.optree.write(alloc, writer, context)
     self.dest.store(alloc, writer, context, value)
 
 class TensorVar(Variable):
-  def __init__(self, tensor, slicing):
+  def __init__(self, tensor, slicing, pretensor=None):
     self.tensor = tensor
     self.slicing = slicing
     self.symbol: Union[None, Symbol] = None
+    self.pretensor = pretensor
     self.variable = None
+    self.indices = None
 
-  def tensors(self):
+  def tensors(self, intensors=True, outtensors=True):
     return [self.tensor]
 
-  def symbols(self):
+  def pretensors(self, intensors=True, outtensors=True):
+    return [self.pretensor]
+
+  def symbols(self, intensors=True, outtensors=True):
     return [self.symbol]
 
+  def getRanges(self, ranges):
+    for i in range(len(self.indices)):
+      if self.indices[i] not in ranges:
+        ranges[self.indices[i]] = (self.symbol.data_view.get_bbox().lower()[i], self.symbol.data_view.get_bbox().upper()[i])
+      crange = ranges[self.indices[i]]
+      crange = (min(crange[0], self.symbol.data_view.get_bbox().lower()[i]), max(crange[1], self.symbol.data_view.get_bbox().upper()[i]))
+      ranges[self.indices[i]] = crange
+    return ranges
+
   def assignSymbols(self, scopes: Scopes):
     if self.symbol is None:
-      pass
+      self.symbol = scopes.get_symbol(self.tensor)
+
+  def assignTensor(self, assigner):
+    self.tensor, self.indices = assigner(self.pretensor)
 
   def declare(self, alloc: VarAlloc, writer: Writer, context: Context):
     pass
 
   def write(self, alloc: VarAlloc, writer: Writer, context: Context):
-    if self.variable is None:
-      self.variable = alloc.alloc()
-      self.symbol.load(context, writer, self.variable, [], False)
+    # TODO: re-enable caching
+    # if self.variable is None:
+    self.variable = alloc.alloc()
+    self.symbol.load(writer, context, self.variable, [f'n{-i-1}' for i in self.indices], False)
+
     return self.variable
 
   def store(self, alloc: VarAlloc, writer: Writer, context: Context, value: str):
     # assume that we don't have to reload
-    self.symbol.store(writer, context, value, [], False)
+    self.symbol.store(writer, context, value, [f'n{-i-1}' for i in self.indices], False)
 
 class TempVar(Variable):
   def __init__(self):
     self.variable = None
 
-  def tensors(self):
+  def tensors(self, intensors=True, outtensors=True):
     return []
 
-  def symbols(self):
+  def symbols(self, intensors=True, outtensors=True):
+    return []
+
+  def pretensors(self, intensors=True, outtensors=True):
     return []
 
   def assignSymbols(self, scopes: Scopes):
     pass
+
+  def assignTensor(self, assigner):
+    pass
+
+  def getRanges(self, ranges):
+    return ranges
 
   def declare(self, alloc: VarAlloc, writer: Writer, context: Context):
-    pass
-    # self.variable = alloc.alloc()
-    # write(f'{context.fp_as_str()} {self.variable};')
+    self.variable = alloc.alloc()
+    writer(f'{context.fp_as_str()} {self.variable};')
 
   def write(self, alloc: VarAlloc, writer: Writer, context: Context):
     assert self.variable is not None
@@ -107,18 +198,58 @@ def write(self, alloc: VarAlloc, writer: Writer, context: Context):
   def store(self, alloc: VarAlloc, writer: Writer, context: Context, value: str):
     if self.variable is None:
       self.variable = alloc.alloc()
-    writer(f'auto {self.variable} = {value};')
+    writer(f'{self.variable} = {value};')
 
+class Immediate(Variable):
+  def __init__(self, value, fptype: FloatingPointType):
+    self.value = value
+    self.fptype = fptype
+
+  def tensors(self, intensors=True, outtensors=True):
+    return []
+
+  def symbols(self, intensors=True, outtensors=True):
+    return []
+
+  def pretensors(self, intensors=True, outtensors=True):
+    return []
+
+  def getRanges(self, ranges):
+    return ranges
+
+  def assignSymbols(self, scopes: Scopes):
+    pass
+
+  def assignTensor(self, assigner):
+    pass
+
+  def declare(self, alloc: VarAlloc, writer: Writer, context: Context):
+    pass
+
+  def write(self, alloc: VarAlloc, writer: Writer, context: Context):
+    return self.fptype.literal(self.value)
+
+  def store(self, alloc: VarAlloc, writer: Writer, context: Context, value: str):
+    pass
 
 class OpNode(Node):
   def __init__(self, operands: List[Node], optype: Operation):
     self.operands = operands
     self.optype = optype
     self.variable = None
+
+  def getRanges(self, ranges):
+    for op in self.operands:
+      ranges = op.getRanges(ranges)
+    return ranges
 
-  def tensors(self):
+  def tensors(self, intensors=True, outtensors=True):
     return [tensor for operand in self.operands for tensor in operand.tensors()]
 
-  def symbols(self):
+  def pretensors(self, intensors=True, outtensors=True):
+    return [tensor for operand in self.operands for tensor in operand.pretensors()]
+
+  def symbols(self, intensors=True, outtensors=True):
     return [symbol for operand in self.operands for symbol in operand.symbols()]
 
   def assignSymbols(self, scopes: Scopes):
@@ -126,9 +257,13 @@ def assignSymbols(self, scopes: Scopes):
       op.assignSymbols(scopes)
 
   def declare(self, alloc: VarAlloc, writer: Writer, context: Context):
-    for i, op in enumerate(self.operands):
+    for op in self.operands:
       op.declare(alloc, writer, context)
 
+  def assignTensor(self, assigner):
+    for op in self.operands:
+      op.assignTensor(assigner)
+
   def operation(self, context: Context, var: List[str]):
     pass
 
@@ -142,44 +277,64 @@ def write(self, alloc: VarAlloc, writer: Writer, context: Context):
     return self.variable
 
 class LexicOpNode(OpNode):
-  def __init__(self, optype: Operation):
-    self.optype = optype
-
-  def operation(self, context: Context, var: List[str]):
-    return context.get_vm().get_lexic().get_operation(self.optype, context.fp_type, *var)
+  def operation(self, context: Context, var: List[str]):
+    return context.get_vm().get_lexic().get_operation(self.optype, context.fp_type, *(var + ['']))
 
 class ConditionalOpNode(OpNode):
-  def __init__(self, optype: Operation):
-    self.optype = optype
-
+  def operation(self, context: Context, var: List[str]):
+    return f'({var[0]}) ? ({var[1]}) : ({var[2]})'
 
 class CastOpNode(OpNode):
-  def __init__(self, targetType: FloatingPointType):
+  def __init__(self, operands: List[Node], targetType: FloatingPointType):
+    self.operands = operands
     self.targetType = targetType
+    self.variable = None
+    self.optype = None
 
   def operation(self, context: Context, var: List[str]):
-    return 'static_cast<{self.targetType}>({var[0]})'
+    return f'static_cast<{self.targetType}>({var[0]})'
 
-class IfNode(Node):
-  def __init__(self, condition: Node, subassignments: List[Assignment]):
+class IfNode(Statement):
+  def __init__(self, condition: Node, subassignments: List[Statement]):
     self.condition = condition
     self.subassignments = subassignments
 
+  def getRanges(self, ranges):
+    ranges = self.condition.getRanges(ranges)
+    for subassignment in self.subassignments:
+      ranges = subassignment.getRanges(ranges)
+    return ranges
+
   def assignSymbols(self, scopes: Scopes):
     self.condition.assignSymbols(scopes)
     for subassignment in self.subassignments:
-      subassignments.assignSymbols(scopes)
+      subassignment.assignSymbols(scopes)
 
-  def tensors(self):
-    return [tensor for operand in self.operands for tensor in operand.tensors()]
+  def assignTensor(self, assigner):
+    self.condition.assignTensor(assigner)
+    for subassignment in self.subassignments:
+      subassignment.assignTensor(assigner)
+
+  def tensors(self, intensors=True, outtensors=True):
+    tensorlist = []
+    if intensors:
+      tensorlist += self.condition.tensors(intensors, outtensors)
+    return tensorlist + [tensor for operand in self.subassignments for tensor in operand.tensors(intensors, outtensors)]
 
-  def symbols(self):
-    return [symbol for operand in self.operands for symbol in operand.symbols()]
+  def pretensors(self, intensors=True, outtensors=True):
+    tensorlist = []
+    if intensors:
+      tensorlist += self.condition.pretensors(intensors, outtensors)
+    return tensorlist + [tensor for operand in self.subassignments for tensor in operand.pretensors(intensors, outtensors)]
+
+  def symbols(self, intensors=True, outtensors=True):
+    tensorlist = []
+    if intensors:
+      tensorlist += self.condition.symbols(intensors, outtensors)
+    return tensorlist + [tensor for operand in self.subassignments for tensor in operand.symbols(intensors, outtensors)]
 
   def declare(self, alloc: VarAlloc, writer: Writer, context: Context):
-    condition.declare(alloc, writer, context)
+    self.condition.declare(alloc, writer, context)
     for subassignment in self.subassignments:
       subassignment.declare(alloc, writer, context)
 
@@ -189,39 +344,63 @@ def write(self, alloc: VarAlloc, writer: Writer, context: Context):
     for subassignment in self.subassignments:
       subassignment.write(alloc, writer, context)
 
-class WhileNode(Node):
-  def __init__(self, condition: Node, subassignments: List[Assignment]):
+class WhileNode(Statement):
+  def __init__(self, condition: Node, subassignments: List[Statement]):
     self.condition = condition
     self.subassignments = subassignments
     self.conditionVar = TempVar()
 
+  def getRanges(self, ranges):
+    ranges = self.condition.getRanges(ranges)
+    for subassignment in self.subassignments:
+      ranges = subassignment.getRanges(ranges)
+    return ranges
+
   def assignSymbols(self, scopes: Scopes):
     self.condition.assignSymbols(scopes)
     for subassignment in self.subassignments:
-      subassignments.assignSymbols(scopes)
+      subassignment.assignSymbols(scopes)
 
-  def tensors(self):
-    return [tensor for operand in self.operands for tensor in operand.tensors()]
-
-  def symbols(self):
-    return [symbol for operand in self.operands for symbol in operand.symbols()]
+  def assignTensor(self, assigner):
+    self.condition.assignTensor(assigner)
+    for subassignment in self.subassignments:
+      subassignment.assignTensor(assigner)
+
+  def tensors(self, intensors=True, outtensors=True):
+    tensorlist = []
+    if intensors:
+      tensorlist += self.condition.tensors(intensors, outtensors)
+    return tensorlist + [tensor for operand in self.subassignments for tensor in operand.tensors(intensors, outtensors)]
+
+  def pretensors(self, intensors=True, outtensors=True):
+    tensorlist = []
+    if intensors:
+      tensorlist += self.condition.pretensors(intensors, outtensors)
+    return tensorlist + [tensor for operand in self.subassignments for tensor in operand.pretensors(intensors, outtensors)]
+
+  def symbols(self, intensors=True, outtensors=True):
+    tensorlist = []
+    if intensors:
+      tensorlist += self.condition.symbols(intensors, outtensors)
+    return tensorlist + [tensor for operand in self.subassignments for tensor in operand.symbols(intensors, outtensors)]
+
   def declare(self, alloc: VarAlloc, writer: Writer, context: Context):
-    condition.declare(alloc, writer, context)
+    self.conditionVar.declare(alloc, writer, context)
+    self.condition.declare(alloc, writer, context)
     for subassignment in self.subassignments:
       subassignment.declare(alloc, writer, context)
 
   def write(self, alloc: VarAlloc, writer: Writer, context: Context):
     resultCondition = self.condition.write(alloc, writer, context)
-    self.conditionVar.store(resultCondition)
+    self.conditionVar.store(alloc, writer, context, resultCondition)
     result = self.conditionVar.write(alloc, writer, context)
     with writer.While(result):
       for subassignment in self.subassignments:
         subassignment.write(alloc, writer, context)
       resultCondition = self.condition.write(alloc, writer, context)
-      self.conditionVar.store(resultCondition)
+      self.conditionVar.store(alloc, writer, context, resultCondition)
 
-def writeAssignments(assignments: List[Assignment], writer: Writer, context: Context):
+def writeAssignments(assignments: List[Statement], writer: Writer, context: Context):
   alloc = VarAlloc()
   for assignment in assignments:
     assignment.declare(alloc, writer, context)
diff --git a/tensorforge/functions.py b/tensorforge/functions.py
new file mode 100644
index 0000000..a5a9828
--- /dev/null
+++ b/tensorforge/functions.py
@@ -0,0 +1,154 @@
+from kernelforge.common.operation import Operation
+from kernelforge.common.basic_types import FloatingPointType
+from kernelforge.generators.optree import Statement, Node, TensorVar, OpNode, LexicOpNode, ConditionalOpNode, CastOpNode, IfNode, WhileNode, Assignment, Immediate, TempVar
+import yateto.ast.node as ytt
+from typing import List, Union
+import math
+
+def scalarblock(statements: List[Statement]):
+  tensors = set()
+  for statement in statements:
+    tensors.update(statement.pretensors())
+  return ytt.ScalarRegion([tensor for tensor in tensors], statements)
+
+BaseType = Union[ytt.Node, Node, float, int, bool]
+
+def assign(target: Union[ytt.Node, TensorVar], source: BaseType):
+  if isinstance(target, ytt.Node):
+    target = tensor(target)
+  return Assignment(target, immc(source))
+
+def conditional(condition: BaseType, subnodes: List[Statement]):
+  return IfNode(immc(condition), subnodes)
+
+def ternary(condition: BaseType, yesnode: BaseType, nonode: BaseType):
+  return ConditionalOpNode([immc(condition), immc(yesnode), immc(nonode)], None)
+
+def loop(condition: BaseType, subnodes: List[Statement]):
+  return WhileNode(immc(condition), subnodes)
+
+def imm(value, fptype):
+  return Immediate(value, fptype)
+
+def tensor(x: ytt.Node, slicing=None):
+  return TensorVar(None, slicing, x)
+
+def immc(x: BaseType):
+  # bool must come before float/int: bool is a subclass of int
+  if isinstance(x, bool):
+    return imm(x, FloatingPointType.BOOL)
+  if isinstance(x, float):
+    return imm(x, FloatingPointType.FLOAT)
+  if isinstance(x, int):
+    return imm(x, FloatingPointType.INT)
+  if isinstance(x, ytt.Node):
+    return tensor(x)
+  return x
+
+def cos(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.COS)
+
+def sin(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.SIN)
+
+def tan(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.TAN)
+
+def acos(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.ACOS)
+
+def asin(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.ASIN)
+
+def atan(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.ATAN)
+
+def cosh(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.COSH)
+
+def sinh(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.SINH)
+
+def tanh(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.TANH)
+
+def acosh(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.ACOSH)
+
+def asinh(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.ASINH)
+
+def atanh(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.ATANH)
+
+def sqrt(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.SQRT)
+
+def cbrt(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.CBRT)
+
+def max(x: BaseType, y: BaseType):
+  return LexicOpNode([immc(x), immc(y)], Operation.MAX)
+
+def min(x: BaseType, y: BaseType):
+  return LexicOpNode([immc(x), immc(y)], Operation.MIN)
+
+def div(x: BaseType, y: BaseType):
+  xconv = immc(x)
+  yconv = immc(y)
+  # TODO: move these optimizations to a visitor
+  if isinstance(xconv, Immediate):
+    if xconv.value == 1 or xconv.value == 1.0:
+      return LexicOpNode([yconv], Operation.RCP)
+  if isinstance(yconv, Immediate):
+    if yconv.value == 1 or yconv.value == 1.0:
+      return xconv
+  return LexicOpNode([xconv, yconv], Operation.DIV)
+
+def mod(x: BaseType, y: BaseType):
+  xconv = immc(x)
+  yconv = immc(y)
+  return LexicOpNode([xconv, yconv], Operation.MOD)
+
+def round(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.ROUND)
+
+def rcp(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.RCP)
+
+def pow(x: BaseType, y: BaseType):
+  xconv = immc(x)
+  yconv = immc(y)
+  # TODO: move these optimizations to a visitor
+  if isinstance(yconv, Immediate):
+    if yconv.value == 2 or yconv.value == 2.0:
+      return LexicOpNode([xconv, xconv], Operation.MUL)
+    if yconv.value == 0.5:
+      return LexicOpNode([xconv], Operation.SQRT)
+    if yconv.value == -0.5:
+      return LexicOpNode([xconv], Operation.RSQRT)
+    if yconv.value == 1/3:
+      return LexicOpNode([xconv], Operation.CBRT)
+    if yconv.value == -1/3:
+      return LexicOpNode([xconv], Operation.RCBRT)
+    if yconv.value == -1 or yconv.value == -1.0:
+      return LexicOpNode([xconv], Operation.RCP)
+    if yconv.value == 1 or yconv.value == 1.0:
+      return xconv
+  if isinstance(xconv, Immediate):
+    if xconv.value == math.e:
+      return LexicOpNode([yconv], Operation.EXP)
+    if xconv.value == 1 or xconv.value == 1.0:
+      return xconv
+  return LexicOpNode([xconv, yconv], Operation.POW)
+
+def exp(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.EXP)
+
+def log(x: BaseType):
+  return LexicOpNode([immc(x)], Operation.LOG)
+
+def temp():
+  return TempVar()
+
+def cast(x: Node, fptype: FloatingPointType):
+  return CastOpNode([immc(x)], fptype)
diff --git a/yateto/ast/node.py b/yateto/ast/node.py
index e4b6422..2cade4b 100644
--- a/yateto/ast/node.py
+++ b/yateto/ast/node.py
@@ -9,6 +9,7 @@ def __init__(self):
     self.indices = None
     self._children = []
     self._eqspp = None
+    self.datatype = None
 
   def size(self):
     return self.indices.size()
@@ -255,6 +256,84 @@ def nonZeroFlops(self):
   def __str__(self):
     return '{}: {}'.format(super().__str__(), str(self._scalar))
 
+# TODO: temporary.
+class ScalarOp(Op):
+  def __init__(self, optype, *ops):
+    super().__init__(*ops)
+    self.optype = optype
+    self.indices = ops[0].indices
+    for op in ops[1:]:
+      self.indices = self.indices | op.indices
+
+  def computeSparsityPattern(self, *spps):
+    if len(spps) == 0:
+      spps = [node.eqspp() for node in self]
+    permute_summand = lambda i: self.permute(self[i].indices, spps[i])
+    spp = permute_summand(0)
+    for i in range(1, len(spps)):
+      add_spp = permute_summand(i)
+      spp = aspp.add(spp, add_spp)
+    return spp
+
+  def nonZeroFlops(self):
+    return 0
+
+class ScalarRegion(Op):
+  def __init__(self, ops, data):
+    super().__init__(*ops)
+    self.data = data
+    self.indices = ops[0].indices
+    for op in ops[1:]:
+      common = self.indices & op.indices
+      self.indices = self.indices.merged(op.indices - common)
+
+  def computeSparsityPattern(self, *spps):
+    if len(spps) == 0:
+      spps = [node.eqspp() for node in self]
+    permute_summand = lambda i: self.permute(self[i].indices, spps[i])
+    spp = permute_summand(0)
+    for i in range(1, len(spps)):
+      add_spp = permute_summand(i)
+      spp = aspp.add(spp, add_spp)
+    return spp
+
+  def nonZeroFlops(self):
+    return 0
+
+class Conditional(Op):
+  def __init__(self, condition, yesop, noop):
+    super().__init__(condition, yesop, noop)
+
+  def computeSparsityPattern(self, *spps):
+    if len(spps) == 0:
+      spps = [node.eqspp() for node in self]
+    permute_summand = lambda i: self.permute(self[i].indices, spps[i])
+    spp = permute_summand(0)
+    for i in range(1, len(spps)):
+      add_spp = permute_summand(i)
+      spp = aspp.add(spp, add_spp)
+    return spp
+
+  def nonZeroFlops(self):
+    return self._children[0].nonZeroFlops() # TODO: add more?
+
+class Loop(Op):
+  def __init__(self, condition, update):
+    super().__init__(condition, update)
+
+  def computeSparsityPattern(self, *spps):
+    if len(spps) == 0:
+      spps = [node.eqspp() for node in self]
+    permute_summand = lambda i: self.permute(self[i].indices, spps[i])
+    spp = permute_summand(0)
+    for i in range(1, len(spps)):
+      add_spp = permute_summand(i)
+      spp = aspp.add(spp, add_spp)
+    return spp
+
+  def nonZeroFlops(self):
+    return self._children[0].nonZeroFlops() # TODO: add more?
+
 class BinOp(Op):
   def __init__(self, lTerm, rTerm):
     super().__init__(lTerm, rTerm)
diff --git a/yateto/ast/transformer.py b/yateto/ast/transformer.py
index 152ce50..8875ecc 100644
--- a/yateto/ast/transformer.py
+++ b/yateto/ast/transformer.py
@@ -103,6 +103,9 @@ def visit_Assign(self, node, bound):
 
     return node
 
+  def visit_ScalarRegion(self, node, bound):
+    return node
+
 ### Optimal binary tree
 
 class StrengthReduction(Transformer):
@@ -186,6 +189,11 @@ def visit_Add(self, node):
     node.setEqspp( node.computeSparsityPattern() )
     return node
 
+  def visit_ScalarOp(self, node):
+    self.generic_visit(node)
+    node.setEqspp( node.computeSparsityPattern() )
+    return node
+
   def visit_ScalarMultiplication(self, node):
     self.generic_visit(node)
     node.setEqspp(node.term().eqspp())
diff --git a/yateto/codegen/factory.py b/yateto/codegen/factory.py
index ce85eb8..97a52d7 100644
--- a/yateto/codegen/factory.py
+++ b/yateto/codegen/factory.py
@@ -195,6 +195,12 @@ def create_Permute(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
     termTerm = self._formatTerm(arguments[0], node.term().indices)
     return self._simpleBody(resultTerm, termTerm, add, scalar, node.indices)
 
+  def create_ScalarOp(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+    return 0
+
+  def create_ScalarRegion(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+    return 0
+
   def _simpleBody(self, resultTerm, termTerm, add, scalar, indices):
     ranges = {idx: Range(0, indices.indexSize(idx)) for idx in indices}
diff --git a/yateto/codegen/gpukernel.py b/yateto/codegen/gpukernel.py
index bed4535..e002572 100644
--- a/yateto/codegen/gpukernel.py
+++ b/yateto/codegen/gpukernel.py
@@ -1,5 +1,6 @@
 from .factory import KernelFactory
-from kernelforge.generators.descriptions import MultilinearDescr
+from kernelforge.generators.descriptions import MultilinearDescr, ElementwiseDescr
+from kernelforge.generators.optree import Assignment, OpNode, TensorVar
 from common import *
 from .common import TensorDescription, IndexedTensorDescription, BatchedOperationsAux
@@ -23,14 +24,31 @@ def __init__(self, arch):
     self._descr_list = []
 
   def add_operation(self, dest, ops, target, permute, add):
-      self._cache_matrices(dest, ops, target, permute)
-      can_be_aligned = self._can_be_aligned(dest, ops, target, permute)
-      self._descr_list.append(MultilinearDescr(self._cache[dest.name],
-                                               [self._cache[op.name() if isinstance(op, Scalar) else op.name] for op in ops],
-                                               target, permute, add=add,
-                                               strict_match=False,
-                                               prefer_align=can_be_aligned))
-      return 0 # self._descr_list[-1].get_flops()
+    self._cache_matrices(dest, ops, target, permute)
+    can_be_aligned = self._can_be_aligned(dest, ops, target, permute)
+    self._descr_list.append(MultilinearDescr(self._cache[dest.name],
+                                             [self._cache[op.name() if isinstance(op, Scalar) else op.name] for op in ops],
+                                             target, permute, add=add,
+                                             strict_match=False,
+                                             prefer_align=can_be_aligned))
+    return 0 # self._descr_list[-1].get_flops()
+
+  def add_scalar(self, ops, statements, indices):
+    indicesIndexed = {}
+    for i, op in enumerate(ops):
+      self.make_tensor(op, False, None)
+      indicesIndexed[op.name() if isinstance(op, Scalar) else op.name] = indices[i]
+
+    def assigner(pretensor):
+      return self._cache[pretensor.name()], indicesIndexed[pretensor.name()]
+
+    for statement in statements:
+      statement.assignTensor(assigner)
+
+    self._descr_list.append(ElementwiseDescr(statements,
+                                             strict_match=False,
+                                             prefer_align=False))
+    return 0
 
   def generate(self, cpp, routineCache):
     context = Context(arch=self._arch.name,
@@ -55,42 +73,42 @@ def _can_be_aligned(self, dest, ops, target, permute):
 
     return aligned
 
+  def make_tensor(self, op, can_be_aligned, dims):
+    if isinstance(op, Scalar):
+      entry = self._add_scalar(op)
+      entry_name = op.name()
+    else:
+      # TODO: refine
+      currentPreShape = list(BoundingBox.fromSpp(op.eqspp))
+      if can_be_aligned:
+        for i, dim in enumerate(dims):
+          if i == 0 and op.memoryLayout.alignedStride(): # previously: dim == 0
+            currentPreShape[i] = currentPreShape[i].aligned(self._arch)
+      currentShape = [b.stop for b in currentPreShape]
+      currentRange = list(BoundingBox(Range(0, b) for b in currentShape))
+
+      entry = self._get_kernelforge_matrix(tensor=op,
+                                           tensor_variable=op,
+                                           shape=currentShape,
+                                           bboxrange=currentRange)
+      entry_name = op.name
+
+    if not (entry_name in self._cache and entry.is_same(self._cache[entry_name])):
+      self._cache[entry_name] = entry
+
   def _cache_matrices(self, dest, ops, target, permute):
     can_be_aligned = self._can_be_aligned(dest, ops, target, permute)
-
-    def make_tensor(op, dims):
-      if isinstance(op, Scalar):
-        entry = self._add_scalar(op)
-        entry_name = op.name()
-      else:
-        # TODO: refine
-        currentPreShape = list(BoundingBox.fromSpp(op.eqspp))
-        if can_be_aligned:
-          for i, dim in enumerate(dims):
-            if i == 0 and op.memoryLayout.alignedStride(): # previously: dim == 0
-              currentPreShape[i] = currentPreShape[i].aligned(self._arch)
-        currentShape = [b.stop for b in currentPreShape]
-        currentRange = list(BoundingBox(Range(0, b) for b in currentShape))
-
-        entry = self._get_kernelforge_matrix(tensor=op,
-                                             tensor_variable=op,
-                                             shape=currentShape,
-                                             bboxrange=currentRange)
-        entry_name = op.name
-
-      if not (entry_name in self._cache and entry.is_same(self._cache[entry_name])):
-        self._cache[entry_name] = entry
 
     # no add onto a matrix that doesn't exist (TODO: check if that's always the case)
     assert not(dest.is_temporary and dest in ops)
 
     for op, optarget in zip(ops, target):
-      make_tensor(op, optarget)
+      self.make_tensor(op, can_be_aligned, optarget)
 
     if dest.is_temporary:
       # (dest is never a scalar---for the time being)
       self._cache[dest.name] = self._gen_tmp_matix(ops, target, permute, dest.name)
     else:
-      make_tensor(dest, [i for i in range(len(dest.indices))])
+      self.make_tensor(dest, can_be_aligned, [i for i in range(len(dest.indices))])
 
   def _add_scalar(self, scalar):
     tensor = Tensor([], Addressing.SCALAR, alias=scalar.name())
@@ -189,27 +207,33 @@ def create(self, node, *args):
 
   def create_LoopOverGEMM(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
     assert len(arguments) == 2
-    return self.handle(IndexedTensorDescription.fromNode(result, node), [IndexedTensorDescription.fromNode(arguments[0], node.leftTerm()), IndexedTensorDescription.fromNode(arguments[1], node.rightTerm())], add, scalar, node.transA(), node.transB())
+    return self.handleLinear(IndexedTensorDescription.fromNode(result, node), [IndexedTensorDescription.fromNode(arguments[0], node.leftTerm()), IndexedTensorDescription.fromNode(arguments[1], node.rightTerm())], add, scalar, node.transA(), node.transB())
 
   def create_IndexSum(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
     assert len(arguments) == 1
-    return self.handle(IndexedTensorDescription.fromNode(result, node), [IndexedTensorDescription.fromNode(arguments[0], node.term())], add, scalar, False, False)
+    return self.handleLinear(IndexedTensorDescription.fromNode(result, node), [IndexedTensorDescription.fromNode(arguments[0], node.term())], add, scalar, False, False)
 
   def create_Product(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
     assert len(arguments) == 2
-    return self.handle(IndexedTensorDescription.fromNode(result, node), [IndexedTensorDescription.fromNode(arguments[0], node.leftTerm()), IndexedTensorDescription.fromNode(arguments[1], node.rightTerm())], add, scalar, False, False)
+    return self.handleLinear(IndexedTensorDescription.fromNode(result, node), [IndexedTensorDescription.fromNode(arguments[0], node.leftTerm()), IndexedTensorDescription.fromNode(arguments[1], node.rightTerm())], add, scalar, False, False)
 
   def create_Permute(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
     term = arguments[0]
-    return self.handle(IndexedTensorDescription(str(result), node.indices, result.memoryLayout(), result.eqspp()), [IndexedTensorDescription(str(term), node.term().indices, term.memoryLayout(), term.eqspp())], add, scalar, False, False)
+    return self.handleLinear(IndexedTensorDescription(str(result), node.indices, result.memoryLayout(), result.eqspp()), [IndexedTensorDescription(str(term), node.term().indices, term.memoryLayout(), term.eqspp())], add, scalar, False, False)
 
   def simple(self, result, term, add, scalar, routineCache):
-    return self.handle(IndexedTensorDescription(str(result), self._indices(result), result.memoryLayout(), result.eqspp()), [IndexedTensorDescription(str(term), self._indices(term), term.memoryLayout(), term.eqspp())], add, scalar, False, False)
+    return self.handleLinear(IndexedTensorDescription(str(result), self._indices(result), result.memoryLayout(), result.eqspp()), [IndexedTensorDescription(str(term), self._indices(term), term.memoryLayout(), term.eqspp())], add, scalar, False, False)
 
-  def handle(self, dest, ops, add, scalar, transposeA, transposeB):
-    # convert indices to loop numbers
-
-    target_indices = dest.indices
+  def create_ScalarRegion(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+    terms = [IndexedTensorDescription.fromNode(arg, term) for arg, term in zip(arguments, node)]
+    target, permute = self.getIndices(None, terms)
+    return self.generator.add_scalar(terms, node.data, target)
+
+  def getIndices(self, dest, ops):
+    if dest is None:
+      target_indices = []
+    else:
+      target_indices = dest.indices
 
     indexindex = {index:i for i, index in enumerate(target_indices)}
 
     contract_counter = -1
@@ -222,6 +246,13 @@
     target = [[indexindex[index] for index in op.indices] for op in ops]
     permute = [[i for i,_ in enumerate(op.indices)] for op in ops]
+
+    return target, permute
+
+  def handleLinear(self, dest, ops, add, scalar, transposeA, transposeB):
+    # convert indices to loop numbers
+
+    target, permute = self.getIndices(dest, ops)
 
     if not (scalar == 1 or scalar == 1.0):
       ops += [scalar]
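
--
Note for reviewers (illustrative, not part of the diff): a minimal sketch of how the scalar-expression DSL in tensorforge/functions.py is meant to be driven. `Q` and `x` stand for hypothetical yateto tensor nodes; the exact wiring into the generator may differ from this draft.

  # Sketch only; assumes Q and x are yateto AST tensor nodes over the same indices.
  from tensorforge import functions as f

  t = f.temp()                                # per-element scratch variable
  stmts = [
      f.assign(t, f.sqrt(x)),                 # t = sqrt(x), elementwise
      f.conditional(t, [                      # if (t) { Q = t**(-1/3); }
          f.assign(Q, f.pow(t, -1.0 / 3.0)),  # constant-folds to a single RCBRT op
      ]),
  ]
  region = f.scalarblock(stmts)               # wraps the statements in a ytt.ScalarRegion

The pow() folding is also why RSQRT and RCBRT must carry distinct enum values: Python's Enum treats members with equal values as aliases, so Operation.RCBRT would otherwise resolve to Operation.RSQRT.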
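A worked example of the index/range bookkeeping feeding ElementwiseInstruction (again a sketch; the bounds 56 and 9 are made up, the real ones come from the symbols' bounding boxes):

  # assignTensor maps the k-th tensor index to the key -k-1, so an assignment
  # over two indices collects ranges = {-1: (0, 56), -2: (0, 9)}.
  ranges = {-1: (0, 56), -2: (0, 9)}
  ks = [ranges[-i - 1] for i in range(len(ranges))]   # [(0, 56), (0, 9)]
  # Dimension 0 is the lead dimension: n0 is bound to threadIdx.x and guarded
  # with "if (n0 >= 0 && n0 < 56)", while the remaining dimensions become
  # unrolled for-loops, e.g. "for (int n1 = 0; n1 < 9; ++n1)".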