Skip to content

Commit

Permalink
Merge pull request #16932 from jdavcs/dev_sa20_fix21
Browse files Browse the repository at this point in the history
SQLAlchemy 2.0 upgrades (part 5)
  • Loading branch information
jdavcs authored Nov 29, 2023
2 parents 55d3d58 + eeeaf7b commit fc0a266
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 55 deletions.
8 changes: 4 additions & 4 deletions lib/galaxy/managers/histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sqlalchemy import (
asc,
desc,
exists,
false,
func,
select,
Expand Down Expand Up @@ -340,13 +341,12 @@ def get_sharing_extra_information(
return extra

def is_history_shared_with(self, history: model.History, user: model.User) -> bool:
stmt = (
select(HistoryUserShareAssociation.id)
stmt = select(
exists()
.where(HistoryUserShareAssociation.user_id == user.id)
.where(HistoryUserShareAssociation.history_id == history.id)
.limit(1)
)
return bool(self.session().execute(stmt).first())
return self.session().scalar(stmt)

def make_members_public(self, trans, item):
"""Make the non-purged datasets in history public.
Expand Down
81 changes: 45 additions & 36 deletions lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6125,7 +6125,7 @@ def __init__(self, id=None, collection_type=None, populated=True, element_count=
self.populated_state = DatasetCollection.populated_states.NEW
self.element_count = element_count

def _get_nested_collection_attributes(
def _build_nested_collection_attributes_stmt(
self,
collection_attributes: Optional[Iterable[str]] = None,
element_attributes: Optional[Iterable[str]] = None,
Expand All @@ -6152,10 +6152,8 @@ def _get_nested_collection_attributes(
dataset_permission_attributes = dataset_permission_attributes or ()
return_entities = return_entities or ()
dataset_collection = self
db_session = object_session(self)
dc = alias(DatasetCollection)
dce = alias(DatasetCollectionElement)

depth_collection_type = dataset_collection.collection_type
order_by_columns = [dce.c.element_index]
nesting_level = 0
Expand All @@ -6165,14 +6163,15 @@ def attribute_columns(column_collection, attributes, nesting_level=None):
return [getattr(column_collection, a).label(f"{a}{label_fragment}") for a in attributes]

q = (
db_session.query(
select(
*attribute_columns(dce.c, element_attributes, nesting_level),
*attribute_columns(dc.c, collection_attributes, nesting_level),
)
.select_from(dce, dc)
.join(dce, dce.c.dataset_collection_id == dc.c.id)
.filter(dc.c.id == dataset_collection.id)
)

while ":" in depth_collection_type:
nesting_level += 1
inner_dc = alias(DatasetCollection)
Expand Down Expand Up @@ -6207,18 +6206,28 @@ def attribute_columns(column_collection, attributes, nesting_level=None):
.add_columns(*attribute_columns(DatasetPermissions, dataset_permission_attributes))
)
for entity in return_entities:
q = q.add_entity(entity)
q = q.add_columns(entity)
if entity == DatasetCollectionElement:
q = q.filter(entity.id == dce.c.id)
return q.distinct().order_by(*order_by_columns)

q = q.order_by(*order_by_columns)
return q

@property
def dataset_states_and_extensions_summary(self):
if not hasattr(self, "_dataset_states_and_extensions_summary"):
q = self._get_nested_collection_attributes(hda_attributes=("extension",), dataset_attributes=("state",))
stmt = self._build_nested_collection_attributes_stmt(
hda_attributes=("extension",), dataset_attributes=("state",)
)
# With DISTINCT, all columns that appear in the ORDER BY clause must appear in the SELECT clause.
stmt = stmt.add_columns(*stmt._order_by_clauses)
stmt = stmt.distinct()

tuples = object_session(self).execute(stmt)

extensions = set()
states = set()
for extension, state in q:
for extension, state, *_ in tuples: # we discard the added columns from the order-by clause
states.add(state)
extensions.add(extension)

Expand All @@ -6232,8 +6241,9 @@ def has_deferred_data(self):
has_deferred_data = False
if object_session(self):
# TODO: Optimize by just querying without returning the states...
q = self._get_nested_collection_attributes(dataset_attributes=("state",))
for (state,) in q:
stmt = self._build_nested_collection_attributes_stmt(dataset_attributes=("state",))
tuples = object_session(self).execute(stmt)
for (state,) in tuples:
if state == Dataset.states.DEFERRED:
has_deferred_data = True
break
Expand All @@ -6254,13 +6264,16 @@ def populated_optimized(self):
if ":" not in self.collection_type:
_populated_optimized = self.populated_state == DatasetCollection.populated_states.OK
else:
q = self._get_nested_collection_attributes(
stmt = self._build_nested_collection_attributes_stmt(
collection_attributes=("populated_state",),
inner_filter=InnerCollectionFilter(
"populated_state", operator.__ne__, DatasetCollection.populated_states.OK
),
)
_populated_optimized = q.session.query(~exists(q.subquery())).scalar()
stmt = stmt.subquery()
stmt = select(~exists(stmt))
session = object_session(self)
_populated_optimized = session.scalar(stmt)

self._populated_optimized = _populated_optimized

Expand All @@ -6276,37 +6289,25 @@ def populated(self):
@property
def dataset_action_tuples(self):
if not hasattr(self, "_dataset_action_tuples"):
q = self._get_nested_collection_attributes(dataset_permission_attributes=("action", "role_id"))
_dataset_action_tuples = []
for _dataset_action_tuple in q:
if _dataset_action_tuple[0] is None:
continue
_dataset_action_tuples.append(_dataset_action_tuple)

self._dataset_action_tuples = _dataset_action_tuples

stmt = self._build_nested_collection_attributes_stmt(dataset_permission_attributes=("action", "role_id"))
tuples = object_session(self).execute(stmt)
self._dataset_action_tuples = [(action, role_id) for action, role_id in tuples if action is not None]
return self._dataset_action_tuples

@property
def element_identifiers_extensions_and_paths(self):
q = self._get_nested_collection_attributes(
element_attributes=("element_identifier",), hda_attributes=("extension",), return_entities=(Dataset,)
)
return [(row[:-2], row.extension, row.Dataset.get_file_name()) for row in q]

@property
def element_identifiers_extensions_paths_and_metadata_files(
self,
) -> List[List[Any]]:
results = []
if object_session(self):
q = self._get_nested_collection_attributes(
stmt = self._build_nested_collection_attributes_stmt(
element_attributes=("element_identifier",),
hda_attributes=("extension",),
return_entities=(HistoryDatasetAssociation, Dataset),
)
tuples = object_session(self).execute(stmt)
# element_identifiers, extension, path
for row in q:
for row in tuples:
result = [row[:-3], row.extension, row.Dataset.get_file_name()]
hda = row.HistoryDatasetAssociation
result.append(hda.get_metadata_file_paths_and_extensions())
Expand Down Expand Up @@ -6351,7 +6352,9 @@ def finalize(self, collection_type_description):
def dataset_instances(self):
db_session = object_session(self)
if db_session and self.id:
return self._get_nested_collection_attributes(return_entities=(HistoryDatasetAssociation,)).all()
stmt = self._build_nested_collection_attributes_stmt(return_entities=(HistoryDatasetAssociation,))
tuples = db_session.execute(stmt).all()
return [tuple[0] for tuple in tuples]
else:
# Sessionless context
instances = []
Expand All @@ -6367,7 +6370,9 @@ def dataset_instances(self):
def dataset_elements(self):
db_session = object_session(self)
if db_session and self.id:
return self._get_nested_collection_attributes(return_entities=(DatasetCollectionElement,)).all()
stmt = self._build_nested_collection_attributes_stmt(return_entities=(DatasetCollectionElement,))
tuples = db_session.execute(stmt).all()
return [tuple[0] for tuple in tuples]
elements = []
for element in self.elements:
if element.is_collection:
Expand Down Expand Up @@ -6452,9 +6457,11 @@ def copy(
return new_collection

def replace_failed_elements(self, replacements):
hda_id_to_element = dict(
self._get_nested_collection_attributes(return_entities=[DatasetCollectionElement], hda_attributes=["id"])
stmt = self._build_nested_collection_attributes_stmt(
return_entities=[DatasetCollectionElement], hda_attributes=["id"]
)
tuples = object_session(self).execute(stmt).all()
hda_id_to_element = dict(tuples)
for failed, replacement in replacements.items():
element = hda_id_to_element.get(failed.id)
if element:
Expand Down Expand Up @@ -6719,10 +6726,12 @@ def job_state_summary_dict(self):
@property
def dataset_dbkeys_and_extensions_summary(self):
if not hasattr(self, "_dataset_dbkeys_and_extensions_summary"):
rows = self.collection._get_nested_collection_attributes(hda_attributes=("_metadata", "extension"))
stmt = self.collection._build_nested_collection_attributes_stmt(hda_attributes=("_metadata", "extension"))
tuples = object_session(self).execute(stmt)

extensions = set()
dbkeys = set()
for row in rows:
for row in tuples:
if row is not None:
dbkey_field = row._metadata.get("dbkey")
if isinstance(dbkey_field, list):
Expand Down
5 changes: 3 additions & 2 deletions scripts/check_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def load_indexes(metadata):

# create EMPTY metadata, then load from database
db_url = get_config(sys.argv)["db_url"]
metadata = MetaData(bind=create_engine(db_url))
metadata.reflect()
metadata = MetaData()
engine = create_engine(db_url)
metadata.reflect(bind=engine)
indexes_in_db = load_indexes(metadata)

all_indexes = set(mapping_indexes.keys()) | set(tsi_mapping_indexes.keys())
Expand Down
56 changes: 43 additions & 13 deletions test/unit/data/test_galaxy_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,22 @@ def test_collections_in_library_folders(self):
# assert len(loaded_dataset_collection.datasets) == 2
# assert loaded_dataset_collection.collection_type == "pair"

def test_dataset_action_tuples(self):
    """Verify dataset_action_tuples skips permissions whose action is None."""
    session = self.model.session
    user = model.User(email="foo", password="foo")
    history = model.History(user=user)
    hda_fwd = model.HistoryDatasetAssociation(history=history, create_dataset=True, sa_session=session)
    hda_rev = model.HistoryDatasetAssociation(history=history, create_dataset=True, sa_session=session)
    role = model.Role()
    # Three permissions on the first dataset; the None-action one must be filtered out.
    perm_a = model.DatasetPermissions(action="action1", dataset=hda_fwd.dataset, role=role)
    perm_none = model.DatasetPermissions(action=None, dataset=hda_fwd.dataset, role=role)
    perm_b = model.DatasetPermissions(action="action3", dataset=hda_fwd.dataset, role=role)
    collection = model.DatasetCollection(collection_type="type1")
    element_fwd = model.DatasetCollectionElement(collection=collection, element=hda_fwd)
    element_rev = model.DatasetCollectionElement(collection=collection, element=hda_rev)
    session.add_all(
        [user, history, hda_fwd, hda_rev, role, perm_a, perm_none, perm_b, collection, element_fwd, element_rev]
    )
    session.flush()
    assert collection.dataset_action_tuples == [("action1", role.id), ("action3", role.id)]

def test_nested_collection_attributes(self):
u = model.User(email="[email protected]", password="password")
h1 = model.History(name="History 1", user=u)
Expand Down Expand Up @@ -392,18 +408,31 @@ def test_nested_collection_attributes(self):
)
self.model.session.add_all([d1, d2, c1, dce1, dce2, c2, dce3, c3, c4, dce4])
self.model.session.flush()
q = c2._get_nested_collection_attributes(

stmt = c2._build_nested_collection_attributes_stmt(
element_attributes=("element_identifier",), hda_attributes=("extension",), dataset_attributes=("state",)
)
assert [(r._fields) for r in q] == [
result = self.model.session.execute(stmt).all()
assert [(r._fields) for r in result] == [
("element_identifier_0", "element_identifier_1", "extension", "state"),
("element_identifier_0", "element_identifier_1", "extension", "state"),
]
assert q.all() == [("inner_list", "forward", "bam", "new"), ("inner_list", "reverse", "txt", "new")]
q = c2._get_nested_collection_attributes(return_entities=(model.HistoryDatasetAssociation,))
assert q.all() == [d1, d2]
q = c2._get_nested_collection_attributes(return_entities=(model.HistoryDatasetAssociation, model.Dataset))
assert q.all() == [(d1, d1.dataset), (d2, d2.dataset)]

stmt = c2._build_nested_collection_attributes_stmt(
element_attributes=("element_identifier",), hda_attributes=("extension",), dataset_attributes=("state",)
)
result = self.model.session.execute(stmt).all()
assert result == [("inner_list", "forward", "bam", "new"), ("inner_list", "reverse", "txt", "new")]

stmt = c2._build_nested_collection_attributes_stmt(return_entities=(model.HistoryDatasetAssociation,))
result = self.model.session.execute(stmt).all()
assert result == [(d1,), (d2,)]

stmt = c2._build_nested_collection_attributes_stmt(
return_entities=(model.HistoryDatasetAssociation, model.Dataset)
)
result = self.model.session.execute(stmt).all()
assert result == [(d1, d1.dataset), (d2, d2.dataset)]
# Assert properties that use _build_nested_collection_attributes_stmt return correct content
assert c2.dataset_instances == [d1, d2]
assert c2.dataset_elements == [dce1, dce2]
Expand All @@ -422,13 +451,14 @@ def test_nested_collection_attributes(self):
assert c3.dataset_instances == []
assert c3.dataset_elements == []
assert c3.dataset_states_and_extensions_summary == (set(), set())
q = c4._get_nested_collection_attributes(element_attributes=("element_identifier",))
assert q.all() == [("outer_list", "inner_list", "forward"), ("outer_list", "inner_list", "reverse")]
assert c4.dataset_elements == [dce1, dce2]
assert c4.element_identifiers_extensions_and_paths == [
(("outer_list", "inner_list", "forward"), "bam", "mock_dataset_14.dat"),
(("outer_list", "inner_list", "reverse"), "txt", "mock_dataset_14.dat"),

stmt = c4._build_nested_collection_attributes_stmt(element_attributes=("element_identifier",))
result = self.model.session.execute(stmt).all()
assert result == [
("outer_list", "inner_list", "forward"),
("outer_list", "inner_list", "reverse"),
]
assert c4.dataset_elements == [dce1, dce2]

def test_dataset_dbkeys_and_extensions_summary(self):
u = model.User(email="[email protected]", password="password")
Expand Down

0 comments on commit fc0a266

Please sign in to comment.