Some small fixes
davschneller committed Dec 14, 2024
1 parent 9553af1 commit f9404ab
Showing 5 changed files with 39 additions and 14 deletions.
23 changes: 14 additions & 9 deletions tensorforge/backend/instructions/builders/multilinear_builder.py
@@ -43,6 +43,7 @@ def __init__(self,
         self._dest_regs = None

         self._use_registers_always = False
+        self._temporary_registers = False
         self._deferred_stores = {}
         self._temporaries = {}

@@ -211,15 +212,18 @@ def _make_store(self):
             raise InternalError(f'gemm-buider: `res` is not in scopes and thus must be tmp')

         dest_symbol = Symbol(name=self._name_shr_reg(),
-                                 stype=SymbolType.SharedMem,
-                                 obj=self._dest_obj.tensor)
+                             stype=SymbolType.SharedMem,
+                             obj=self._dest_obj.tensor)

-        self._scopes.add_symbol(dest_symbol)
-        self._instructions.append(StoreRegToShr(context=self._context,
-                                                src=self._temp_regs,
-                                                dest=dest_symbol,
-                                                shr_mem=self._shr_mem,
-                                                num_threads=self._num_threads))
+        if self._temporary_registers:
+            self._deferred_stores[dest_symbol.name] = (self._temp_regs, dest_symbol)
+        else:
+            self._scopes.add_symbol(dest_symbol)
+            self._instructions.append(StoreRegToShr(context=self._context,
+                                                    src=self._temp_regs,
+                                                    dest=dest_symbol,
+                                                    shr_mem=self._shr_mem,
+                                                    num_threads=self._num_threads))

     def _insert_sync_block(self):
         self._instructions.append(SyncThreads(context=self._context,
@@ -233,7 +237,8 @@ def _name_shr_reg(self):
     def build_epilogue(self):
         self._reset()
         for store_regs, store_global in self._deferred_stores.values():
-            self._instructions.append(StoreRegToGlb(context=self._context,
+            if store_global.stype == SymbolType.Global:
+                self._instructions.append(StoreRegToGlb(context=self._context,
                                                          src=store_regs,
                                                          dest=store_global,
                                                          alpha=1,#self._descr.alpha,
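The change above introduces deferred stores: when registers are marked temporary, _make_store records the register/symbol pair in _deferred_stores instead of emitting a StoreRegToShr, and build_epilogue later flushes only the entries whose destination is truly global. A minimal, self-contained sketch of that pattern (the class and instruction names below are simplified stand-ins, not the actual tensorforge API):

from collections import namedtuple

# Hypothetical stand-in for the builder's Symbol type.
Symbol = namedtuple('Symbol', ['name', 'stype'])

class StoreBuilder:
    def __init__(self, temporary_registers=False):
        self._temporary_registers = temporary_registers
        self._deferred_stores = {}  # symbol name -> (registers, symbol)
        self._instructions = []

    def make_store(self, regs, symbol):
        if self._temporary_registers:
            # Defer: remember the store instead of emitting it immediately.
            self._deferred_stores[symbol.name] = (regs, symbol)
        else:
            self._instructions.append(('store_reg_to_shr', regs, symbol))

    def build_epilogue(self):
        # Flush deferred stores; only genuinely global destinations are written back.
        for regs, symbol in self._deferred_stores.values():
            if symbol.stype == 'Global':
                self._instructions.append(('store_reg_to_glb', regs, symbol))

builder = StoreBuilder(temporary_registers=True)
builder.make_store('r0', Symbol('tmp0', 'SharedMem'))
builder.make_store('r1', Symbol('out', 'Global'))
builder.build_epilogue()  # emits a single store_reg_to_glb, for 'out' only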
2 changes: 1 addition & 1 deletion tensorforge/backend/opt/shr_mem_analyzer.py
@@ -23,7 +23,7 @@ def apply(self) -> None:
         self._check_regions()

         max_memory, mem_per_region = self._compute_total_shr_mem_size()
-        self._shr_mem_obj.set_size_per_mult(max_memory)
+        self._shr_mem_obj.set_size_per_mult(max_memory) # TODO:

         offsets = self._compute_start_addresses(mem_per_region)
         self._assign_offsets(offsets)
6 changes: 6 additions & 0 deletions tensorforge/common/vm/hw_descr_db.yml
@@ -39,6 +39,9 @@
 - arch: sm_90
   base: sm_60
   max_local_mem_size_per_block: 227kB
+- arch: sm_100
+  base: sm_60
+  max_local_mem_size_per_block: 227kB
 - arch: gfx900
   vec_unit_length: 64
   max_local_mem_size_per_block: 64kB
@@ -63,6 +66,9 @@
   base: gfx90a
 - arch: gfx942
   base: gfx90a
+- arch: gfx950
+  base: gfx90a
+  max_local_mem_size_per_block: 160kB
 - arch: gfx1010
   vec_unit_length: 32
   max_local_mem_size_per_block: 128kB # assume WGP mode
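Judging by the file's shape, an entry's base: field makes it inherit the remaining fields (vec_unit_length, memory caps, ...) from an existing architecture, with locally set keys overriding inherited ones — so the new sm_100 and gfx950 entries only need to state where they differ. A hypothetical resolver illustrating that assumed semantics (requires PyYAML; the toy data mirrors the entries above, values illustrative):

import yaml

def resolve(entries, arch):
    """Merge an arch entry with its 'base' chain; nearer entries win on conflicts."""
    by_arch = {e['arch']: e for e in entries}
    chain, entry = [], by_arch[arch]
    while entry is not None:
        chain.append(entry)
        entry = by_arch.get(entry.get('base'))
    merged = {}
    for e in reversed(chain):  # apply the base first, then each override
        merged.update(e)
    merged.pop('base', None)
    return merged

entries = yaml.safe_load("""
- arch: gfx90a
  vec_unit_length: 64
  max_local_mem_size_per_block: 64kB
- arch: gfx950
  base: gfx90a
  max_local_mem_size_per_block: 160kB
""")
print(resolve(entries, 'gfx950'))
# gfx950 keeps gfx90a's vec_unit_length but overrides the local-memory cap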
18 changes: 16 additions & 2 deletions tensorforge/common/vm/lexic/target_lexic.py
@@ -11,6 +11,7 @@ def __init__(self, backend, underlying_hardware):
         self.thread_idx_z = "tz"
         self.block_idx_x = "bx"
         self.block_idx_z = "bz"
+        self.block_dim_x = "tX"
         self.block_dim_y = "tY"
         self.block_dim_z = "tZ"
         self.grid_dim_x = "omp_get_num_teams()"
@@ -20,11 +21,20 @@
     def multifile(self):
        return False

+    def get_launch_size(self, func_name, block, shmem):
+        return ''
+
+    def set_shmem_size(self, func_name, shmem):
+        return ''
+
     def get_launch_code(self, func_name, grid, block, stream, func_params, shmem):
         return f"{func_name}({stream}, {grid}[0], {block}[0], {block}[1], {func_params})"

     def declare_shared_memory_inline(self, name, precision, size, alignment):
         return ""

+    def declare_shared_memory(self, name, precision):
+        return ""
+
     def kernel_definition(self, file, kernel_bounds, base_name, params, precision=None, total_shared_mem_size=None, global_symbols=None):
         bounds = "*".join(str(kb) for kb in kernel_bounds)
@@ -106,7 +116,7 @@ def __enter__(self):
             file(f'#pragma omp target nowait depend(inout: streamobj[0]) map(to:bX) map(from: {", ".join(f"{symbol.name}_ptr[0:bX]" for symbol in batched_symbols_in + batched_symbols_inout)}) is_device_ptr({", ".join(symbol.name for symbol in batched_symbols_in + batched_symbols_inout)}) {device}')
             with file.Scope():
                 for symbol in batched_symbols_in + batched_symbols_inout:
-                    file('#pragma omp loop nowait collapse(2)')
+                    file('#pragma omp loop collapse(2)')
                     with file.For(f'int j = 0; j < bX; ++j'):
                         with file.For(f'int i = 0; i < {symbol.obj.get_real_volume()}; ++i'):
                             file(f'{symbol.name}_ptr[j][i] = {symbol.name}[j][i];')
@@ -115,11 +125,13 @@ def epilogue():
                 file(f'#pragma omp target nowait depend(inout: streamobj[0]) map(to:bX) map(to: {", ".join(f"{symbol.name}_ptr[0:bX]" for symbol in batched_symbols_out + batched_symbols_inout)}) is_device_ptr({", ".join(symbol.name for symbol in batched_symbols_out + batched_symbols_inout)}) {device}')
                 with file.Scope():
                     for symbol in batched_symbols_out + batched_symbols_inout:
-                        file('#pragma omp loop nowait collapse(2)')
+                        file('#pragma omp loop collapse(2)')
                         with file.For(f'int j = 0; j < bX; ++j'):
                             with file.For(f'int i = 0; i < {symbol.obj.get_real_volume()}; ++i'):
                                 file(f'{symbol.name}[j][i] = {symbol.name}_ptr[j][i];')
             self.epilogue = epilogue
+        else:
+            self.epilogue = lambda: None

         batched_symbols_out_str = f'map(from: {", ".join(f"{symbol.name}_ptr[0:bX]" for symbol in batched_symbols_out)})' if len(batched_symbols_out) > 0 else ''
         batched_symbols_in_str = f'map(to: {", ".join(f"{symbol.name}_ptr[0:bX]" for symbol in batched_symbols_in)})' if len(batched_symbols_in) > 0 else ''
@@ -196,6 +208,8 @@ def get_fptype(self, fptype, length=1):
         return f'__attribute__ ((vector_size (sizeof({fptype}) * {length}))) {fptype}'

     def get_operation(self, op: Operation, fptype, value1, value2):
+        fpsuffix = 'f' if fptype == FloatingPointType.FLOAT else ''
+        fpprefix = 'f' if fptype == FloatingPointType.FLOAT else 'd'
        if op == Operation.COPY:
            return value1
        elif op == Operation.ADD:
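Two notes on this file. The nowait clause dropped from the copy loops is not among the clauses the OpenMP loop construct accepts (bind, collapse, order, and the privatization/reduction clauses), so its removal is likely a conformance fix. The new fpsuffix/fpprefix strings follow the usual C math-library naming convention, where single-precision variants carry an f suffix (fmaxf vs. fmax, 1.0f vs. 1.0) and some intrinsic families a per-type prefix. Their call sites are outside this diff, so the sketch below is an assumed usage, not the file's actual code:

from enum import Enum, auto

class FloatingPointType(Enum):
    FLOAT = auto()
    DOUBLE = auto()

def fp_strings(fptype):
    fpsuffix = 'f' if fptype == FloatingPointType.FLOAT else ''
    fpprefix = 'f' if fptype == FloatingPointType.FLOAT else 'd'
    return fpsuffix, fpprefix

def emit_max(fptype, a, b):
    # Suffix style: fmaxf(x, y) for float, fmax(x, y) for double.
    fpsuffix, _ = fp_strings(fptype)
    return f'fmax{fpsuffix}({a}, {b})'

def emit_mul_rn(fptype, a, b):
    # Prefix style, e.g. CUDA's round-to-nearest intrinsics:
    # __fmul_rn for float, __dmul_rn for double.
    _, fpprefix = fp_strings(fptype)
    return f'__{fpprefix}mul_rn({a}, {b})'

print(emit_max(FloatingPointType.FLOAT, 'x', 'y'))     # fmaxf(x, y)
print(emit_mul_rn(FloatingPointType.DOUBLE, 'x', 'y')) # __dmul_rn(x, y)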
4 changes: 2 additions & 2 deletions tensorforge/generators/generator.py
@@ -82,8 +82,8 @@ def __init__(self,
         self._check_consistency_with_user_options()
         self._name_operands(self.descr_list)

-        self._persistent_threading = True
-        self._preload_globals = True
+        self._persistent_threading = False
+        self._preload_globals = False

     def set_kernel_name(self, name):
         self._base_kernel_name = name
