Skip to content

Commit

Permalink
Second target kernel fix pass
Browse files: browse the repository at this point in the history
  • Loading branch information
davschneller committed Dec 16, 2024
1 parent 380d574 commit 9c1fd45
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tensorforge/common/vm/lexic/target_lexic.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -105,15 +105,15 @@ def __enter__(self):
for symbol in batched_symbols_in:
file(f'static std::unordered_map<const {precision}**, {precision}(*)[{symbol.obj.get_real_volume()}]> {symbol.name}_datamap;')
with file.If(f'{symbol.name}_datamap[{symbol.name}] == nullptr'):
file(f'{symbol.name}_datamap[{symbol.name}] = reinterpret_cast<decltype({symbol.name}_ptr)>(std::malloc(sizeof({precision}[{symbol.obj.get_real_volume()}]) * numElements));')
file(f'{symbol.name}_datamap[{symbol.name}] = reinterpret_cast<{precision}(*)[{symbol.obj.get_real_volume()}]>(omp_alloc(sizeof({precision}[{symbol.obj.get_real_volume()}]) * numElements, llvm_omp_target_shared_mem_alloc));')
file(f'auto* {symbol.name}_ptr = {symbol.name}_datamap[{symbol.name}];')
for symbol in batched_symbols_out + batched_symbols_inout:
file(f'static std::unordered_map<{precision}**, {precision}(*)[{symbol.obj.get_real_volume()}]> {symbol.name}_datamap;')
with file.If(f'{symbol.name}_datamap[{symbol.name}] == nullptr'):
file(f'{symbol.name}_datamap[{symbol.name}] = reinterpret_cast<decltype({symbol.name}_ptr)>(std::malloc(sizeof({precision}[{symbol.obj.get_real_volume()}]) * numElements));')
file(f'{symbol.name}_datamap[{symbol.name}] = reinterpret_cast<{precision}(*)[{symbol.obj.get_real_volume()}]>(omp_alloc(sizeof({precision}[{symbol.obj.get_real_volume()}]) * numElements, llvm_omp_target_shared_mem_alloc));')
file(f'auto* {symbol.name}_ptr = {symbol.name}_datamap[{symbol.name}];')
if len(batched_symbols_in + batched_symbols_inout) > 0:
file(f'#pragma omp target nowait depend(inout: streamobj[0]) map(from: {", ".join(f"{symbol.name}_ptr[0:numElements]" for symbol in batched_symbols_in + batched_symbols_inout)}) is_device_ptr({", ".join(symbol.name for symbol in batched_symbols_in + batched_symbols_inout)}) {device}')
file(f'#pragma omp target nowait depend(inout: streamobj[0]) is_device_ptr({", ".join(f"{symbol.name}_ptr" for symbol in batched_symbols_in + batched_symbols_inout)}) is_device_ptr({", ".join(symbol.name for symbol in batched_symbols_in + batched_symbols_inout)}) {device}')
with file.Scope():
for symbol in batched_symbols_in + batched_symbols_inout:
file('#pragma omp loop collapse(2)')
Expand All @@ -122,7 +122,7 @@ def __enter__(self):
file(f'{symbol.name}_ptr[j][i] = {symbol.name}[j][i];')
if len(batched_symbols_out + batched_symbols_inout) > 0:
def epilogue():
file(f'#pragma omp target nowait depend(inout: streamobj[0]) map(to: {", ".join(f"{symbol.name}_ptr[0:numElements]" for symbol in batched_symbols_out + batched_symbols_inout)}) is_device_ptr({", ".join(symbol.name for symbol in batched_symbols_out + batched_symbols_inout)}) {device}')
file(f'#pragma omp target nowait depend(inout: streamobj[0]) is_device_ptr({", ".join(f"{symbol.name}_ptr" for symbol in batched_symbols_out + batched_symbols_inout)}) is_device_ptr({", ".join(symbol.name for symbol in batched_symbols_out + batched_symbols_inout)}) {device}')
with file.Scope():
for symbol in batched_symbols_out + batched_symbols_inout:
file('#pragma omp loop collapse(2)')
Expand Down

0 comments on commit 9c1fd45

Please sign in to comment.