Skip to content

Commit

Permalink
Firedrake coupling (#513)
Browse files Browse the repository at this point in the history
* Added coupling of pySDC with Firedrake. So far only simple heat equation
example with short tutorial.

* Pipeline fix

* Implemented @tlunet's suggestion

* Linting
  • Loading branch information
brownbaerchen authored Jan 22, 2025
1 parent c472a48 commit 5ac6441
Show file tree
Hide file tree
Showing 12 changed files with 804 additions and 11 deletions.
52 changes: 52 additions & 0 deletions .github/workflows/ci_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,58 @@ jobs:
path: |
data_libpressio
coverage_libpressio.dat
user_firedrake_tests:
runs-on: ubuntu-latest
container:
image: firedrakeproject/firedrake-vanilla:latest
options: --user root
volumes:
- ${{ github.workspace }}:/repositories
defaults:
run:
shell: bash -l {0}
steps:
- name: Checkout pySDC
uses: actions/checkout@v4
with:
path: ./pySDC
- name: Checkout gusto
uses: actions/checkout@v4
with:
repository: firedrakeproject/gusto
path: ./gusto_repo
- name: Install pySDC
run: |
. /home/firedrake/firedrake/bin/activate
python -m pip install --no-deps -e /repositories/pySDC
python -m pip install qmat
- name: Install gusto
run: |
. /home/firedrake/firedrake/bin/activate
python -m pip install -e /repositories/gusto_repo
- name: run pytest
run: |
. /home/firedrake/firedrake/bin/activate
firedrake-clean
cd ./pySDC
coverage run -m pytest --continue-on-collection-errors -v --durations=0 /repositories/pySDC/pySDC/tests -m firedrake
timeout-minutes: 120
- name: Make coverage report
run: |
. /home/firedrake/firedrake/bin/activate
cd ./pySDC
mv data ../data_firedrake
coverage combine
mv .coverage ../coverage_firedrake.dat
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: test-artifacts-firedrake
path: |
data_firedrake
coverage_firedrake.dat
user_monodomain_tests_linux:
runs-on: ubuntu-latest
Expand Down
58 changes: 58 additions & 0 deletions pySDC/helpers/firedrake_ensemble_communicator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from mpi4py import MPI
import firedrake as fd
import numpy as np


class FiredrakeEnsembleCommunicator:
"""
Ensemble communicator for performing multiple similar distributed simulations with Firedrake, see https://www.firedrakeproject.org/firedrake/parallelism.html
This is intended to do space-time parallelism in pySDC.
This class wraps the time communicator. All requests that are not overloaded are passed to the time communicator. For instance, `ensemble.rank` will return the rank in the time communicator.
Some operations are overloaded to use the interface of the MPI communicator but handles communication with the ensemble communicator instead.
"""

def __init__(self, comm, space_size):
"""
Args:
comm (MPI.Intracomm): MPI communicator, which will be split into time and space communicators
space_size (int): Size of the spatial communicators
Attributes:
ensemble (firedrake.Ensemble): Ensemble communicator
"""
self.ensemble = fd.Ensemble(comm, space_size)

@property
def space_comm(self):
return self.ensemble.comm

@property
def time_comm(self):
return self.ensemble.ensemble_comm

def __getattr__(self, name):
return getattr(self.time_comm, name)

def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
if type(sendbuf) in [np.ndarray]:
self.ensemble.ensemble_comm.Reduce(sendbuf, recvbuf, op, root)
else:
assert op == MPI.SUM
self.ensemble.reduce(sendbuf, recvbuf, root=root)

def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
if type(sendbuf) in [np.ndarray]:
self.ensemble.ensemble_comm.Allreduce(sendbuf, recvbuf, op)
else:
assert op == MPI.SUM
self.ensemble.allreduce(sendbuf, recvbuf)

def Bcast(self, buf, root=0):
if type(buf) in [np.ndarray]:
self.ensemble.ensemble_comm.Bcast(buf, root)
else:
self.ensemble.bcast(buf, root=root)


def get_ensemble(comm, space_size):
return fd.Ensemble(comm, space_size)
2 changes: 1 addition & 1 deletion pySDC/implementations/datatype_classes/fenics_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __rmul__(self, other):
Args:
other (float): factor
Raises:
DataError: is other is not a float
DataError: if other is not a float
Returns:
fenics_mesh: copy of original values scaled by factor
"""
Expand Down
114 changes: 114 additions & 0 deletions pySDC/implementations/datatype_classes/firedrake_mesh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import firedrake as fd

from pySDC.core.errors import DataError


class firedrake_mesh(object):
"""
Wrapper for firedrake function data.
Attributes:
functionspace (firedrake.Function): firedrake data
"""

def __init__(self, init, val=0.0):
if fd.functionspaceimpl.WithGeometry in type(init).__mro__:
self.functionspace = fd.Function(init)
self.functionspace.assign(val)
elif fd.Function in type(init).__mro__:
self.functionspace = fd.Function(init)
elif type(init) == firedrake_mesh:
self.functionspace = init.functionspace.copy(deepcopy=True)
else:
raise DataError('something went wrong during %s initialization' % type(init))

def __getattr__(self, key):
return getattr(self.functionspace, key)

@property
def asnumpy(self):
"""
Get a numpy array of the values associated with this data
"""
return self.functionspace.dat._numpy_data

def __add__(self, other):
if isinstance(other, type(self)):
me = firedrake_mesh(other)
me.functionspace.assign(self.functionspace + other.functionspace)
return me
else:
raise DataError("Type error: cannot add %s to %s" % (type(other), type(self)))

def __sub__(self, other):
if isinstance(other, type(self)):
me = firedrake_mesh(other)
me.functionspace.assign(self.functionspace - other.functionspace)
return me
else:
raise DataError("Type error: cannot add %s to %s" % (type(other), type(self)))

def __rmul__(self, other):
"""
Overloading the right multiply by scalar factor
Args:
other (float): factor
Raises:
DataError: if other is not a float
Returns:
fenics_mesh: copy of original values scaled by factor
"""

try:
me = firedrake_mesh(self)
me.functionspace.assign(other * self.functionspace)
return me
except TypeError as e:
raise DataError("Type error: cannot multiply %s to %s" % (type(other), type(self))) from e

def __abs__(self):
"""
Overloading the abs operator for mesh types
Returns:
float: L2 norm
"""

return fd.norm(self.functionspace, 'L2')


class IMEX_firedrake_mesh(object):
"""
Datatype for IMEX integration with firedrake data.
Attributes:
impl (firedrake_mesh): implicit part
expl (firedrake_mesh): explicit part
"""

def __init__(self, init, val=0.0):
if type(init) == type(self):
self.impl = firedrake_mesh(init.impl)
self.expl = firedrake_mesh(init.expl)
else:
self.impl = firedrake_mesh(init, val=val)
self.expl = firedrake_mesh(init, val=val)

def __add__(self, other):
me = IMEX_firedrake_mesh(self)
me.impl = self.impl + other.impl
me.expl = self.expl + other.expl
return me

def __sub__(self, other):
me = IMEX_firedrake_mesh(self)
me.impl = self.impl - other.impl
me.expl = self.expl - other.expl
return me

def __rmul__(self, other):
me = IMEX_firedrake_mesh(self)
me.impl = other * self.impl
me.expl = other * self.expl
return me
Loading

0 comments on commit 5ac6441

Please sign in to comment.