Skip to content

Commit

Permalink
compiler: convert all in visitors to f-string
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 30, 2025
1 parent 33d300d commit d68ecfe
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 78 deletions.
8 changes: 4 additions & 4 deletions devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,10 @@ def __init__(self):
_cstd = 'c99'

def __init__(self, **kwargs):
_name = kwargs.pop('name', self.__class__.__name__)
if isinstance(_name, Compiler):
_name = _name.name
self._name = _name
name = kwargs.pop('name', self.__class__.__name__)
if isinstance(name, Compiler):
name = name.name
self._name = name

super().__init__(**kwargs)

Expand Down
7 changes: 2 additions & 5 deletions devito/data/allocators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import abc
from functools import reduce
from operator import mul
import ctypes
from ctypes.util import find_library
import mmap
Expand All @@ -11,7 +9,7 @@

from devito.logger import logger
from devito.parameters import configuration
from devito.tools import is_integer, dtype_alloc_ctype
from devito.tools import is_integer, infer_datasize

__all__ = ['ALLOC_ALIGNED', 'ALLOC_NUMA_LOCAL', 'ALLOC_NUMA_ANY',
'ALLOC_KNL_MCDRAM', 'ALLOC_KNL_DRAM', 'ALLOC_GUARD',
Expand Down Expand Up @@ -92,8 +90,7 @@ def initialize(cls):
return

def alloc(self, shape, dtype, padding=0):
ctype, c_scale = dtype_alloc_ctype(dtype)
datasize = int(reduce(mul, shape) * c_scale)
ctype, datasize = infer_datasize(dtype, shape)

# Add padding, if any
try:
Expand Down
125 changes: 67 additions & 58 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,23 +80,24 @@ def indent(self):
return ' ' * self._depth

def visit_Node(self, o):
return self.indent + '<%s>' % o.__class__.__name__
return self.indent + f'<{o.__class__.__name__}>'

def visit_Generable(self, o):
body = ' %s' % str(o) if self.verbose else ''
return self.indent + '<C.%s%s>' % (o.__class__.__name__, body)
body = f" {str(o) if self.verbose else ''}"
return self.indent + f'<C.{o.__class__.__name__}{body}>'

def visit_Callable(self, o):
self._depth += 1
body = self._visit(o.children)
self._depth -= 1
return self.indent + '<Callable %s>\n%s' % (o.name, body)
return self.indent + f'<Callable {o.name}>\n{body}'

def visit_CallableBody(self, o):
self._depth += 1
body = [self._visit(o.init), self._visit(o.unpacks), self._visit(o.body)]
self._depth -= 1
return self.indent + "%s\n%s" % (o.__repr__(), '\n'.join([i for i in body if i]))
cbody = '\n'.join([i for i in body if i])
return self.indent + f"{o.__repr__()}\n{cbody}"

def visit_list(self, o):
return ('\n').join([self._visit(i) for i in o])
Expand All @@ -111,62 +112,64 @@ def visit_List(self, o):
else:
body = [self._visit(o.body)]
self._depth -= 1
return self.indent + "%s\n%s" % (o.__repr__(), '\n'.join(body))
cbody = '\n'.join(body)
return self.indent + f"{o.__repr__()}\n{cbody}"

def visit_TimedList(self, o):
self._depth += 1
body = [self._visit(o.body)]
self._depth -= 1
return self.indent + "%s\n%s" % (o.__repr__(), '\n'.join(body))
cbody = '\n'.join(body)
return self.indent + f"{o.__repr__()}\n{cbody}"

def visit_Iteration(self, o):
self._depth += 1
body = self._visit(o.children)
self._depth -= 1
if self.verbose:
detail = '::%s::%s' % (o.index, o.limits)
detail = f'::{o.index}::{o.limits}'
props = [str(i) for i in o.properties]
props = '[%s] ' % ','.join(props) if props else ''
cprops = ','.join(props) if props else ''
props = f'[{cprops}] '
else:
detail, props = '', ''
return self.indent + "<%sIteration %s%s>\n%s" % (props, o.dim.name, detail, body)
return self.indent + f"<{props}Iteration {o.dim.name}{detail}>\n{body}"

def visit_While(self, o):
self._depth += 1
body = self._visit(o.children)
self._depth -= 1
return self.indent + "<While %s>\n%s" % (o.condition, body)
return self.indent + f"<While {o.condition}>\n{body}"

def visit_Expression(self, o):
if self.verbose:
body = "%s = %s" % (o.expr.lhs, o.expr.rhs)
return self.indent + "<Expression %s>" % body
body = f"{o.expr.lhs} = {o.expr.rhs}"
return self.indent + f"<Expression {body}>"
else:
return self.indent + str(o)

def visit_AugmentedExpression(self, o):
if self.verbose:
body = "%s %s= %s" % (o.expr.lhs, o.op, o.expr.rhs)
return self.indent + "<%s %s>" % (o.__class__.__name__, body)
body = f"{o.expr.lhs} {o.op}= {o.expr.rhs}"
return self.indent + f"<{o.__class__.__name__} {body}>"
else:
return self.indent + str(o)

def visit_HaloSpot(self, o):
self._depth += 1
body = self._visit(o.children)
self._depth -= 1
return self.indent + "%s\n%s" % (o.__repr__(), body)
return self.indent + f"{o.__repr__()}\n{body}"

def visit_Conditional(self, o):
self._depth += 1
then_body = self._visit(o.then_body)
self._depth -= 1
if o.else_body:
else_body = self._visit(o.else_body)
return self.indent + "<If %s>\n%s\n<Else>\n%s" % (o.condition,
then_body, else_body)
return self.indent + f"<If {o.condition}>\n{then_body}\n<Else>\n{else_body}"
else:
return self.indent + "<If %s>\n%s" % (o.condition, then_body)
return self.indent + f"<If {o.condition}>\n{then_body}"


class CGen(Visitor):
Expand Down Expand Up @@ -249,20 +252,20 @@ def _gen_value(self, obj, mode=1, masked=()):

if (obj._mem_stack or obj._mem_constant) and mode == 1:
strtype = self.ccode(obj._C_typedata)
strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape)
strshape = ''.join(f'[{self.ccode(i)}]' for i in obj.symbolic_shape)
else:
strtype = self.ccode(obj._C_ctype)
strshape = ''
if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1:
if not obj._mem_stack:
strtype = '%s%s' % (strtype, self._restrict_keyword)
strtype = f'{strtype}{self._restrict_keyword}'
strtype = ' '.join(qualifiers + [strtype])

if obj.is_LocalObject and obj._C_modifier is not None and mode == 2:
strtype += obj._C_modifier

strname = obj._C_name
strobj = '%s%s' % (strname, strshape)
strobj = f'{strname}{strshape}'

if obj.is_LocalObject and obj.cargs and mode == 1:
arguments = [self.ccode(i) for i in obj.cargs]
Expand Down Expand Up @@ -386,16 +389,16 @@ def visit_PointerCast(self, o):

if f.is_PointerArray:
# lvalue
lvalue = c.Value(cstr, '**%s' % f.name)
lvalue = c.Value(cstr, f'**{f.name}')

# rvalue
if isinstance(o.obj, ArrayObject):
v = '%s->%s' % (o.obj.name, f._C_name)
v = f'{o.obj.name}->{f._C_name}'
elif isinstance(o.obj, IndexedData):
v = f._C_name
else:
assert False
rvalue = '(%s**) %s' % (cstr, v)
rvalue = f'({cstr}**) {v}'

else:
# lvalue
Expand All @@ -404,12 +407,12 @@ def visit_PointerCast(self, o):
else:
v = f.name
if o.flat is None:
shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape)
rshape = '(*)%s' % shape
shape = ''.join(f"[{self.ccode(i)}]" for i in o.castshape)
rshape = f'(*){shape}'
lvalue = c.Value(cstr, f'(*{self._restrict_keyword} {v}){shape}')
else:
rshape = '*'
lvalue = c.Value(cstr, '*%s' % v)
lvalue = c.Value(cstr, f'*{v}')
if o.alignment and f._data_alignment:
lvalue = c.AlignedAttribute(f._data_alignment, lvalue)

Expand All @@ -422,14 +425,14 @@ def visit_PointerCast(self, o):
else:
assert False

rvalue = '(%s %s) %s->%s' % (cstr, rshape, f._C_name, v)
rvalue = f'({cstr} {rshape}) {f._C_name}->{v}'
else:
if isinstance(o.obj, Pointer):
v = o.obj.name
else:
v = f._C_name

rvalue = '(%s %s) %s' % (cstr, rshape, v)
rvalue = f'({cstr} {rshape}) {v}'

return c.Initializer(lvalue, rvalue)

Expand All @@ -439,17 +442,19 @@ def visit_Dereference(self, o):
i = a1.indexed
cstr = self.ccode(i._C_typedata)
if o.flat is None:
shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
a1.dim.name)
shape = ''.join(f"[{self.ccode(i)}]" for i in a0.symbolic_shape[1:])
rvalue = f'({cstr} (*){shape}) {a1.name}[{a1.dim.name}]'
lvalue = c.Value(cstr, f'(*{self._restrict_keyword} {a0.name}){shape}')
else:
rvalue = '(%s *) %s[%s]' % (cstr, a1.name, a1.dim.name)
rvalue = f'({cstr} *) {a1.name}[{a1.dim.name}]'
lvalue = c.Value(cstr, f'*{self._restrict_keyword} {a0.name}')
if a0._data_alignment:
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
else:
rvalue = '*%s' % a1.name if a1.is_Symbol else '%s->%s' % (a1.name, a0._C_name)
if a1.is_Symbol:
rvalue = f'*{a1.name}'
else:
rvalue = f'{a1.name}->{a0._C_name}'
lvalue = self._gen_value(a0, 0)
return c.Initializer(lvalue, rvalue)

Expand Down Expand Up @@ -494,7 +499,7 @@ def visit_Expression(self, o):
def visit_AugmentedExpression(self, o):
c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype)
c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype)
code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs))
code = c.Statement(f"{c_lhs} {o.op}= {c_rhs}")
if o.pragmas:
code = c.Module(self._visit(o.pragmas) + (code,))
return code
Expand Down Expand Up @@ -538,23 +543,23 @@ def visit_Iteration(self, o):

# For backward direction flip loop bounds
if o.direction == Backward:
loop_init = 'int %s = %s' % (o.index, self.ccode(_max))
loop_cond = '%s >= %s' % (o.index, self.ccode(_min))
loop_inc = '%s -= %s' % (o.index, o.limits[2])
loop_init = f'int {o.index} = {self.ccode(_max)}'
loop_cond = f'{o.index} >= {self.ccode(_min)}'
loop_inc = f'{o.index} -= {o.limits[2]}'
else:
loop_init = 'int %s = %s' % (o.index, self.ccode(_min))
loop_cond = '%s <= %s' % (o.index, self.ccode(_max))
loop_inc = '%s += %s' % (o.index, o.limits[2])
loop_init = f'int {o.index} = {self.ccode(_min)}'
loop_cond = f'{o.index} <= {self.ccode(_max)}'
loop_inc = f'{o.index} += {o.limits[2]}'

# Append unbounded indices, if any
if o.uindices:
uinit = ['%s = %s' % (i.name, self.ccode(i.symbolic_min)) for i in o.uindices]
uinit = [f'{i.name} = {self.ccode(i.symbolic_min)}' for i in o.uindices]
loop_init = c.Line(', '.join([loop_init] + uinit))

ustep = []
for i in o.uindices:
op = '=' if i.is_Modulo else '+='
ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr)))
ustep.append(f'{i.name} {op} {self.ccode(i.symbolic_incr)}')
loop_inc = c.Line(', '.join([loop_inc] + ustep))

# Create For header+body
Expand All @@ -577,7 +582,7 @@ def visit_While(self, o):
return c.While(condition, c.Block(body))
else:
# Hack: cgen doesn't support body-less while-loops, i.e. `while(...);`
return c.Statement('while(%s)' % condition)
return c.Statement(f'while({condition})')

def visit_Callable(self, o):
body = flatten(self._visit(i) for i in o.children)
Expand All @@ -597,7 +602,7 @@ def visit_MultiTraversable(self, o):
return c.Collection(body)

def visit_UsingNamespace(self, o):
return c.Statement('using namespace %s' % str(o.namespace))
return c.Statement(f'using namespace {str(o.namespace)}')

def visit_Lambda(self, o):
body = []
Expand All @@ -615,9 +620,11 @@ def visit_Lambda(self, o):
extra.append(' '.join(str(i) for i in o.special))
if o.attributes:
extra.append(' ')
extra.append(' '.join('[[%s]]' % i for i in o.attributes))
top = c.Line('[%s](%s)%s' %
(', '.join(captures), ', '.join(decls), ''.join(extra)))
extra.append(' '.join(f'[[{i}]]' for i in o.attributes))
ccapt = ', '.join(captures)
cdecls = ', '.join(decls)
cextra = ''.join(extra)
top = c.Line(f'[{ccapt}]({cdecls}){cextra}')
return LambdaCollection([top, c.Block(body)])

def visit_HaloSpot(self, o):
Expand All @@ -626,7 +633,8 @@ def visit_HaloSpot(self, o):

def visit_KernelLaunch(self, o):
if o.templates:
templates = '<%s>' % ','.join([str(i) for i in o.templates])
ctemplates = ','.join([str(i) for i in o.templates])
templates = f'<{ctemplates}>'
else:
templates = ''

Expand All @@ -640,8 +648,7 @@ def visit_KernelLaunch(self, o):
arguments = self._args_call(o.arguments)
arguments = ','.join(arguments)

return c.Statement('%s%s<<<%s>>>(%s)'
% (o.name, templates, launch_config, arguments))
return c.Statement(f'{o.name}{templates}<<<{launch_config}>>>({arguments})')

# Operator-handle machinery

Expand Down Expand Up @@ -742,7 +749,7 @@ class CInterface(CGen):

def _operator_includes(self, o):
includes = super()._operator_includes(o)
includes.append(c.Include("%s.h" % o.name, system=False))
includes.append(c.Include(f"{o.name}.h", system=False))

return includes

Expand All @@ -754,7 +761,7 @@ def visit_Operator(self, o):
typedecls = self._operator_typedecls(o, mode='public')
guarded_typedecls = []
for i in typedecls:
guard = "DEVITO_%s" % i.tpname.upper()
guard = f"DEVITO_{i.tpname.upper()}"
iflines = [c.Define(guard, ""), blankline, i, blankline]
guarded_typedecl = c.IfNDef(guard, iflines, [])
guarded_typedecls.extend([guarded_typedecl, blankline])
Expand Down Expand Up @@ -1383,13 +1390,15 @@ def __init__(self, name, arguments, is_expr=False, is_indirect=False,

def generate(self):
if self.templates:
tip = "%s<%s>" % (self.name, ", ".join(str(i) for i in self.templates))
ctemplates = ", ".join(str(i) for i in self.templates)
tip = f"{self.name}<{ctemplates}>"
else:
tip = self.name
if not self.is_indirect:
tip = "%s(" % tip
tip = f"{tip}("
else:
tip = "%s%s" % (tip, ',' if self.arguments else '')
cargs = ',' if self.arguments else ''
tip = f"{tip}{cargs}"
processed = []
for i in self.arguments:
if isinstance(i, (MultilineCall, LambdaCollection)):
Expand All @@ -1411,7 +1420,7 @@ def generate(self):
if not self.is_expr:
tip += ";"
if self.cast:
tip = '(%s)%s' % (self.cast, tip)
tip = f'({self.cast}){tip}'
yield tip


Expand Down
Loading

0 comments on commit d68ecfe

Please sign in to comment.