diff --git a/psiflow/execution.py b/psiflow/execution.py index 1b9a529..fb81296 100644 --- a/psiflow/execution.py +++ b/psiflow/execution.py @@ -488,7 +488,7 @@ def from_config( ) model_training = ModelTraining.from_config( container=container, - **kwargs.pop("ModelTraining", {}), + **kwargs.pop("ModelTraining", {'gpu': True}), # avoid triggering assertion ) reference_evaluations = [] # reference evaluations might be class specific for key in list(kwargs.keys()):