Skip to content

Commit

Permalink
Fix tD (targetDART) lexic
Browse files: browse the repository at this point in the history
  • Loading branch information
davschneller committed Dec 16, 2024
1 parent f9404ab commit 380d574
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions tensorforge/common/vm/lexic/target_lexic.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -104,41 +104,41 @@ def __enter__(self):
deviceAny = 'device(TARGETDART_ANY)'
for symbol in batched_symbols_in:
file(f'static std::unordered_map<const {precision}**, {precision}(*)[{symbol.obj.get_real_volume()}]> {symbol.name}_datamap;')
with file.If(f'{symbol.name}_datamap[{symbol.name}] == nullptr'):
file(f'{symbol.name}_datamap[{symbol.name}] = reinterpret_cast<decltype({symbol.name}_ptr)>(std::malloc(sizeof({precision}[{symbol.obj.get_real_volume()}]) * numElements));')
file(f'auto* {symbol.name}_ptr = {symbol.name}_datamap[{symbol.name}];')
with file.If(f'{symbol.name}_ptr == nullptr'):
file(f'{symbol.name}_ptr = reinterpret_cast<decltype({symbol.name}_ptr)>(std::malloc(sizeof({precision}[{symbol.obj.get_real_volume()}]) * bX));')
for symbol in batched_symbols_out + batched_symbols_inout:
file(f'static std::unordered_map<{precision}**, {precision}(*)[{symbol.obj.get_real_volume()}]> {symbol.name}_datamap;')
with file.If(f'{symbol.name}_datamap[{symbol.name}] == nullptr'):
file(f'{symbol.name}_datamap[{symbol.name}] = reinterpret_cast<decltype({symbol.name}_ptr)>(std::malloc(sizeof({precision}[{symbol.obj.get_real_volume()}]) * numElements));')
file(f'auto* {symbol.name}_ptr = {symbol.name}_datamap[{symbol.name}];')
with file.If(f'{symbol.name}_ptr == nullptr'):
file(f'{symbol.name}_ptr = reinterpret_cast<decltype({symbol.name}_ptr)>(std::malloc(sizeof({precision}[{symbol.obj.get_real_volume()}]) * bX));')
if len(batched_symbols_in + batched_symbols_inout) > 0:
file(f'#pragma omp target nowait depend(inout: streamobj[0]) map(to:bX) map(from: {", ".join(f"{symbol.name}_ptr[0:bX]" for symbol in batched_symbols_in + batched_symbols_inout)}) is_device_ptr({", ".join(symbol.name for symbol in batched_symbols_in + batched_symbols_inout)}) {device}')
file(f'#pragma omp target nowait depend(inout: streamobj[0]) map(from: {", ".join(f"{symbol.name}_ptr[0:numElements]" for symbol in batched_symbols_in + batched_symbols_inout)}) is_device_ptr({", ".join(symbol.name for symbol in batched_symbols_in + batched_symbols_inout)}) {device}')
with file.Scope():
for symbol in batched_symbols_in + batched_symbols_inout:
file('#pragma omp loop collapse(2)')
with file.For(f'int j = 0; j < bX; ++j'):
with file.For(f'int j = 0; j < numElements; ++j'):
with file.For(f'int i = 0; i < {symbol.obj.get_real_volume()}; ++i'):
file(f'{symbol.name}_ptr[j][i] = {symbol.name}[j][i];')
if len(batched_symbols_out + batched_symbols_inout) > 0:
def epilogue():
file(f'#pragma omp target nowait depend(inout: streamobj[0]) map(to:bX) map(to: {", ".join(f"{symbol.name}_ptr[0:bX]" for symbol in batched_symbols_out + batched_symbols_inout)}) is_device_ptr({", ".join(symbol.name for symbol in batched_symbols_out + batched_symbols_inout)}) {device}')
file(f'#pragma omp target nowait depend(inout: streamobj[0]) map(to: {", ".join(f"{symbol.name}_ptr[0:numElements]" for symbol in batched_symbols_out + batched_symbols_inout)}) is_device_ptr({", ".join(symbol.name for symbol in batched_symbols_out + batched_symbols_inout)}) {device}')
with file.Scope():
for symbol in batched_symbols_out + batched_symbols_inout:
file('#pragma omp loop collapse(2)')
with file.For(f'int j = 0; j < bX; ++j'):
with file.For(f'int j = 0; j < numElements; ++j'):
with file.For(f'int i = 0; i < {symbol.obj.get_real_volume()}; ++i'):
file(f'{symbol.name}[j][i] = {symbol.name}_ptr[j][i];')
self.epilogue = epilogue
else:
self.epilogue = lambda: None

batched_symbols_out_str = f'map(from: {", ".join(f"{symbol.name}_ptr[0:bX]" for symbol in batched_symbols_out)})' if len(batched_symbols_out) > 0 else ''
batched_symbols_in_str = f'map(to: {", ".join(f"{symbol.name}_ptr[0:bX]" for symbol in batched_symbols_in)})' if len(batched_symbols_in) > 0 else ''
batched_symbols_inout_str = f'map(tofrom: {", ".join(f"{symbol.name}_ptr[0:bX]" for symbol in batched_symbols_inout)})' if len(batched_symbols_inout) > 0 else ''
strided_symbols_out_str = f'map(from: {", ".join(f"{symbol.name}[0:{symbol.obj.get_real_volume()}*bX]" for symbol in strided_symbols_out)})' if len(strided_symbols_out) > 0 else ''
strided_symbols_in_str = f'map(to: {", ".join(f"{symbol.name}[0:{symbol.obj.get_real_volume()}*bX]" for symbol in strided_symbols_in)})' if len(strided_symbols_in) > 0 else ''
strided_symbols_inout_str = f'map(tofrom: {", ".join(f"{symbol.name}[0:{symbol.obj.get_real_volume()}*bX]" for symbol in strided_symbols_inout)})' if len(strided_symbols_inout) > 0 else ''
batched_symbols_out_str = f'map(from: {", ".join(f"{symbol.name}_ptr[0:numElements]" for symbol in batched_symbols_out)})' if len(batched_symbols_out) > 0 else ''
batched_symbols_in_str = f'map(to: {", ".join(f"{symbol.name}_ptr[0:numElements]" for symbol in batched_symbols_in)})' if len(batched_symbols_in) > 0 else ''
batched_symbols_inout_str = f'map(tofrom: {", ".join(f"{symbol.name}_ptr[0:numElements]" for symbol in batched_symbols_inout)})' if len(batched_symbols_inout) > 0 else ''
strided_symbols_out_str = f'map(from: {", ".join(f"{symbol.name}[0:{symbol.obj.get_real_volume()}*numElements]" for symbol in strided_symbols_out)})' if len(strided_symbols_out) > 0 else ''
strided_symbols_in_str = f'map(to: {", ".join(f"{symbol.name}[0:{symbol.obj.get_real_volume()}*numElements]" for symbol in strided_symbols_in)})' if len(strided_symbols_in) > 0 else ''
strided_symbols_inout_str = f'map(tofrom: {", ".join(f"{symbol.name}[0:{symbol.obj.get_real_volume()}*numElements]" for symbol in strided_symbols_inout)})' if len(strided_symbols_inout) > 0 else ''
constant_symbols_str = f'map(to: {", ".join(f"{symbol.name}[0:{symbol.obj.get_real_volume()}]" for symbol in constant_symbols)})' if len(constant_symbols) > 0 else ''
# TODO: map offsets
file(f'#pragma omp target teams nowait num_teams(bX) map(to:tX) depend(inout: streamobj[0]) {constant_symbols_str} {strided_symbols_in_str} {strided_symbols_out_str} {strided_symbols_inout_str} {batched_symbols_in_str} {batched_symbols_out_str} {batched_symbols_inout_str} thread_limit({bounds}) {deviceAny}')
Expand Down Expand Up @@ -210,6 +210,7 @@ def get_fptype(self, fptype, length=1):
def get_operation(self, op: Operation, fptype, value1, value2):
fpsuffix = 'f' if fptype == FloatingPointType.FLOAT else ''
fpprefix = 'f' if fptype == FloatingPointType.FLOAT else 'd'
fpname = str(fptype)
if op == Operation.COPY:
return value1
elif op == Operation.ADD:
Expand All @@ -227,9 +228,9 @@ def get_operation(self, op: Operation, fptype, value1, value2):
elif op == Operation.NEG:
return f'(-{value1})'
elif op == Operation.MIN:
return f'std::min({value1}, {value2})'
return f'std::min({fpname}({value1}), {fpname}({value2}))'
elif op == Operation.MAX:
return f'std::max({value1}, {value2})'
return f'std::max({fpname}({value1}), {fpname}({value2}))'
elif op == Operation.POW:
return f'std::pow({value1}, {value2})'
elif op == Operation.EXP:
Expand Down

0 comments on commit 380d574

Please sign in to comment.