Merge pull request #19395 from mvdbeek/alternative_format_source_fix
Alternative `format_source` fix
jdavcs authored Jan 23, 2025
2 parents e2636d7 + f28ea20 commit 16e18d4
Showing 15 changed files with 136 additions and 51 deletions. Only the first eight file diffs are reproduced below; the remaining diffs did not load.
17 changes: 9 additions & 8 deletions lib/galaxy/model/__init__.py
@@ -6829,9 +6829,10 @@ def dataset_elements_and_identifiers(self, identifiers=None):
     def first_dataset_element(self) -> Optional["DatasetCollectionElement"]:
         for element in self.elements:
             if element.is_collection:
-                first_element = element.child_collection.first_dataset_element
-                if first_element:
-                    return first_element
+                if element.child_collection:
+                    first_element = element.child_collection.first_dataset_element
+                    if first_element:
+                        return first_element
             else:
                 return element
         return None
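The guard matters because `child_collection` is None for a collection-typed element whose nested collection has not been populated yet, so the old attribute chain could raise AttributeError. A minimal sketch of the failure mode, using stand-in objects rather than the Galaxy models:

# Stand-ins, not the Galaxy models: child_collection may be None for a
# nested element whose collection has not been populated yet.
class FakeElement:
    def __init__(self, is_collection=False, child_collection=None):
        self.is_collection = is_collection
        self.child_collection = child_collection


pending = FakeElement(is_collection=True)  # child_collection is None
# Old: pending.child_collection.first_dataset_element -> AttributeError
# New: the `if element.child_collection:` guard simply skips such elements.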
@@ -7003,7 +7004,7 @@ class HistoryDatasetCollectionAssociation(
     create_time: Mapped[datetime] = mapped_column(default=now, nullable=True)
     update_time: Mapped[datetime] = mapped_column(default=now, onupdate=now, index=True, nullable=True)
 
-    collection = relationship("DatasetCollection")
+    collection: Mapped["DatasetCollection"] = relationship("DatasetCollection")
     history: Mapped[Optional["History"]] = relationship(back_populates="dataset_collections")
 
     copied_from_history_dataset_collection_association = relationship(
@@ -7421,18 +7422,18 @@ class DatasetCollectionElement(Base, Dictifiable, Serializable):
     element_index: Mapped[Optional[int]]
     element_identifier: Mapped[Optional[str]] = mapped_column(Unicode(255))
 
-    hda = relationship(
+    hda: Mapped[Optional["HistoryDatasetAssociation"]] = relationship(
         "HistoryDatasetAssociation",
         primaryjoin=(lambda: DatasetCollectionElement.hda_id == HistoryDatasetAssociation.id),
     )
-    ldda = relationship(
+    ldda: Mapped[Optional["LibraryDatasetDatasetAssociation"]] = relationship(
         "LibraryDatasetDatasetAssociation",
         primaryjoin=(lambda: DatasetCollectionElement.ldda_id == LibraryDatasetDatasetAssociation.id),
    )
-    child_collection = relationship(
+    child_collection: Mapped[Optional["DatasetCollection"]] = relationship(
         "DatasetCollection", primaryjoin=(lambda: DatasetCollectionElement.child_collection_id == DatasetCollection.id)
     )
-    collection = relationship(
+    collection: Mapped[DatasetCollection] = relationship(
         "DatasetCollection",
         primaryjoin=(lambda: DatasetCollection.id == DatasetCollectionElement.dataset_collection_id),
         back_populates="elements",
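Typing the relationships with `Mapped[...]` is what makes the None-ability of these attributes visible to a type checker. A self-contained sketch of the pattern with generic SQLAlchemy 2.0 models (not the Galaxy schema):

from typing import List, Optional

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class Collection(Base):
    __tablename__ = "collection"
    id: Mapped[int] = mapped_column(primary_key=True)
    elements: Mapped[List["Element"]] = relationship(back_populates="collection")


class Element(Base):
    __tablename__ = "element"
    id: Mapped[int] = mapped_column(primary_key=True)
    collection_id: Mapped[Optional[int]] = mapped_column(ForeignKey("collection.id"))
    # Optional[...] mirrors child_collection above: the attribute may be None,
    # so mypy now forces a guard before it is dereferenced.
    collection: Mapped[Optional["Collection"]] = relationship(back_populates="elements")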
10 changes: 9 additions & 1 deletion lib/galaxy/tool_util/parser/interface.py
@@ -37,6 +37,12 @@
     ResourceRequirement,
     ToolRequirements,
 )
+from galaxy.tool_util.parser.output_objects import (
+    ToolOutput,
+    ToolOutputCollection,
+)
+from galaxy.tools import Tool
+
 
 NOT_IMPLEMENTED_MESSAGE = "Galaxy tool format does not yet support this tool feature."
@@ -331,7 +337,9 @@ def parse_provided_metadata_file(self):
         return "galaxy.json"
 
     @abstractmethod
-    def parse_outputs(self, tool):
+    def parse_outputs(
+        self, tool: Optional["Tool"]
+    ) -> Tuple[Dict[str, "ToolOutput"], Dict[str, "ToolOutputCollection"]]:
         """Return a pair of output and output collections ordered
         dictionaries for use by Tool.
         """
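With the abstract signature typed, every `ToolSource` implementation now promises the same return shape. A hedged sketch of what a subclass provides, relying on the imports added above and eliding the real parsing:

def parse_outputs(self, tool):
    # sketch only: build both ordered dicts from the tool source
    outputs: Dict[str, ToolOutput] = {}
    output_collections: Dict[str, ToolOutputCollection] = {}
    return outputs, output_collections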
3 changes: 2 additions & 1 deletion lib/galaxy/tool_util/parser/output_models.py
@@ -8,6 +8,7 @@
 from typing import (
     List,
     Optional,
+    Sequence,
     Union,
 )
@@ -105,7 +106,7 @@ class FilePatternDatasetCollectionDescription(DatasetCollectionDescription):
 ToolOutput = Annotated[ToolOutputT, Field(discriminator="type")]
 
 
-def from_tool_source(tool_source: ToolSource) -> List[ToolOutput]:
+def from_tool_source(tool_source: ToolSource) -> Sequence[ToolOutput]:
     tool_outputs, tool_output_collections = tool_source.parse_outputs(None)
     outputs = []
     for tool_output in tool_outputs.values():
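A likely motivation for the `List` → `Sequence` change: `List` is invariant, so a `List[ToolOutput]` annotation rejects a list built from one concrete member of the discriminated union, while the covariant, read-only `Sequence` accepts it. An illustrative, non-Galaxy example:

from typing import List, Sequence


class Output: ...            # stands in for the ToolOutput union
class DatasetOutput(Output): ...


def concrete() -> List[DatasetOutput]:
    return [DatasetOutput()]


def takes_seq(outputs: Sequence[Output]) -> None: ...


takes_seq(concrete())               # OK: Sequence is covariant
broken: List[Output] = concrete()   # mypy error: List is invariant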
2 changes: 1 addition & 1 deletion lib/galaxy/tool_util/parser/output_objects.py
@@ -281,7 +281,7 @@ def __init__(
         self.collection = True
         self.default_format = default_format
         self.structure = structure
-        self.outputs: Dict[str, str] = {}
+        self.outputs: Dict[str, ToolOutput] = {}
 
         self.inherit_format = inherit_format
         self.inherit_metadata = inherit_metadata
6 changes: 6 additions & 0 deletions lib/galaxy/tools/__init__.py
@@ -83,6 +83,10 @@
     PageSource,
     ToolSource,
 )
+from galaxy.tool_util.parser.output_objects import (
+    ToolOutput,
+    ToolOutputCollection,
+)
 from galaxy.tool_util.parser.util import (
     parse_profile_version,
     parse_tool_version_with_defaults,
@@ -847,6 +851,8 @@ def __init__(
         self.tool_errors = None
         # Parse XML element containing configuration
         self.tool_source = tool_source
+        self.outputs: Dict[str, ToolOutput] = {}
+        self.output_collections: Dict[str, ToolOutputCollection] = {}
         self._is_workflow_compatible = None
         self.__help = None
         self.__tests: Optional[str] = None
7 changes: 4 additions & 3 deletions lib/galaxy/tools/actions/__init__.py
@@ -9,6 +9,7 @@
     cast,
     Dict,
     List,
+    MutableMapping,
     Optional,
     Set,
     Tuple,
@@ -533,7 +534,7 @@ def handle_output(name, output, hidden=None):
                 output,
                 wrapped_params.params,
                 inp_data,
-                inp_dataset_collections,
+                input_collections,
                 input_ext,
                 python_template_version=tool.python_template_version,
                 execution_cache=execution_cache,
@@ -1156,7 +1157,7 @@ def determine_output_format(
     output: "ToolOutput",
     parameter_context,
     input_datasets,
-    input_dataset_collections,
+    input_dataset_collections: MutableMapping[str, model.HistoryDatasetCollectionAssociation],
     random_input_ext,
     python_template_version="3",
     execution_cache=None,
@@ -1198,7 +1199,7 @@
 
         if collection_name in input_dataset_collections:
             try:
-                input_collection = input_dataset_collections[collection_name][0][0]
+                input_collection = input_dataset_collections[collection_name]
                 input_collection_collection = input_collection.collection
                 if element_index is None:
                     # just pick the first HDA
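The mapping shape is the heart of this file's change: determine_output_format now receives a flat name → HistoryDatasetCollectionAssociation mapping, so the old double index `[0][0]` into a name → list-of-pairs structure goes away. An illustrative sketch with stand-in objects (the old tuple's second member is elided here; its exact contents are not shown in this diff):

class FakeHDCA:                      # stand-in, not the Galaxy model
    def __init__(self, collection):
        self.collection = collection


hdca = FakeHDCA(collection="list:fastqsanger")

old_shape = {"input_list": [(hdca, None)]}   # name -> [(hdca, ...)]
new_shape = {"input_list": hdca}             # name -> HDCA, as typed above

assert old_shape["input_list"][0][0] is new_shape["input_list"]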
58 changes: 42 additions & 16 deletions lib/galaxy/tools/evaluation.py
@@ -33,6 +33,7 @@
     MinimalToolApp,
 )
 from galaxy.tool_util.data import TabularToolDataTable
+from galaxy.tools.actions import determine_output_format
 from galaxy.tools.parameters import (
     visit_input_values,
     wrapped_json,
@@ -64,6 +65,7 @@
     safe_makedirs,
     unicodify,
 )
+from galaxy.util.path import StrPath
 from galaxy.util.template import (
     fill_template,
     InputNotFoundSyntaxError,
@@ -102,7 +104,7 @@ def __init__(self, *args: object, tool_id: Optional[str], tool_version: str, is_
         self.is_latest = is_latest
 
 
-def global_tool_logs(func, config_file: str, action_str: str, tool: "Tool"):
+def global_tool_logs(func, config_file: Optional[StrPath], action_str: str, tool: "Tool"):
     try:
         return func()
     except Exception as e:
@@ -130,7 +132,7 @@ class ToolEvaluator:
     job: model.Job
     materialize_datasets: bool = True
 
-    def __init__(self, app: MinimalToolApp, tool, job, local_working_directory):
+    def __init__(self, app: MinimalToolApp, tool: "Tool", job, local_working_directory):
         self.app = app
         self.job = job
         self.tool = tool
@@ -186,6 +188,9 @@ def set_compute_environment(self, compute_environment: ComputeEnvironment, get_s
             out_data,
             output_collections=out_collections,
         )
+        # late update of format_source outputs
+        self._eval_format_source(job, inp_data, out_data)
+
         self.execute_tool_hooks(inp_data=inp_data, out_data=out_data, incoming=incoming)
 
     def execute_tool_hooks(self, inp_data, out_data, incoming):
@@ -275,6 +280,23 @@ def _materialize_objects(
 
         return undeferred_objects
 
+    def _eval_format_source(
+        self,
+        job: model.Job,
+        inp_data: Dict[str, Optional[model.DatasetInstance]],
+        out_data: Dict[str, model.DatasetInstance],
+    ):
+        for output_name, output in out_data.items():
+            if (
+                (tool_output := self.tool.outputs.get(output_name))
+                and (tool_output.format_source or tool_output.change_format)
+                and output.extension == "expression.json"
+            ):
+                input_collections = {jtidca.name: jtidca.dataset_collection for jtidca in job.input_dataset_collections}
+                ext = determine_output_format(tool_output, self.param_dict, inp_data, input_collections, None)
+                if ext:
+                    output.extension = ext
+
     def _replaced_deferred_objects(
         self,
         inp_data: Dict[str, Optional[model.DatasetInstance]],
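This new method is the core of the fix: an output whose datatype depends on format_source or change_format can reach the evaluator still carrying the placeholder extension "expression.json", and only here, with the job's real inputs and input collections in hand, is the final extension resolved. A minimal, generic resolver illustrating what format_source means (the real determine_output_format also handles change_format rules and collection element selection):

from typing import Dict, Optional


class FakeDataset:                   # stand-in, not the Galaxy model
    def __init__(self, extension: str):
        self.extension = extension


def resolve_format(format_source: Optional[str], inputs: Dict[str, FakeDataset]) -> Optional[str]:
    # an output with format_source="input1" inherits input1's datatype
    if format_source and (source := inputs.get(format_source)):
        return source.extension
    return None


assert resolve_format("input1", {"input1": FakeDataset("fastqsanger")}) == "fastqsanger"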
@@ -364,6 +386,9 @@ def do_walk(inputs, input_values):
         do_walk(inputs, input_values)
 
     def __populate_wrappers(self, param_dict, input_datasets, job_working_directory):
+
+        element_identifier_mapper = ElementIdentifierMapper(input_datasets)
+
         def wrap_input(input_values, input):
             value = input_values[input.name]
             if isinstance(input, DataToolParameter) and input.multiple:
@@ -380,26 +405,26 @@ def wrap_input(input_values, input):
 
             elif isinstance(input, DataToolParameter):
                 dataset = input_values[input.name]
-                wrapper_kwds = dict(
+                element_identifier = element_identifier_mapper.identifier(dataset, param_dict)
+                input_values[input.name] = DatasetFilenameWrapper(
+                    dataset=dataset,
                     datatypes_registry=self.app.datatypes_registry,
                     tool=self.tool,
                     name=input.name,
                     compute_environment=self.compute_environment,
+                    identifier=element_identifier,
+                    formats=input.formats,
                 )
-                element_identifier = element_identifier_mapper.identifier(dataset, param_dict)
-                if element_identifier:
-                    wrapper_kwds["identifier"] = element_identifier
-                wrapper_kwds["formats"] = input.formats
-                input_values[input.name] = DatasetFilenameWrapper(dataset, **wrapper_kwds)
             elif isinstance(input, DataCollectionToolParameter):
                 dataset_collection = value
-                wrapper_kwds = dict(
+                wrapper = DatasetCollectionWrapper(
+                    job_working_directory=job_working_directory,
+                    has_collection=dataset_collection,
                     datatypes_registry=self.app.datatypes_registry,
                     compute_environment=self.compute_environment,
                     tool=self.tool,
                     name=input.name,
                 )
-                wrapper = DatasetCollectionWrapper(job_working_directory, dataset_collection, **wrapper_kwds)
                 input_values[input.name] = wrapper
             elif isinstance(input, SelectToolParameter):
                 if input.multiple:
@@ -409,14 +434,13 @@ def wrap_input(input_values, input):
                     )
                 else:
                     input_values[input.name] = InputValueWrapper(
-                        input, value, param_dict, profile=self.tool and self.tool.profile
+                        input, value, param_dict, profile=self.tool and self.tool.profile or None
                     )
 
         # HACK: only wrap if check_values is not false, this deals with external
         # tools where the inputs don't even get passed through. These
         # tools (e.g. UCSC) should really be handled in a special way.
         if self.tool.check_values:
-            element_identifier_mapper = ElementIdentifierMapper(input_datasets)
             self.__walk_inputs(self.tool.inputs, param_dict, wrap_input)
 
     def __populate_input_dataset_wrappers(self, param_dict, input_datasets):
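The trailing `or None` in the InputValueWrapper call looks redundant, but `a and b` evaluates to its falsy left operand rather than to None, so without it the expression can produce values (and, for mypy, types) other than a plain optional profile. Most plausibly this is a normalization for the type checker; the runtime difference only shows up for falsy values:

class ZeroProfileTool:               # hypothetical stand-in
    profile = 0.0                    # falsy but real value


tool = ZeroProfileTool()
print(tool and tool.profile)             # 0.0  (falsy float passes through)
print(tool and tool.profile or None)     # None (every falsy outcome normalized)

tool = None
print(tool and tool.profile)             # None either way when tool is absent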
@@ -443,13 +467,13 @@ def __populate_input_dataset_wrappers(self, param_dict, input_datasets):
                 param_dict[name] = wrapper
                 continue
             if not isinstance(param_dict_value, ToolParameterValueWrapper):
-                wrapper_kwds = dict(
+                param_dict[name] = DatasetFilenameWrapper(
+                    dataset=data,
                     datatypes_registry=self.app.datatypes_registry,
                     tool=self.tool,
                     name=name,
                     compute_environment=self.compute_environment,
                 )
-                param_dict[name] = DatasetFilenameWrapper(data, **wrapper_kwds)
 
     def __populate_output_collection_wrappers(self, param_dict, output_collections, job_working_directory):
         tool = self.tool
@@ -460,14 +484,15 @@ def __populate_output_collection_wrappers(self, param_dict, output_collections,
                 # message = message_template % ( name, tool.output_collections )
                 # raise AssertionError( message )
 
-            wrapper_kwds = dict(
+            wrapper = DatasetCollectionWrapper(
+                job_working_directory=job_working_directory,
+                has_collection=out_collection,
                 datatypes_registry=self.app.datatypes_registry,
                 compute_environment=self.compute_environment,
                 io_type="output",
                 tool=tool,
                 name=name,
             )
-            wrapper = DatasetCollectionWrapper(job_working_directory, out_collection, **wrapper_kwds)
             param_dict[name] = wrapper
             # TODO: Handle nested collections...
             for element_identifier, output_def in tool.output_collections[name].outputs.items():
@@ -662,6 +687,7 @@ def _build_command_line(self):
         if interpreter:
             # TODO: path munging for cluster/dataset server relocatability
             executable = command_line.split()[0]
+            assert self.tool.tool_dir
             tool_dir = os.path.abspath(self.tool.tool_dir)
             abs_executable = os.path.join(tool_dir, executable)
             command_line = command_line.replace(executable, f"{interpreter} {shlex.quote(abs_executable)}", 1)
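The bare assert added in _build_command_line (and the similar one in basic.py below) does double duty: it fails fast if the attribute is unset, and it narrows the Optional type so the following call type-checks. A minimal illustration:

import os
from typing import Optional


def absolute_tool_dir(tool_dir: Optional[str]) -> str:
    assert tool_dir  # narrows Optional[str] -> str for mypy, guards at runtime
    return os.path.abspath(tool_dir)


print(absolute_tool_dir("lib/galaxy/tools"))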
35 changes: 21 additions & 14 deletions lib/galaxy/tools/parameters/basic.py
@@ -1998,6 +1998,7 @@ def do_validate(v):
                     dataset_count += 1
                     do_validate(v.hda)
                 else:
+                    assert v.child_collection
                     for dataset_instance in v.child_collection.dataset_instances:
                         dataset_count += 1
                         do_validate(dataset_instance)
@@ -2176,33 +2177,39 @@ def from_json(self, value, trans, other_values=None):
             dataset_matcher_factory = get_dataset_matcher_factory(trans)
             dataset_matcher = dataset_matcher_factory.dataset_matcher(self, other_values)
             for v in rval:
+                value_to_check: Union[
+                    DatasetInstance, DatasetCollection, DatasetCollectionElement, HistoryDatasetCollectionAssociation
+                ] = v
                 if isinstance(v, DatasetCollectionElement):
                     if hda := v.hda:
-                        v = hda
+                        value_to_check = hda
                     elif ldda := v.ldda:
-                        v = ldda
+                        value_to_check = ldda
                     elif collection := v.child_collection:
-                        v = collection
-                    elif not v.collection and v.collection.populated_optimized:
+                        value_to_check = collection
+                    elif v.collection and not v.collection.populated_optimized:
                         raise ParameterValueError("the selected collection has not been populated.", self.name)
                     else:
                         raise ParameterValueError("Collection element in unexpected state", self.name)
-                if isinstance(v, DatasetInstance):
-                    if v.deleted:
+                if isinstance(value_to_check, DatasetInstance):
+                    if value_to_check.deleted:
                         raise ParameterValueError("the previously selected dataset has been deleted.", self.name)
-                    elif v.dataset and v.dataset.state in [Dataset.states.ERROR, Dataset.states.DISCARDED]:
+                    elif value_to_check.dataset and value_to_check.dataset.state in [
+                        Dataset.states.ERROR,
+                        Dataset.states.DISCARDED,
+                    ]:
                         raise ParameterValueError(
                             "the previously selected dataset has entered an unusable state", self.name
                         )
-                    match = dataset_matcher.hda_match(v)
+                    match = dataset_matcher.hda_match(value_to_check)
                     if match and match.implicit_conversion:
-                        v.implicit_conversion = True  # type:ignore[union-attr]
-                elif isinstance(v, HistoryDatasetCollectionAssociation):
-                    if v.deleted:
+                        value_to_check.implicit_conversion = True  # type:ignore[attr-defined]
+                elif isinstance(value_to_check, HistoryDatasetCollectionAssociation):
+                    if value_to_check.deleted:
                         raise ParameterValueError("the previously selected dataset collection has been deleted.", self.name)
-                    v = v.collection
-                if isinstance(v, DatasetCollection):
-                    if v.elements_deleted:
+                    value_to_check = value_to_check.collection
+                if isinstance(value_to_check, DatasetCollection):
+                    if value_to_check.elements_deleted:
                         raise ParameterValueError(
                             "the previously selected dataset collection has elements that are deleted.", self.name
                         )
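Two things happen in this hunk: the loop stops rebinding `v` to whatever it unwraps to, routing every branch through an explicitly annotated `value_to_check` instead, and the populated check is fixed (the old `not v.collection and v.collection.populated_optimized` dereferenced a collection it had just established to be falsy). A reduced illustration of the naming pattern, with stand-in classes:

from typing import Union


class Dataset:                       # stand-in, not the Galaxy model
    deleted = False


class Element:                       # stand-in for DatasetCollectionElement
    def __init__(self, hda: Dataset):
        self.hda = hda


def validate(v: Union[Dataset, Element]) -> None:
    # a separate, explicitly typed name instead of rebinding v keeps every
    # isinstance branch below well-typed for mypy
    value_to_check: Union[Dataset, Element] = v
    if isinstance(v, Element):
        value_to_check = v.hda
    if isinstance(value_to_check, Dataset) and value_to_check.deleted:
        raise ValueError("the previously selected dataset has been deleted.")


validate(Element(Dataset()))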