Skip to content

Commit

Permalink
Remove extra offsets
Browse files Browse the repository at this point in the history
  • Loading branch information
davschneller committed Dec 18, 2024
1 parent 9f886a5 commit e227f39
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
8 changes: 4 additions & 4 deletions tensorforge/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ def deduce_addresing(self, term):

def deduce_arg(self, term, as_const=False):
if term.is_compute_constant or term.is_temporary:
extra_offset = '0'
extra_offset = ''
else:
extra_offset = f'{self.EXTRA_OFFSET_NAME}_{term.name}'
extra_offset = f', {self.EXTRA_OFFSET_NAME}_{term.name}'

if as_const:
addressing = self.deduce_addresing(term)
ptr = self._get_ptr_type(addressing)
datatype = self.underlying_data_type if term.datatype is None else term.datatype
const_ptr_type = f'const {datatype} {ptr}'
return f'const_cast<{const_ptr_type}>({term.name}), {extra_offset}'
return f'const_cast<{const_ptr_type}>({term.name}){extra_offset}'
else:
return f'{term.name}, {extra_offset}'
return f'{term.name}{extra_offset}'
22 changes: 12 additions & 10 deletions tensorforge/generators/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ def __init__(self, context, global_mem, mem_size_per_mult, num_threads):
super().__init__(context, global_mem, mem_size_per_mult, num_threads)

def get_num_mults_per_block(self):
return 1
# the //2 is a heuristic
# self._max_threads // self._num_threads // 2
max_thread_mults = 256 // self._num_threads
# max_thread_mults = 256 // self._num_threads
if self._mem_per_mult == 0:
return max_thread_mults
else:
Expand Down Expand Up @@ -210,8 +211,8 @@ def _deduce_num_threads(self):
if isinstance(gemm_descr, ElementwiseDescr):
compress = False
break
if compress:
self._num_threads = 32
# if compress:
# self._num_threads = 32

def _deduce_accumulator_size(self):
for descr in self.descr_list:
Expand All @@ -232,13 +233,13 @@ def _emit_global_ir(self):
builder = GetElementPtrBuilder(self._context, self._scopes)
for symbol in self._scopes.get_global_scope().values():
if symbol.obj.addressing == Addressing.SCALAR or (symbol.obj.addressing == Addressing.NONE and symbol.stype == SymbolType.Data):
builder.build(symbol)
builder.build(symbol, symbol.obj.addressing == Addressing.PTR_BASED)
self._global_ir.extend(builder.get_instructions())

builder = GlobalLoaderBuilder(self._context, self._scopes, self._shr_mem_obj, self._num_threads)
for symbol in self._scopes.get_global_scope().values():
if symbol.obj.addressing == Addressing.NONE and symbol.stype != SymbolType.Data:
builder.build(symbol)
builder.build(symbol, symbol.obj.addressing == Addressing.PTR_BASED)
self._global_ir.extend(builder.get_instructions())

self._global_ir.append(SyncBlock(self._context, self._num_threads))
Expand All @@ -249,7 +250,7 @@ def _emit_ir(self):
self._scopes.add_scope()
for symbol in self._scopes.get_global_scope().values():
if not self._preload_globals or not (symbol.obj.addressing == Addressing.NONE or symbol.obj.addressing == Addressing.SCALAR):
builder.build(symbol)
builder.build(symbol, symbol.obj.addressing == Addressing.PTR_BASED)
self._ir.extend(builder.get_instructions())

self._scopes.add_scope()
Expand Down Expand Up @@ -474,7 +475,7 @@ def generate_call_site(self,
for symbol in symbols:
if symbol.obj.alias in mat_name_map:
args.append(mat_name_map[symbol.obj.alias])
if symbol.obj.addressing != Addressing.SCALAR:
if symbol.obj.addressing == Addressing.PTR_BASED:
args.append(offset_name_map[symbol.obj.alias])

# add num. elements
Expand All @@ -499,7 +500,8 @@ def _get_element_size_guard(self):
return f'{GeneralLexicon.BATCH_ID_NAME} < {GeneralLexicon.NUM_ELEMENTS}'

def _get_flag_guard(self, writer):
writer(f'bool isFlagsProvided = ({GeneralLexicon.FLAGS_NAME} != nullptr);')
flag_value = f'static_cast<bool>({GeneralLexicon.FLAGS_NAME}[{GeneralLexicon.BATCH_ID_NAME}])'
writer(f'bool allowed = isFlagsProvided ? {flag_value} : true;')
writer(f'bool allowed = true;')
# with writer.If('{GeneralLexicon.FLAGS_NAME} != nullptr'):
# flag_value = f'static_cast<bool>({GeneralLexicon.FLAGS_NAME}[{GeneralLexicon.BATCH_ID_NAME}])'
# writer(f'allowed = {flag_value};')
return 'allowed'

0 comments on commit e227f39

Please sign in to comment.