From 1cf858d18c0c838f622abcf33716abb28dcf7ec1 Mon Sep 17 00:00:00 2001
From: Omar Valenzuela
Date: Wed, 5 Jun 2024 10:58:49 -0700
Subject: [PATCH 01/16] pull in tests from czid-platformics

---
 test_app/Makefile                              |   4 +
 test_app/schema/schema.yaml                    |  25 +
 test_app/tests/output/api/mutations.py         |  56 ++
 test_app/tests/output/api/queries.py           |  51 ++
 test_app/tests/output/api/types/contig.py      | 378 ++++++++++++++
 .../tests/output/api/types/genomic_range.py    | 426 +++++++++++++++
 test_app/tests/output/api/types/sample.py      | 432 +++++++++++++++
 .../tests/output/api/types/sequencing_read.py  | 492 ++++++++++++++++++
 .../tests/output/cerbos/policies/contig.yaml   |  22 +
 .../output/cerbos/policies/genomic_range.yaml  |  22 +
 .../tests/output/cerbos/policies/sample.yaml   |  22 +
 .../cerbos/policies/sequencing_read.yaml       |  22 +
 .../tests/output/database/models/__init__.py   |  20 +
 .../tests/output/database/models/contig.py     |  32 ++
 .../output/database/models/genomic_range.py    |  32 ++
 .../tests/output/database/models/sample.py     |  36 ++
 .../output/database/models/sequencing_read.py  |  51 ++
 test_app/tests/output/support/enums.py         |  49 ++
 .../output/test_infra/factories/contig.py      |  37 ++
 .../test_infra/factories/genomic_range.py      |  36 ++
 .../output/test_infra/factories/sample.py      |  36 ++
 .../test_infra/factories/sequencing_read.py    |  73 +++
 test_app/tests/test_aggregate_queries.py       | 290 ++++++++++-
 test_app/tests/test_authorization.py           | 238 ++++++++-
 test_app/tests/test_basic_queries.py           |  24 +-
 test_app/tests/test_bulk_download_deletion.py  |  72 +++
 test_app/tests/test_bulk_download_policy.py    | 123 +++++
 test_app/tests/test_cascade_deletion.py        |  84 +++
 test_app/tests/test_error_handling.py          |  54 ++
 test_app/tests/test_field_constraints.py       | 135 +++++
 test_app/tests/test_field_visibility.py        | 184 +++++++
 test_app/tests/test_file_concatenation.py      |   8 +-
 test_app/tests/test_file_mutations.py          | 122 ++++-
 test_app/tests/test_file_queries.py            |  13 +-
 test_app/tests/test_file_uploads.py            |   6 +-
 test_app/tests/test_limit_offset_queries.py    | 108 ++++
 test_app/tests/test_nested_queries.py          |  11 +-
 .../test_schemas/overrides/api/.gitignore      |   1 +
 .../overrides/api/extra_test_code.py.j2        |   4 +
 test_app/tests/test_schemas/platformics.yaml   | 443 ++++++++++++++++
 test_app/tests/test_sorting_queries.py         | 243 +++++++++
 test_app/tests/test_where_clause.py            | 168 +++++-
 42 files changed, 4630 insertions(+), 55 deletions(-)
 create mode 100644 test_app/tests/output/api/mutations.py
 create mode 100644 test_app/tests/output/api/queries.py
 create mode 100644 test_app/tests/output/api/types/contig.py
 create mode 100644 test_app/tests/output/api/types/genomic_range.py
 create mode 100644 test_app/tests/output/api/types/sample.py
 create mode 100644 test_app/tests/output/api/types/sequencing_read.py
 create mode 100644 test_app/tests/output/cerbos/policies/contig.yaml
 create mode 100644 test_app/tests/output/cerbos/policies/genomic_range.yaml
 create mode 100644 test_app/tests/output/cerbos/policies/sample.yaml
 create mode 100644 test_app/tests/output/cerbos/policies/sequencing_read.yaml
 create mode 100644 test_app/tests/output/database/models/__init__.py
 create mode 100644 test_app/tests/output/database/models/contig.py
 create mode 100644 test_app/tests/output/database/models/genomic_range.py
 create mode 100644 test_app/tests/output/database/models/sample.py
 create mode 100644 test_app/tests/output/database/models/sequencing_read.py
 create mode 100644 test_app/tests/output/support/enums.py
 create mode 100644 test_app/tests/output/test_infra/factories/contig.py
 create mode 100644 test_app/tests/output/test_infra/factories/genomic_range.py
 create mode 100644 test_app/tests/output/test_infra/factories/sample.py
 create mode 100644 test_app/tests/output/test_infra/factories/sequencing_read.py
 create mode 100644 test_app/tests/test_bulk_download_deletion.py
 create mode 100644 test_app/tests/test_bulk_download_policy.py
 create mode 100644 test_app/tests/test_cascade_deletion.py
 create mode 100644 test_app/tests/test_error_handling.py
 create mode 100644 test_app/tests/test_field_constraints.py
 create mode 100644 test_app/tests/test_field_visibility.py
 create mode 100644 test_app/tests/test_limit_offset_queries.py
 create mode 100644 test_app/tests/test_schemas/overrides/api/.gitignore
 create mode 100644 test_app/tests/test_schemas/overrides/api/extra_test_code.py.j2
 create mode 100644 test_app/tests/test_schemas/platformics.yaml
 create mode 100644 test_app/tests/test_sorting_queries.py

diff --git a/test_app/Makefile b/test_app/Makefile
index 3fc5603..91ef303 100644
--- a/test_app/Makefile
+++ b/test_app/Makefile
@@ -68,6 +68,10 @@ codegen: ## Run codegen to convert the LinkML schema to a GQL API
 test: init ## Run tests
 	$(docker_compose) exec $(APP_CONTAINER) pytest -vvv
 
+.PHONY: test-file
+test-file: init ## Run tests for a specific file, ex: make test-file FILE=tests/test_file.py
+	$(docker_compose) exec $(APP_CONTAINER) pytest -vvv $(FILE)
+
 .PHONY: restart
 restart: ## Restart the GQL service
 	$(docker_compose_run) $(APP_CONTAINER) supervisorctl restart graphql_api
diff --git a/test_app/schema/schema.yaml b/test_app/schema/schema.yaml
index 8664dbd..386fb08 100644
--- a/test_app/schema/schema.yaml
+++ b/test_app/schema/schema.yaml
@@ -218,8 +218,15 @@ classes:
       sequencing_read:
         range: SequencingRead
        inverse: Sample.contigs
+        multivalued: true
       sequence:
         required: true
+      upstream_database:
+        range: UpstreamDatabase
+        inverse: UpstreamDatabase.contigs
+        required: true
+        annotations:
+          mutable: false
     annotations:
       plural: Contigs
 
@@ -234,3 +241,21 @@
         inverse: entity.id
         annotations:
           hidden: true
+
+  UpstreamDatabase:
+    is_a: Entity
+    mixins:
+      - EntityMixin
+    attributes:
+      name:
+        range: string
+        required: true
+        annotations:
+          indexed: true
+      contigs:
+        range: Contig
+        inverse: Contig.upstream_database
+        multivalued: true
+      # This is where NCBI indexes would live
+    annotations:
+      plural: UpstreamDatabases
diff --git a/test_app/tests/output/api/mutations.py b/test_app/tests/output/api/mutations.py
new file mode 100644
index 0000000..01772cb
--- /dev/null
+++ b/test_app/tests/output/api/mutations.py
@@ -0,0 +1,56 @@
+"""
+GraphQL mutations for files and entities
+
+Auto-generated by running 'make codegen'. Do not edit.
+Make changes to the template codegen/templates/api/mutations.py.j2 instead.
+""" + +import strawberry +from typing import Sequence +from api.files import ( + File, + create_file, + upload_file, + mark_upload_complete, + concatenate_files, + SignedURL, + MultipartUploadResponse, +) +from api.types.sample import Sample, create_sample, update_sample, delete_sample +from api.types.sequencing_read import ( + SequencingRead, + create_sequencing_read, + update_sequencing_read, + delete_sequencing_read, +) +from api.types.genomic_range import GenomicRange, create_genomic_range, update_genomic_range, delete_genomic_range +from api.types.contig import Contig, create_contig, update_contig, delete_contig + + +@strawberry.type +class Mutation: + # File mutations + create_file: File = create_file + upload_file: MultipartUploadResponse = upload_file + mark_upload_complete: File = mark_upload_complete + concatenate_files: SignedURL = concatenate_files + + # Sample mutations + create_sample: Sample = create_sample + update_sample: Sequence[Sample] = update_sample + delete_sample: Sequence[Sample] = delete_sample + + # SequencingRead mutations + create_sequencing_read: SequencingRead = create_sequencing_read + update_sequencing_read: Sequence[SequencingRead] = update_sequencing_read + delete_sequencing_read: Sequence[SequencingRead] = delete_sequencing_read + + # GenomicRange mutations + create_genomic_range: GenomicRange = create_genomic_range + update_genomic_range: Sequence[GenomicRange] = update_genomic_range + delete_genomic_range: Sequence[GenomicRange] = delete_genomic_range + + # Contig mutations + create_contig: Contig = create_contig + update_contig: Sequence[Contig] = update_contig + delete_contig: Sequence[Contig] = delete_contig diff --git a/test_app/tests/output/api/queries.py b/test_app/tests/output/api/queries.py new file mode 100644 index 0000000..5c659a7 --- /dev/null +++ b/test_app/tests/output/api/queries.py @@ -0,0 +1,51 @@ +""" +Supported GraphQL queries for files and entities + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/api/queries.py.j2 instead. 
+""" + +import strawberry +from strawberry import relay +from typing import Sequence, List +from api.files import File, resolve_files +from api.types.sample import Sample, resolve_samples, SampleAggregate, resolve_samples_aggregate +from api.types.sequencing_read import ( + SequencingRead, + resolve_sequencing_reads, + SequencingReadAggregate, + resolve_sequencing_reads_aggregate, +) +from api.types.genomic_range import ( + GenomicRange, + resolve_genomic_ranges, + GenomicRangeAggregate, + resolve_genomic_ranges_aggregate, +) +from api.types.contig import Contig, resolve_contigs, ContigAggregate, resolve_contigs_aggregate + + +@strawberry.type +class Query: + # Allow relay-style queries by node ID + node: relay.Node = relay.node() + nodes: List[relay.Node] = relay.node() + # Query files + files: Sequence[File] = resolve_files + + # Query entities + samples: Sequence[Sample] = resolve_samples + sequencing_reads: Sequence[SequencingRead] = resolve_sequencing_reads + genomic_ranges: Sequence[GenomicRange] = resolve_genomic_ranges + contigs: Sequence[Contig] = resolve_contigs + + # Query entity aggregates + samples_aggregate: SampleAggregate = resolve_samples_aggregate + sequencing_reads_aggregate: SequencingReadAggregate = resolve_sequencing_reads_aggregate + genomic_ranges_aggregate: GenomicRangeAggregate = resolve_genomic_ranges_aggregate + contigs_aggregate: ContigAggregate = resolve_contigs_aggregate + + @strawberry.field + def uncaught_exception(self) -> str: + # Trigger an AttributeException + return self.kaboom diff --git a/test_app/tests/output/api/types/contig.py b/test_app/tests/output/api/types/contig.py new file mode 100644 index 0000000..04baf39 --- /dev/null +++ b/test_app/tests/output/api/types/contig.py @@ -0,0 +1,378 @@ +""" +GraphQL type for Contig + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/api/types/class_name.py.j2 instead. +""" + +# ruff: noqa: E501 Line too long + + +import typing +from typing import TYPE_CHECKING, Annotated, Optional, Sequence + +import database.models as db +import strawberry +from platformics.api.core.helpers import get_db_rows, get_aggregate_db_rows +from api.types.entities import EntityInterface +from cerbos.sdk.client import CerbosClient +from cerbos.sdk.model import Principal, Resource +from fastapi import Depends +from platformics.api.core.errors import PlatformicsException +from platformics.api.core.deps import get_cerbos_client, get_db_session, require_auth_principal +from platformics.api.core.gql_to_sql import ( + aggregator_map, + IntComparators, + StrComparators, + UUIDComparators, +) +from platformics.api.core.strawberry_extensions import DependencyExtension +from platformics.security.authorization import CerbosAction +from sqlalchemy import inspect +from sqlalchemy.engine.row import RowMapping +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry.types import Info +from typing_extensions import TypedDict +import enum + +E = typing.TypeVar("E", db.File, db.Entity) +T = typing.TypeVar("T") + +if TYPE_CHECKING: + from api.types.sequencing_read import SequencingReadWhereClause, SequencingRead + + pass +else: + SequencingReadWhereClause = "SequencingReadWhereClause" + SequencingRead = "SequencingRead" + pass + + +""" +------------------------------------------------------------------------------ +Dataloaders +------------------------------------------------------------------------------ +These are batching functions for loading related objects to avoid N+1 queries. 
+""" + + +@strawberry.field +async def load_sequencing_read_rows( + root: "Contig", + info: Info, + where: Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")] | None = None, +) -> Optional[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]]: + dataloader = info.context["sqlalchemy_loader"] + mapper = inspect(db.Contig) + relationship = mapper.relationships["sequencing_read"] + return await dataloader.loader_for(relationship, where).load(root.sequencing_read_id) # type:ignore + + +""" +------------------------------------------------------------------------------ +Define Strawberry GQL types +------------------------------------------------------------------------------ +""" + +""" +Only let users specify IDs in WHERE clause when mutating data (for safety). +We can extend that list as we gather more use cases from the FE team. +""" + + +@strawberry.input +class ContigWhereClauseMutations(TypedDict): + id: UUIDComparators | None + + +""" +Supported WHERE clause attributes +""" + + +@strawberry.input +class ContigWhereClause(TypedDict): + id: UUIDComparators | None + producing_run_id: IntComparators | None + owner_user_id: IntComparators | None + collection_id: IntComparators | None + sequencing_read: ( + Optional[Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")]] | None + ) + sequence: Optional[StrComparators] | None + + +""" +Define Contig type +""" + + +@strawberry.type +class Contig(EntityInterface): + id: strawberry.ID + producing_run_id: Optional[int] + owner_user_id: int + collection_id: int + sequencing_read: Optional[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]] = ( + load_sequencing_read_rows + ) # type:ignore + sequence: str + + +""" +We need to add this to each Queryable type so that strawberry will accept either our +Strawberry type *or* a SQLAlchemy model instance as a valid response class from a resolver +""" +Contig.__strawberry_definition__.is_type_of = ( # type: ignore + lambda obj, info: type(obj) == db.Contig or type(obj) == Contig +) + +""" +------------------------------------------------------------------------------ +Aggregation types +------------------------------------------------------------------------------ +""" + +""" +Define columns that support numerical aggregations +""" + + +@strawberry.type +class ContigNumericalColumns: + producing_run_id: Optional[int] = None + owner_user_id: Optional[int] = None + collection_id: Optional[int] = None + + +""" +Define columns that support min/max aggregations +""" + + +@strawberry.type +class ContigMinMaxColumns: + producing_run_id: Optional[int] = None + owner_user_id: Optional[int] = None + collection_id: Optional[int] = None + sequence: Optional[str] = None + + +""" +Define enum of all columns to support count and count(distinct) aggregations +""" + + +@strawberry.enum +class ContigCountColumns(enum.Enum): + sequencing_read = "sequencing_read" + sequence = "sequence" + entity_id = "entity_id" + id = "id" + producing_run_id = "producing_run_id" + owner_user_id = "owner_user_id" + collection_id = "collection_id" + created_at = "created_at" + updated_at = "updated_at" + deleted_at = "deleted_at" + + +""" +All supported aggregation functions +""" + + +@strawberry.type +class ContigAggregateFunctions: + # This is a hack to accept "distinct" and "columns" as arguments to "count" + @strawberry.field + def count(self, distinct: Optional[bool] = False, columns: Optional[ContigCountColumns] = None) -> 
Optional[int]: + # Count gets set with the proper value in the resolver, so we just return it here + return self.count # type: ignore + + sum: Optional[ContigNumericalColumns] = None + avg: Optional[ContigNumericalColumns] = None + min: Optional[ContigMinMaxColumns] = None + max: Optional[ContigMinMaxColumns] = None + stddev: Optional[ContigNumericalColumns] = None + variance: Optional[ContigNumericalColumns] = None + + +""" +Wrapper around ContigAggregateFunctions +""" + + +@strawberry.type +class ContigAggregate: + aggregate: Optional[ContigAggregateFunctions] = None + + +""" +------------------------------------------------------------------------------ +Mutation types +------------------------------------------------------------------------------ +""" + + +@strawberry.input() +class ContigCreateInput: + collection_id: int + sequencing_read_id: Optional[strawberry.ID] = None + sequence: str + + +@strawberry.input() +class ContigUpdateInput: + collection_id: Optional[int] = None + sequencing_read_id: Optional[strawberry.ID] = None + sequence: Optional[str] = None + + +""" +------------------------------------------------------------------------------ +Utilities +------------------------------------------------------------------------------ +""" + + +@strawberry.field(extensions=[DependencyExtension()]) +async def resolve_contigs( + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), + where: Optional[ContigWhereClause] = None, +) -> typing.Sequence[Contig]: + """ + Resolve Contig objects. Used for queries (see api/queries.py). + """ + return await get_db_rows(db.Contig, session, cerbos_client, principal, where, []) # type: ignore + + +def format_contig_aggregate_output(query_results: RowMapping) -> ContigAggregateFunctions: + """ + Given a row from the DB containing the results of an aggregate query, + format the results using the proper GraphQL types. + """ + output = ContigAggregateFunctions() + for aggregate_name, value in query_results.items(): + if aggregate_name == "count": + output.count = value + else: + aggregator_fn, col_name = aggregate_name.split("_", 1) + # Filter out the group_by key from the results if one was provided. + if aggregator_fn in aggregator_map.keys(): + if not getattr(output, aggregator_fn): + if aggregate_name in ["min", "max"]: + setattr(output, aggregator_fn, ContigMinMaxColumns()) + else: + setattr(output, aggregator_fn, ContigNumericalColumns()) + setattr(getattr(output, aggregator_fn), col_name, value) + return output + + +@strawberry.field(extensions=[DependencyExtension()]) +async def resolve_contigs_aggregate( + info: Info, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), + where: Optional[ContigWhereClause] = None, +) -> ContigAggregate: + """ + Aggregate values for Contig objects. Used for queries (see api/queries.py). 
+ """ + # Get the selected aggregate functions and columns to operate on + # TODO: not sure why selected_fields is a list + # The first list of selections will always be ["aggregate"], so just grab the first item + selections = info.selected_fields[0].selections[0].selections + rows = await get_aggregate_db_rows(db.Contig, session, cerbos_client, principal, where, selections, []) # type: ignore + aggregate_output = format_contig_aggregate_output(rows) + return ContigAggregate(aggregate=aggregate_output) + + +@strawberry.mutation(extensions=[DependencyExtension()]) +async def create_contig( + input: ContigCreateInput, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), +) -> db.Entity: + """ + Create a new Contig object. Used for mutations (see api/mutations.py). + """ + params = input.__dict__ + + # Validate that user can create entity in this collection + attr = {"collection_id": input.collection_id} + resource = Resource(id="NEW_ID", kind=db.Contig.__tablename__, attr=attr) + if not cerbos_client.is_allowed("create", principal, resource): + raise PlatformicsException("Unauthorized: Cannot create entity in this collection") + + # Save to DB + params["owner_user_id"] = int(principal.id) + new_entity = db.Contig(**params) + session.add(new_entity) + await session.commit() + return new_entity + + +@strawberry.mutation(extensions=[DependencyExtension()]) +async def update_contig( + input: ContigUpdateInput, + where: ContigWhereClauseMutations, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), +) -> Sequence[db.Entity]: + """ + Update Contig objects. Used for mutations (see api/mutations.py). + """ + params = input.__dict__ + + # Need at least one thing to update + num_params = len([x for x in params if params[x] is not None]) + if num_params == 0: + raise PlatformicsException("No fields to update") + + # Fetch entities for update, if we have access to them + entities = await get_db_rows(db.Contig, session, cerbos_client, principal, where, [], CerbosAction.UPDATE) + if len(entities) == 0: + raise PlatformicsException("Unauthorized: Cannot update entities") + + # Validate that the user has access to the new collection ID + if input.collection_id: + attr = {"collection_id": input.collection_id} + resource = Resource(id="SOME_ID", kind=db.Contig.__tablename__, attr=attr) + if not cerbos_client.is_allowed(CerbosAction.UPDATE, principal, resource): + raise PlatformicsException("Unauthorized: Cannot access new collection") + + # Update DB + for entity in entities: + for key in params: + if params[key]: + setattr(entity, key, params[key]) + await session.commit() + return entities + + +@strawberry.mutation(extensions=[DependencyExtension()]) +async def delete_contig( + where: ContigWhereClauseMutations, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), +) -> Sequence[db.Entity]: + """ + Delete Contig objects. Used for mutations (see api/mutations.py). 
+ """ + # Fetch entities for deletion, if we have access to them + entities = await get_db_rows(db.Contig, session, cerbos_client, principal, where, [], CerbosAction.DELETE) + if len(entities) == 0: + raise PlatformicsException("Unauthorized: Cannot delete entities") + + # Update DB + for entity in entities: + await session.delete(entity) + await session.commit() + return entities diff --git a/test_app/tests/output/api/types/genomic_range.py b/test_app/tests/output/api/types/genomic_range.py new file mode 100644 index 0000000..1745de4 --- /dev/null +++ b/test_app/tests/output/api/types/genomic_range.py @@ -0,0 +1,426 @@ +""" +GraphQL type for GenomicRange + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/api/types/class_name.py.j2 instead. +""" + +# ruff: noqa: E501 Line too long + + +import typing +from typing import TYPE_CHECKING, Annotated, Optional, Sequence, Callable + +import database.models as db +import strawberry +from platformics.api.core.helpers import get_db_rows, get_aggregate_db_rows +from api.files import File, FileWhereClause +from api.types.entities import EntityInterface +from api.types.sequencing_read import SequencingReadAggregate, format_sequencing_read_aggregate_output +from cerbos.sdk.client import CerbosClient +from cerbos.sdk.model import Principal, Resource +from fastapi import Depends +from platformics.api.core.errors import PlatformicsException +from platformics.api.core.deps import get_cerbos_client, get_db_session, require_auth_principal +from platformics.api.core.gql_to_sql import ( + aggregator_map, + IntComparators, + UUIDComparators, +) +from platformics.api.core.strawberry_extensions import DependencyExtension +from platformics.security.authorization import CerbosAction +from sqlalchemy import inspect +from sqlalchemy.engine.row import RowMapping +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry import relay +from strawberry.types import Info +from typing_extensions import TypedDict +import enum + +E = typing.TypeVar("E", db.File, db.Entity) +T = typing.TypeVar("T") + +if TYPE_CHECKING: + from api.types.sequencing_read import SequencingReadWhereClause, SequencingRead + + pass +else: + SequencingReadWhereClause = "SequencingReadWhereClause" + SequencingRead = "SequencingRead" + pass + + +""" +------------------------------------------------------------------------------ +Dataloaders +------------------------------------------------------------------------------ +These are batching functions for loading related objects to avoid N+1 queries. 
+""" + + +@relay.connection( + relay.ListConnection[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]] # type:ignore +) +async def load_sequencing_read_rows( + root: "GenomicRange", + info: Info, + where: Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")] | None = None, +) -> Sequence[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]]: + dataloader = info.context["sqlalchemy_loader"] + mapper = inspect(db.GenomicRange) + relationship = mapper.relationships["sequencing_reads"] + return await dataloader.loader_for(relationship, where).load(root.id) # type:ignore + + +@strawberry.field +async def load_sequencing_read_aggregate_rows( + root: "GenomicRange", + info: Info, + where: Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")] | None = None, +) -> Optional[Annotated["SequencingReadAggregate", strawberry.lazy("api.types.sequencing_read")]]: + selections = info.selected_fields[0].selections[0].selections + dataloader = info.context["sqlalchemy_loader"] + mapper = inspect(db.GenomicRange) + relationship = mapper.relationships["sequencing_reads"] + rows = await dataloader.aggregate_loader_for(relationship, where, selections).load(root.id) # type:ignore + # Aggregate queries always return a single row, so just grab the first one + result = rows[0] if rows else None + aggregate_output = format_sequencing_read_aggregate_output(result) + return SequencingReadAggregate(aggregate=aggregate_output) + + +""" +------------------------------------------------------------------------------ +Dataloader for File object +------------------------------------------------------------------------------ +""" + + +def load_files_from(attr_name: str) -> Callable: + @strawberry.field + async def load_files( + root: "GenomicRange", + info: Info, + where: Annotated["FileWhereClause", strawberry.lazy("api.files")] | None = None, + ) -> Optional[Annotated["File", strawberry.lazy("api.files")]]: + """ + Given a list of GenomicRange IDs for a certain file type, return related Files + """ + dataloader = info.context["sqlalchemy_loader"] + mapper = inspect(db.GenomicRange) + relationship = mapper.relationships[attr_name] + return await dataloader.loader_for(relationship, where).load(getattr(root, f"{attr_name}_id")) # type:ignore + + return load_files + + +""" +------------------------------------------------------------------------------ +Define Strawberry GQL types +------------------------------------------------------------------------------ +""" + +""" +Only let users specify IDs in WHERE clause when mutating data (for safety). +We can extend that list as we gather more use cases from the FE team. 
+""" + + +@strawberry.input +class GenomicRangeWhereClauseMutations(TypedDict): + id: UUIDComparators | None + + +""" +Supported WHERE clause attributes +""" + + +@strawberry.input +class GenomicRangeWhereClause(TypedDict): + id: UUIDComparators | None + producing_run_id: IntComparators | None + owner_user_id: IntComparators | None + collection_id: IntComparators | None + sequencing_reads: ( + Optional[Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")]] | None + ) + + +""" +Define GenomicRange type +""" + + +@strawberry.type +class GenomicRange(EntityInterface): + id: strawberry.ID + producing_run_id: Optional[int] + owner_user_id: int + collection_id: int + file_id: Optional[strawberry.ID] + file: Optional[Annotated["File", strawberry.lazy("api.files")]] = load_files_from("file") # type: ignore + sequencing_reads: Sequence[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]] = ( + load_sequencing_read_rows + ) # type:ignore + sequencing_reads_aggregate: Optional[ + Annotated["SequencingReadAggregate", strawberry.lazy("api.types.sequencing_read")] + ] = load_sequencing_read_aggregate_rows # type:ignore + + +""" +We need to add this to each Queryable type so that strawberry will accept either our +Strawberry type *or* a SQLAlchemy model instance as a valid response class from a resolver +""" +GenomicRange.__strawberry_definition__.is_type_of = ( # type: ignore + lambda obj, info: type(obj) == db.GenomicRange or type(obj) == GenomicRange +) + +""" +------------------------------------------------------------------------------ +Aggregation types +------------------------------------------------------------------------------ +""" + +""" +Define columns that support numerical aggregations +""" + + +@strawberry.type +class GenomicRangeNumericalColumns: + producing_run_id: Optional[int] = None + owner_user_id: Optional[int] = None + collection_id: Optional[int] = None + + +""" +Define columns that support min/max aggregations +""" + + +@strawberry.type +class GenomicRangeMinMaxColumns: + producing_run_id: Optional[int] = None + owner_user_id: Optional[int] = None + collection_id: Optional[int] = None + + +""" +Define enum of all columns to support count and count(distinct) aggregations +""" + + +@strawberry.enum +class GenomicRangeCountColumns(enum.Enum): + file = "file" + sequencing_reads = "sequencing_reads" + entity_id = "entity_id" + id = "id" + producing_run_id = "producing_run_id" + owner_user_id = "owner_user_id" + collection_id = "collection_id" + created_at = "created_at" + updated_at = "updated_at" + deleted_at = "deleted_at" + + +""" +All supported aggregation functions +""" + + +@strawberry.type +class GenomicRangeAggregateFunctions: + # This is a hack to accept "distinct" and "columns" as arguments to "count" + @strawberry.field + def count( + self, distinct: Optional[bool] = False, columns: Optional[GenomicRangeCountColumns] = None + ) -> Optional[int]: + # Count gets set with the proper value in the resolver, so we just return it here + return self.count # type: ignore + + sum: Optional[GenomicRangeNumericalColumns] = None + avg: Optional[GenomicRangeNumericalColumns] = None + min: Optional[GenomicRangeMinMaxColumns] = None + max: Optional[GenomicRangeMinMaxColumns] = None + stddev: Optional[GenomicRangeNumericalColumns] = None + variance: Optional[GenomicRangeNumericalColumns] = None + + +""" +Wrapper around GenomicRangeAggregateFunctions +""" + + +@strawberry.type +class GenomicRangeAggregate: + aggregate: 
Optional[GenomicRangeAggregateFunctions] = None + + +""" +------------------------------------------------------------------------------ +Mutation types +------------------------------------------------------------------------------ +""" + + +@strawberry.input() +class GenomicRangeCreateInput: + collection_id: int + file_id: Optional[strawberry.ID] = None + + +@strawberry.input() +class GenomicRangeUpdateInput: + collection_id: Optional[int] = None + file_id: Optional[strawberry.ID] = None + + +""" +------------------------------------------------------------------------------ +Utilities +------------------------------------------------------------------------------ +""" + + +@strawberry.field(extensions=[DependencyExtension()]) +async def resolve_genomic_ranges( + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), + where: Optional[GenomicRangeWhereClause] = None, +) -> typing.Sequence[GenomicRange]: + """ + Resolve GenomicRange objects. Used for queries (see api/queries.py). + """ + return await get_db_rows(db.GenomicRange, session, cerbos_client, principal, where, []) # type: ignore + + +def format_genomic_range_aggregate_output(query_results: RowMapping) -> GenomicRangeAggregateFunctions: + """ + Given a row from the DB containing the results of an aggregate query, + format the results using the proper GraphQL types. + """ + output = GenomicRangeAggregateFunctions() + for aggregate_name, value in query_results.items(): + if aggregate_name == "count": + output.count = value + else: + aggregator_fn, col_name = aggregate_name.split("_", 1) + # Filter out the group_by key from the results if one was provided. + if aggregator_fn in aggregator_map.keys(): + if not getattr(output, aggregator_fn): + if aggregate_name in ["min", "max"]: + setattr(output, aggregator_fn, GenomicRangeMinMaxColumns()) + else: + setattr(output, aggregator_fn, GenomicRangeNumericalColumns()) + setattr(getattr(output, aggregator_fn), col_name, value) + return output + + +@strawberry.field(extensions=[DependencyExtension()]) +async def resolve_genomic_ranges_aggregate( + info: Info, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), + where: Optional[GenomicRangeWhereClause] = None, +) -> GenomicRangeAggregate: + """ + Aggregate values for GenomicRange objects. Used for queries (see api/queries.py). + """ + # Get the selected aggregate functions and columns to operate on + # TODO: not sure why selected_fields is a list + # The first list of selections will always be ["aggregate"], so just grab the first item + selections = info.selected_fields[0].selections[0].selections + rows = await get_aggregate_db_rows(db.GenomicRange, session, cerbos_client, principal, where, selections, []) # type: ignore + aggregate_output = format_genomic_range_aggregate_output(rows) + return GenomicRangeAggregate(aggregate=aggregate_output) + + +@strawberry.mutation(extensions=[DependencyExtension()]) +async def create_genomic_range( + input: GenomicRangeCreateInput, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), +) -> db.Entity: + """ + Create a new GenomicRange object. Used for mutations (see api/mutations.py). 
+ """ + params = input.__dict__ + + # Validate that user can create entity in this collection + attr = {"collection_id": input.collection_id} + resource = Resource(id="NEW_ID", kind=db.GenomicRange.__tablename__, attr=attr) + if not cerbos_client.is_allowed("create", principal, resource): + raise PlatformicsException("Unauthorized: Cannot create entity in this collection") + + # Save to DB + params["owner_user_id"] = int(principal.id) + new_entity = db.GenomicRange(**params) + session.add(new_entity) + await session.commit() + return new_entity + + +@strawberry.mutation(extensions=[DependencyExtension()]) +async def update_genomic_range( + input: GenomicRangeUpdateInput, + where: GenomicRangeWhereClauseMutations, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), +) -> Sequence[db.Entity]: + """ + Update GenomicRange objects. Used for mutations (see api/mutations.py). + """ + params = input.__dict__ + + # Need at least one thing to update + num_params = len([x for x in params if params[x] is not None]) + if num_params == 0: + raise PlatformicsException("No fields to update") + + # Fetch entities for update, if we have access to them + entities = await get_db_rows(db.GenomicRange, session, cerbos_client, principal, where, [], CerbosAction.UPDATE) + if len(entities) == 0: + raise PlatformicsException("Unauthorized: Cannot update entities") + + # Validate that the user has access to the new collection ID + if input.collection_id: + attr = {"collection_id": input.collection_id} + resource = Resource(id="SOME_ID", kind=db.GenomicRange.__tablename__, attr=attr) + if not cerbos_client.is_allowed(CerbosAction.UPDATE, principal, resource): + raise PlatformicsException("Unauthorized: Cannot access new collection") + + # Update DB + for entity in entities: + for key in params: + if params[key]: + setattr(entity, key, params[key]) + await session.commit() + return entities + + +@strawberry.mutation(extensions=[DependencyExtension()]) +async def delete_genomic_range( + where: GenomicRangeWhereClauseMutations, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), +) -> Sequence[db.Entity]: + """ + Delete GenomicRange objects. Used for mutations (see api/mutations.py). + """ + # Fetch entities for deletion, if we have access to them + entities = await get_db_rows(db.GenomicRange, session, cerbos_client, principal, where, [], CerbosAction.DELETE) + if len(entities) == 0: + raise PlatformicsException("Unauthorized: Cannot delete entities") + + # Update DB + for entity in entities: + await session.delete(entity) + await session.commit() + return entities diff --git a/test_app/tests/output/api/types/sample.py b/test_app/tests/output/api/types/sample.py new file mode 100644 index 0000000..4d902f8 --- /dev/null +++ b/test_app/tests/output/api/types/sample.py @@ -0,0 +1,432 @@ +""" +GraphQL type for Sample + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/api/types/class_name.py.j2 instead. 
+""" + +# ruff: noqa: E501 Line too long + + +import typing +from typing import TYPE_CHECKING, Annotated, Optional, Sequence + +import database.models as db +import strawberry +import datetime +from platformics.api.core.helpers import get_db_rows, get_aggregate_db_rows +from api.types.entities import EntityInterface +from api.types.sequencing_read import SequencingReadAggregate, format_sequencing_read_aggregate_output +from cerbos.sdk.client import CerbosClient +from cerbos.sdk.model import Principal, Resource +from fastapi import Depends +from platformics.api.core.errors import PlatformicsException +from platformics.api.core.deps import get_cerbos_client, get_db_session, require_auth_principal +from platformics.api.core.gql_to_sql import ( + aggregator_map, + DatetimeComparators, + IntComparators, + StrComparators, + UUIDComparators, + BoolComparators, +) +from platformics.api.core.strawberry_extensions import DependencyExtension +from platformics.security.authorization import CerbosAction +from sqlalchemy import inspect +from sqlalchemy.engine.row import RowMapping +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry import relay +from strawberry.types import Info +from typing_extensions import TypedDict +import enum + +E = typing.TypeVar("E", db.File, db.Entity) +T = typing.TypeVar("T") + +if TYPE_CHECKING: + from api.types.sequencing_read import SequencingReadWhereClause, SequencingRead + + pass +else: + SequencingReadWhereClause = "SequencingReadWhereClause" + SequencingRead = "SequencingRead" + pass + + +""" +------------------------------------------------------------------------------ +Dataloaders +------------------------------------------------------------------------------ +These are batching functions for loading related objects to avoid N+1 queries. 
+""" + + +@relay.connection( + relay.ListConnection[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]] # type:ignore +) +async def load_sequencing_read_rows( + root: "Sample", + info: Info, + where: Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")] | None = None, +) -> Sequence[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]]: + dataloader = info.context["sqlalchemy_loader"] + mapper = inspect(db.Sample) + relationship = mapper.relationships["sequencing_reads"] + return await dataloader.loader_for(relationship, where).load(root.id) # type:ignore + + +@strawberry.field +async def load_sequencing_read_aggregate_rows( + root: "Sample", + info: Info, + where: Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")] | None = None, +) -> Optional[Annotated["SequencingReadAggregate", strawberry.lazy("api.types.sequencing_read")]]: + selections = info.selected_fields[0].selections[0].selections + dataloader = info.context["sqlalchemy_loader"] + mapper = inspect(db.Sample) + relationship = mapper.relationships["sequencing_reads"] + rows = await dataloader.aggregate_loader_for(relationship, where, selections).load(root.id) # type:ignore + # Aggregate queries always return a single row, so just grab the first one + result = rows[0] if rows else None + aggregate_output = format_sequencing_read_aggregate_output(result) + return SequencingReadAggregate(aggregate=aggregate_output) + + +""" +------------------------------------------------------------------------------ +Define Strawberry GQL types +------------------------------------------------------------------------------ +""" + +""" +Only let users specify IDs in WHERE clause when mutating data (for safety). +We can extend that list as we gather more use cases from the FE team. 
+""" + + +@strawberry.input +class SampleWhereClauseMutations(TypedDict): + id: UUIDComparators | None + + +""" +Supported WHERE clause attributes +""" + + +@strawberry.input +class SampleWhereClause(TypedDict): + id: UUIDComparators | None + producing_run_id: IntComparators | None + owner_user_id: IntComparators | None + collection_id: IntComparators | None + name: Optional[StrComparators] | None + sample_type: Optional[StrComparators] | None + water_control: Optional[BoolComparators] | None + collection_date: Optional[DatetimeComparators] | None + collection_location: Optional[StrComparators] | None + description: Optional[StrComparators] | None + sequencing_reads: ( + Optional[Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")]] | None + ) + + +""" +Define Sample type +""" + + +@strawberry.type +class Sample(EntityInterface): + id: strawberry.ID + producing_run_id: Optional[int] + owner_user_id: int + collection_id: int + name: str + sample_type: str + water_control: bool + collection_date: Optional[datetime.datetime] = None + collection_location: str + description: Optional[str] = None + sequencing_reads: Sequence[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]] = ( + load_sequencing_read_rows + ) # type:ignore + sequencing_reads_aggregate: Optional[ + Annotated["SequencingReadAggregate", strawberry.lazy("api.types.sequencing_read")] + ] = load_sequencing_read_aggregate_rows # type:ignore + + +""" +We need to add this to each Queryable type so that strawberry will accept either our +Strawberry type *or* a SQLAlchemy model instance as a valid response class from a resolver +""" +Sample.__strawberry_definition__.is_type_of = ( # type: ignore + lambda obj, info: type(obj) == db.Sample or type(obj) == Sample +) + +""" +------------------------------------------------------------------------------ +Aggregation types +------------------------------------------------------------------------------ +""" + +""" +Define columns that support numerical aggregations +""" + + +@strawberry.type +class SampleNumericalColumns: + producing_run_id: Optional[int] = None + owner_user_id: Optional[int] = None + collection_id: Optional[int] = None + + +""" +Define columns that support min/max aggregations +""" + + +@strawberry.type +class SampleMinMaxColumns: + producing_run_id: Optional[int] = None + owner_user_id: Optional[int] = None + collection_id: Optional[int] = None + name: Optional[str] = None + sample_type: Optional[str] = None + collection_date: Optional[datetime.datetime] = None + collection_location: Optional[str] = None + description: Optional[str] = None + + +""" +Define enum of all columns to support count and count(distinct) aggregations +""" + + +@strawberry.enum +class SampleCountColumns(enum.Enum): + name = "name" + sample_type = "sample_type" + water_control = "water_control" + collection_date = "collection_date" + collection_location = "collection_location" + description = "description" + sequencing_reads = "sequencing_reads" + entity_id = "entity_id" + id = "id" + producing_run_id = "producing_run_id" + owner_user_id = "owner_user_id" + collection_id = "collection_id" + created_at = "created_at" + updated_at = "updated_at" + deleted_at = "deleted_at" + + +""" +All supported aggregation functions +""" + + +@strawberry.type +class SampleAggregateFunctions: + # This is a hack to accept "distinct" and "columns" as arguments to "count" + @strawberry.field + def count(self, distinct: Optional[bool] = False, columns: 
Optional[SampleCountColumns] = None) -> Optional[int]: + # Count gets set with the proper value in the resolver, so we just return it here + return self.count # type: ignore + + sum: Optional[SampleNumericalColumns] = None + avg: Optional[SampleNumericalColumns] = None + min: Optional[SampleMinMaxColumns] = None + max: Optional[SampleMinMaxColumns] = None + stddev: Optional[SampleNumericalColumns] = None + variance: Optional[SampleNumericalColumns] = None + + +""" +Wrapper around SampleAggregateFunctions +""" + + +@strawberry.type +class SampleAggregate: + aggregate: Optional[SampleAggregateFunctions] = None + + +""" +------------------------------------------------------------------------------ +Mutation types +------------------------------------------------------------------------------ +""" + + +@strawberry.input() +class SampleCreateInput: + collection_id: int + name: str + sample_type: str + water_control: bool + collection_date: Optional[datetime.datetime] = None + collection_location: str + description: Optional[str] = None + + +@strawberry.input() +class SampleUpdateInput: + collection_id: Optional[int] = None + name: Optional[str] = None + sample_type: Optional[str] = None + water_control: Optional[bool] = None + collection_date: Optional[datetime.datetime] = None + collection_location: Optional[str] = None + description: Optional[str] = None + + +""" +------------------------------------------------------------------------------ +Utilities +------------------------------------------------------------------------------ +""" + + +@strawberry.field(extensions=[DependencyExtension()]) +async def resolve_samples( + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), + where: Optional[SampleWhereClause] = None, +) -> typing.Sequence[Sample]: + """ + Resolve Sample objects. Used for queries (see api/queries.py). + """ + return await get_db_rows(db.Sample, session, cerbos_client, principal, where, []) # type: ignore + + +def format_sample_aggregate_output(query_results: RowMapping) -> SampleAggregateFunctions: + """ + Given a row from the DB containing the results of an aggregate query, + format the results using the proper GraphQL types. + """ + output = SampleAggregateFunctions() + for aggregate_name, value in query_results.items(): + if aggregate_name == "count": + output.count = value + else: + aggregator_fn, col_name = aggregate_name.split("_", 1) + # Filter out the group_by key from the results if one was provided. + if aggregator_fn in aggregator_map.keys(): + if not getattr(output, aggregator_fn): + if aggregate_name in ["min", "max"]: + setattr(output, aggregator_fn, SampleMinMaxColumns()) + else: + setattr(output, aggregator_fn, SampleNumericalColumns()) + setattr(getattr(output, aggregator_fn), col_name, value) + return output + + +@strawberry.field(extensions=[DependencyExtension()]) +async def resolve_samples_aggregate( + info: Info, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), + where: Optional[SampleWhereClause] = None, +) -> SampleAggregate: + """ + Aggregate values for Sample objects. Used for queries (see api/queries.py). 
+ """ + # Get the selected aggregate functions and columns to operate on + # TODO: not sure why selected_fields is a list + # The first list of selections will always be ["aggregate"], so just grab the first item + selections = info.selected_fields[0].selections[0].selections + rows = await get_aggregate_db_rows(db.Sample, session, cerbos_client, principal, where, selections, []) # type: ignore + aggregate_output = format_sample_aggregate_output(rows) + return SampleAggregate(aggregate=aggregate_output) + + +@strawberry.mutation(extensions=[DependencyExtension()]) +async def create_sample( + input: SampleCreateInput, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), +) -> db.Entity: + """ + Create a new Sample object. Used for mutations (see api/mutations.py). + """ + params = input.__dict__ + + # Validate that user can create entity in this collection + attr = {"collection_id": input.collection_id} + resource = Resource(id="NEW_ID", kind=db.Sample.__tablename__, attr=attr) + if not cerbos_client.is_allowed("create", principal, resource): + raise PlatformicsException("Unauthorized: Cannot create entity in this collection") + + # Save to DB + params["owner_user_id"] = int(principal.id) + new_entity = db.Sample(**params) + session.add(new_entity) + await session.commit() + return new_entity + + +@strawberry.mutation(extensions=[DependencyExtension()]) +async def update_sample( + input: SampleUpdateInput, + where: SampleWhereClauseMutations, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), +) -> Sequence[db.Entity]: + """ + Update Sample objects. Used for mutations (see api/mutations.py). + """ + params = input.__dict__ + + # Need at least one thing to update + num_params = len([x for x in params if params[x] is not None]) + if num_params == 0: + raise PlatformicsException("No fields to update") + + # Fetch entities for update, if we have access to them + entities = await get_db_rows(db.Sample, session, cerbos_client, principal, where, [], CerbosAction.UPDATE) + if len(entities) == 0: + raise PlatformicsException("Unauthorized: Cannot update entities") + + # Validate that the user has access to the new collection ID + if input.collection_id: + attr = {"collection_id": input.collection_id} + resource = Resource(id="SOME_ID", kind=db.Sample.__tablename__, attr=attr) + if not cerbos_client.is_allowed(CerbosAction.UPDATE, principal, resource): + raise PlatformicsException("Unauthorized: Cannot access new collection") + + # Update DB + for entity in entities: + for key in params: + if params[key]: + setattr(entity, key, params[key]) + await session.commit() + return entities + + +@strawberry.mutation(extensions=[DependencyExtension()]) +async def delete_sample( + where: SampleWhereClauseMutations, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), +) -> Sequence[db.Entity]: + """ + Delete Sample objects. Used for mutations (see api/mutations.py). 
+ """ + # Fetch entities for deletion, if we have access to them + entities = await get_db_rows(db.Sample, session, cerbos_client, principal, where, [], CerbosAction.DELETE) + if len(entities) == 0: + raise PlatformicsException("Unauthorized: Cannot delete entities") + + # Update DB + for entity in entities: + await session.delete(entity) + await session.commit() + return entities diff --git a/test_app/tests/output/api/types/sequencing_read.py b/test_app/tests/output/api/types/sequencing_read.py new file mode 100644 index 0000000..28a00e1 --- /dev/null +++ b/test_app/tests/output/api/types/sequencing_read.py @@ -0,0 +1,492 @@ +""" +GraphQL type for SequencingRead + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/api/types/class_name.py.j2 instead. +""" + +# ruff: noqa: E501 Line too long + + +import typing +from typing import TYPE_CHECKING, Annotated, Optional, Sequence, Callable + +import database.models as db +import strawberry +from platformics.api.core.helpers import get_db_rows, get_aggregate_db_rows +from api.files import File, FileWhereClause +from api.types.entities import EntityInterface +from api.types.contig import ContigAggregate, format_contig_aggregate_output +from cerbos.sdk.client import CerbosClient +from cerbos.sdk.model import Principal, Resource +from fastapi import Depends +from platformics.api.core.errors import PlatformicsException +from platformics.api.core.deps import get_cerbos_client, get_db_session, require_auth_principal +from platformics.api.core.gql_to_sql import ( + aggregator_map, + EnumComparators, + IntComparators, + UUIDComparators, + BoolComparators, +) +from platformics.api.core.strawberry_extensions import DependencyExtension +from platformics.security.authorization import CerbosAction +from sqlalchemy import inspect +from sqlalchemy.engine.row import RowMapping +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry import relay +from strawberry.types import Info +from typing_extensions import TypedDict +import enum +from support.enums import SequencingProtocol, SequencingTechnology, NucleicAcid + +E = typing.TypeVar("E", db.File, db.Entity) +T = typing.TypeVar("T") + +if TYPE_CHECKING: + from api.types.sample import SampleWhereClause, Sample + from api.types.genomic_range import GenomicRangeWhereClause, GenomicRange + from api.types.contig import ContigWhereClause, Contig + + pass +else: + SampleWhereClause = "SampleWhereClause" + Sample = "Sample" + GenomicRangeWhereClause = "GenomicRangeWhereClause" + GenomicRange = "GenomicRange" + ContigWhereClause = "ContigWhereClause" + Contig = "Contig" + pass + + +""" +------------------------------------------------------------------------------ +Dataloaders +------------------------------------------------------------------------------ +These are batching functions for loading related objects to avoid N+1 queries. 
+""" + + +@strawberry.field +async def load_sample_rows( + root: "SequencingRead", + info: Info, + where: Annotated["SampleWhereClause", strawberry.lazy("api.types.sample")] | None = None, +) -> Optional[Annotated["Sample", strawberry.lazy("api.types.sample")]]: + dataloader = info.context["sqlalchemy_loader"] + mapper = inspect(db.SequencingRead) + relationship = mapper.relationships["sample"] + return await dataloader.loader_for(relationship, where).load(root.sample_id) # type:ignore + + +@strawberry.field +async def load_genomic_range_rows( + root: "SequencingRead", + info: Info, + where: Annotated["GenomicRangeWhereClause", strawberry.lazy("api.types.genomic_range")] | None = None, +) -> Optional[Annotated["GenomicRange", strawberry.lazy("api.types.genomic_range")]]: + dataloader = info.context["sqlalchemy_loader"] + mapper = inspect(db.SequencingRead) + relationship = mapper.relationships["primer_file"] + return await dataloader.loader_for(relationship, where).load(root.primer_file_id) # type:ignore + + +@relay.connection( + relay.ListConnection[Annotated["Contig", strawberry.lazy("api.types.contig")]] # type:ignore +) +async def load_contig_rows( + root: "SequencingRead", + info: Info, + where: Annotated["ContigWhereClause", strawberry.lazy("api.types.contig")] | None = None, +) -> Sequence[Annotated["Contig", strawberry.lazy("api.types.contig")]]: + dataloader = info.context["sqlalchemy_loader"] + mapper = inspect(db.SequencingRead) + relationship = mapper.relationships["contigs"] + return await dataloader.loader_for(relationship, where).load(root.id) # type:ignore + + +@strawberry.field +async def load_contig_aggregate_rows( + root: "SequencingRead", + info: Info, + where: Annotated["ContigWhereClause", strawberry.lazy("api.types.contig")] | None = None, +) -> Optional[Annotated["ContigAggregate", strawberry.lazy("api.types.contig")]]: + selections = info.selected_fields[0].selections[0].selections + dataloader = info.context["sqlalchemy_loader"] + mapper = inspect(db.SequencingRead) + relationship = mapper.relationships["contigs"] + rows = await dataloader.aggregate_loader_for(relationship, where, selections).load(root.id) # type:ignore + # Aggregate queries always return a single row, so just grab the first one + result = rows[0] if rows else None + aggregate_output = format_contig_aggregate_output(result) + return ContigAggregate(aggregate=aggregate_output) + + +""" +------------------------------------------------------------------------------ +Dataloader for File object +------------------------------------------------------------------------------ +""" + + +def load_files_from(attr_name: str) -> Callable: + @strawberry.field + async def load_files( + root: "SequencingRead", + info: Info, + where: Annotated["FileWhereClause", strawberry.lazy("api.files")] | None = None, + ) -> Optional[Annotated["File", strawberry.lazy("api.files")]]: + """ + Given a list of SequencingRead IDs for a certain file type, return related Files + """ + dataloader = info.context["sqlalchemy_loader"] + mapper = inspect(db.SequencingRead) + relationship = mapper.relationships[attr_name] + return await dataloader.loader_for(relationship, where).load(getattr(root, f"{attr_name}_id")) # type:ignore + + return load_files + + +""" +------------------------------------------------------------------------------ +Define Strawberry GQL types +------------------------------------------------------------------------------ +""" + +""" +Only let users specify IDs in WHERE clause when mutating data (for safety). 
+We can extend that list as we gather more use cases from the FE team. +""" + + +@strawberry.input +class SequencingReadWhereClauseMutations(TypedDict): + id: UUIDComparators | None + + +""" +Supported WHERE clause attributes +""" + + +@strawberry.input +class SequencingReadWhereClause(TypedDict): + id: UUIDComparators | None + producing_run_id: IntComparators | None + owner_user_id: IntComparators | None + collection_id: IntComparators | None + sample: Optional[Annotated["SampleWhereClause", strawberry.lazy("api.types.sample")]] | None + protocol: Optional[EnumComparators[SequencingProtocol]] | None + technology: Optional[EnumComparators[SequencingTechnology]] | None + nucleic_acid: Optional[EnumComparators[NucleicAcid]] | None + primer_file: Optional[Annotated["GenomicRangeWhereClause", strawberry.lazy("api.types.genomic_range")]] | None + contigs: Optional[Annotated["ContigWhereClause", strawberry.lazy("api.types.contig")]] | None + clearlabs_export: Optional[BoolComparators] | None + + +""" +Define SequencingRead type +""" + + +@strawberry.type +class SequencingRead(EntityInterface): + id: strawberry.ID + producing_run_id: Optional[int] + owner_user_id: int + collection_id: int + sample: Optional[Annotated["Sample", strawberry.lazy("api.types.sample")]] = load_sample_rows # type:ignore + protocol: SequencingProtocol + r1_file_id: Optional[strawberry.ID] + r1_file: Optional[Annotated["File", strawberry.lazy("api.files")]] = load_files_from("r1_file") # type: ignore + r2_file_id: Optional[strawberry.ID] + r2_file: Optional[Annotated["File", strawberry.lazy("api.files")]] = load_files_from("r2_file") # type: ignore + technology: SequencingTechnology + nucleic_acid: NucleicAcid + primer_file: Optional[Annotated["GenomicRange", strawberry.lazy("api.types.genomic_range")]] = ( + load_genomic_range_rows + ) # type:ignore + contigs: Sequence[Annotated["Contig", strawberry.lazy("api.types.contig")]] = load_contig_rows # type:ignore + contigs_aggregate: Optional[Annotated["ContigAggregate", strawberry.lazy("api.types.contig")]] = ( + load_contig_aggregate_rows + ) # type:ignore + clearlabs_export: bool + + +""" +We need to add this to each Queryable type so that strawberry will accept either our +Strawberry type *or* a SQLAlchemy model instance as a valid response class from a resolver +""" +SequencingRead.__strawberry_definition__.is_type_of = ( # type: ignore + lambda obj, info: type(obj) == db.SequencingRead or type(obj) == SequencingRead +) + +""" +------------------------------------------------------------------------------ +Aggregation types +------------------------------------------------------------------------------ +""" + +""" +Define columns that support numerical aggregations +""" + + +@strawberry.type +class SequencingReadNumericalColumns: + producing_run_id: Optional[int] = None + owner_user_id: Optional[int] = None + collection_id: Optional[int] = None + + +""" +Define columns that support min/max aggregations +""" + + +@strawberry.type +class SequencingReadMinMaxColumns: + producing_run_id: Optional[int] = None + owner_user_id: Optional[int] = None + collection_id: Optional[int] = None + + +""" +Define enum of all columns to support count and count(distinct) aggregations +""" + + +@strawberry.enum +class SequencingReadCountColumns(enum.Enum): + sample = "sample" + protocol = "protocol" + r1_file = "r1_file" + r2_file = "r2_file" + technology = "technology" + nucleic_acid = "nucleic_acid" + primer_file = "primer_file" + contigs = "contigs" + clearlabs_export = "clearlabs_export" 
+ entity_id = "entity_id" + id = "id" + producing_run_id = "producing_run_id" + owner_user_id = "owner_user_id" + collection_id = "collection_id" + created_at = "created_at" + updated_at = "updated_at" + deleted_at = "deleted_at" + + +""" +All supported aggregation functions +""" + + +@strawberry.type +class SequencingReadAggregateFunctions: + # This is a hack to accept "distinct" and "columns" as arguments to "count" + @strawberry.field + def count( + self, distinct: Optional[bool] = False, columns: Optional[SequencingReadCountColumns] = None + ) -> Optional[int]: + # Count gets set with the proper value in the resolver, so we just return it here + return self.count # type: ignore + + sum: Optional[SequencingReadNumericalColumns] = None + avg: Optional[SequencingReadNumericalColumns] = None + min: Optional[SequencingReadMinMaxColumns] = None + max: Optional[SequencingReadMinMaxColumns] = None + stddev: Optional[SequencingReadNumericalColumns] = None + variance: Optional[SequencingReadNumericalColumns] = None + + +""" +Wrapper around SequencingReadAggregateFunctions +""" + + +@strawberry.type +class SequencingReadAggregate: + aggregate: Optional[SequencingReadAggregateFunctions] = None + + +""" +------------------------------------------------------------------------------ +Mutation types +------------------------------------------------------------------------------ +""" + + +@strawberry.input() +class SequencingReadCreateInput: + collection_id: int + sample_id: Optional[strawberry.ID] = None + protocol: SequencingProtocol + r1_file_id: Optional[strawberry.ID] = None + r2_file_id: Optional[strawberry.ID] = None + technology: SequencingTechnology + nucleic_acid: NucleicAcid + primer_file_id: Optional[strawberry.ID] = None + clearlabs_export: bool + + +@strawberry.input() +class SequencingReadUpdateInput: + collection_id: Optional[int] = None + sample_id: Optional[strawberry.ID] = None + protocol: Optional[SequencingProtocol] = None + r1_file_id: Optional[strawberry.ID] = None + r2_file_id: Optional[strawberry.ID] = None + technology: Optional[SequencingTechnology] = None + nucleic_acid: Optional[NucleicAcid] = None + primer_file_id: Optional[strawberry.ID] = None + clearlabs_export: Optional[bool] = None + + +""" +------------------------------------------------------------------------------ +Utilities +------------------------------------------------------------------------------ +""" + + +@strawberry.field(extensions=[DependencyExtension()]) +async def resolve_sequencing_reads( + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), + where: Optional[SequencingReadWhereClause] = None, +) -> typing.Sequence[SequencingRead]: + """ + Resolve SequencingRead objects. Used for queries (see api/queries.py). + """ + return await get_db_rows(db.SequencingRead, session, cerbos_client, principal, where, []) # type: ignore + + +def format_sequencing_read_aggregate_output(query_results: RowMapping) -> SequencingReadAggregateFunctions: + """ + Given a row from the DB containing the results of an aggregate query, + format the results using the proper GraphQL types. + """ + output = SequencingReadAggregateFunctions() + for aggregate_name, value in query_results.items(): + if aggregate_name == "count": + output.count = value + else: + aggregator_fn, col_name = aggregate_name.split("_", 1) + # Filter out the group_by key from the results if one was provided. 
+ if aggregator_fn in aggregator_map.keys(): + if not getattr(output, aggregator_fn): + if aggregate_name in ["min", "max"]: + setattr(output, aggregator_fn, SequencingReadMinMaxColumns()) + else: + setattr(output, aggregator_fn, SequencingReadNumericalColumns()) + setattr(getattr(output, aggregator_fn), col_name, value) + return output + + +@strawberry.field(extensions=[DependencyExtension()]) +async def resolve_sequencing_reads_aggregate( + info: Info, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), + where: Optional[SequencingReadWhereClause] = None, +) -> SequencingReadAggregate: + """ + Aggregate values for SequencingRead objects. Used for queries (see api/queries.py). + """ + # Get the selected aggregate functions and columns to operate on + # TODO: not sure why selected_fields is a list + # The first list of selections will always be ["aggregate"], so just grab the first item + selections = info.selected_fields[0].selections[0].selections + rows = await get_aggregate_db_rows(db.SequencingRead, session, cerbos_client, principal, where, selections, []) # type: ignore + aggregate_output = format_sequencing_read_aggregate_output(rows) + return SequencingReadAggregate(aggregate=aggregate_output) + + +@strawberry.mutation(extensions=[DependencyExtension()]) +async def create_sequencing_read( + input: SequencingReadCreateInput, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), +) -> db.Entity: + """ + Create a new SequencingRead object. Used for mutations (see api/mutations.py). + """ + params = input.__dict__ + + # Validate that user can create entity in this collection + attr = {"collection_id": input.collection_id} + resource = Resource(id="NEW_ID", kind=db.SequencingRead.__tablename__, attr=attr) + if not cerbos_client.is_allowed("create", principal, resource): + raise PlatformicsException("Unauthorized: Cannot create entity in this collection") + + # Save to DB + params["owner_user_id"] = int(principal.id) + new_entity = db.SequencingRead(**params) + session.add(new_entity) + await session.commit() + return new_entity + + +@strawberry.mutation(extensions=[DependencyExtension()]) +async def update_sequencing_read( + input: SequencingReadUpdateInput, + where: SequencingReadWhereClauseMutations, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), +) -> Sequence[db.Entity]: + """ + Update SequencingRead objects. Used for mutations (see api/mutations.py). 
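+
+    Example (illustrative, mirroring the mutations exercised in this patch's
+    tests; the UUID is a placeholder):
+
+        mutation {
+          updateSequencingRead(
+            where: {id: {_eq: "00000000-0000-0000-0000-000000000000"}},
+            input: {nucleicAcid: DNA}
+          ) { id }
+        }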
+ """ + params = input.__dict__ + + # Need at least one thing to update + num_params = len([x for x in params if params[x] is not None]) + if num_params == 0: + raise PlatformicsException("No fields to update") + + # Fetch entities for update, if we have access to them + entities = await get_db_rows(db.SequencingRead, session, cerbos_client, principal, where, [], CerbosAction.UPDATE) + if len(entities) == 0: + raise PlatformicsException("Unauthorized: Cannot update entities") + + # Validate that the user has access to the new collection ID + if input.collection_id: + attr = {"collection_id": input.collection_id} + resource = Resource(id="SOME_ID", kind=db.SequencingRead.__tablename__, attr=attr) + if not cerbos_client.is_allowed(CerbosAction.UPDATE, principal, resource): + raise PlatformicsException("Unauthorized: Cannot access new collection") + + # Update DB + for entity in entities: + for key in params: + if params[key]: + setattr(entity, key, params[key]) + await session.commit() + return entities + + +@strawberry.mutation(extensions=[DependencyExtension()]) +async def delete_sequencing_read( + where: SequencingReadWhereClauseMutations, + session: AsyncSession = Depends(get_db_session, use_cache=False), + cerbos_client: CerbosClient = Depends(get_cerbos_client), + principal: Principal = Depends(require_auth_principal), +) -> Sequence[db.Entity]: + """ + Delete SequencingRead objects. Used for mutations (see api/mutations.py). + """ + # Fetch entities for deletion, if we have access to them + entities = await get_db_rows(db.SequencingRead, session, cerbos_client, principal, where, [], CerbosAction.DELETE) + if len(entities) == 0: + raise PlatformicsException("Unauthorized: Cannot delete entities") + + # Update DB + for entity in entities: + await session.delete(entity) + await session.commit() + return entities diff --git a/test_app/tests/output/cerbos/policies/contig.yaml b/test_app/tests/output/cerbos/policies/contig.yaml new file mode 100644 index 0000000..6df5ad5 --- /dev/null +++ b/test_app/tests/output/cerbos/policies/contig.yaml @@ -0,0 +1,22 @@ +# Auto-generated by running 'make codegen'. Do not edit. +# Make changes to the template codegen/templates/cerbos/policies/class_name.yaml.j2 instead. +# yaml-language-server: $schema=https://api.cerbos.dev/latest/cerbos/policy/v1/Policy.schema.json +apiVersion: api.cerbos.dev/v1 +resourcePolicy: + version: "default" + importDerivedRoles: + - common_roles + resource: "contig" + rules: + - actions: ['view', 'create', 'update'] + effect: EFFECT_ALLOW + derivedRoles: + - project_member + + - actions: ['download', 'delete'] + effect: EFFECT_ALLOW + derivedRoles: + - owner + schemas: + principalSchema: + ref: cerbos:///principal.json diff --git a/test_app/tests/output/cerbos/policies/genomic_range.yaml b/test_app/tests/output/cerbos/policies/genomic_range.yaml new file mode 100644 index 0000000..cdee2d7 --- /dev/null +++ b/test_app/tests/output/cerbos/policies/genomic_range.yaml @@ -0,0 +1,22 @@ +# Auto-generated by running 'make codegen'. Do not edit. +# Make changes to the template codegen/templates/cerbos/policies/class_name.yaml.j2 instead. 
+# yaml-language-server: $schema=https://api.cerbos.dev/latest/cerbos/policy/v1/Policy.schema.json +apiVersion: api.cerbos.dev/v1 +resourcePolicy: + version: "default" + importDerivedRoles: + - common_roles + resource: "genomic_range" + rules: + - actions: ['view', 'create', 'update'] + effect: EFFECT_ALLOW + derivedRoles: + - project_member + + - actions: ['download', 'delete'] + effect: EFFECT_ALLOW + derivedRoles: + - owner + schemas: + principalSchema: + ref: cerbos:///principal.json diff --git a/test_app/tests/output/cerbos/policies/sample.yaml b/test_app/tests/output/cerbos/policies/sample.yaml new file mode 100644 index 0000000..b34e82e --- /dev/null +++ b/test_app/tests/output/cerbos/policies/sample.yaml @@ -0,0 +1,22 @@ +# Auto-generated by running 'make codegen'. Do not edit. +# Make changes to the template codegen/templates/cerbos/policies/class_name.yaml.j2 instead. +# yaml-language-server: $schema=https://api.cerbos.dev/latest/cerbos/policy/v1/Policy.schema.json +apiVersion: api.cerbos.dev/v1 +resourcePolicy: + version: "default" + importDerivedRoles: + - common_roles + resource: "sample" + rules: + - actions: ['view', 'create', 'update'] + effect: EFFECT_ALLOW + derivedRoles: + - project_member + + - actions: ['download', 'delete'] + effect: EFFECT_ALLOW + derivedRoles: + - owner + schemas: + principalSchema: + ref: cerbos:///principal.json diff --git a/test_app/tests/output/cerbos/policies/sequencing_read.yaml b/test_app/tests/output/cerbos/policies/sequencing_read.yaml new file mode 100644 index 0000000..964dd3c --- /dev/null +++ b/test_app/tests/output/cerbos/policies/sequencing_read.yaml @@ -0,0 +1,22 @@ +# Auto-generated by running 'make codegen'. Do not edit. +# Make changes to the template codegen/templates/cerbos/policies/class_name.yaml.j2 instead. +# yaml-language-server: $schema=https://api.cerbos.dev/latest/cerbos/policy/v1/Policy.schema.json +apiVersion: api.cerbos.dev/v1 +resourcePolicy: + version: "default" + importDerivedRoles: + - common_roles + resource: "sequencing_read" + rules: + - actions: ['view', 'create', 'update'] + effect: EFFECT_ALLOW + derivedRoles: + - project_member + + - actions: ['download', 'delete'] + effect: EFFECT_ALLOW + derivedRoles: + - owner + schemas: + principalSchema: + ref: cerbos:///principal.json diff --git a/test_app/tests/output/database/models/__init__.py b/test_app/tests/output/database/models/__init__.py new file mode 100644 index 0000000..76a31d7 --- /dev/null +++ b/test_app/tests/output/database/models/__init__.py @@ -0,0 +1,20 @@ +""" +Make database models importable + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/database/models/__init__.py.j2 instead. +""" + +# isort: skip_file + +from sqlalchemy.orm import configure_mappers + +from platformics.database.models.base import Base, meta, Entity # noqa: F401 +from database.models.sample import Sample # noqa: F401 +from database.models.sequencing_read import SequencingRead # noqa: F401 +from database.models.genomic_range import GenomicRange # noqa: F401 +from database.models.contig import Contig # noqa: F401 + +from database.models.file import File, FileStatus # noqa: F401 + +configure_mappers() diff --git a/test_app/tests/output/database/models/contig.py b/test_app/tests/output/database/models/contig.py new file mode 100644 index 0000000..9cfffb4 --- /dev/null +++ b/test_app/tests/output/database/models/contig.py @@ -0,0 +1,32 @@ +""" +SQLAlchemy database model for Contig + +Auto-generated by running 'make codegen'. 
Do not edit. +Make changes to the template codegen/templates/database/models/class_name.py.j2 instead. +""" + +import uuid +from typing import TYPE_CHECKING + +from platformics.database.models.base import Entity +from sqlalchemy import ForeignKey, String +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +if TYPE_CHECKING: + from database.models.file import File + from database.models.sequencing_read import SequencingRead +else: + File = "File" + SequencingRead = "SequencingRead" + + +class Contig(Entity): + __tablename__ = "contig" + __mapper_args__ = {"polymorphic_identity": __tablename__, "polymorphic_load": "inline"} + sequencing_read_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("sequencing_read.entity_id"), nullable=True) + sequencing_read: Mapped["SequencingRead"] = relationship( + "SequencingRead", back_populates="contigs", foreign_keys=sequencing_read_id + ) + sequence: Mapped[str] = mapped_column(String, nullable=False) + entity_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), nullable=False, primary_key=True) diff --git a/test_app/tests/output/database/models/genomic_range.py b/test_app/tests/output/database/models/genomic_range.py new file mode 100644 index 0000000..059dab0 --- /dev/null +++ b/test_app/tests/output/database/models/genomic_range.py @@ -0,0 +1,32 @@ +""" +SQLAlchemy database model for GenomicRange + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/database/models/class_name.py.j2 instead. +""" + +import uuid +from typing import TYPE_CHECKING + +from platformics.database.models.base import Entity +from sqlalchemy import ForeignKey +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +if TYPE_CHECKING: + from database.models.file import File + from database.models.sequencing_read import SequencingRead +else: + File = "File" + SequencingRead = "SequencingRead" + + +class GenomicRange(Entity): + __tablename__ = "genomic_range" + __mapper_args__ = {"polymorphic_identity": __tablename__, "polymorphic_load": "inline"} + file_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("file.id"), nullable=True) + file: Mapped["File"] = relationship("File", foreign_keys=file_id) + sequencing_reads: Mapped[list[SequencingRead]] = relationship( + "SequencingRead", back_populates="primer_file", uselist=True, foreign_keys="SequencingRead.primer_file_id" + ) + entity_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), nullable=False, primary_key=True) diff --git a/test_app/tests/output/database/models/sample.py b/test_app/tests/output/database/models/sample.py new file mode 100644 index 0000000..a077737 --- /dev/null +++ b/test_app/tests/output/database/models/sample.py @@ -0,0 +1,36 @@ +""" +SQLAlchemy database model for Sample + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/database/models/class_name.py.j2 instead. 
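+
+Note (illustrative, based on the mapping below): each generated model is a
+joined-table subclass of Entity, with entity_id acting as both primary key and
+foreign key to entity.id, so shared columns such as owner_user_id and
+collection_id live on the base entity table rather than on "sample".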
+""" + +import uuid +import datetime +from typing import TYPE_CHECKING + +from platformics.database.models.base import Entity +from sqlalchemy import ForeignKey, String, Boolean, DateTime +from sqlalchemy.orm import Mapped, mapped_column, relationship + +if TYPE_CHECKING: + from database.models.file import File + from database.models.sequencing_read import SequencingRead +else: + File = "File" + SequencingRead = "SequencingRead" + + +class Sample(Entity): + __tablename__ = "sample" + __mapper_args__ = {"polymorphic_identity": __tablename__, "polymorphic_load": "inline"} + name: Mapped[str] = mapped_column(String, nullable=False) + sample_type: Mapped[str] = mapped_column(String, nullable=False) + water_control: Mapped[bool] = mapped_column(Boolean, nullable=False) + collection_date: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=True) + collection_location: Mapped[str] = mapped_column(String, nullable=False) + description: Mapped[str] = mapped_column(String, nullable=True) + sequencing_reads: Mapped[list[SequencingRead]] = relationship( + "SequencingRead", back_populates="sample", uselist=True, foreign_keys="SequencingRead.sample_id" + ) + entity_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), nullable=False, primary_key=True) diff --git a/test_app/tests/output/database/models/sequencing_read.py b/test_app/tests/output/database/models/sequencing_read.py new file mode 100644 index 0000000..6d24748 --- /dev/null +++ b/test_app/tests/output/database/models/sequencing_read.py @@ -0,0 +1,51 @@ +""" +SQLAlchemy database model for SequencingRead + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/database/models/class_name.py.j2 instead. +""" + +import uuid +from typing import TYPE_CHECKING + +from platformics.database.models.base import Entity +from sqlalchemy import ForeignKey, Enum, Boolean +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship +from support.enums import SequencingProtocol, SequencingTechnology, NucleicAcid + +if TYPE_CHECKING: + from database.models.file import File + from database.models.sample import Sample + from database.models.genomic_range import GenomicRange + from database.models.contig import Contig +else: + File = "File" + Sample = "Sample" + GenomicRange = "GenomicRange" + Contig = "Contig" + + +class SequencingRead(Entity): + __tablename__ = "sequencing_read" + __mapper_args__ = {"polymorphic_identity": __tablename__, "polymorphic_load": "inline"} + sample_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("sample.entity_id"), nullable=True) + sample: Mapped["Sample"] = relationship("Sample", back_populates="sequencing_reads", foreign_keys=sample_id) + protocol: Mapped[SequencingProtocol] = mapped_column(Enum(SequencingProtocol, native_enum=False), nullable=False) + r1_file_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("file.id"), nullable=True) + r1_file: Mapped["File"] = relationship("File", foreign_keys=r1_file_id) + r2_file_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("file.id"), nullable=True) + r2_file: Mapped["File"] = relationship("File", foreign_keys=r2_file_id) + technology: Mapped[SequencingTechnology] = mapped_column( + Enum(SequencingTechnology, native_enum=False), nullable=False + ) + nucleic_acid: Mapped[NucleicAcid] = mapped_column(Enum(NucleicAcid, native_enum=False), nullable=False) + primer_file_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("genomic_range.entity_id"), 
nullable=True) + primer_file: Mapped["GenomicRange"] = relationship( + "GenomicRange", back_populates="sequencing_reads", foreign_keys=primer_file_id + ) + contigs: Mapped[list[Contig]] = relationship( + "Contig", back_populates="sequencing_read", uselist=True, foreign_keys="Contig.sequencing_read_id" + ) + clearlabs_export: Mapped[bool] = mapped_column(Boolean, nullable=False) + entity_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), nullable=False, primary_key=True) diff --git a/test_app/tests/output/support/enums.py b/test_app/tests/output/support/enums.py new file mode 100644 index 0000000..1b1e7f3 --- /dev/null +++ b/test_app/tests/output/support/enums.py @@ -0,0 +1,49 @@ +""" +GraphQL enums + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/support/enums.py.j2 instead. +""" + +import strawberry +import enum + + +@strawberry.enum +class FileStatus(enum.Enum): + SUCCESS = "SUCCESS" + FAILED = "FAILED" + PENDING = "PENDING" + + +@strawberry.enum +class FileAcessProtocol(enum.Enum): + s3 = "s3" + + +@strawberry.enum +class NucleicAcid(enum.Enum): + RNA = "RNA" + DNA = "DNA" + + +@strawberry.enum +class SequencingProtocol(enum.Enum): + ampliseq = "ampliseq" + artic = "artic" + artic_v3 = "artic_v3" + artic_v4 = "artic_v4" + artic_v5 = "artic_v5" + combined_msspe_artic = "combined_msspe_artic" + covidseq = "covidseq" + midnight = "midnight" + msspe = "msspe" + snap = "snap" + varskip = "varskip" + easyseq = "easyseq" + + +@strawberry.enum +class SequencingTechnology(enum.Enum): + Illumina = "Illumina" + Nanopore = "Nanopore" diff --git a/test_app/tests/output/test_infra/factories/contig.py b/test_app/tests/output/test_infra/factories/contig.py new file mode 100644 index 0000000..a908c54 --- /dev/null +++ b/test_app/tests/output/test_infra/factories/contig.py @@ -0,0 +1,37 @@ +""" +Factory for generating Contig objects. + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/test_infra/factories/class_name.py.j2 instead. +""" + +# ruff: noqa: E501 Line too long + +import factory +from database.models import Contig +from test_infra.factories.main import CommonFactory +from test_infra.factories.sequencing_read import SequencingReadFactory +from factory import Faker, fuzzy +from faker_biology.bioseq import Bioseq +from faker_biology.physiology import Organ +from faker_enum import EnumProvider + +Faker.add_provider(Bioseq) +Faker.add_provider(Organ) +Faker.add_provider(EnumProvider) + + +class ContigFactory(CommonFactory): + class Meta: + sqlalchemy_session = None # workaround for a bug in factoryboy + model = Contig + # Match entity_id with existing db rows to determine whether we should + # create a new row or not. + sqlalchemy_get_or_create = ("entity_id",) + + sequencing_read = factory.SubFactory( + SequencingReadFactory, + owner_user_id=factory.SelfAttribute("..owner_user_id"), + collection_id=factory.SelfAttribute("..collection_id"), + ) + sequence = fuzzy.FuzzyText() diff --git a/test_app/tests/output/test_infra/factories/genomic_range.py b/test_app/tests/output/test_infra/factories/genomic_range.py new file mode 100644 index 0000000..c814a0f --- /dev/null +++ b/test_app/tests/output/test_infra/factories/genomic_range.py @@ -0,0 +1,36 @@ +""" +Factory for generating GenomicRange objects. + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/test_infra/factories/class_name.py.j2 instead. 
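+
+Illustrative usage (mirrors how the factories are driven by the tests in this
+patch; bind a session before creating rows):
+
+    SessionStorage.set_session(session)
+    GenomicRangeFactory.create_batch(2, owner_user_id=111, collection_id=888)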
+""" + +# ruff: noqa: E501 Line too long + +import factory +from database.models import GenomicRange +from test_infra.factories.main import CommonFactory, FileFactory +from factory import Faker +from faker_biology.bioseq import Bioseq +from faker_biology.physiology import Organ +from faker_enum import EnumProvider + +Faker.add_provider(Bioseq) +Faker.add_provider(Organ) +Faker.add_provider(EnumProvider) + + +class GenomicRangeFactory(CommonFactory): + class Meta: + sqlalchemy_session = None # workaround for a bug in factoryboy + model = GenomicRange + # Match entity_id with existing db rows to determine whether we should + # create a new row or not. + sqlalchemy_get_or_create = ("entity_id",) + + file = factory.RelatedFactory( + FileFactory, + factory_related_name="entity", + entity_field_name="file", + file_format="fastq", + ) diff --git a/test_app/tests/output/test_infra/factories/sample.py b/test_app/tests/output/test_infra/factories/sample.py new file mode 100644 index 0000000..f893e95 --- /dev/null +++ b/test_app/tests/output/test_infra/factories/sample.py @@ -0,0 +1,36 @@ +""" +Factory for generating Sample objects. + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/test_infra/factories/class_name.py.j2 instead. +""" + +# ruff: noqa: E501 Line too long + +import factory +from database.models import Sample +from test_infra.factories.main import CommonFactory +from factory import Faker, fuzzy +from faker_biology.bioseq import Bioseq +from faker_biology.physiology import Organ +from faker_enum import EnumProvider + +Faker.add_provider(Bioseq) +Faker.add_provider(Organ) +Faker.add_provider(EnumProvider) + + +class SampleFactory(CommonFactory): + class Meta: + sqlalchemy_session = None # workaround for a bug in factoryboy + model = Sample + # Match entity_id with existing db rows to determine whether we should + # create a new row or not. + sqlalchemy_get_or_create = ("entity_id",) + + name = fuzzy.FuzzyText() + sample_type = fuzzy.FuzzyText() + water_control = factory.Faker("boolean") + collection_date = factory.Faker("date") + collection_location = fuzzy.FuzzyText() + description = fuzzy.FuzzyText() diff --git a/test_app/tests/output/test_infra/factories/sequencing_read.py b/test_app/tests/output/test_infra/factories/sequencing_read.py new file mode 100644 index 0000000..9e1426d --- /dev/null +++ b/test_app/tests/output/test_infra/factories/sequencing_read.py @@ -0,0 +1,73 @@ +""" +Factory for generating SequencingRead objects. + +Auto-generated by running 'make codegen'. Do not edit. +Make changes to the template codegen/templates/test_infra/factories/class_name.py.j2 instead. +""" + +# ruff: noqa: E501 Line too long + +import factory +from database.models import SequencingRead +from test_infra.factories.main import CommonFactory, FileFactory +from test_infra.factories.sample import SampleFactory +from test_infra.factories.genomic_range import GenomicRangeFactory +from factory import Faker, fuzzy +from faker_biology.bioseq import Bioseq +from faker_biology.physiology import Organ +from faker_enum import EnumProvider + +Faker.add_provider(Bioseq) +Faker.add_provider(Organ) +Faker.add_provider(EnumProvider) + + +class SequencingReadFactory(CommonFactory): + class Meta: + sqlalchemy_session = None # workaround for a bug in factoryboy + model = SequencingRead + # Match entity_id with existing db rows to determine whether we should + # create a new row or not. 
+ sqlalchemy_get_or_create = ("entity_id",) + + sample = factory.SubFactory( + SampleFactory, + owner_user_id=factory.SelfAttribute("..owner_user_id"), + collection_id=factory.SelfAttribute("..collection_id"), + ) + protocol = fuzzy.FuzzyChoice( + [ + "ampliseq", + "artic", + "artic_v3", + "artic_v4", + "artic_v5", + "combined_msspe_artic", + "covidseq", + "midnight", + "msspe", + "snap", + "varskip", + "easyseq", + ] + ) + r1_file = factory.RelatedFactory( + FileFactory, + factory_related_name="entity", + entity_field_name="r1_file", + file_format="fastq", + ) + r2_file = factory.RelatedFactory( + FileFactory, + factory_related_name="entity", + entity_field_name="r2_file", + file_format="fastq", + ) + technology = fuzzy.FuzzyChoice(["Illumina", "Nanopore"]) + nucleic_acid = fuzzy.FuzzyChoice(["RNA", "DNA"]) + primer_file = factory.SubFactory( + GenomicRangeFactory, + owner_user_id=factory.SelfAttribute("..owner_user_id"), + collection_id=factory.SelfAttribute("..collection_id"), + ) + clearlabs_export = factory.Faker("boolean") diff --git a/test_app/tests/test_aggregate_queries.py b/test_app/tests/test_aggregate_queries.py index dfeb722..f4fdb58 100644 --- a/test_app/tests/test_aggregate_queries.py +++ b/test_app/tests/test_aggregate_queries.py @@ -3,10 +3,12 @@ """ import pytest -from conftest import GQLTestClient, SessionStorage from platformics.database.connect import SyncDB +from conftest import GQLTestClient, SessionStorage from test_infra.factories.sample import SampleFactory from test_infra.factories.sequencing_read import SequencingReadFactory +from test_infra.factories.contig import ContigFactory +from test_infra.factories.upstream_database import UpstreamDatabaseFactory @pytest.mark.asyncio @@ -60,12 +62,12 @@ async def test_basic_aggregate_query( } """ output = await gql_client.query(query, user_id=user_id, member_projects=[project_id, secondary_project_id]) - avg_collectionId = output["data"]["samplesAggregate"]["aggregate"]["avg"]["collectionId"] - count = output["data"]["samplesAggregate"]["aggregate"]["count"] - max_collectionLocation = output["data"]["samplesAggregate"]["aggregate"]["max"]["collectionLocation"] - min_collectionLocation = output["data"]["samplesAggregate"]["aggregate"]["min"]["collectionLocation"] - stddev_collectionId = output["data"]["samplesAggregate"]["aggregate"]["stddev"]["collectionId"] - sum_ownerUserId = output["data"]["samplesAggregate"]["aggregate"]["sum"]["ownerUserId"] + avg_collectionId = output["data"]["samplesAggregate"]["aggregate"][0]["avg"]["collectionId"] + count = output["data"]["samplesAggregate"]["aggregate"][0]["count"] + max_collectionLocation = output["data"]["samplesAggregate"]["aggregate"][0]["max"]["collectionLocation"] + min_collectionLocation = output["data"]["samplesAggregate"]["aggregate"][0]["min"]["collectionLocation"] + stddev_collectionId = output["data"]["samplesAggregate"]["aggregate"][0]["stddev"]["collectionId"] + sum_ownerUserId = output["data"]["samplesAggregate"]["aggregate"][0]["sum"]["ownerUserId"] assert avg_collectionId == 189 assert count == 5 @@ -106,8 +108,8 @@ async def test_nested_aggregate_query( } """ results = await gql_client.query(query, user_id=111, member_projects=[888]) - assert results["data"]["samples"][0]["sequencingReadsAggregate"]["aggregate"]["count"] == 2 - assert results["data"]["samples"][1]["sequencingReadsAggregate"]["aggregate"]["count"] == 3 + assert results["data"]["samples"][0]["sequencingReadsAggregate"]["aggregate"][0]["count"] == 2 + assert 
results["data"]["samples"][1]["sequencingReadsAggregate"]["aggregate"][0]["count"] == 3 @pytest.mark.asyncio @@ -148,3 +150,273 @@ async def test_count_distinct_query( """ results = await gql_client.query(query, user_id=111, member_projects=[888]) assert results["data"]["samplesAggregate"]["aggregate"][0]["count"] == 2 + + +@pytest.mark.asyncio +async def test_groupby_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can perform a groupby query + """ + with sync_db.session() as session: + SessionStorage.set_session(session) + SampleFactory.create_batch(2, owner_user_id=111, collection_id=888, collection_location="San Francisco, CA") + SampleFactory.create_batch(3, owner_user_id=111, collection_id=888, collection_location="Mountain View, CA") + + query = """ + query MyQuery { + samplesAggregate { + aggregate { + groupBy { + collectionLocation + } + count + } + } + } + """ + results = await gql_client.query(query, user_id=111, member_projects=[888]) + aggregate = results["data"]["samplesAggregate"]["aggregate"] + for group in aggregate: + if group["groupBy"]["collectionLocation"] == "San Francisco, CA": + assert group["count"] == 2 + elif group["groupBy"]["collectionLocation"] == "Mountain View, CA": + assert group["count"] == 3 + + +@pytest.mark.asyncio +async def test_groupby_query_with_nested_fields( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can perform a groupby query with nested fields + """ + with sync_db.session() as session: + SessionStorage.set_session(session) + sample_1 = SampleFactory(owner_user_id=111, collection_id=888, collection_location="San Francisco, CA") + sample_2 = SampleFactory(owner_user_id=111, collection_id=888, collection_location="Mountain View, CA") + SequencingReadFactory.create_batch( + 2, sample=sample_1, owner_user_id=sample_1.owner_user_id, collection_id=sample_1.collection_id + ) + SequencingReadFactory.create_batch( + 3, sample=sample_2, owner_user_id=sample_2.owner_user_id, collection_id=sample_2.collection_id + ) + + query = """ + query MyQuery { + sequencingReadsAggregate { + aggregate { + groupBy { + sample { + collectionLocation + } + } + count + } + } + } + """ + results = await gql_client.query(query, user_id=111, member_projects=[888]) + aggregate = results["data"]["sequencingReadsAggregate"]["aggregate"] + for group in aggregate: + if group["groupBy"]["sample"]["collectionLocation"] == "San Francisco, CA": + assert group["count"] == 2 + elif group["groupBy"]["sample"]["collectionLocation"] == "Mountain View, CA": + assert group["count"] == 3 + + +@pytest.mark.asyncio +async def test_groupby_query_with_multiple_fields( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can perform a groupby query with fields nested at multiple levels + """ + with sync_db.session() as session: + SessionStorage.set_session(session) + sample_1 = SampleFactory(owner_user_id=111, collection_id=888, collection_location="San Francisco, CA") + sample_2 = SampleFactory(owner_user_id=111, collection_id=888, collection_location="Mountain View, CA") + SequencingReadFactory.create_batch( + 1, + sample=sample_1, + owner_user_id=sample_1.owner_user_id, + collection_id=sample_1.collection_id, + technology="Illumina", + ) + SequencingReadFactory.create_batch( + 2, + sample=sample_1, + owner_user_id=sample_1.owner_user_id, + collection_id=sample_1.collection_id, + technology="Nanopore", + ) + SequencingReadFactory.create_batch( + 3, + sample=sample_2, + 
owner_user_id=sample_2.owner_user_id, + collection_id=sample_2.collection_id, + technology="Illumina", + ) + SequencingReadFactory.create_batch( + 4, + sample=sample_2, + owner_user_id=sample_2.owner_user_id, + collection_id=sample_2.collection_id, + technology="Nanopore", + ) + + query = """ + query MyQuery { + sequencingReadsAggregate { + aggregate { + groupBy { + sample { + collectionLocation + } + technology + } + count + } + } + } + """ + results = await gql_client.query(query, user_id=111, member_projects=[888]) + aggregate = results["data"]["sequencingReadsAggregate"]["aggregate"] + for group in aggregate: + if ( + group["groupBy"]["sample"]["collectionLocation"] == "San Francisco, CA" + and group["groupBy"]["technology"] == "Illumina" + ): + assert group["count"] == 1 + elif ( + group["groupBy"]["sample"]["collectionLocation"] == "San Francisco, CA" + and group["groupBy"]["technology"] == "Nanopore" + ): + assert group["count"] == 2 + elif ( + group["groupBy"]["sample"]["collectionLocation"] == "Mountain View, CA" + and group["groupBy"]["technology"] == "Illumina" + ): + assert group["count"] == 3 + elif ( + group["groupBy"]["sample"]["collectionLocation"] == "Mountain View, CA" + and group["groupBy"]["technology"] == "Nanopore" + ): + assert group["count"] == 4 + + +@pytest.mark.asyncio +async def test_deeply_nested_groupby_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can perform a deeply nested groupby query + """ + + user_id = 12345 + project_id = 123 + + with sync_db.session() as session: + SessionStorage.set_session(session) + SessionStorage.set_session(session) + + upstream_db_1 = UpstreamDatabaseFactory(owner_user_id=user_id, collection_id=project_id, name="NCBI") + upstream_db_2 = UpstreamDatabaseFactory(owner_user_id=user_id, collection_id=project_id, name="GTDB") + contig_1 = ContigFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_1) + contig_2 = ContigFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_2) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contigs=contig_1) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contigs=contig_1) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contigs=contig_2) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contigs=contig_2) + + query = """ + query MyQuery { + sequencingReadsAggregate { + aggregate { + count + groupBy { + contigs { + upstreamDatabase { + name + } + } + } + } + } + } + """ + results = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + aggregate = results["data"]["sequencingReadsAggregate"]["aggregate"] + for group in aggregate: + if group["groupBy"]["contigs"]["upstreamDatabase"]["name"] == "NCBI": + assert group["count"] == 2 + elif group["groupBy"]["contigs"]["upstreamDatabase"]["name"] == "GTDB": + assert group["count"] == 2 + + +@pytest.mark.asyncio +async def test_soft_deleted_data_not_in_aggregate_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that soft-deleted data is not included in aggregate queries + """ + user_id = 12345 + project_id = 123 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + SampleFactory.create_batch( + 4, collection_location="San Francisco, CA", owner_user_id=user_id, collection_id=project_id + ) + sample_to_delete = SampleFactory( + collection_location="Mountain View, 
CA", owner_user_id=user_id, collection_id=project_id + ) + + # Verify that there are 5 samples in the database + aggregate_query = """ + query MyQuery { + samplesAggregate { + aggregate { + count + } + } + } + """ + output = await gql_client.query(aggregate_query, user_id=user_id, member_projects=[project_id]) + assert output["data"]["samplesAggregate"]["aggregate"][0]["count"] == 5 + + # Soft-delete a sample by updating its deleted_at field + soft_delete_query = f""" + mutation MyMutation {{ + updateSample (where: {{ id: {{ _eq: "{sample_to_delete.id}" }} }}, input: {{ deletedAt: "2021-01-01T00:00:00Z" }}) {{ + id + }} + }} + """ + output = await gql_client.query(soft_delete_query, user_id=user_id, member_projects=[project_id]) + assert output["data"]["updateSample"][0]["id"] == str(sample_to_delete.id) + + # The soft-deleted sample should not be included in the aggregate query anymore + output = await gql_client.query(aggregate_query, user_id=user_id, member_projects=[project_id]) + assert output["data"]["samplesAggregate"]["aggregate"][0]["count"] == 4 + + # The soft-deleted sample should be included in the aggregate query if we specifically ask for it + aggregate_soft_deleted_query = """ + query MyQuery { + samplesAggregate(where: { deletedAt: { _is_null: false } }) { + aggregate { + count + } + } + } + """ + output = await gql_client.query(aggregate_soft_deleted_query, user_id=user_id, member_projects=[project_id]) + assert output["data"]["samplesAggregate"]["aggregate"][0]["count"] == 1 diff --git a/test_app/tests/test_authorization.py b/test_app/tests/test_authorization.py index 3f794e9..8c22988 100644 --- a/test_app/tests/test_authorization.py +++ b/test_app/tests/test_authorization.py @@ -2,10 +2,14 @@ Authorization spot-checks """ +import uuid + import pytest -from platformics.database.connect import SyncDB +from database.models import Sample from conftest import GQLTestClient, SessionStorage from test_infra.factories.sample import SampleFactory +from test_infra.factories.sequencing_read import SequencingReadFactory +from platformics.database.connect import SyncDB @pytest.mark.asyncio @@ -45,3 +49,235 @@ async def test_collection_authorization( output = await gql_client.query(query, user_id=user_id, member_projects=project_ids) assert len(output["data"]["samples"]) == num_results assert {sample["collectionLocation"] for sample in output["data"]["samples"]} == set(cities) + + +@pytest.mark.asyncio +async def test_system_fields_only_creatable_by_system( + gql_client: GQLTestClient, +) -> None: + """ + Make sure only system users can set system fields + """ + user_id = 12345 + project_ids = [333] + producing_run_id = str(uuid.uuid4()) + query = f""" + mutation MyMutation {{ + createGenomicRange(input: {{collectionId: {project_ids[0]}, producingRunId: "{producing_run_id}" }}) {{ + collectionId + producingRunId + }} + }} + """ + + # Our mutation should have been saved because we are a system user. + output = await gql_client.query(query, user_id=user_id, member_projects=project_ids, service_identity="workflows") + assert output["data"]["createGenomicRange"]["collectionId"] == 333 + assert output["data"]["createGenomicRange"]["producingRunId"] == producing_run_id + + # Our mutation should have ignored producing run because we're not a system user. 
+ output = await gql_client.query(query, user_id=user_id, member_projects=project_ids) + assert output["data"]["createGenomicRange"]["collectionId"] == 333 + assert output["data"]["createGenomicRange"]["producingRunId"] is None + + +@pytest.mark.asyncio +async def test_system_fields_only_mutable_by_system( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Make sure only system users can mutate system fields + """ + user_id = 12345 + project_ids = [333] + + with sync_db.session() as session: + SessionStorage.set_session(session) + sample = SampleFactory.create(collection_location="City1", owner_user_id=999, collection_id=333) + + # Fetch all samples + def get_query(input_value: str) -> str: + return f""" + mutation MyMutation {{ + updateSample( + where: {{id: {{_eq: "{sample.id}" }} }}, + input: {{systemMutableField: "{input_value}"}}) {{ + id + systemMutableField + }} + }} + """ + + output = await gql_client.query( + get_query("hello"), user_id=user_id, member_projects=project_ids, service_identity="workflows" + ) + # Our mutation should have been saved because we are a system user. + assert output["data"]["updateSample"][0]["systemMutableField"] == "hello" + + output = await gql_client.query(get_query("goodbye"), user_id=user_id, member_projects=project_ids) + # This field should have been ignored because we're not a system user + assert output["data"]["updateSample"][0]["systemMutableField"] == "hello" + + +@pytest.mark.asyncio +async def test_system_types_only_mutable_by_system( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Make sure only system users can mutate system fields + """ + user_id = 12345 + project_ids = [333] + + # Fetch all samples + def get_update_query(id: str, input_value: str) -> str: + return f""" + mutation MyMutation {{ + updateSystemWritableOnlyType( + where: {{id: {{_eq: "{id}" }} }}, + input: {{name: "{input_value}"}}) {{ + id + name + }} + }} + """ + + # Fetch all samples + create_query = f""" + mutation MyMutation {{ + createSystemWritableOnlyType( + input: {{ + collectionId: 333, + name: "row name here" + }} + ) {{ id, name }} + }} + """ + + # Our mutation should have been saved because we are a system user. 
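+    # (Illustrative note: service_identity="workflows" is what marks the request
+    # as coming from a trusted system service rather than an end user, per the
+    # other tests in this file.)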
+ output = await gql_client.query( + create_query, user_id=user_id, member_projects=project_ids, service_identity="workflows" + ) + assert output["data"]["createSystemWritableOnlyType"]["name"] == "row name here" + item_id = output["data"]["createSystemWritableOnlyType"]["id"] + + # Our mutation should have failed with an authorization error because we are not a system user + output = await gql_client.query(create_query, user_id=user_id, member_projects=project_ids) + assert "Unauthorized" in output["errors"][0]["message"] + assert "not creatable" in output["errors"][0]["message"] + + # This field should have been ignored because we're not a system user + output = await gql_client.query(get_update_query(item_id, "new_name"), user_id=user_id, member_projects=project_ids) + assert "Unauthorized" in output["errors"][0]["message"] + assert "not mutable" in output["errors"][0]["message"] + + # This field should have been ignored because we're not a system user + output = await gql_client.query( + get_update_query(item_id, "new_name"), + user_id=user_id, + member_projects=project_ids, + service_identity="workflows", + ) + assert output["data"]["updateSystemWritableOnlyType"][0]["name"] == "new_name" + + +@pytest.mark.asyncio +async def test_update_wont_associate_inaccessible_relationships( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Make sure users can only see samples in collections they have access to. + """ + owner_user_id = 333 + user_id = 12345 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + test_sample0 = SampleFactory.create(collection_location="City1", owner_user_id=999, collection_id=444) + test_sample1 = SampleFactory.create(collection_location="City2", owner_user_id=999, collection_id=444) + test_sample2 = SampleFactory.create(collection_location="City3", owner_user_id=999, collection_id=444) + test_sample3 = SampleFactory.create(collection_location="City4", owner_user_id=999, collection_id=444) + test_sequencing_read = SequencingReadFactory.create( + sample=test_sample0, owner_user_id=owner_user_id, collection_id=111 + ) + + def gen_query(test_sample: Sample) -> str: + # Fetch all samples + query = f""" + mutation MyMutation {{ + updateSequencingRead( + where: {{id: {{_eq: "{test_sequencing_read.id}"}} }}, + input: {{ + sampleId: "{test_sample.id}" + }} + ) {{ + id + sample {{ + id + }} + }} + }} + """ + return query + + # We are a member of 444 so this should work. + output = await gql_client.query(gen_query(test_sample1), user_id=user_id, member_projects=[111, 444]) + assert output["data"]["updateSequencingRead"][0]["sample"]["id"] == str(test_sample1.id) + + # We are NOT a member of 444 so this should break. + output = await gql_client.query(gen_query(test_sample2), user_id=user_id, member_projects=[111, 555]) + assert "Unauthorized" in output["errors"][0]["message"] + + # Project 444 is public so this should work + output = await gql_client.query( + gen_query(test_sample3), user_id=user_id, member_projects=[111, 555], viewer_projects=[444] + ) + assert output["data"]["updateSequencingRead"][0]["sample"]["id"] == str(test_sample3.id) + + +@pytest.mark.asyncio +async def test_create_wont_associate_inaccessible_relationships( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Make sure users can only see samples in collections they have access to. 
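+    Specifically, createSequencingRead should reject a sampleId that points to a
+    sample in a collection the caller is not a member of, unless that collection
+    is otherwise visible to the caller (the viewer/public case below).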
+ """ + owner_user_id = 333 + user_id = 12345 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + test_sample = SampleFactory.create(collection_location="City2", owner_user_id=owner_user_id, collection_id=444) + + # Fetch all samples + query = f""" + mutation MyMutation {{ + createSequencingRead( + input: {{ + collectionId: 111, + sampleId: "{test_sample.id}", + protocol: artic_v4, + technology: Illumina, + nucleicAcid: RNA, + clearlabsExport: false + }} + ) {{ id }} + }} + """ + # We are a member of 444 so this should work. + output = await gql_client.query(query, user_id=user_id, member_projects=[111, 444]) + assert output["data"]["createSequencingRead"]["id"] + + # We are NOT a member of 444 so this should break. + output = await gql_client.query(query, user_id=user_id, member_projects=[111, 555]) + assert "Unauthorized" in output["errors"][0]["message"] + + # Project 444 is public so this should work + output = await gql_client.query(query, user_id=user_id, member_projects=[111, 555], viewer_projects=[444]) + assert output["data"]["createSequencingRead"]["id"] diff --git a/test_app/tests/test_basic_queries.py b/test_app/tests/test_basic_queries.py index ac9d741..7c4d569 100644 --- a/test_app/tests/test_basic_queries.py +++ b/test_app/tests/test_basic_queries.py @@ -2,10 +2,13 @@ Test basic queries and mutations """ +import datetime import pytest from platformics.database.connect import SyncDB from conftest import GQLTestClient, SessionStorage -from test_infra.factories.sample import SampleFactory +from platformics.codegen.tests.output.test_infra.factories.sample import SampleFactory + +date_now = datetime.datetime.now() @pytest.mark.asyncio @@ -24,13 +27,25 @@ async def test_graphql_query( with sync_db.session() as session: SessionStorage.set_session(session) SampleFactory.create_batch( - 2, collection_location="San Francisco, CA", owner_user_id=user_id, collection_id=project_id + 2, + collection_location="San Francisco, CA", + collection_date=date_now, + owner_user_id=user_id, + collection_id=project_id, ) SampleFactory.create_batch( - 6, collection_location="Mountain View, CA", owner_user_id=user_id, collection_id=project_id + 6, + collection_location="Mountain View, CA", + collection_date=date_now, + owner_user_id=user_id, + collection_id=project_id, ) SampleFactory.create_batch( - 4, collection_location="Phoenix, AZ", owner_user_id=secondary_user_id, collection_id=9999 + 4, + collection_location="Phoenix, AZ", + collection_date=date_now, + owner_user_id=secondary_user_id, + collection_id=9999, ) # Fetch all samples @@ -69,6 +84,7 @@ async def test_graphql_mutations( sampleType: "Type 1" waterControl: false collectionLocation: "San Francisco, CA" + collectionDate: "2024-01-01" collectionId: 123 }) { id, diff --git a/test_app/tests/test_bulk_download_deletion.py b/test_app/tests/test_bulk_download_deletion.py new file mode 100644 index 0000000..69ceb31 --- /dev/null +++ b/test_app/tests/test_bulk_download_deletion.py @@ -0,0 +1,72 @@ +""" +Test deletion of bulkDownloads > 7 days old +""" + +import pytest +import datetime +from platformics.database.connect import SyncDB +from conftest import SessionStorage, GQLTestClient, FileFactory +from platformics.codegen.tests.output.test_infra.factories.bulk_download import BulkDownloadFactory + + +@pytest.mark.asyncio +async def test_delete_old_bulk_downloads( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can delete bulk downloads older than 7 days + """ + 
user_id = 12345 + project_id = 123 + + # Create mock data: 3 current bulk downloads, 2 bulk downloads from 1 week ago, and 5 bulk downloads from 1 month ago + with sync_db.session() as session: + SessionStorage.set_session(session) + current_time = datetime.datetime.now() + one_week_ago = current_time - datetime.timedelta(days=7) + one_month_ago = current_time - datetime.timedelta(days=30) + + current_bulk_downloads = BulkDownloadFactory.create_batch(3, owner_user_id=user_id, collection_id=None) + one_week_old_bulk_downloads = BulkDownloadFactory.create_batch( + 2, owner_user_id=user_id, collection_id=None, created_at=one_week_ago + ) + one_month_old_bulk_downloads = BulkDownloadFactory.create_batch( + 5, owner_user_id=user_id, collection_id=None, created_at=one_month_ago + ) + all_old_bulk_downloads = one_week_old_bulk_downloads + one_month_old_bulk_downloads + FileFactory.update_file_ids() + + # Delete old bulk downloads + query = """ + mutation MyMutation { + deleteOldBulkDownloads { + id + } + } + """ + + # Verify that the mutation can't be called by a non-system user + result = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + assert result["data"] is None + assert "Unauthorized" + + # Verify that the mutation deletes all bulk downloads older than 7 days + result = await gql_client.query(query, user_id=user_id, member_projects=[project_id], service_identity="rails") + assert len(result["data"]["deleteOldBulkDownloads"]) == 7 + assert [bd["id"] for bd in result["data"]["deleteOldBulkDownloads"]] == [ + str(bd.id) for bd in all_old_bulk_downloads + ] + + # Check that current bulk downloads are still there + query = """ + query MyQuery { + bulkDownloads { + id + } + } + """ + + result = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + assert len(result["data"]["bulkDownloads"]) == 3 + assert [bd["id"] for bd in result["data"]["bulkDownloads"]] == [str(bd.id) for bd in current_bulk_downloads] diff --git a/test_app/tests/test_bulk_download_policy.py b/test_app/tests/test_bulk_download_policy.py new file mode 100644 index 0000000..9455c45 --- /dev/null +++ b/test_app/tests/test_bulk_download_policy.py @@ -0,0 +1,123 @@ +""" +Test collection_id policy for entities and bulk downloads +1. Test that users cannot create normal entities without a collection_id, or update them to have a null collection_id. +2. Test that users cannot create bulk downloads WITH a collection_id. +3. Test that only owners can view their own bulk downloads +""" + +import pytest +from platformics.database.connect import SyncDB +from conftest import SessionStorage, GQLTestClient +from platformics.codegen.tests.output.test_infra.factories.bulk_download import BulkDownloadFactory +from platformics.codegen.tests.output.test_infra.factories.sample import SampleFactory + + +@pytest.mark.asyncio +async def test_null_collection_id_for_regular_entities( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that users cannot create normal entities without a collection_id, or update them to have a null collection_id. 
+ """ + owner_user_id = 333 + collection_id = 444 + + # Attempt to create a sample without a collection_id + query = f""" + mutation MyMutation {{ + createSample( + input: {{ + name: "No collection id", + sampleType: "Type 1", + waterControl: false, + collectionLocation: "San Francisco, CA", + collectionDate: "2024-01-01", + }} + ) {{ id }} + }} + """ + + output = await gql_client.query(query, user_id=owner_user_id, member_projects=[collection_id]) + assert "Unauthorized: Cannot create entity in this collection" in output["errors"][0]["message"] + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + sample = SampleFactory.create(name="Test Sample", owner_user_id=owner_user_id, collection_id=collection_id) + + # Attempt to update the sample to have a null collection_id + query = f""" + mutation MyMutation {{ + updateSample( + where: {{id: {{_eq: "{sample.id}"}} }}, + input: {{ + collectionId: null + }} + ) {{ id }} + }} + """ + + output = await gql_client.query(query, user_id=owner_user_id, member_projects=[collection_id]) + assert ( + "Field 'collectionId' is not defined by type 'SampleUpdateInput'. Did you mean 'collectionDate'?" + in output["errors"][0]["message"] + ) + + +@pytest.mark.asyncio +async def test_null_collection_id_for_bulk_downloads( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that users cannot create bulk downloads WITH a collection_id. + """ + owner_user_id = 333 + collection_id = 444 + + # Attempt to create a bulk download with a collection_id + query = f""" + mutation MyMutation {{ + createBulkDownload( + input: {{ + collectionId: {collection_id}, + downloadDisplayName: "Test Bulk Download", + }} + ) {{ id }} + }} + """ + + output = await gql_client.query(query, user_id=owner_user_id, member_projects=[collection_id]) + assert "Unauthorized: Cannot create entity in this collection" in output["errors"][0]["message"] + + +@pytest.mark.asyncio +async def test_view_bulk_downloads( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that only owners can view their own bulk downloads + """ + user_id = 111 + other_user_id = 222 + project_id = 123 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + # Create 4 bulk downloads owned by user_id, and 3 for another user + BulkDownloadFactory.create_batch(4, owner_user_id=user_id, collection_id=None) + BulkDownloadFactory.create_batch(3, owner_user_id=other_user_id, collection_id=None) + + # Fetch all bulk downloads + query = """ + query MyQuery { + bulkDownloads { + id + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + assert len(output["data"]["bulkDownloads"]) == 4 diff --git a/test_app/tests/test_cascade_deletion.py b/test_app/tests/test_cascade_deletion.py new file mode 100644 index 0000000..af06a83 --- /dev/null +++ b/test_app/tests/test_cascade_deletion.py @@ -0,0 +1,84 @@ +""" +Test cascade deletion +""" + +import pytest +from mypy_boto3_s3.client import S3Client +from platformics.database.connect import SyncDB +from conftest import SessionStorage, GQLTestClient, FileFactory +from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory + + +@pytest.mark.asyncio +async def test_cascade_delete( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can make cascade deletions + """ + user_id = 12345 + project_id = 123 + + # Create mock data: 2 SequencingReads, each with 
a different Sample, and each with R1/R2 + with sync_db.session() as session: + SessionStorage.set_session(session) + sequencing_reads = SequencingReadFactory.create_batch( + 2, technology="Illumina", owner_user_id=user_id, collection_id=project_id + ) + FileFactory.update_file_ids() + + # Delete the first Sample + query = f""" + mutation MyMutation {{ + deleteSample (where: {{ id: {{ _eq: "{sequencing_reads[0].sample_id}" }} }}) {{ + id + }} + }} + """ + result = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + assert result["data"]["deleteSample"][0]["id"] == str(sequencing_reads[0].sample_id) + + # The first SequencingRead should be deleted + query = f""" + query MyQuery {{ + sequencingReadsAggregate(where: {{ id: {{ _eq: "{sequencing_reads[0].entity_id}"}} }}) {{ + aggregate {{ count }} + }} + }} + """ + result = await gql_client.query(query, member_projects=[project_id]) + assert result["data"]["sequencingReadsAggregate"]["aggregate"][0]["count"] == 0 + + # The second SequencingRead should still exist + query = f""" + query MyQuery {{ + sequencingReadsAggregate(where: {{ id: {{ _eq: "{sequencing_reads[1].entity_id}"}} }}) {{ + aggregate {{ count }} + }} + }} + """ + result = await gql_client.query(query, member_projects=[project_id]) + assert result["data"]["sequencingReadsAggregate"]["aggregate"][0]["count"] == 1 + + # Files from the first SequencingRead should be deleted + query = f""" + query MyQuery {{ + files(where: {{ entityId: {{ _eq: "{sequencing_reads[0].entity_id}" }} }}) {{ + id + }} + }} + """ + result = await gql_client.query(query, member_projects=[project_id]) + assert len(result["data"]["files"]) == 0 + + # Files from the second SequencingRead should NOT be deleted + query = f""" + query MyQuery {{ + files(where: {{ entityId: {{ _eq: "{sequencing_reads[1].entity_id}" }} }}) {{ + id + }} + }} + """ + result = await gql_client.query(query, member_projects=[project_id]) + assert len(result["data"]["files"]) == 2 diff --git a/test_app/tests/test_error_handling.py b/test_app/tests/test_error_handling.py new file mode 100644 index 0000000..fd92f35 --- /dev/null +++ b/test_app/tests/test_error_handling.py @@ -0,0 +1,54 @@ +""" +Test basic error handling +""" + +import pytest +from conftest import GQLTestClient + + +@pytest.mark.asyncio +async def test_unauthorized_error( + gql_client: GQLTestClient, +) -> None: + """ + Validate that expected errors don't get masked by our error handler. + """ + query = """ + mutation createOneSample { + createSample(input: { + name: "Test Sample" + sampleType: "Type 1" + waterControl: false + collectionLocation: "San Francisco, CA" + collectionDate: "2024-01-01" + collectionId: 123 + }) { + id, + collectionLocation + } + } + """ + output = await gql_client.query(query, member_projects=[456]) + + # Make sure we haven't masked expected errors. + assert output["data"] is None + assert "Unauthorized: Cannot create entity in this collection" in output["errors"][0]["message"] + + +@pytest.mark.asyncio +async def test_python_error( + gql_client: GQLTestClient, +) -> None: + """ + Validate that unexpected errors do get masked by our error handler. + """ + query = """ + query causeException { + uncaughtException + } + """ + output = await gql_client.query(query, member_projects=[456]) + + # Make sure we have masked unexpected errors. 
+    assert output["data"] is None
+    assert "Unexpected error" in output["errors"][0]["message"]
diff --git a/test_app/tests/test_field_constraints.py b/test_app/tests/test_field_constraints.py
new file mode 100644
index 0000000..00b4c45
--- /dev/null
+++ b/test_app/tests/test_field_constraints.py
@@ -0,0 +1,135 @@
+"""
+Field constraint spot-checks
+"""
+
+import uuid
+import json
+from typing import Any
+
+import pytest
+from database.models import Sample
+from conftest import GQLTestClient, SessionStorage
+from platformics.codegen.tests.output.test_infra.factories.constraint_checked_type import ConstraintCheckedTypeFactory
+from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory
+from platformics.database.connect import SyncDB
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "field_name,invalid_values,valid_values",
+    [
+        ("length3To8", ["a", "", " a ", " abcdefghi "], ["abcde", " abc ", "abcdefgh "]),
+        ("regexFormatCheck", ["hi", "aaa-bbb-ccc", "100-100-10000"], ["222-33-4444", " aaa333-44-5555ccc "]),
+        ("minValue0", [-7, -14], [0, 3, 999999]),
+        ("maxValue9", [11, 44444], [0, -17, 9]),
+        ("minValue0MaxValue9", [-39, 10, 9999], [0, 6, 9]),
+        ("float1dot1To2dot2", [0, 3.3], [1.1, 1.2, 2.2]),
+        ("noStringChecks", [], ["", "aasdfasdfasd", " lorem ipsum !@#$%^&*() "]),
+        ("noIntChecks", [], [65, 0, 400, 7]),
+        ("noFloatChecks", [], [-65.50, 0.00001, 400.6, 7]),
+        ("boolField", [], [True, False]),
+        ("enumField", [], ["RNA", "DNA"]),
+    ],
+)
+async def test_create_validation(
+    field_name: str,
+    invalid_values: list[Any],
+    valid_values: list[Any],
+    gql_client: GQLTestClient,
+) -> None:
+    """
+    Make sure input field validation works.
+    """
+    user_id = 12345
+    project_ids = [333]
+
+    def get_query(field_name: str, value: Any) -> str:
+        if "enum" not in field_name:
+            value = json.dumps(value)
+        query = f"""
+            mutation MyMutation {{
+                createConstraintCheckedType(input: {{collectionId: {project_ids[0]}, {field_name}: {value} }}) {{
+                    collectionId
+                    {field_name}
+                }}
+            }}
+        """
+        return query
+
+    # These should succeed
+    for value in valid_values:
+        query = get_query(field_name, value)
+        output = await gql_client.query(query, user_id=user_id, member_projects=project_ids)
+        if type(value) == str:
+            assert output["data"]["createConstraintCheckedType"][field_name] == value.strip()
+        else:
+            assert output["data"]["createConstraintCheckedType"][field_name] == value
+
+    # These should fail
+    for value in invalid_values:
+        query = get_query(field_name, value)
+        output = await gql_client.query(query, user_id=user_id, member_projects=project_ids)
+        assert "Validation Error" in output["errors"][0]["message"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "field_name,invalid_values,valid_values",
+    [
+        ("length3To8", ["a", "", " a ", " abcdefghi "], ["abcde", " abc ", "abcdefgh "]),
+        ("regexFormatCheck", ["hi", "aaa-bbb-ccc", "100-100-10000"], ["222-33-4444", " aaa333-44-5555ccc "]),
+        ("minValue0", [-7, -14], [0, 3, 999999]),
+        ("maxValue9", [11, 44444], [0, -17, 9]),
+        ("minValue0MaxValue9", [-39, 10, 9999], [0, 6, 9]),
+        ("float1dot1To2dot2", [0, 3.3], [1.1, 1.2, 2.2]),
+        ("noStringChecks", [], ["", "aasdfasdfasd", " lorem ipsum !@#$%^&*() "]),
+        ("noIntChecks", [], [65, 0, 400, 7]),
+        ("noFloatChecks", [], [-65.50, 0.00001, 400.6, 7]),
+        ("boolField", [], [True, False]),
+        ("enumField", [], ["RNA", "DNA"]),
+    ],
+)
+async def test_update_validation(
+    field_name: str,
+    invalid_values: list[Any],
+    valid_values: list[Any],
+    gql_client: GQLTestClient,
+    sync_db: SyncDB,
+) -> None:
+    """
+    Make sure input field validation works.
+    """
+    with sync_db.session() as session:
+        SessionStorage.set_session(session)
+        instance = ConstraintCheckedTypeFactory.create(owner_user_id=999, collection_id=333)
+
+    user_id = 12345
+    project_ids = [333]
+
+    def get_query(id: uuid.UUID, field_name: str, value: Any) -> str:
+        if "enum" not in field_name:
+            value = json.dumps(value)
+        query = f"""
+            mutation MyMutation {{
+                updateConstraintCheckedType(where: {{id: {{_eq: "{id}" }} }}, input: {{ {field_name}: {value} }}) {{
+                    collectionId
+                    {field_name}
+                }}
+            }}
+        """
+        return query
+
+    # These should succeed
+    for value in valid_values:
+        query = get_query(instance.id, field_name, value)
+        output = await gql_client.query(query, user_id=user_id, member_projects=project_ids)
+        if type(value) == str:
+            assert output["data"]["updateConstraintCheckedType"][0][field_name] == value.strip()
+        else:
+            assert output["data"]["updateConstraintCheckedType"][0][field_name] == value
+
+    # These should fail
+    for value in invalid_values:
+        query = get_query(instance.id, field_name, value)
+        output = await gql_client.query(query, user_id=user_id, member_projects=project_ids)
+        assert "Validation Error" in output["errors"][0]["message"]
diff --git a/test_app/tests/test_field_visibility.py b/test_app/tests/test_field_visibility.py
new file mode 100644
index 0000000..5779ab3
--- /dev/null
+++ b/test_app/tests/test_field_visibility.py
@@ -0,0 +1,184 @@
+"""
+Test field visibility in the GQL schema
+"""
+
+import datetime
+
+import pytest
+from conftest import GQLTestClient
+
+date_now = datetime.datetime.now()
+
+
+@pytest.mark.asyncio
+async def test_hidden_fields(
+    gql_client: GQLTestClient,
+) -> None:
+    """
+    Test that we can hide fields from the GQL interface
+    """
+    user_id = 12345
+    project_id = 123
+
+    # Introspect the GenomicRange type
+    query = """
+        query MyQuery {
+            __type(name: "GenomicRange") {
+                name
+                fields {
+                    name
+                    type {
+                        name
+                        kind
+                    }
+                }
+            }
+        }
+    """
+    output = await gql_client.query(query, user_id=user_id, member_projects=[project_id])
+    field_names = [field["name"] for field in output["data"]["__type"]["fields"]]
+
+    # entityId is a hidden field, make sure it's not part of our type def.
+    assert "entityId" not in field_names
+
+    # ownerUserId is a regular field inherited from Entity, make sure we can see it.
+    assert "ownerUserId" in field_names
+
+    # file is a regular field on the GR table, make sure we can see it.
+    assert "file" in field_names
+    assert "fileId" in field_names
+
+
+@pytest.mark.asyncio
+async def test_hidden_mutations(
+    gql_client: GQLTestClient,
+) -> None:
+    """
+    Test that we don't generate mutations unless they make sense
+    """
+    user_id = 12345
+    project_id = 123
+
+    # Introspect the Mutations fields
+    query = """
+        query MyQuery {
+            __schema {
+                mutationType {
+                    fields {
+                        name
+                    }
+                }
+            }
+        }
+    """
+    output = await gql_client.query(query, user_id=user_id, member_projects=[project_id])
+    mutations = [field["name"] for field in output["data"]["__schema"]["mutationType"]["fields"]]
+
+    # ImmutableType has `mutable: false` set on the entire table so it shouldn't have a mutation
+    assert "updateImmutableType" not in mutations
+
+    # ImmutableType still allows creation.
+    assert "createImmutableType" in mutations
+
+
+# Make sure we only allow certain fields to be set at entity creation time.
+@pytest.mark.asyncio +async def test_update_fields( + gql_client: GQLTestClient, +) -> None: + """ + Test that we don't show immutable fields in update mutations. + """ + user_id = 12345 + project_id = 123 + + # Introspect the Mutations fields + query = """ + fragment FullType on __Type { + kind + name + inputFields { + ...InputValue + } + } + fragment InputValue on __InputValue { + name + type { + ...TypeRef + } + defaultValue + } + fragment TypeRef on __Type { + kind + name + ofType { + kind + name + } + } + query IntrospectionQuery { + __schema { + types { + ...FullType + } + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + create_type = [item for item in output["data"]["__schema"]["types"] if item["name"] == "SequencingReadUpdateInput"][ + 0 + ] + fields = [field["name"] for field in create_type["inputFields"]] + # We have a limited subset of mutable fields on SequencingRead + assert set(fields) == set(["nucleicAcid", "clearlabsExport", "technology", "sampleId", "deletedAt"]) + + +# Make sure we only allow certain fields to be set at entity creation time. +@pytest.mark.asyncio +async def test_creation_fields( + gql_client: GQLTestClient, +) -> None: + """ + Test that we don't generate mutations unless they actually do something + """ + user_id = 12345 + project_id = 123 + + # Introspect the Mutations fields + query = """ + fragment FullType on __Type { + kind + name + inputFields { + ...InputValue + } + } + fragment InputValue on __InputValue { + name + type { + ...TypeRef + } + defaultValue + } + fragment TypeRef on __Type { + kind + name + ofType { + kind + name + } + } + query IntrospectionQuery { + __schema { + types { + ...FullType + } + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + create_type = [item for item in output["data"]["__schema"]["types"] if item["name"] == "GenomicRangeCreateInput"][0] + fields = [field["name"] for field in create_type["inputFields"]] + # Producing run id and collection id are always settable on a new entity. + # producingRunId is only settable by a system user, and collectionId is settable by users. 
+ assert set(fields) == set(["producingRunId", "collectionId", "deletedAt"]) diff --git a/test_app/tests/test_file_concatenation.py b/test_app/tests/test_file_concatenation.py index d1ea3e0..edc5d46 100644 --- a/test_app/tests/test_file_concatenation.py +++ b/test_app/tests/test_file_concatenation.py @@ -4,10 +4,10 @@ import pytest import requests -from conftest import GQLTestClient, SessionStorage from mypy_boto3_s3.client import S3Client from platformics.database.connect import SyncDB -from test_infra.factories.sequencing_read import SequencingReadFactory +from conftest import SessionStorage, GQLTestClient +from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory @pytest.mark.parametrize( @@ -27,8 +27,8 @@ async def test_concatenation( user_id = 12345 project_id = 111 member_projects = [project_id] - fasta_file_1 = f"tests/fixtures/{file_name_1}" - fasta_file_2 = f"tests/fixtures/{file_name_2}" + fasta_file_1 = f"test_infra/fixtures/{file_name_1}" + fasta_file_2 = f"test_infra/fixtures/{file_name_2}" # Create mock data with sync_db.session() as session: diff --git a/test_app/tests/test_file_mutations.py b/test_app/tests/test_file_mutations.py index 68336f6..4a97cf5 100644 --- a/test_app/tests/test_file_mutations.py +++ b/test_app/tests/test_file_mutations.py @@ -3,15 +3,15 @@ """ import os - import pytest +import typing import sqlalchemy as sa -from conftest import FileFactory, GQLTestClient, SessionStorage -from database.models import SequencingRead from mypy_boto3_s3.client import S3Client from platformics.database.connect import SyncDB -from platformics.database.models import File, FileStatus -from test_infra.factories.sequencing_read import SequencingReadFactory +from database.models import File, FileStatus +from conftest import SessionStorage, FileFactory, GQLTestClient +from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory +from platformics.codegen.tests.output.database.models import SequencingRead @pytest.mark.asyncio @@ -35,7 +35,7 @@ async def test_file_validation( files = session.execute(sa.select(File)).scalars().all() file = list(filter(lambda file: file.entity_field_name == "r1_file", files))[0] - valid_fastq_file = "tests/fixtures/test1.fastq" + valid_fastq_file = "test_infra/fixtures/test1.fastq" file_size = os.stat(valid_fastq_file).st_size moto_client.put_object(Bucket=file.namespace, Key=file.path, Body=open(valid_fastq_file, "rb")) @@ -106,16 +106,8 @@ async def test_invalid_fastq( @pytest.mark.parametrize( "member_projects,project_id,entity_field", [ - ( - [456], - 123, - "r1_file", - ), # Can't create file for entity you don't have access to - ( - [123], - 123, - "does_not_exist", - ), # Can't create file for entity that isn't connected to a valid file type + ([456], 123, "r1_file"), # Can't create file for entity you don't have access to + ([123], 123, "does_not_exist"), # Can't create file for entity that isn't connected to a valid file type ([123], 123, "r1_file"), # Can create file for entity you have access to ], ) @@ -170,7 +162,7 @@ async def test_upload_file( assert output["errors"] is not None return - # Moto produces a hard-coded tokens + # Moto produces hard-coded tokens assert output["data"]["uploadFile"]["credentials"]["accessKeyId"].endswith("EXAMPLE") assert output["data"]["uploadFile"]["credentials"]["secretAccessKey"].endswith("EXAMPLEKEY") @@ -197,7 +189,7 @@ async def test_create_file( # Upload a fastq file to a mock bucket so we can create a file 
object from it file_namespace = "local-bucket" file_path = "test1.fastq" - file_path_local = "tests/fixtures/test1.fastq" + file_path_local = "test_infra/fixtures/test1.fastq" file_size = os.stat(file_path_local).st_size with open(file_path_local, "rb") as fp: moto_client.put_object(Bucket=file_namespace, Key=file_path, Body=fp) @@ -211,7 +203,7 @@ async def test_create_file( file: {{ name: "{file_path}", fileFormat: "fastq", - protocol: "s3", + protocol: s3, namespace: "{file_namespace}", path: "{file_path}" }} @@ -221,5 +213,95 @@ async def test_create_file( }} }} """ - output = await gql_client.query(mutation, member_projects=[123]) + output = await gql_client.query(mutation, member_projects=[123], service_identity="workflows") assert output["data"]["createFile"]["size"] == file_size + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "file_path,multiple_files_for_one_path,should_delete", + [ + ("nextgen/test1.fastq", False, True), + ("bla/test1.fastq", False, False), + ("nextgen/test1.fastq", True, False), + ("bla/test1.fastq", True, False), + ], +) +async def test_delete_from_s3( + file_path: str, + should_delete: bool, + multiple_files_for_one_path: bool, + sync_db: SyncDB, + gql_client: GQLTestClient, + moto_client: S3Client, + monkeypatch: typing.Any, +) -> None: + """ + Test that we delete a file from S3 under the right circumstances + """ + user1_id = 12345 + project1_id = 123 + user2_id = 67890 + project2_id = 456 + bucket = "local-bucket" + + # Patch the S3 client to make sure tests are operating on the same mock bucket + monkeypatch.setattr(File, "get_s3_client", lambda: moto_client) + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + SequencingReadFactory.create(owner_user_id=user1_id, collection_id=project1_id) + FileFactory.update_file_ids() + session.commit() + files = session.execute(sa.select(File)).scalars().all() + file = list(filter(lambda file: file.entity_field_name == "r1_file", files))[0] + file.path = file_path + file.namespace = bucket # set the bucket to make sure the mock file is in the right place + session.commit() + + # Also test the case where multiple files point to the same path + if multiple_files_for_one_path: + sequencing_read = SequencingReadFactory.create(owner_user_id=user2_id, collection_id=project2_id) + FileFactory.update_file_ids() + session.commit() + session.refresh(sequencing_read) + sequencing_read.r1_file.path = file_path + sequencing_read.r1_file.namespace = bucket + session.commit() + + valid_fastq_file = "test_infra/fixtures/test1.fastq" + moto_client.put_object(Bucket=file.namespace, Key=file.path, Body=open(valid_fastq_file, "rb")) + + # Delete SequencingRead and cascade to File objects + query = f""" + mutation MyMutation {{ + deleteSequencingRead(where: {{ id: {{ _eq: "{file.entity_id}" }} }}) {{ + id + }} + }} + """ + + # File should exist on S3 before the deletion + assert "Contents" in moto_client.list_objects(Bucket=file.namespace, Prefix=file.path) + + # Issue deletion + result = await gql_client.query(query, user_id=user1_id, member_projects=[project1_id]) + assert result["data"]["deleteSequencingRead"][0]["id"] == str(file.entity_id) + + # Make sure file either does or does not exist + if should_delete: + assert "Contents" not in moto_client.list_objects(Bucket=file.namespace, Prefix=file.path) + else: + assert "Contents" in moto_client.list_objects(Bucket=file.namespace, Prefix=file.path) + + # Make sure File object doesn't exist either + query = f""" + query MyQuery {{ + 
files(where: {{ id: {{ _eq: "{file.id}" }} }}) {{ + id + }} + }} + """ + result = await gql_client.query(query, user_id=user1_id, member_projects=[project1_id]) + assert result["data"]["files"] == [] diff --git a/test_app/tests/test_file_queries.py b/test_app/tests/test_file_queries.py index 0cb555c..5ec5504 100644 --- a/test_app/tests/test_file_queries.py +++ b/test_app/tests/test_file_queries.py @@ -3,9 +3,11 @@ """ import pytest -from platformics.database.connect import SyncDB +import sqlalchemy as sa +from database.models.file import File from conftest import FileFactory, GQLTestClient, SessionStorage -from test_infra.factories.sequencing_read import SequencingReadFactory +from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory +from platformics.database.connect import SyncDB @pytest.mark.asyncio @@ -46,12 +48,11 @@ async def test_file_query( } """ output = await gql_client.query(query, member_projects=[project1_id]) - # Each SequencingRead results in 5 files: + # Each SequencingRead results in 3 files: # r1_file, r2_file # primer_file -> GenomicRange file - # GenomicRange produces ReferenceGenome -> file and file_index - # so we expect 8 * 5 = 40 files. - assert len(output["data"]["files"]) == 40 + # so we expect 8 * 3 = 24 files. + assert len(output["data"]["files"]) == 24 for file in output["data"]["files"]: assert file["path"] is not None assert file["entity"]["collectionId"] == project1_id diff --git a/test_app/tests/test_file_uploads.py b/test_app/tests/test_file_uploads.py index 38a79b0..5de2b5d 100644 --- a/test_app/tests/test_file_uploads.py +++ b/test_app/tests/test_file_uploads.py @@ -7,7 +7,7 @@ from mypy_boto3_s3.client import S3Client from platformics.database.connect import SyncDB from conftest import SessionStorage, GQLTestClient -from test_infra.factories.sequencing_read import SequencingReadFactory +from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory @pytest.mark.asyncio @@ -60,7 +60,7 @@ async def test_upload_process( credentials = output["data"]["uploadFile"]["credentials"] # Upload the file - fastq_file = "tests/fixtures/test1.fastq" + fastq_file = "test_infra/fixtures/test1.fastq" fastq_file_size = os.stat(fastq_file).st_size moto_client.put_object(Bucket=credentials["namespace"], Key=credentials["path"], Body=open(fastq_file, "rb")) @@ -93,7 +93,7 @@ async def test_upload_process_multiple_files_per_entity( user_id = 12345 project_id = 111 member_projects = [project_id] - fastq_file = "tests/fixtures/test1.fastq" + fastq_file = "test_infra/fixtures/test1.fastq" # Create mock data with sync_db.session() as session: diff --git a/test_app/tests/test_limit_offset_queries.py b/test_app/tests/test_limit_offset_queries.py new file mode 100644 index 0000000..5fd079c --- /dev/null +++ b/test_app/tests/test_limit_offset_queries.py @@ -0,0 +1,108 @@ +""" +Test limit/offset on top-level queries +""" + +import datetime +import pytest +from platformics.database.connect import SyncDB +from conftest import GQLTestClient, SessionStorage +from platformics.codegen.tests.output.test_infra.factories.sample import SampleFactory + +date_now = datetime.datetime.now() + + +@pytest.mark.asyncio +async def test_limit_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can limit the number of samples returned + """ + user_id = 12345 + project_id = 123 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + 
SampleFactory.create_batch(10, owner_user_id=user_id, collection_id=project_id) + + # Fetch all samples + query = """ + query limitQuery { + samples(limitOffset: {limit: 3}) { + id, + collectionLocation + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + assert len(output["data"]["samples"]) == 3 + + +@pytest.mark.asyncio +async def test_offset_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can offset the number of samples returned + """ + user_id = 12345 + project_id = 123 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + for i in range(10): + SampleFactory.create(name=f"Sample {i}", owner_user_id=user_id, collection_id=project_id) + + # Fetch all samples + all_samples_query = """ + query allSamples { + samples(orderBy: {name: asc}) { + name + } + } + """ + + output = await gql_client.query(all_samples_query, user_id=user_id, member_projects=[project_id]) + all_sample_names = [sample["name"] for sample in output["data"]["samples"]] + + # Fetch samples with limit: 3, offset: 3 + query = """ + query offsetQuery { + samples(limitOffset: {limit: 3, offset: 3}, orderBy: {name: asc}) { + name + } + } + """ + + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + offset_sample_names = [sample["name"] for sample in output["data"]["samples"]] + assert offset_sample_names == all_sample_names[3:6] + + # If we offset by 10, we should get an empty list + query = """ + query offsetQuery { + samples(limitOffset: {limit: 1, offset: 10}, orderBy: {name: asc}) { + name + } + } + """ + + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + assert len(output["data"]["samples"]) == 0 + + # If a user includes an offset without a limit, we should get an error + query = """ + query offsetQuery { + samples(limitOffset: {offset: 1}, orderBy: {name: asc}) { + name + } + } + """ + + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + assert output["data"] is None + assert "Cannot use offset without limit" in output["errors"][0]["message"] diff --git a/test_app/tests/test_nested_queries.py b/test_app/tests/test_nested_queries.py index d1adb7b..22fde81 100644 --- a/test_app/tests/test_nested_queries.py +++ b/test_app/tests/test_nested_queries.py @@ -3,14 +3,13 @@ """ import base64 -from collections import defaultdict - import pytest -from conftest import GQLTestClient, SessionStorage -from platformics.api.types.entities import Entity +from collections import defaultdict from platformics.database.connect import SyncDB -from test_infra.factories.sample import SampleFactory -from test_infra.factories.sequencing_read import SequencingReadFactory +from conftest import GQLTestClient, SessionStorage +from platformics.codegen.tests.output.test_infra.factories.sample import SampleFactory +from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory +from api.types.entities import Entity def get_id(entity: Entity) -> str: diff --git a/test_app/tests/test_schemas/overrides/api/.gitignore b/test_app/tests/test_schemas/overrides/api/.gitignore new file mode 100644 index 0000000..f833585 --- /dev/null +++ b/test_app/tests/test_schemas/overrides/api/.gitignore @@ -0,0 +1 @@ +queries.py.j2 diff --git a/test_app/tests/test_schemas/overrides/api/extra_test_code.py.j2 b/test_app/tests/test_schemas/overrides/api/extra_test_code.py.j2 new file mode 
100644 index 0000000..cb2dfe0 --- /dev/null +++ b/test_app/tests/test_schemas/overrides/api/extra_test_code.py.j2 @@ -0,0 +1,4 @@ + @strawberry.field + def uncaught_exception(self) -> str: + # Trigger an AttributeException + return self.kaboom # type: ignore diff --git a/test_app/tests/test_schemas/platformics.yaml b/test_app/tests/test_schemas/platformics.yaml new file mode 100644 index 0000000..31d3440 --- /dev/null +++ b/test_app/tests/test_schemas/platformics.yaml @@ -0,0 +1,443 @@ +id: https://czid.org/entities/schema/platformics +title: CZID Platformics Bio-Entities Schema +name: platformics +default_range: string + +types: + string: + uri: xsd:string + base: str + description: A character string + + integer: + uri: xsd:integer + base: int + description: An integer + + uuid: + uri: xsd:string + typeof: str + base: str + description: A UUID + +enums: + FileStatus: + permissible_values: + SUCCESS: + FAILED: + PENDING: + FileAcessProtocol: + permissible_values: + s3: + NucleicAcid: + permissible_values: + RNA: + DNA: + SequencingProtocol: + permissible_values: + ampliseq: + artic: + artic_v3: + artic_v4: + artic_v5: + combined_msspe_artic: + covidseq: + midnight: + msspe: + snap: + varskip: + easyseq: + SequencingTechnology: + permissible_values: + Illumina: + Nanopore: + TaxonLevel: + permissible_values: + level_sublevel: + level_species: + level_genus: + level_family: + level_order: + level_class: + level_phylum: + level_kingdom: + level_superkingdom: + FileAccessProtocol: + permissible_values: + s3: + description: This file is accessible via the (AWS) S3 protocol + FileUploadClient: + permissible_values: + browser: + description: File uploaded from the user's browser + cli: + description: File uploaded from the CLI + s3: + description: File uploaded from S3 + basespace: + description: File uploaded from Illumina Basespace Cloud + +classes: + Entity: + attributes: + id: + identifier: true + range: uuid + readonly: true # The API handles generating the values for these fields + required: true + producing_run_id: + range: uuid + minimum_value: 0 + annotations: + mutable: false # This field can't be modified by an `Update` mutation + system_writable_only: True + owner_user_id: + range: integer + minimum_value: 0 + readonly: true + required: true + collection_id: + range: integer + minimum_value: 0 + required: false + annotations: + mutable: false + created_at: + range: date + required: true + readonly: true + updated_at: + range: date + readonly: true + deleted_at: + range: date + # NOTE - the LinkML schema doesn't support a native "plural name" field as far as I can tell, so + # we're using an annotation here to tack on the extra functionality that we need. We do this because + # English pluralization is hard, and we don't want to have to write a custom pluralization function. + # This basically means we now have our own "dialect" of LinkML to worry about. We may want to see if + # pluralization can be added to the core spec in the future. 
+ annotations: + plural: Entities + + File: + attributes: + id: + identifier: true + range: uuid + entity_field_name: + range: string + required: true + entity: + range: Entity + required: true + status: + range: FileStatus + required: true + protocol: + range: FileAccessProtocol + required: true + namespace: + range: string + required: true + path: + range: string + required: true + file_format: + range: string + required: true + compression_type: + range: string + size: + range: integer + minimum_value: 0 + + Sample: + is_a: Entity + mixins: + - EntityMixin + attributes: + name: + range: string + required: true + sample_type: + range: string + required: true + water_control: + range: boolean + required: true + collection_date: + range: date + collection_location: + range: string + required: true + notes: + range: string + sequencing_reads: + range: SequencingRead + multivalued: true + inverse: SequencingRead.sample + annotations: + cascade_delete: true + system_mutable_field: + range: string + annotations: + system_writable_only: True + annotations: + plural: Samples + + SequencingRead: + is_a: Entity + mixins: + - EntityMixin + attributes: + sample: + range: Sample + inverse: Sample.sequencing_reads + protocol: + range: SequencingProtocol + required: true + annotations: + mutable: false + r1_file: + range: File + readonly: true + annotations: + cascade_delete: true + r2_file: + range: File + readonly: true + annotations: + cascade_delete: true + technology: + range: SequencingTechnology + required: true + nucleic_acid: + range: NucleicAcid + required: true + primer_file: + range: GenomicRange + inverse: GenomicRange.sequencing_reads + annotations: + mutable: false + consensus_genomes: + range: ConsensusGenome + inverse: ConsensusGenome.sequence_read + multivalued: true + annotations: + cascade_delete: true + clearlabs_export: + range: boolean + required: true + taxon: + range: Taxon + inverse: Taxon.sequencing_reads + annotations: + mutable: false + annotations: + plural: SequencingReads + + ConsensusGenome: + is_a: Entity + mixins: + - EntityMixin + attributes: + sequence_read: + range: SequencingRead + required: true + inverse: SequencingRead.consensus_genomes + annotations: + mutable: false + sequence: + readonly: true + range: File + annotations: + cascade_delete: true + metrics: + range: MetricConsensusGenome + inverse: MetricConsensusGenome.consensus_genome + inlined: true + annotations: + cascade_delete: true + intermediate_outputs: + range: File + readonly: true + annotations: + cascade_delete: true + annotations: + plural: ConsensusGenomes + + MetricConsensusGenome: + is_a: Entity + mixins: + - EntityMixin + attributes: + consensus_genome: + range: ConsensusGenome + inverse: ConsensusGenome.metrics + required: true + annotations: + mutable: false + total_reads: + range: integer + minimum_value: 0 + maximum_value: 999999999999 + annotations: + mutable: false + mapped_reads: + range: integer + annotations: + mutable: false + annotations: + plural: MetricsConsensusGenomes + + GenomicRange: + is_a: Entity + mixins: + - EntityMixin + attributes: + file: + range: File + readonly: true + sequencing_reads: + range: SequencingRead + inverse: SequencingRead.primer_file + multivalued: true + annotations: + plural: GenomicRanges + + Taxon: + is_a: Entity + mixins: + - EntityMixin + attributes: + name: + range: string + required: true + is_phage: + range: boolean + required: true + upstream_database: + range: UpstreamDatabase + required: true + inverse: UpstreamDatabase.taxa + 
upstream_database_identifier: + range: string + required: true + level: + range: TaxonLevel + required: true + sequencing_reads: + range: SequencingRead + inverse: SequencingRead.taxon + multivalued: true + annotations: + plural: Taxa + + UpstreamDatabase: + is_a: Entity + mixins: + - EntityMixin + attributes: + name: + range: string + required: true + taxa: + range: Taxon + multivalued: true + inverse: Taxon.upstream_database + annotations: + plural: UpstreamDatabases + + BulkDownload: + is_a: Entity + mixins: + - EntityMixin + attributes: + download_display_name: + range: string + required: true + annotations: + mutable: false + file: + range: File + readonly: true + annotations: + cascade_delete: true + annotations: + plural: BulkDownloads + + SystemWritableOnlyType: + is_a: Entity + mixins: + - EntityMixin + attributes: + name: + range: string + required: true + annotations: + system_writable_only: true + plural: SystemWritableOnlyTypes + + ImmutableType: + is_a: Entity + mixins: + - EntityMixin + attributes: + name: + range: string + required: true + annotations: + mutable: false + plural: ImmutableTypes + + ConstraintCheckedType: + is_a: Entity + mixins: + - EntityMixin + attributes: + length_3_to_8: + range: string + annotations: + minimum_length: 3 + maximum_length: 8 + regex_format_check: + range: string + pattern: '\d{3}-\d{2}-\d{4}' + min_value_0: + range: integer + minimum_value: 0 + enum_field: + range: NucleicAcid + bool_field: + range: boolean + max_value_9: + range: integer + maximum_value: 9 + min_value_0_max_value_9: + range: integer + minimum_value: 0 + maximum_value: 9 + float_1dot1_to_2dot2: + range: float + minimum_value: 1.1 + maximum_value: 2.2 + no_string_checks: + range: string + no_int_checks: + range: integer + no_float_checks: + range: float + annotations: + plural: ConstraintCheckedTypes + + EntityMixin: + mixin: true + attributes: + entity_id: + required: true + readonly: true + range: uuid + identifier: true + inverse: entity.id + annotations: + hidden: true diff --git a/test_app/tests/test_sorting_queries.py b/test_app/tests/test_sorting_queries.py new file mode 100644 index 0000000..6230262 --- /dev/null +++ b/test_app/tests/test_sorting_queries.py @@ -0,0 +1,243 @@ +""" +Test queries with an ORDER BY clause +""" + +import pytest +from platformics.database.connect import SyncDB +from conftest import GQLTestClient, SessionStorage +from platformics.codegen.tests.output.test_infra.factories.sample import SampleFactory +from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory +from platformics.codegen.tests.output.test_infra.factories.taxon import TaxonFactory +from platformics.codegen.tests.output.test_infra.factories.upstream_database import UpstreamDatabaseFactory + + +@pytest.mark.asyncio +async def test_basic_order_by_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can add an ORDER BY clause to a query + """ + user_id = 12345 + project_id = 123 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + SampleFactory.create_batch( + 2, collection_location="San Francisco, CA", owner_user_id=user_id, collection_id=project_id + ) + SampleFactory.create_batch( + 1, collection_location="Mountain View, CA", owner_user_id=user_id, collection_id=project_id + ) + SampleFactory.create_batch( + 2, collection_location="Los Angeles, CA", owner_user_id=user_id, collection_id=project_id + ) + + # Fetch all samples, in descending order of 
collection location + query = """ + query MyQuery { + samples(orderBy: {collectionLocation: desc}) { + collectionLocation + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + locations = [sample["collectionLocation"] for sample in output["data"]["samples"]] + assert locations == [ + "San Francisco, CA", + "San Francisco, CA", + "Mountain View, CA", + "Los Angeles, CA", + "Los Angeles, CA", + ] + + +@pytest.mark.asyncio +async def test_order_multiple_fields_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can sort by multiple fields, and that the order of the fields are preserved + """ + user_id = 12345 + project_id = 123 + secondary_project_id = 234 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + SampleFactory(owner_user_id=user_id, collection_id=project_id, collection_location="San Francisco, CA") + SampleFactory( + owner_user_id=user_id, collection_id=secondary_project_id, collection_location="San Francisco, CA" + ) + SampleFactory(owner_user_id=user_id, collection_id=project_id, collection_location="Mountain View, CA") + SampleFactory( + owner_user_id=user_id, collection_id=secondary_project_id, collection_location="Mountain View, CA" + ) + + # Fetch all samples, in descending order of collection id and then ascending order of collection location + query = """ + query MyQuery { + samples(orderBy: [{collectionId: desc}, {collectionLocation: asc}]) { + collectionLocation, + collectionId + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id, secondary_project_id]) + locations = [sample["collectionLocation"] for sample in output["data"]["samples"]] + collection_ids = [sample["collectionId"] for sample in output["data"]["samples"]] + assert locations == ["Mountain View, CA", "San Francisco, CA", "Mountain View, CA", "San Francisco, CA"] + assert collection_ids == [234, 234, 123, 123] + + # Fetch all samples, in ascending order of collection location and then descending order of collection id + query = """ + query MyQuery { + samples(orderBy: [{collectionLocation: asc}, {collectionId: desc}]) { + collectionLocation, + collectionId + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id, secondary_project_id]) + locations = [sample["collectionLocation"] for sample in output["data"]["samples"]] + collection_ids = [sample["collectionId"] for sample in output["data"]["samples"]] + assert locations == ["Mountain View, CA", "Mountain View, CA", "San Francisco, CA", "San Francisco, CA"] + assert collection_ids == [234, 123, 234, 123] + + +@pytest.mark.asyncio +async def test_sort_nested_objects_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can sort nested objects + """ + user_id = 12345 + project_id = 123 + + with sync_db.session() as session: + SessionStorage.set_session(session) + sample_1 = SampleFactory( + owner_user_id=user_id, collection_id=project_id, collection_location="San Francisco, CA" + ) + sample_2 = SampleFactory( + owner_user_id=user_id, collection_id=project_id, collection_location="Mountain View, CA" + ) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_1, nucleic_acid="DNA") + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_1, nucleic_acid="RNA") + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_2, 
nucleic_acid="DNA") + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_2, nucleic_acid="RNA") + + # Fetch all samples, in descending order of collection location, and then in ascending order of the related sequencing read's nucleic acid + query = """ + query MyQuery { + samples(orderBy: {collectionLocation: desc}) { + collectionLocation + sequencingReads(orderBy: {nucleicAcid: asc}) { + edges { + node { + nucleicAcid + } + } + } + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + locations = [sample["collectionLocation"] for sample in output["data"]["samples"]] + assert locations == ["San Francisco, CA", "Mountain View, CA"] + nucleic_acids = [] + for sample in output["data"]["samples"]: + for sr in sample["sequencingReads"]["edges"]: + nucleic_acids.append(sr["node"]["nucleicAcid"]) + assert nucleic_acids == ["DNA", "RNA", "DNA", "RNA"] + + +@pytest.mark.asyncio +async def test_order_by_related_field_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can sort by fields of a related object + """ + user_id = 12345 + project_id = 123 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + sample_1 = SampleFactory( + owner_user_id=user_id, collection_id=project_id, collection_location="Mountain View, CA" + ) + sample_2 = SampleFactory( + owner_user_id=user_id, collection_id=project_id, collection_location="San Francisco, CA" + ) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_1, nucleic_acid="DNA") + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_1, nucleic_acid="RNA") + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_2, nucleic_acid="DNA") + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_2, nucleic_acid="RNA") + + # Fetch all sequencing reads, in descending order of the related sample's collection location + query = """ + query MyQuery { + sequencingReads(orderBy: {sample: {collectionLocation: desc}}) { + nucleicAcid + sample { + collectionLocation + } + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + nucleic_acids = [sr["nucleicAcid"] for sr in output["data"]["sequencingReads"]] + collection_locations = [sr["sample"]["collectionLocation"] for sr in output["data"]["sequencingReads"]] + assert nucleic_acids == ["DNA", "RNA", "DNA", "RNA"] + assert collection_locations == ["San Francisco, CA", "San Francisco, CA", "Mountain View, CA", "Mountain View, CA"] + + +@pytest.mark.asyncio +async def test_deeply_nested_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can sort by fields of a very deeply nested object + """ + user_id = 12345 + project_id = 123 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + upstream_db_1 = UpstreamDatabaseFactory(owner_user_id=user_id, collection_id=project_id, name="NCBI") + upstream_db_2 = UpstreamDatabaseFactory(owner_user_id=user_id, collection_id=project_id, name="GTDB") + taxon_1 = TaxonFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_1) + taxon_2 = TaxonFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_2) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, taxon=taxon_1) + 
SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, taxon=taxon_1)
+        SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, taxon=taxon_2)
+        SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, taxon=taxon_2)
+
+    # Fetch all sequencing reads, in descending order of the related taxon's upstream database's name
+    query = """
+        query MyQuery {
+            sequencingReads(orderBy: {taxon: {upstreamDatabase: {name: desc}}}) {
+                id
+                taxon {
+                    upstreamDatabase {
+                        name
+                    }
+                }
+            }
+        }
+    """
+
+    output = await gql_client.query(query, user_id=user_id, member_projects=[project_id])
+    upstream_database_names = [d["taxon"]["upstreamDatabase"]["name"] for d in output["data"]["sequencingReads"]]
+    assert upstream_database_names == ["NCBI", "NCBI", "GTDB", "GTDB"]
diff --git a/test_app/tests/test_where_clause.py b/test_app/tests/test_where_clause.py
index 7910d4f..3ffeb57 100644
--- a/test_app/tests/test_where_clause.py
+++ b/test_app/tests/test_where_clause.py
@@ -4,8 +4,9 @@
 import pytest
 from platformics.database.connect import SyncDB
-from conftest import GQLTestClient, SessionStorage
-from test_infra.factories.sequencing_read import SequencingReadFactory
+from conftest import GQLTestClient, SessionStorage, FileFactory
+from platformics.codegen.tests.output.test_infra.factories.sample import SampleFactory
+from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory
 from support.enums import SequencingTechnology
 
 user_id = 12345
@@ -155,3 +156,166 @@ async def test_where_clause_mutations(sync_db: SyncDB, gql_client: GQLTestClient
         assert sequencing_read["technology"] == new_technology
     else:
         assert sequencing_read["technology"] == prev_technology
+
+
+@pytest.mark.asyncio
+async def test_where_clause_regex_match(sync_db: SyncDB, gql_client: GQLTestClient) -> None:
+    """
+    Verify that the regex operators work as expected.
+ """ + # Regex for any string with "MATCH" in it + regex = ".*MATCH.*" + with sync_db.session() as session: + SessionStorage.set_session(session) + # Sample names that match the regex (case-sensitive) + case_sensitive_matches = ["A MATCH", "A MATCHING SAMPLE"] + # Sample names that match the regex if case is ignored, but do not match if case is considered + case_insensitive_matches = ["a match if ignore case", "a matching sample if ignore case"] + # Sample names that don't match the regex at all + no_matches = ["asdf1234", "HCTAM"] + # Create the samples + all_sample_names = case_sensitive_matches + case_insensitive_matches + no_matches + for name in all_sample_names: + SampleFactory.create(name=name, owner_user_id=user_id, collection_id=project_id) + + match_case_query = f""" + query GetSamplesMatchingCase {{ + samples ( where: {{ + name: {{ + _regex: "{regex}" + }} + }}) {{ + name + }} + }} + """ + + match_case_query_output = await gql_client.query(match_case_query, member_projects=[project_id]) + assert len(match_case_query_output["data"]["samples"]) == 2 + output_sample_names = [sample["name"] for sample in match_case_query_output["data"]["samples"]] + assert sorted(output_sample_names) == sorted(case_sensitive_matches) + + ignore_case_query = f""" + query GetSamplesMatchingIgnoreCase {{ + samples ( where: {{ + name: {{ + _iregex: "{regex}" + }} + }}) {{ + name + }} + }} + """ + + ignore_case_query_output = await gql_client.query(ignore_case_query, member_projects=[project_id]) + assert len(ignore_case_query_output["data"]["samples"]) == 4 + output_sample_names = [sample["name"] for sample in ignore_case_query_output["data"]["samples"]] + assert sorted(output_sample_names) == sorted(case_sensitive_matches + case_insensitive_matches) + + no_match_query = f""" + query GetSamplesNoMatch {{ + samples ( where: {{ + name: {{ + _nregex: "{regex}" + }} + }}) {{ + name + }} + }} + """ + + no_match_query_output = await gql_client.query(no_match_query, member_projects=[project_id]) + assert len(no_match_query_output["data"]["samples"]) == 4 + output_sample_names = [sample["name"] for sample in no_match_query_output["data"]["samples"]] + assert sorted(output_sample_names) == sorted(no_matches + case_insensitive_matches) + + no_match_ignore_case_query = f""" + query GetSamplesNoMatchIgnoreCase {{ + samples ( where: {{ + name: {{ + _niregex: "{regex}" + }} + }}) {{ + name + }} + }} + """ + + no_match_ignore_case_query_output = await gql_client.query(no_match_ignore_case_query, member_projects=[project_id]) + assert len(no_match_ignore_case_query_output["data"]["samples"]) == 2 + output_sample_names = [sample["name"] for sample in no_match_ignore_case_query_output["data"]["samples"]] + assert sorted(output_sample_names) == sorted(no_matches) + + +@pytest.mark.asyncio +async def test_soft_deleted_objects(sync_db: SyncDB, gql_client: GQLTestClient) -> None: + """ + Verify that the where clause works as expected with soft-deleted objects. + By default, soft-deleted objects should not be returned. 
+ """ + sequencing_reads = generate_sequencing_reads(sync_db) + FileFactory.update_file_ids() + # Soft delete the first 3 sequencing reads by updating the deleted_at field + deleted_ids = [str(sequencing_reads[0].id), str(sequencing_reads[1].id), str(sequencing_reads[2].id)] + soft_delete_mutation = f""" + mutation SoftDeleteSequencingReads {{ + updateSequencingRead ( + where: {{ + id: {{ _in: [ "{deleted_ids[0]}", "{deleted_ids[1]}", "{deleted_ids[2]}" ] }}, + }}, + input: {{ + deletedAt: "2021-01-01T00:00:00Z", + }} + ) {{ + id + }} + }} + """ + output = await gql_client.query(soft_delete_mutation, member_projects=[project_id]) + assert len(output["data"]["updateSequencingRead"]) == 3 + + # Check that the soft-deleted sequencing reads are not returned + regular_query = """ + query GetSequencingReads { + sequencingReads { + id + } + } + """ + output = await gql_client.query(regular_query, member_projects=[project_id]) + assert len(output["data"]["sequencingReads"]) == 2 + for sequencing_read in output["data"]["sequencingReads"]: + assert sequencing_read["id"] not in deleted_ids + + # Check that the soft-deleted sequencing reads are returned when explicitly requested + soft_deleted_query = """ + query GetSequencingReads { + sequencingReads ( where: { deletedAt: { _is_null: false } }) { + id + } + } + """ + output = await gql_client.query(soft_deleted_query, member_projects=[project_id]) + assert len(output["data"]["sequencingReads"]) == 3 + for sequencing_read in output["data"]["sequencingReads"]: + assert sequencing_read["id"] in deleted_ids + + # Check that we can hard-delete the soft-deleted objects + hard_delete_mutation = f""" + mutation DeleteSequencingReads {{ + deleteSequencingRead ( + where: {{ + id: {{ _in: [ "{deleted_ids[0]}", "{deleted_ids[1]}", "{deleted_ids[2]}" ] }}, + }} + ) {{ + id + }} + }} + """ + + output = await gql_client.query(hard_delete_mutation, user_id=user_id, member_projects=[project_id]) + assert len(output["data"]["deleteSequencingRead"]) == 3 + + # Check that the hard-deleted sequencing reads are not returned + output = await gql_client.query(soft_deleted_query, member_projects=[project_id]) + assert len(output["data"]["sequencingReads"]) == 0 From 9b42b3d7e3e53e36951d86f698e3c059773ca2c5 Mon Sep 17 00:00:00 2001 From: Omar Valenzuela Date: Wed, 5 Jun 2024 11:46:42 -0700 Subject: [PATCH 02/16] get test_deeply_nested_groupby_query to pass --- test_app/schema/schema.yaml | 15 ++++++++------- test_app/tests/test_aggregate_queries.py | 14 +++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/test_app/schema/schema.yaml b/test_app/schema/schema.yaml index 386fb08..2fca120 100644 --- a/test_app/schema/schema.yaml +++ b/test_app/schema/schema.yaml @@ -189,10 +189,11 @@ classes: primer_file: range: GenomicRange inverse: GenomicRange.sequencing_reads - contigs: + contig: range: Contig - inverse: Contig.sequencing_read - multivalued: true + inverse: Contig.sequencing_reads + annotations: + mutable: false annotations: plural: SequencingReads @@ -215,15 +216,15 @@ classes: mixins: - EntityMixin attributes: - sequencing_read: + sequencing_reads: range: SequencingRead - inverse: Sample.contigs + inverse: SequencingRead.contig multivalued: true sequence: required: true upstream_database: range: UpstreamDatabase - inverse: UpstreamDatabase.contigs + inverse: UpstreamDatabase.contig required: true annotations: mutable: false @@ -252,7 +253,7 @@ classes: required: true annotations: indexed: true - contigs: + contig: range: Contig inverse: 
Contig.upstream_database multivalued: true diff --git a/test_app/tests/test_aggregate_queries.py b/test_app/tests/test_aggregate_queries.py index f4fdb58..cf52bd6 100644 --- a/test_app/tests/test_aggregate_queries.py +++ b/test_app/tests/test_aggregate_queries.py @@ -329,10 +329,10 @@ async def test_deeply_nested_groupby_query( upstream_db_2 = UpstreamDatabaseFactory(owner_user_id=user_id, collection_id=project_id, name="GTDB") contig_1 = ContigFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_1) contig_2 = ContigFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_2) - SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contigs=contig_1) - SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contigs=contig_1) - SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contigs=contig_2) - SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contigs=contig_2) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_1) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_1) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_2) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_2) query = """ query MyQuery { @@ -340,7 +340,7 @@ async def test_deeply_nested_groupby_query( aggregate { count groupBy { - contigs { + contig { upstreamDatabase { name } @@ -353,9 +353,9 @@ async def test_deeply_nested_groupby_query( results = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) aggregate = results["data"]["sequencingReadsAggregate"]["aggregate"] for group in aggregate: - if group["groupBy"]["contigs"]["upstreamDatabase"]["name"] == "NCBI": + if group["groupBy"]["contig"]["upstreamDatabase"]["name"] == "NCBI": assert group["count"] == 2 - elif group["groupBy"]["contigs"]["upstreamDatabase"]["name"] == "GTDB": + elif group["groupBy"]["contig"]["upstreamDatabase"]["name"] == "GTDB": assert group["count"] == 2 From ed9cf3d95d2d1ba9d28fcde63f826d6a25638860 Mon Sep 17 00:00:00 2001 From: Omar Valenzuela Date: Wed, 5 Jun 2024 12:55:34 -0700 Subject: [PATCH 03/16] resolve test errors, add ConstraintCheckedType --- test_app/schema/schema.yaml | 40 ++++++ test_app/tests/test_basic_queries.py | 2 +- test_app/tests/test_bulk_download_deletion.py | 72 ---------- test_app/tests/test_bulk_download_policy.py | 123 ------------------ test_app/tests/test_cascade_deletion.py | 2 +- test_app/tests/test_field_constraints.py | 4 +- test_app/tests/test_file_concatenation.py | 2 +- test_app/tests/test_file_mutations.py | 4 +- test_app/tests/test_file_queries.py | 4 +- test_app/tests/test_file_uploads.py | 2 +- test_app/tests/test_limit_offset_queries.py | 2 +- test_app/tests/test_nested_queries.py | 6 +- test_app/tests/test_sorting_queries.py | 24 ++-- test_app/tests/test_where_clause.py | 4 +- 14 files changed, 68 insertions(+), 223 deletions(-) delete mode 100644 test_app/tests/test_bulk_download_deletion.py delete mode 100644 test_app/tests/test_bulk_download_policy.py diff --git a/test_app/schema/schema.yaml b/test_app/schema/schema.yaml index 2fca120..8b3b497 100644 --- a/test_app/schema/schema.yaml +++ b/test_app/schema/schema.yaml @@ -260,3 +260,43 @@ classes: # This is where NCBI indexes would live annotations: plural: UpstreamDatabases + + ConstraintCheckedType: + is_a: Entity + mixins: + - EntityMixin + attributes: + 
length_3_to_8: + range: string + annotations: + minimum_length: 3 + maximum_length: 8 + regex_format_check: + range: string + pattern: '\d{3}-\d{2}-\d{4}' + min_value_0: + range: integer + minimum_value: 0 + enum_field: + range: NucleicAcid + bool_field: + range: boolean + max_value_9: + range: integer + maximum_value: 9 + min_value_0_max_value_9: + range: integer + minimum_value: 0 + maximum_value: 9 + float_1dot1_to_2dot2: + range: float + minimum_value: 1.1 + maximum_value: 2.2 + no_string_checks: + range: string + no_int_checks: + range: integer + no_float_checks: + range: float + annotations: + plural: ConstraintCheckedTypes diff --git a/test_app/tests/test_basic_queries.py b/test_app/tests/test_basic_queries.py index 7c4d569..0ab1f6a 100644 --- a/test_app/tests/test_basic_queries.py +++ b/test_app/tests/test_basic_queries.py @@ -6,7 +6,7 @@ import pytest from platformics.database.connect import SyncDB from conftest import GQLTestClient, SessionStorage -from platformics.codegen.tests.output.test_infra.factories.sample import SampleFactory +from test_infra.factories.sample import SampleFactory date_now = datetime.datetime.now() diff --git a/test_app/tests/test_bulk_download_deletion.py b/test_app/tests/test_bulk_download_deletion.py deleted file mode 100644 index 69ceb31..0000000 --- a/test_app/tests/test_bulk_download_deletion.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -Test deletion of bulkDownloads > 7 days old -""" - -import pytest -import datetime -from platformics.database.connect import SyncDB -from conftest import SessionStorage, GQLTestClient, FileFactory -from platformics.codegen.tests.output.test_infra.factories.bulk_download import BulkDownloadFactory - - -@pytest.mark.asyncio -async def test_delete_old_bulk_downloads( - sync_db: SyncDB, - gql_client: GQLTestClient, -) -> None: - """ - Test that we can delete bulk downloads older than 7 days - """ - user_id = 12345 - project_id = 123 - - # Create mock data: 3 current bulk downloads, 2 bulk downloads from 1 week ago, and 5 bulk downloads from 1 month ago - with sync_db.session() as session: - SessionStorage.set_session(session) - current_time = datetime.datetime.now() - one_week_ago = current_time - datetime.timedelta(days=7) - one_month_ago = current_time - datetime.timedelta(days=30) - - current_bulk_downloads = BulkDownloadFactory.create_batch(3, owner_user_id=user_id, collection_id=None) - one_week_old_bulk_downloads = BulkDownloadFactory.create_batch( - 2, owner_user_id=user_id, collection_id=None, created_at=one_week_ago - ) - one_month_old_bulk_downloads = BulkDownloadFactory.create_batch( - 5, owner_user_id=user_id, collection_id=None, created_at=one_month_ago - ) - all_old_bulk_downloads = one_week_old_bulk_downloads + one_month_old_bulk_downloads - FileFactory.update_file_ids() - - # Delete old bulk downloads - query = """ - mutation MyMutation { - deleteOldBulkDownloads { - id - } - } - """ - - # Verify that the mutation can't be called by a non-system user - result = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) - assert result["data"] is None - assert "Unauthorized" - - # Verify that the mutation deletes all bulk downloads older than 7 days - result = await gql_client.query(query, user_id=user_id, member_projects=[project_id], service_identity="rails") - assert len(result["data"]["deleteOldBulkDownloads"]) == 7 - assert [bd["id"] for bd in result["data"]["deleteOldBulkDownloads"]] == [ - str(bd.id) for bd in all_old_bulk_downloads - ] - - # Check that current bulk downloads are still 
there - query = """ - query MyQuery { - bulkDownloads { - id - } - } - """ - - result = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) - assert len(result["data"]["bulkDownloads"]) == 3 - assert [bd["id"] for bd in result["data"]["bulkDownloads"]] == [str(bd.id) for bd in current_bulk_downloads] diff --git a/test_app/tests/test_bulk_download_policy.py b/test_app/tests/test_bulk_download_policy.py deleted file mode 100644 index 9455c45..0000000 --- a/test_app/tests/test_bulk_download_policy.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -Test collection_id policy for entities and bulk downloads -1. Test that users cannot create normal entities without a collection_id, or update them to have a null collection_id. -2. Test that users cannot create bulk downloads WITH a collection_id. -3. Test that only owners can view their own bulk downloads -""" - -import pytest -from platformics.database.connect import SyncDB -from conftest import SessionStorage, GQLTestClient -from platformics.codegen.tests.output.test_infra.factories.bulk_download import BulkDownloadFactory -from platformics.codegen.tests.output.test_infra.factories.sample import SampleFactory - - -@pytest.mark.asyncio -async def test_null_collection_id_for_regular_entities( - sync_db: SyncDB, - gql_client: GQLTestClient, -) -> None: - """ - Test that users cannot create normal entities without a collection_id, or update them to have a null collection_id. - """ - owner_user_id = 333 - collection_id = 444 - - # Attempt to create a sample without a collection_id - query = f""" - mutation MyMutation {{ - createSample( - input: {{ - name: "No collection id", - sampleType: "Type 1", - waterControl: false, - collectionLocation: "San Francisco, CA", - collectionDate: "2024-01-01", - }} - ) {{ id }} - }} - """ - - output = await gql_client.query(query, user_id=owner_user_id, member_projects=[collection_id]) - assert "Unauthorized: Cannot create entity in this collection" in output["errors"][0]["message"] - - # Create mock data - with sync_db.session() as session: - SessionStorage.set_session(session) - sample = SampleFactory.create(name="Test Sample", owner_user_id=owner_user_id, collection_id=collection_id) - - # Attempt to update the sample to have a null collection_id - query = f""" - mutation MyMutation {{ - updateSample( - where: {{id: {{_eq: "{sample.id}"}} }}, - input: {{ - collectionId: null - }} - ) {{ id }} - }} - """ - - output = await gql_client.query(query, user_id=owner_user_id, member_projects=[collection_id]) - assert ( - "Field 'collectionId' is not defined by type 'SampleUpdateInput'. Did you mean 'collectionDate'?" - in output["errors"][0]["message"] - ) - - -@pytest.mark.asyncio -async def test_null_collection_id_for_bulk_downloads( - sync_db: SyncDB, - gql_client: GQLTestClient, -) -> None: - """ - Test that users cannot create bulk downloads WITH a collection_id. 
- """ - owner_user_id = 333 - collection_id = 444 - - # Attempt to create a bulk download with a collection_id - query = f""" - mutation MyMutation {{ - createBulkDownload( - input: {{ - collectionId: {collection_id}, - downloadDisplayName: "Test Bulk Download", - }} - ) {{ id }} - }} - """ - - output = await gql_client.query(query, user_id=owner_user_id, member_projects=[collection_id]) - assert "Unauthorized: Cannot create entity in this collection" in output["errors"][0]["message"] - - -@pytest.mark.asyncio -async def test_view_bulk_downloads( - sync_db: SyncDB, - gql_client: GQLTestClient, -) -> None: - """ - Test that only owners can view their own bulk downloads - """ - user_id = 111 - other_user_id = 222 - project_id = 123 - - # Create mock data - with sync_db.session() as session: - SessionStorage.set_session(session) - # Create 4 bulk downloads owned by user_id, and 3 for another user - BulkDownloadFactory.create_batch(4, owner_user_id=user_id, collection_id=None) - BulkDownloadFactory.create_batch(3, owner_user_id=other_user_id, collection_id=None) - - # Fetch all bulk downloads - query = """ - query MyQuery { - bulkDownloads { - id - } - } - """ - output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) - assert len(output["data"]["bulkDownloads"]) == 4 diff --git a/test_app/tests/test_cascade_deletion.py b/test_app/tests/test_cascade_deletion.py index af06a83..ee5644d 100644 --- a/test_app/tests/test_cascade_deletion.py +++ b/test_app/tests/test_cascade_deletion.py @@ -6,7 +6,7 @@ from mypy_boto3_s3.client import S3Client from platformics.database.connect import SyncDB from conftest import SessionStorage, GQLTestClient, FileFactory -from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory +from test_infra.factories.sequencing_read import SequencingReadFactory @pytest.mark.asyncio diff --git a/test_app/tests/test_field_constraints.py b/test_app/tests/test_field_constraints.py index 00b4c45..a798034 100644 --- a/test_app/tests/test_field_constraints.py +++ b/test_app/tests/test_field_constraints.py @@ -9,8 +9,8 @@ import pytest from database.models import Sample from conftest import GQLTestClient, SessionStorage -from platformics.codegen.tests.output.test_infra.factories.constraint_checked_type import ConstraintCheckedTypeFactory -from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory +from test_infra.factories.constraint_checked_type import ConstraintCheckedTypeFactory +from test_infra.factories.sequencing_read import SequencingReadFactory from platformics.database.connect import SyncDB diff --git a/test_app/tests/test_file_concatenation.py b/test_app/tests/test_file_concatenation.py index edc5d46..d39c116 100644 --- a/test_app/tests/test_file_concatenation.py +++ b/test_app/tests/test_file_concatenation.py @@ -7,7 +7,7 @@ from mypy_boto3_s3.client import S3Client from platformics.database.connect import SyncDB from conftest import SessionStorage, GQLTestClient -from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory +from test_infra.factories.sequencing_read import SequencingReadFactory @pytest.mark.parametrize( diff --git a/test_app/tests/test_file_mutations.py b/test_app/tests/test_file_mutations.py index 4a97cf5..b277fd2 100644 --- a/test_app/tests/test_file_mutations.py +++ b/test_app/tests/test_file_mutations.py @@ -10,8 +10,8 @@ from platformics.database.connect import SyncDB from database.models 
import File, FileStatus from conftest import SessionStorage, FileFactory, GQLTestClient -from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory -from platformics.codegen.tests.output.database.models import SequencingRead +from test_infra.factories.sequencing_read import SequencingReadFactory +from database.models import SequencingRead @pytest.mark.asyncio diff --git a/test_app/tests/test_file_queries.py b/test_app/tests/test_file_queries.py index 5ec5504..5143157 100644 --- a/test_app/tests/test_file_queries.py +++ b/test_app/tests/test_file_queries.py @@ -4,9 +4,9 @@ import pytest import sqlalchemy as sa -from database.models.file import File +from database.models import File from conftest import FileFactory, GQLTestClient, SessionStorage -from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory +from test_infra.factories.sequencing_read import SequencingReadFactory from platformics.database.connect import SyncDB diff --git a/test_app/tests/test_file_uploads.py b/test_app/tests/test_file_uploads.py index 5de2b5d..7b06cc2 100644 --- a/test_app/tests/test_file_uploads.py +++ b/test_app/tests/test_file_uploads.py @@ -7,7 +7,7 @@ from mypy_boto3_s3.client import S3Client from platformics.database.connect import SyncDB from conftest import SessionStorage, GQLTestClient -from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory +from test_infra.factories.sequencing_read import SequencingReadFactory @pytest.mark.asyncio diff --git a/test_app/tests/test_limit_offset_queries.py b/test_app/tests/test_limit_offset_queries.py index 5fd079c..acd5497 100644 --- a/test_app/tests/test_limit_offset_queries.py +++ b/test_app/tests/test_limit_offset_queries.py @@ -6,7 +6,7 @@ import pytest from platformics.database.connect import SyncDB from conftest import GQLTestClient, SessionStorage -from platformics.codegen.tests.output.test_infra.factories.sample import SampleFactory +from test_infra.factories.sample import SampleFactory date_now = datetime.datetime.now() diff --git a/test_app/tests/test_nested_queries.py b/test_app/tests/test_nested_queries.py index 22fde81..3a0d8f8 100644 --- a/test_app/tests/test_nested_queries.py +++ b/test_app/tests/test_nested_queries.py @@ -7,9 +7,9 @@ from collections import defaultdict from platformics.database.connect import SyncDB from conftest import GQLTestClient, SessionStorage -from platformics.codegen.tests.output.test_infra.factories.sample import SampleFactory -from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory -from api.types.entities import Entity +from test_infra.factories.sample import SampleFactory +from test_infra.factories.sequencing_read import SequencingReadFactory +from platformics.api.types.entities import Entity def get_id(entity: Entity) -> str: diff --git a/test_app/tests/test_sorting_queries.py b/test_app/tests/test_sorting_queries.py index 6230262..bbdfb2e 100644 --- a/test_app/tests/test_sorting_queries.py +++ b/test_app/tests/test_sorting_queries.py @@ -5,10 +5,10 @@ import pytest from platformics.database.connect import SyncDB from conftest import GQLTestClient, SessionStorage -from platformics.codegen.tests.output.test_infra.factories.sample import SampleFactory -from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory -from platformics.codegen.tests.output.test_infra.factories.taxon import TaxonFactory -from 
platformics.codegen.tests.output.test_infra.factories.upstream_database import UpstreamDatabaseFactory +from test_infra.factories.sample import SampleFactory +from test_infra.factories.sequencing_read import SequencingReadFactory +from test_infra.factories.contig import ContigFactory +from test_infra.factories.upstream_database import UpstreamDatabaseFactory @pytest.mark.asyncio @@ -217,19 +217,19 @@ async def test_deeply_nested_query( SessionStorage.set_session(session) upstream_db_1 = UpstreamDatabaseFactory(owner_user_id=user_id, collection_id=project_id, name="NCBI") upstream_db_2 = UpstreamDatabaseFactory(owner_user_id=user_id, collection_id=project_id, name="GTDB") - taxon_1 = TaxonFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_1) - taxon_2 = TaxonFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_2) - SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, taxon=taxon_1) - SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, taxon=taxon_1) - SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, taxon=taxon_2) - SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, taxon=taxon_2) + contig_1 = ContigFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_1) + contig_2 = ContigFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_2) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_1) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_1) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_2) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_2) # Fetch all contigs, in descending order of the related sequencing read's taxon's upstream database's name query = """ query MyQuery { sequencingReads(orderBy: {taxon: {upstreamDatabase: {name: desc}}}) { id - taxon { + contig { upstreamDatabase { name } @@ -239,5 +239,5 @@ async def test_deeply_nested_query( """ output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) - upstream_database_names = [d["taxon"]["upstreamDatabase"]["name"] for d in output["data"]["sequencingReads"]] + upstream_database_names = [d["contig"]["upstreamDatabase"]["name"] for d in output["data"]["sequencingReads"]] assert upstream_database_names == ["NCBI", "NCBI", "GTDB", "GTDB"] diff --git a/test_app/tests/test_where_clause.py b/test_app/tests/test_where_clause.py index 3ffeb57..a7ee270 100644 --- a/test_app/tests/test_where_clause.py +++ b/test_app/tests/test_where_clause.py @@ -5,8 +5,8 @@ import pytest from platformics.database.connect import SyncDB from conftest import GQLTestClient, SessionStorage, FileFactory -from platformics.codegen.tests.output.test_infra.factories.sample import SampleFactory -from platformics.codegen.tests.output.test_infra.factories.sequencing_read import SequencingReadFactory +from test_infra.factories.sample import SampleFactory +from test_infra.factories.sequencing_read import SequencingReadFactory from support.enums import SequencingTechnology user_id = 12345 From 577db0cc81c78ba3e8dc60e5f57ad125a0f1e258 Mon Sep 17 00:00:00 2001 From: Omar Valenzuela Date: Wed, 5 Jun 2024 12:58:11 -0700 Subject: [PATCH 04/16] fix tests/test_test_setup.py::test_graphql_query --- test_app/tests/test_test_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/test_app/tests/test_test_setup.py b/test_app/tests/test_test_setup.py index 89c9f7e..509c637 100644 --- a/test_app/tests/test_test_setup.py +++ b/test_app/tests/test_test_setup.py @@ -12,7 +12,7 @@ async def test_graphql_query(gql_client: GQLTestClient, api_test_schema: FastAPI """ Make sure we're using the right schema and http client """ - assert api_test_schema.title == "Codegen Tests" + assert api_test_schema.title == "Platformics" assert gql_client.http_client.base_url.host == "test-codegen" # Reference genomes is not an entity in the mock schema but is one in the real schema From ee3b78b40776c5800cb37c04d9156107a12b4330 Mon Sep 17 00:00:00 2001 From: Omar Valenzuela Date: Wed, 5 Jun 2024 13:00:25 -0700 Subject: [PATCH 05/16] fix tests/test_sorting_queries.py --- test_app/tests/test_sorting_queries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_app/tests/test_sorting_queries.py b/test_app/tests/test_sorting_queries.py index bbdfb2e..8a5c48d 100644 --- a/test_app/tests/test_sorting_queries.py +++ b/test_app/tests/test_sorting_queries.py @@ -227,7 +227,7 @@ async def test_deeply_nested_query( # Fetch all contigs, in descending order of the related sequencing read's taxon's upstream database's name query = """ query MyQuery { - sequencingReads(orderBy: {taxon: {upstreamDatabase: {name: desc}}}) { + sequencingReads(orderBy: {contig: {upstreamDatabase: {name: desc}}}) { id contig { upstreamDatabase { From 1c69ec5ac1c88647c00198c06a8e2dbbbf4e4ec1 Mon Sep 17 00:00:00 2001 From: Omar Valenzuela Date: Wed, 5 Jun 2024 15:30:09 -0700 Subject: [PATCH 06/16] remove unnecessary files --- Makefile | 4 - test_app/tests/output/api/mutations.py | 56 -- test_app/tests/output/api/queries.py | 51 -- test_app/tests/output/api/types/contig.py | 378 -------------- .../tests/output/api/types/genomic_range.py | 426 --------------- test_app/tests/output/api/types/sample.py | 432 --------------- .../tests/output/api/types/sequencing_read.py | 492 ------------------ .../tests/output/cerbos/policies/contig.yaml | 22 - .../output/cerbos/policies/genomic_range.yaml | 22 - .../tests/output/cerbos/policies/sample.yaml | 22 - .../cerbos/policies/sequencing_read.yaml | 22 - .../tests/output/database/models/__init__.py | 20 - .../tests/output/database/models/contig.py | 32 -- .../output/database/models/genomic_range.py | 32 -- .../tests/output/database/models/sample.py | 36 -- .../output/database/models/sequencing_read.py | 51 -- test_app/tests/output/support/enums.py | 49 -- .../output/test_infra/factories/contig.py | 37 -- .../test_infra/factories/genomic_range.py | 36 -- .../output/test_infra/factories/sample.py | 36 -- .../test_infra/factories/sequencing_read.py | 73 --- .../test_schemas/overrides/api/.gitignore | 1 - .../overrides/api/extra_test_code.py.j2 | 4 - test_app/tests/test_schemas/platformics.yaml | 443 ---------------- 24 files changed, 2777 deletions(-) delete mode 100644 test_app/tests/output/api/mutations.py delete mode 100644 test_app/tests/output/api/queries.py delete mode 100644 test_app/tests/output/api/types/contig.py delete mode 100644 test_app/tests/output/api/types/genomic_range.py delete mode 100644 test_app/tests/output/api/types/sample.py delete mode 100644 test_app/tests/output/api/types/sequencing_read.py delete mode 100644 test_app/tests/output/cerbos/policies/contig.yaml delete mode 100644 test_app/tests/output/cerbos/policies/genomic_range.yaml delete mode 100644 test_app/tests/output/cerbos/policies/sample.yaml delete mode 100644
test_app/tests/output/cerbos/policies/sequencing_read.yaml delete mode 100644 test_app/tests/output/database/models/__init__.py delete mode 100644 test_app/tests/output/database/models/contig.py delete mode 100644 test_app/tests/output/database/models/genomic_range.py delete mode 100644 test_app/tests/output/database/models/sample.py delete mode 100644 test_app/tests/output/database/models/sequencing_read.py delete mode 100644 test_app/tests/output/support/enums.py delete mode 100644 test_app/tests/output/test_infra/factories/contig.py delete mode 100644 test_app/tests/output/test_infra/factories/genomic_range.py delete mode 100644 test_app/tests/output/test_infra/factories/sample.py delete mode 100644 test_app/tests/output/test_infra/factories/sequencing_read.py delete mode 100644 test_app/tests/test_schemas/overrides/api/.gitignore delete mode 100644 test_app/tests/test_schemas/overrides/api/extra_test_code.py.j2 delete mode 100644 test_app/tests/test_schemas/platformics.yaml diff --git a/Makefile b/Makefile index 3b02fc4..1a7915e 100644 --- a/Makefile +++ b/Makefile @@ -20,10 +20,6 @@ help: ## display help for this makefile @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' @echo "### SHARED FUNCTIONS END ###" -.PHONY: codegen -codegen: - $(MAKE_TEST_APP) codegen - .PHONY: codegen codegen: ## Run codegen to convert the LinkML schema to a GQL API $(docker_compose_run) $(CONTAINER) python3 -m platformics.cli.main api generate --schemafile ./schema/schema.yaml --output-prefix . diff --git a/test_app/tests/output/api/mutations.py b/test_app/tests/output/api/mutations.py deleted file mode 100644 index 01772cb..0000000 --- a/test_app/tests/output/api/mutations.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -GraphQL mutations for files and entities - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/api/mutations.py.j2 instead. 
-""" - -import strawberry -from typing import Sequence -from api.files import ( - File, - create_file, - upload_file, - mark_upload_complete, - concatenate_files, - SignedURL, - MultipartUploadResponse, -) -from api.types.sample import Sample, create_sample, update_sample, delete_sample -from api.types.sequencing_read import ( - SequencingRead, - create_sequencing_read, - update_sequencing_read, - delete_sequencing_read, -) -from api.types.genomic_range import GenomicRange, create_genomic_range, update_genomic_range, delete_genomic_range -from api.types.contig import Contig, create_contig, update_contig, delete_contig - - -@strawberry.type -class Mutation: - # File mutations - create_file: File = create_file - upload_file: MultipartUploadResponse = upload_file - mark_upload_complete: File = mark_upload_complete - concatenate_files: SignedURL = concatenate_files - - # Sample mutations - create_sample: Sample = create_sample - update_sample: Sequence[Sample] = update_sample - delete_sample: Sequence[Sample] = delete_sample - - # SequencingRead mutations - create_sequencing_read: SequencingRead = create_sequencing_read - update_sequencing_read: Sequence[SequencingRead] = update_sequencing_read - delete_sequencing_read: Sequence[SequencingRead] = delete_sequencing_read - - # GenomicRange mutations - create_genomic_range: GenomicRange = create_genomic_range - update_genomic_range: Sequence[GenomicRange] = update_genomic_range - delete_genomic_range: Sequence[GenomicRange] = delete_genomic_range - - # Contig mutations - create_contig: Contig = create_contig - update_contig: Sequence[Contig] = update_contig - delete_contig: Sequence[Contig] = delete_contig diff --git a/test_app/tests/output/api/queries.py b/test_app/tests/output/api/queries.py deleted file mode 100644 index 5c659a7..0000000 --- a/test_app/tests/output/api/queries.py +++ /dev/null @@ -1,51 +0,0 @@ -""" -Supported GraphQL queries for files and entities - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/api/queries.py.j2 instead. 
-""" - -import strawberry -from strawberry import relay -from typing import Sequence, List -from api.files import File, resolve_files -from api.types.sample import Sample, resolve_samples, SampleAggregate, resolve_samples_aggregate -from api.types.sequencing_read import ( - SequencingRead, - resolve_sequencing_reads, - SequencingReadAggregate, - resolve_sequencing_reads_aggregate, -) -from api.types.genomic_range import ( - GenomicRange, - resolve_genomic_ranges, - GenomicRangeAggregate, - resolve_genomic_ranges_aggregate, -) -from api.types.contig import Contig, resolve_contigs, ContigAggregate, resolve_contigs_aggregate - - -@strawberry.type -class Query: - # Allow relay-style queries by node ID - node: relay.Node = relay.node() - nodes: List[relay.Node] = relay.node() - # Query files - files: Sequence[File] = resolve_files - - # Query entities - samples: Sequence[Sample] = resolve_samples - sequencing_reads: Sequence[SequencingRead] = resolve_sequencing_reads - genomic_ranges: Sequence[GenomicRange] = resolve_genomic_ranges - contigs: Sequence[Contig] = resolve_contigs - - # Query entity aggregates - samples_aggregate: SampleAggregate = resolve_samples_aggregate - sequencing_reads_aggregate: SequencingReadAggregate = resolve_sequencing_reads_aggregate - genomic_ranges_aggregate: GenomicRangeAggregate = resolve_genomic_ranges_aggregate - contigs_aggregate: ContigAggregate = resolve_contigs_aggregate - - @strawberry.field - def uncaught_exception(self) -> str: - # Trigger an AttributeException - return self.kaboom diff --git a/test_app/tests/output/api/types/contig.py b/test_app/tests/output/api/types/contig.py deleted file mode 100644 index 04baf39..0000000 --- a/test_app/tests/output/api/types/contig.py +++ /dev/null @@ -1,378 +0,0 @@ -""" -GraphQL type for Contig - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/api/types/class_name.py.j2 instead. -""" - -# ruff: noqa: E501 Line too long - - -import typing -from typing import TYPE_CHECKING, Annotated, Optional, Sequence - -import database.models as db -import strawberry -from platformics.api.core.helpers import get_db_rows, get_aggregate_db_rows -from api.types.entities import EntityInterface -from cerbos.sdk.client import CerbosClient -from cerbos.sdk.model import Principal, Resource -from fastapi import Depends -from platformics.api.core.errors import PlatformicsException -from platformics.api.core.deps import get_cerbos_client, get_db_session, require_auth_principal -from platformics.api.core.gql_to_sql import ( - aggregator_map, - IntComparators, - StrComparators, - UUIDComparators, -) -from platformics.api.core.strawberry_extensions import DependencyExtension -from platformics.security.authorization import CerbosAction -from sqlalchemy import inspect -from sqlalchemy.engine.row import RowMapping -from sqlalchemy.ext.asyncio import AsyncSession -from strawberry.types import Info -from typing_extensions import TypedDict -import enum - -E = typing.TypeVar("E", db.File, db.Entity) -T = typing.TypeVar("T") - -if TYPE_CHECKING: - from api.types.sequencing_read import SequencingReadWhereClause, SequencingRead - - pass -else: - SequencingReadWhereClause = "SequencingReadWhereClause" - SequencingRead = "SequencingRead" - pass - - -""" ------------------------------------------------------------------------------- -Dataloaders ------------------------------------------------------------------------------- -These are batching functions for loading related objects to avoid N+1 queries. 
-""" - - -@strawberry.field -async def load_sequencing_read_rows( - root: "Contig", - info: Info, - where: Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")] | None = None, -) -> Optional[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]]: - dataloader = info.context["sqlalchemy_loader"] - mapper = inspect(db.Contig) - relationship = mapper.relationships["sequencing_read"] - return await dataloader.loader_for(relationship, where).load(root.sequencing_read_id) # type:ignore - - -""" ------------------------------------------------------------------------------- -Define Strawberry GQL types ------------------------------------------------------------------------------- -""" - -""" -Only let users specify IDs in WHERE clause when mutating data (for safety). -We can extend that list as we gather more use cases from the FE team. -""" - - -@strawberry.input -class ContigWhereClauseMutations(TypedDict): - id: UUIDComparators | None - - -""" -Supported WHERE clause attributes -""" - - -@strawberry.input -class ContigWhereClause(TypedDict): - id: UUIDComparators | None - producing_run_id: IntComparators | None - owner_user_id: IntComparators | None - collection_id: IntComparators | None - sequencing_read: ( - Optional[Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")]] | None - ) - sequence: Optional[StrComparators] | None - - -""" -Define Contig type -""" - - -@strawberry.type -class Contig(EntityInterface): - id: strawberry.ID - producing_run_id: Optional[int] - owner_user_id: int - collection_id: int - sequencing_read: Optional[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]] = ( - load_sequencing_read_rows - ) # type:ignore - sequence: str - - -""" -We need to add this to each Queryable type so that strawberry will accept either our -Strawberry type *or* a SQLAlchemy model instance as a valid response class from a resolver -""" -Contig.__strawberry_definition__.is_type_of = ( # type: ignore - lambda obj, info: type(obj) == db.Contig or type(obj) == Contig -) - -""" ------------------------------------------------------------------------------- -Aggregation types ------------------------------------------------------------------------------- -""" - -""" -Define columns that support numerical aggregations -""" - - -@strawberry.type -class ContigNumericalColumns: - producing_run_id: Optional[int] = None - owner_user_id: Optional[int] = None - collection_id: Optional[int] = None - - -""" -Define columns that support min/max aggregations -""" - - -@strawberry.type -class ContigMinMaxColumns: - producing_run_id: Optional[int] = None - owner_user_id: Optional[int] = None - collection_id: Optional[int] = None - sequence: Optional[str] = None - - -""" -Define enum of all columns to support count and count(distinct) aggregations -""" - - -@strawberry.enum -class ContigCountColumns(enum.Enum): - sequencing_read = "sequencing_read" - sequence = "sequence" - entity_id = "entity_id" - id = "id" - producing_run_id = "producing_run_id" - owner_user_id = "owner_user_id" - collection_id = "collection_id" - created_at = "created_at" - updated_at = "updated_at" - deleted_at = "deleted_at" - - -""" -All supported aggregation functions -""" - - -@strawberry.type -class ContigAggregateFunctions: - # This is a hack to accept "distinct" and "columns" as arguments to "count" - @strawberry.field - def count(self, distinct: Optional[bool] = False, columns: Optional[ContigCountColumns] = None) -> 
Optional[int]: - # Count gets set with the proper value in the resolver, so we just return it here - return self.count # type: ignore - - sum: Optional[ContigNumericalColumns] = None - avg: Optional[ContigNumericalColumns] = None - min: Optional[ContigMinMaxColumns] = None - max: Optional[ContigMinMaxColumns] = None - stddev: Optional[ContigNumericalColumns] = None - variance: Optional[ContigNumericalColumns] = None - - -""" -Wrapper around ContigAggregateFunctions -""" - - -@strawberry.type -class ContigAggregate: - aggregate: Optional[ContigAggregateFunctions] = None - - -""" ------------------------------------------------------------------------------- -Mutation types ------------------------------------------------------------------------------- -""" - - -@strawberry.input() -class ContigCreateInput: - collection_id: int - sequencing_read_id: Optional[strawberry.ID] = None - sequence: str - - -@strawberry.input() -class ContigUpdateInput: - collection_id: Optional[int] = None - sequencing_read_id: Optional[strawberry.ID] = None - sequence: Optional[str] = None - - -""" ------------------------------------------------------------------------------- -Utilities ------------------------------------------------------------------------------- -""" - - -@strawberry.field(extensions=[DependencyExtension()]) -async def resolve_contigs( - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), - where: Optional[ContigWhereClause] = None, -) -> typing.Sequence[Contig]: - """ - Resolve Contig objects. Used for queries (see api/queries.py). - """ - return await get_db_rows(db.Contig, session, cerbos_client, principal, where, []) # type: ignore - - -def format_contig_aggregate_output(query_results: RowMapping) -> ContigAggregateFunctions: - """ - Given a row from the DB containing the results of an aggregate query, - format the results using the proper GraphQL types. - """ - output = ContigAggregateFunctions() - for aggregate_name, value in query_results.items(): - if aggregate_name == "count": - output.count = value - else: - aggregator_fn, col_name = aggregate_name.split("_", 1) - # Filter out the group_by key from the results if one was provided. - if aggregator_fn in aggregator_map.keys(): - if not getattr(output, aggregator_fn): - if aggregate_name in ["min", "max"]: - setattr(output, aggregator_fn, ContigMinMaxColumns()) - else: - setattr(output, aggregator_fn, ContigNumericalColumns()) - setattr(getattr(output, aggregator_fn), col_name, value) - return output - - -@strawberry.field(extensions=[DependencyExtension()]) -async def resolve_contigs_aggregate( - info: Info, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), - where: Optional[ContigWhereClause] = None, -) -> ContigAggregate: - """ - Aggregate values for Contig objects. Used for queries (see api/queries.py). 
- """ - # Get the selected aggregate functions and columns to operate on - # TODO: not sure why selected_fields is a list - # The first list of selections will always be ["aggregate"], so just grab the first item - selections = info.selected_fields[0].selections[0].selections - rows = await get_aggregate_db_rows(db.Contig, session, cerbos_client, principal, where, selections, []) # type: ignore - aggregate_output = format_contig_aggregate_output(rows) - return ContigAggregate(aggregate=aggregate_output) - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def create_contig( - input: ContigCreateInput, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), -) -> db.Entity: - """ - Create a new Contig object. Used for mutations (see api/mutations.py). - """ - params = input.__dict__ - - # Validate that user can create entity in this collection - attr = {"collection_id": input.collection_id} - resource = Resource(id="NEW_ID", kind=db.Contig.__tablename__, attr=attr) - if not cerbos_client.is_allowed("create", principal, resource): - raise PlatformicsException("Unauthorized: Cannot create entity in this collection") - - # Save to DB - params["owner_user_id"] = int(principal.id) - new_entity = db.Contig(**params) - session.add(new_entity) - await session.commit() - return new_entity - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def update_contig( - input: ContigUpdateInput, - where: ContigWhereClauseMutations, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), -) -> Sequence[db.Entity]: - """ - Update Contig objects. Used for mutations (see api/mutations.py). - """ - params = input.__dict__ - - # Need at least one thing to update - num_params = len([x for x in params if params[x] is not None]) - if num_params == 0: - raise PlatformicsException("No fields to update") - - # Fetch entities for update, if we have access to them - entities = await get_db_rows(db.Contig, session, cerbos_client, principal, where, [], CerbosAction.UPDATE) - if len(entities) == 0: - raise PlatformicsException("Unauthorized: Cannot update entities") - - # Validate that the user has access to the new collection ID - if input.collection_id: - attr = {"collection_id": input.collection_id} - resource = Resource(id="SOME_ID", kind=db.Contig.__tablename__, attr=attr) - if not cerbos_client.is_allowed(CerbosAction.UPDATE, principal, resource): - raise PlatformicsException("Unauthorized: Cannot access new collection") - - # Update DB - for entity in entities: - for key in params: - if params[key]: - setattr(entity, key, params[key]) - await session.commit() - return entities - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def delete_contig( - where: ContigWhereClauseMutations, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), -) -> Sequence[db.Entity]: - """ - Delete Contig objects. Used for mutations (see api/mutations.py). 
- """ - # Fetch entities for deletion, if we have access to them - entities = await get_db_rows(db.Contig, session, cerbos_client, principal, where, [], CerbosAction.DELETE) - if len(entities) == 0: - raise PlatformicsException("Unauthorized: Cannot delete entities") - - # Update DB - for entity in entities: - await session.delete(entity) - await session.commit() - return entities diff --git a/test_app/tests/output/api/types/genomic_range.py b/test_app/tests/output/api/types/genomic_range.py deleted file mode 100644 index 1745de4..0000000 --- a/test_app/tests/output/api/types/genomic_range.py +++ /dev/null @@ -1,426 +0,0 @@ -""" -GraphQL type for GenomicRange - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/api/types/class_name.py.j2 instead. -""" - -# ruff: noqa: E501 Line too long - - -import typing -from typing import TYPE_CHECKING, Annotated, Optional, Sequence, Callable - -import database.models as db -import strawberry -from platformics.api.core.helpers import get_db_rows, get_aggregate_db_rows -from api.files import File, FileWhereClause -from api.types.entities import EntityInterface -from api.types.sequencing_read import SequencingReadAggregate, format_sequencing_read_aggregate_output -from cerbos.sdk.client import CerbosClient -from cerbos.sdk.model import Principal, Resource -from fastapi import Depends -from platformics.api.core.errors import PlatformicsException -from platformics.api.core.deps import get_cerbos_client, get_db_session, require_auth_principal -from platformics.api.core.gql_to_sql import ( - aggregator_map, - IntComparators, - UUIDComparators, -) -from platformics.api.core.strawberry_extensions import DependencyExtension -from platformics.security.authorization import CerbosAction -from sqlalchemy import inspect -from sqlalchemy.engine.row import RowMapping -from sqlalchemy.ext.asyncio import AsyncSession -from strawberry import relay -from strawberry.types import Info -from typing_extensions import TypedDict -import enum - -E = typing.TypeVar("E", db.File, db.Entity) -T = typing.TypeVar("T") - -if TYPE_CHECKING: - from api.types.sequencing_read import SequencingReadWhereClause, SequencingRead - - pass -else: - SequencingReadWhereClause = "SequencingReadWhereClause" - SequencingRead = "SequencingRead" - pass - - -""" ------------------------------------------------------------------------------- -Dataloaders ------------------------------------------------------------------------------- -These are batching functions for loading related objects to avoid N+1 queries. 
-""" - - -@relay.connection( - relay.ListConnection[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]] # type:ignore -) -async def load_sequencing_read_rows( - root: "GenomicRange", - info: Info, - where: Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")] | None = None, -) -> Sequence[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]]: - dataloader = info.context["sqlalchemy_loader"] - mapper = inspect(db.GenomicRange) - relationship = mapper.relationships["sequencing_reads"] - return await dataloader.loader_for(relationship, where).load(root.id) # type:ignore - - -@strawberry.field -async def load_sequencing_read_aggregate_rows( - root: "GenomicRange", - info: Info, - where: Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")] | None = None, -) -> Optional[Annotated["SequencingReadAggregate", strawberry.lazy("api.types.sequencing_read")]]: - selections = info.selected_fields[0].selections[0].selections - dataloader = info.context["sqlalchemy_loader"] - mapper = inspect(db.GenomicRange) - relationship = mapper.relationships["sequencing_reads"] - rows = await dataloader.aggregate_loader_for(relationship, where, selections).load(root.id) # type:ignore - # Aggregate queries always return a single row, so just grab the first one - result = rows[0] if rows else None - aggregate_output = format_sequencing_read_aggregate_output(result) - return SequencingReadAggregate(aggregate=aggregate_output) - - -""" ------------------------------------------------------------------------------- -Dataloader for File object ------------------------------------------------------------------------------- -""" - - -def load_files_from(attr_name: str) -> Callable: - @strawberry.field - async def load_files( - root: "GenomicRange", - info: Info, - where: Annotated["FileWhereClause", strawberry.lazy("api.files")] | None = None, - ) -> Optional[Annotated["File", strawberry.lazy("api.files")]]: - """ - Given a list of GenomicRange IDs for a certain file type, return related Files - """ - dataloader = info.context["sqlalchemy_loader"] - mapper = inspect(db.GenomicRange) - relationship = mapper.relationships[attr_name] - return await dataloader.loader_for(relationship, where).load(getattr(root, f"{attr_name}_id")) # type:ignore - - return load_files - - -""" ------------------------------------------------------------------------------- -Define Strawberry GQL types ------------------------------------------------------------------------------- -""" - -""" -Only let users specify IDs in WHERE clause when mutating data (for safety). -We can extend that list as we gather more use cases from the FE team. 
-""" - - -@strawberry.input -class GenomicRangeWhereClauseMutations(TypedDict): - id: UUIDComparators | None - - -""" -Supported WHERE clause attributes -""" - - -@strawberry.input -class GenomicRangeWhereClause(TypedDict): - id: UUIDComparators | None - producing_run_id: IntComparators | None - owner_user_id: IntComparators | None - collection_id: IntComparators | None - sequencing_reads: ( - Optional[Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")]] | None - ) - - -""" -Define GenomicRange type -""" - - -@strawberry.type -class GenomicRange(EntityInterface): - id: strawberry.ID - producing_run_id: Optional[int] - owner_user_id: int - collection_id: int - file_id: Optional[strawberry.ID] - file: Optional[Annotated["File", strawberry.lazy("api.files")]] = load_files_from("file") # type: ignore - sequencing_reads: Sequence[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]] = ( - load_sequencing_read_rows - ) # type:ignore - sequencing_reads_aggregate: Optional[ - Annotated["SequencingReadAggregate", strawberry.lazy("api.types.sequencing_read")] - ] = load_sequencing_read_aggregate_rows # type:ignore - - -""" -We need to add this to each Queryable type so that strawberry will accept either our -Strawberry type *or* a SQLAlchemy model instance as a valid response class from a resolver -""" -GenomicRange.__strawberry_definition__.is_type_of = ( # type: ignore - lambda obj, info: type(obj) == db.GenomicRange or type(obj) == GenomicRange -) - -""" ------------------------------------------------------------------------------- -Aggregation types ------------------------------------------------------------------------------- -""" - -""" -Define columns that support numerical aggregations -""" - - -@strawberry.type -class GenomicRangeNumericalColumns: - producing_run_id: Optional[int] = None - owner_user_id: Optional[int] = None - collection_id: Optional[int] = None - - -""" -Define columns that support min/max aggregations -""" - - -@strawberry.type -class GenomicRangeMinMaxColumns: - producing_run_id: Optional[int] = None - owner_user_id: Optional[int] = None - collection_id: Optional[int] = None - - -""" -Define enum of all columns to support count and count(distinct) aggregations -""" - - -@strawberry.enum -class GenomicRangeCountColumns(enum.Enum): - file = "file" - sequencing_reads = "sequencing_reads" - entity_id = "entity_id" - id = "id" - producing_run_id = "producing_run_id" - owner_user_id = "owner_user_id" - collection_id = "collection_id" - created_at = "created_at" - updated_at = "updated_at" - deleted_at = "deleted_at" - - -""" -All supported aggregation functions -""" - - -@strawberry.type -class GenomicRangeAggregateFunctions: - # This is a hack to accept "distinct" and "columns" as arguments to "count" - @strawberry.field - def count( - self, distinct: Optional[bool] = False, columns: Optional[GenomicRangeCountColumns] = None - ) -> Optional[int]: - # Count gets set with the proper value in the resolver, so we just return it here - return self.count # type: ignore - - sum: Optional[GenomicRangeNumericalColumns] = None - avg: Optional[GenomicRangeNumericalColumns] = None - min: Optional[GenomicRangeMinMaxColumns] = None - max: Optional[GenomicRangeMinMaxColumns] = None - stddev: Optional[GenomicRangeNumericalColumns] = None - variance: Optional[GenomicRangeNumericalColumns] = None - - -""" -Wrapper around GenomicRangeAggregateFunctions -""" - - -@strawberry.type -class GenomicRangeAggregate: - aggregate: 
Optional[GenomicRangeAggregateFunctions] = None - - -""" ------------------------------------------------------------------------------- -Mutation types ------------------------------------------------------------------------------- -""" - - -@strawberry.input() -class GenomicRangeCreateInput: - collection_id: int - file_id: Optional[strawberry.ID] = None - - -@strawberry.input() -class GenomicRangeUpdateInput: - collection_id: Optional[int] = None - file_id: Optional[strawberry.ID] = None - - -""" ------------------------------------------------------------------------------- -Utilities ------------------------------------------------------------------------------- -""" - - -@strawberry.field(extensions=[DependencyExtension()]) -async def resolve_genomic_ranges( - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), - where: Optional[GenomicRangeWhereClause] = None, -) -> typing.Sequence[GenomicRange]: - """ - Resolve GenomicRange objects. Used for queries (see api/queries.py). - """ - return await get_db_rows(db.GenomicRange, session, cerbos_client, principal, where, []) # type: ignore - - -def format_genomic_range_aggregate_output(query_results: RowMapping) -> GenomicRangeAggregateFunctions: - """ - Given a row from the DB containing the results of an aggregate query, - format the results using the proper GraphQL types. - """ - output = GenomicRangeAggregateFunctions() - for aggregate_name, value in query_results.items(): - if aggregate_name == "count": - output.count = value - else: - aggregator_fn, col_name = aggregate_name.split("_", 1) - # Filter out the group_by key from the results if one was provided. - if aggregator_fn in aggregator_map.keys(): - if not getattr(output, aggregator_fn): - if aggregate_name in ["min", "max"]: - setattr(output, aggregator_fn, GenomicRangeMinMaxColumns()) - else: - setattr(output, aggregator_fn, GenomicRangeNumericalColumns()) - setattr(getattr(output, aggregator_fn), col_name, value) - return output - - -@strawberry.field(extensions=[DependencyExtension()]) -async def resolve_genomic_ranges_aggregate( - info: Info, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), - where: Optional[GenomicRangeWhereClause] = None, -) -> GenomicRangeAggregate: - """ - Aggregate values for GenomicRange objects. Used for queries (see api/queries.py). - """ - # Get the selected aggregate functions and columns to operate on - # TODO: not sure why selected_fields is a list - # The first list of selections will always be ["aggregate"], so just grab the first item - selections = info.selected_fields[0].selections[0].selections - rows = await get_aggregate_db_rows(db.GenomicRange, session, cerbos_client, principal, where, selections, []) # type: ignore - aggregate_output = format_genomic_range_aggregate_output(rows) - return GenomicRangeAggregate(aggregate=aggregate_output) - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def create_genomic_range( - input: GenomicRangeCreateInput, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), -) -> db.Entity: - """ - Create a new GenomicRange object. Used for mutations (see api/mutations.py). 
- """ - params = input.__dict__ - - # Validate that user can create entity in this collection - attr = {"collection_id": input.collection_id} - resource = Resource(id="NEW_ID", kind=db.GenomicRange.__tablename__, attr=attr) - if not cerbos_client.is_allowed("create", principal, resource): - raise PlatformicsException("Unauthorized: Cannot create entity in this collection") - - # Save to DB - params["owner_user_id"] = int(principal.id) - new_entity = db.GenomicRange(**params) - session.add(new_entity) - await session.commit() - return new_entity - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def update_genomic_range( - input: GenomicRangeUpdateInput, - where: GenomicRangeWhereClauseMutations, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), -) -> Sequence[db.Entity]: - """ - Update GenomicRange objects. Used for mutations (see api/mutations.py). - """ - params = input.__dict__ - - # Need at least one thing to update - num_params = len([x for x in params if params[x] is not None]) - if num_params == 0: - raise PlatformicsException("No fields to update") - - # Fetch entities for update, if we have access to them - entities = await get_db_rows(db.GenomicRange, session, cerbos_client, principal, where, [], CerbosAction.UPDATE) - if len(entities) == 0: - raise PlatformicsException("Unauthorized: Cannot update entities") - - # Validate that the user has access to the new collection ID - if input.collection_id: - attr = {"collection_id": input.collection_id} - resource = Resource(id="SOME_ID", kind=db.GenomicRange.__tablename__, attr=attr) - if not cerbos_client.is_allowed(CerbosAction.UPDATE, principal, resource): - raise PlatformicsException("Unauthorized: Cannot access new collection") - - # Update DB - for entity in entities: - for key in params: - if params[key]: - setattr(entity, key, params[key]) - await session.commit() - return entities - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def delete_genomic_range( - where: GenomicRangeWhereClauseMutations, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), -) -> Sequence[db.Entity]: - """ - Delete GenomicRange objects. Used for mutations (see api/mutations.py). - """ - # Fetch entities for deletion, if we have access to them - entities = await get_db_rows(db.GenomicRange, session, cerbos_client, principal, where, [], CerbosAction.DELETE) - if len(entities) == 0: - raise PlatformicsException("Unauthorized: Cannot delete entities") - - # Update DB - for entity in entities: - await session.delete(entity) - await session.commit() - return entities diff --git a/test_app/tests/output/api/types/sample.py b/test_app/tests/output/api/types/sample.py deleted file mode 100644 index 4d902f8..0000000 --- a/test_app/tests/output/api/types/sample.py +++ /dev/null @@ -1,432 +0,0 @@ -""" -GraphQL type for Sample - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/api/types/class_name.py.j2 instead. 
-""" - -# ruff: noqa: E501 Line too long - - -import typing -from typing import TYPE_CHECKING, Annotated, Optional, Sequence - -import database.models as db -import strawberry -import datetime -from platformics.api.core.helpers import get_db_rows, get_aggregate_db_rows -from api.types.entities import EntityInterface -from api.types.sequencing_read import SequencingReadAggregate, format_sequencing_read_aggregate_output -from cerbos.sdk.client import CerbosClient -from cerbos.sdk.model import Principal, Resource -from fastapi import Depends -from platformics.api.core.errors import PlatformicsException -from platformics.api.core.deps import get_cerbos_client, get_db_session, require_auth_principal -from platformics.api.core.gql_to_sql import ( - aggregator_map, - DatetimeComparators, - IntComparators, - StrComparators, - UUIDComparators, - BoolComparators, -) -from platformics.api.core.strawberry_extensions import DependencyExtension -from platformics.security.authorization import CerbosAction -from sqlalchemy import inspect -from sqlalchemy.engine.row import RowMapping -from sqlalchemy.ext.asyncio import AsyncSession -from strawberry import relay -from strawberry.types import Info -from typing_extensions import TypedDict -import enum - -E = typing.TypeVar("E", db.File, db.Entity) -T = typing.TypeVar("T") - -if TYPE_CHECKING: - from api.types.sequencing_read import SequencingReadWhereClause, SequencingRead - - pass -else: - SequencingReadWhereClause = "SequencingReadWhereClause" - SequencingRead = "SequencingRead" - pass - - -""" ------------------------------------------------------------------------------- -Dataloaders ------------------------------------------------------------------------------- -These are batching functions for loading related objects to avoid N+1 queries. 
-""" - - -@relay.connection( - relay.ListConnection[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]] # type:ignore -) -async def load_sequencing_read_rows( - root: "Sample", - info: Info, - where: Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")] | None = None, -) -> Sequence[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]]: - dataloader = info.context["sqlalchemy_loader"] - mapper = inspect(db.Sample) - relationship = mapper.relationships["sequencing_reads"] - return await dataloader.loader_for(relationship, where).load(root.id) # type:ignore - - -@strawberry.field -async def load_sequencing_read_aggregate_rows( - root: "Sample", - info: Info, - where: Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")] | None = None, -) -> Optional[Annotated["SequencingReadAggregate", strawberry.lazy("api.types.sequencing_read")]]: - selections = info.selected_fields[0].selections[0].selections - dataloader = info.context["sqlalchemy_loader"] - mapper = inspect(db.Sample) - relationship = mapper.relationships["sequencing_reads"] - rows = await dataloader.aggregate_loader_for(relationship, where, selections).load(root.id) # type:ignore - # Aggregate queries always return a single row, so just grab the first one - result = rows[0] if rows else None - aggregate_output = format_sequencing_read_aggregate_output(result) - return SequencingReadAggregate(aggregate=aggregate_output) - - -""" ------------------------------------------------------------------------------- -Define Strawberry GQL types ------------------------------------------------------------------------------- -""" - -""" -Only let users specify IDs in WHERE clause when mutating data (for safety). -We can extend that list as we gather more use cases from the FE team. 
-""" - - -@strawberry.input -class SampleWhereClauseMutations(TypedDict): - id: UUIDComparators | None - - -""" -Supported WHERE clause attributes -""" - - -@strawberry.input -class SampleWhereClause(TypedDict): - id: UUIDComparators | None - producing_run_id: IntComparators | None - owner_user_id: IntComparators | None - collection_id: IntComparators | None - name: Optional[StrComparators] | None - sample_type: Optional[StrComparators] | None - water_control: Optional[BoolComparators] | None - collection_date: Optional[DatetimeComparators] | None - collection_location: Optional[StrComparators] | None - description: Optional[StrComparators] | None - sequencing_reads: ( - Optional[Annotated["SequencingReadWhereClause", strawberry.lazy("api.types.sequencing_read")]] | None - ) - - -""" -Define Sample type -""" - - -@strawberry.type -class Sample(EntityInterface): - id: strawberry.ID - producing_run_id: Optional[int] - owner_user_id: int - collection_id: int - name: str - sample_type: str - water_control: bool - collection_date: Optional[datetime.datetime] = None - collection_location: str - description: Optional[str] = None - sequencing_reads: Sequence[Annotated["SequencingRead", strawberry.lazy("api.types.sequencing_read")]] = ( - load_sequencing_read_rows - ) # type:ignore - sequencing_reads_aggregate: Optional[ - Annotated["SequencingReadAggregate", strawberry.lazy("api.types.sequencing_read")] - ] = load_sequencing_read_aggregate_rows # type:ignore - - -""" -We need to add this to each Queryable type so that strawberry will accept either our -Strawberry type *or* a SQLAlchemy model instance as a valid response class from a resolver -""" -Sample.__strawberry_definition__.is_type_of = ( # type: ignore - lambda obj, info: type(obj) == db.Sample or type(obj) == Sample -) - -""" ------------------------------------------------------------------------------- -Aggregation types ------------------------------------------------------------------------------- -""" - -""" -Define columns that support numerical aggregations -""" - - -@strawberry.type -class SampleNumericalColumns: - producing_run_id: Optional[int] = None - owner_user_id: Optional[int] = None - collection_id: Optional[int] = None - - -""" -Define columns that support min/max aggregations -""" - - -@strawberry.type -class SampleMinMaxColumns: - producing_run_id: Optional[int] = None - owner_user_id: Optional[int] = None - collection_id: Optional[int] = None - name: Optional[str] = None - sample_type: Optional[str] = None - collection_date: Optional[datetime.datetime] = None - collection_location: Optional[str] = None - description: Optional[str] = None - - -""" -Define enum of all columns to support count and count(distinct) aggregations -""" - - -@strawberry.enum -class SampleCountColumns(enum.Enum): - name = "name" - sample_type = "sample_type" - water_control = "water_control" - collection_date = "collection_date" - collection_location = "collection_location" - description = "description" - sequencing_reads = "sequencing_reads" - entity_id = "entity_id" - id = "id" - producing_run_id = "producing_run_id" - owner_user_id = "owner_user_id" - collection_id = "collection_id" - created_at = "created_at" - updated_at = "updated_at" - deleted_at = "deleted_at" - - -""" -All supported aggregation functions -""" - - -@strawberry.type -class SampleAggregateFunctions: - # This is a hack to accept "distinct" and "columns" as arguments to "count" - @strawberry.field - def count(self, distinct: Optional[bool] = False, columns: 
Optional[SampleCountColumns] = None) -> Optional[int]: - # Count gets set with the proper value in the resolver, so we just return it here - return self.count # type: ignore - - sum: Optional[SampleNumericalColumns] = None - avg: Optional[SampleNumericalColumns] = None - min: Optional[SampleMinMaxColumns] = None - max: Optional[SampleMinMaxColumns] = None - stddev: Optional[SampleNumericalColumns] = None - variance: Optional[SampleNumericalColumns] = None - - -""" -Wrapper around SampleAggregateFunctions -""" - - -@strawberry.type -class SampleAggregate: - aggregate: Optional[SampleAggregateFunctions] = None - - -""" ------------------------------------------------------------------------------- -Mutation types ------------------------------------------------------------------------------- -""" - - -@strawberry.input() -class SampleCreateInput: - collection_id: int - name: str - sample_type: str - water_control: bool - collection_date: Optional[datetime.datetime] = None - collection_location: str - description: Optional[str] = None - - -@strawberry.input() -class SampleUpdateInput: - collection_id: Optional[int] = None - name: Optional[str] = None - sample_type: Optional[str] = None - water_control: Optional[bool] = None - collection_date: Optional[datetime.datetime] = None - collection_location: Optional[str] = None - description: Optional[str] = None - - -""" ------------------------------------------------------------------------------- -Utilities ------------------------------------------------------------------------------- -""" - - -@strawberry.field(extensions=[DependencyExtension()]) -async def resolve_samples( - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), - where: Optional[SampleWhereClause] = None, -) -> typing.Sequence[Sample]: - """ - Resolve Sample objects. Used for queries (see api/queries.py). - """ - return await get_db_rows(db.Sample, session, cerbos_client, principal, where, []) # type: ignore - - -def format_sample_aggregate_output(query_results: RowMapping) -> SampleAggregateFunctions: - """ - Given a row from the DB containing the results of an aggregate query, - format the results using the proper GraphQL types. - """ - output = SampleAggregateFunctions() - for aggregate_name, value in query_results.items(): - if aggregate_name == "count": - output.count = value - else: - aggregator_fn, col_name = aggregate_name.split("_", 1) - # Filter out the group_by key from the results if one was provided. - if aggregator_fn in aggregator_map.keys(): - if not getattr(output, aggregator_fn): - if aggregate_name in ["min", "max"]: - setattr(output, aggregator_fn, SampleMinMaxColumns()) - else: - setattr(output, aggregator_fn, SampleNumericalColumns()) - setattr(getattr(output, aggregator_fn), col_name, value) - return output - - -@strawberry.field(extensions=[DependencyExtension()]) -async def resolve_samples_aggregate( - info: Info, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), - where: Optional[SampleWhereClause] = None, -) -> SampleAggregate: - """ - Aggregate values for Sample objects. Used for queries (see api/queries.py). 
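(Illustrative sketch, not generated code: the aggregate types above are easiest to read from the query side. The snippet below assumes the generated root field is named samplesAggregate, that field names follow the default camel-case conversion, and that the gql_client fixture and member_projects argument seen elsewhere in this test suite are available.)

    # Hypothetical async test body, per the assumptions above
    query = """
        query {
            samplesAggregate {
                aggregate {
                    count(distinct: true, columns: name)
                    min { collectionDate }
                    sum { collectionId }
                }
            }
        }
    """
    output = await gql_client.query(query, member_projects=[123])
    assert output["data"]["samplesAggregate"]["aggregate"]["count"] is not None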
- """ - # Get the selected aggregate functions and columns to operate on - # TODO: not sure why selected_fields is a list - # The first list of selections will always be ["aggregate"], so just grab the first item - selections = info.selected_fields[0].selections[0].selections - rows = await get_aggregate_db_rows(db.Sample, session, cerbos_client, principal, where, selections, []) # type: ignore - aggregate_output = format_sample_aggregate_output(rows) - return SampleAggregate(aggregate=aggregate_output) - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def create_sample( - input: SampleCreateInput, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), -) -> db.Entity: - """ - Create a new Sample object. Used for mutations (see api/mutations.py). - """ - params = input.__dict__ - - # Validate that user can create entity in this collection - attr = {"collection_id": input.collection_id} - resource = Resource(id="NEW_ID", kind=db.Sample.__tablename__, attr=attr) - if not cerbos_client.is_allowed("create", principal, resource): - raise PlatformicsException("Unauthorized: Cannot create entity in this collection") - - # Save to DB - params["owner_user_id"] = int(principal.id) - new_entity = db.Sample(**params) - session.add(new_entity) - await session.commit() - return new_entity - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def update_sample( - input: SampleUpdateInput, - where: SampleWhereClauseMutations, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), -) -> Sequence[db.Entity]: - """ - Update Sample objects. Used for mutations (see api/mutations.py). - """ - params = input.__dict__ - - # Need at least one thing to update - num_params = len([x for x in params if params[x] is not None]) - if num_params == 0: - raise PlatformicsException("No fields to update") - - # Fetch entities for update, if we have access to them - entities = await get_db_rows(db.Sample, session, cerbos_client, principal, where, [], CerbosAction.UPDATE) - if len(entities) == 0: - raise PlatformicsException("Unauthorized: Cannot update entities") - - # Validate that the user has access to the new collection ID - if input.collection_id: - attr = {"collection_id": input.collection_id} - resource = Resource(id="SOME_ID", kind=db.Sample.__tablename__, attr=attr) - if not cerbos_client.is_allowed(CerbosAction.UPDATE, principal, resource): - raise PlatformicsException("Unauthorized: Cannot access new collection") - - # Update DB - for entity in entities: - for key in params: - if params[key]: - setattr(entity, key, params[key]) - await session.commit() - return entities - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def delete_sample( - where: SampleWhereClauseMutations, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), -) -> Sequence[db.Entity]: - """ - Delete Sample objects. Used for mutations (see api/mutations.py). 
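(Illustrative sketch, not generated code: because the mutation where-clause type only exposes `id`, delete requests are expected to target rows by UUID. The sketch assumes the generated mutation field is named deleteSample and that the UUID comparator type exposes an `_eq` operator; the UUID shown is a placeholder.)

    # Hypothetical async test body, per the assumptions above
    mutation = """
        mutation {
            deleteSample(where: { id: { _eq: "00000000-0000-0000-0000-000000000000" } }) {
                id
            }
        }
    """
    output = await gql_client.query(mutation, member_projects=[123])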
- """ - # Fetch entities for deletion, if we have access to them - entities = await get_db_rows(db.Sample, session, cerbos_client, principal, where, [], CerbosAction.DELETE) - if len(entities) == 0: - raise PlatformicsException("Unauthorized: Cannot delete entities") - - # Update DB - for entity in entities: - await session.delete(entity) - await session.commit() - return entities diff --git a/test_app/tests/output/api/types/sequencing_read.py b/test_app/tests/output/api/types/sequencing_read.py deleted file mode 100644 index 28a00e1..0000000 --- a/test_app/tests/output/api/types/sequencing_read.py +++ /dev/null @@ -1,492 +0,0 @@ -""" -GraphQL type for SequencingRead - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/api/types/class_name.py.j2 instead. -""" - -# ruff: noqa: E501 Line too long - - -import typing -from typing import TYPE_CHECKING, Annotated, Optional, Sequence, Callable - -import database.models as db -import strawberry -from platformics.api.core.helpers import get_db_rows, get_aggregate_db_rows -from api.files import File, FileWhereClause -from api.types.entities import EntityInterface -from api.types.contig import ContigAggregate, format_contig_aggregate_output -from cerbos.sdk.client import CerbosClient -from cerbos.sdk.model import Principal, Resource -from fastapi import Depends -from platformics.api.core.errors import PlatformicsException -from platformics.api.core.deps import get_cerbos_client, get_db_session, require_auth_principal -from platformics.api.core.gql_to_sql import ( - aggregator_map, - EnumComparators, - IntComparators, - UUIDComparators, - BoolComparators, -) -from platformics.api.core.strawberry_extensions import DependencyExtension -from platformics.security.authorization import CerbosAction -from sqlalchemy import inspect -from sqlalchemy.engine.row import RowMapping -from sqlalchemy.ext.asyncio import AsyncSession -from strawberry import relay -from strawberry.types import Info -from typing_extensions import TypedDict -import enum -from support.enums import SequencingProtocol, SequencingTechnology, NucleicAcid - -E = typing.TypeVar("E", db.File, db.Entity) -T = typing.TypeVar("T") - -if TYPE_CHECKING: - from api.types.sample import SampleWhereClause, Sample - from api.types.genomic_range import GenomicRangeWhereClause, GenomicRange - from api.types.contig import ContigWhereClause, Contig - - pass -else: - SampleWhereClause = "SampleWhereClause" - Sample = "Sample" - GenomicRangeWhereClause = "GenomicRangeWhereClause" - GenomicRange = "GenomicRange" - ContigWhereClause = "ContigWhereClause" - Contig = "Contig" - pass - - -""" ------------------------------------------------------------------------------- -Dataloaders ------------------------------------------------------------------------------- -These are batching functions for loading related objects to avoid N+1 queries. 
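(Illustrative sketch, not generated code: a nested query like the one below is the case these dataloaders batch for, since each nested level resolves through loader_for()/aggregate_loader_for() instead of issuing one query per parent row. The root field name, camel-case field names, and the relay edges/node shape of the to-many fields are assumptions based on the generated types above.)

    query = """
        query {
            samples {
                id
                sequencingReads {
                    edges {
                        node {
                            id
                            contigsAggregate { aggregate { count } }
                        }
                    }
                }
            }
        }
    """
    output = await gql_client.query(query, member_projects=[123])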
-""" - - -@strawberry.field -async def load_sample_rows( - root: "SequencingRead", - info: Info, - where: Annotated["SampleWhereClause", strawberry.lazy("api.types.sample")] | None = None, -) -> Optional[Annotated["Sample", strawberry.lazy("api.types.sample")]]: - dataloader = info.context["sqlalchemy_loader"] - mapper = inspect(db.SequencingRead) - relationship = mapper.relationships["sample"] - return await dataloader.loader_for(relationship, where).load(root.sample_id) # type:ignore - - -@strawberry.field -async def load_genomic_range_rows( - root: "SequencingRead", - info: Info, - where: Annotated["GenomicRangeWhereClause", strawberry.lazy("api.types.genomic_range")] | None = None, -) -> Optional[Annotated["GenomicRange", strawberry.lazy("api.types.genomic_range")]]: - dataloader = info.context["sqlalchemy_loader"] - mapper = inspect(db.SequencingRead) - relationship = mapper.relationships["primer_file"] - return await dataloader.loader_for(relationship, where).load(root.primer_file_id) # type:ignore - - -@relay.connection( - relay.ListConnection[Annotated["Contig", strawberry.lazy("api.types.contig")]] # type:ignore -) -async def load_contig_rows( - root: "SequencingRead", - info: Info, - where: Annotated["ContigWhereClause", strawberry.lazy("api.types.contig")] | None = None, -) -> Sequence[Annotated["Contig", strawberry.lazy("api.types.contig")]]: - dataloader = info.context["sqlalchemy_loader"] - mapper = inspect(db.SequencingRead) - relationship = mapper.relationships["contigs"] - return await dataloader.loader_for(relationship, where).load(root.id) # type:ignore - - -@strawberry.field -async def load_contig_aggregate_rows( - root: "SequencingRead", - info: Info, - where: Annotated["ContigWhereClause", strawberry.lazy("api.types.contig")] | None = None, -) -> Optional[Annotated["ContigAggregate", strawberry.lazy("api.types.contig")]]: - selections = info.selected_fields[0].selections[0].selections - dataloader = info.context["sqlalchemy_loader"] - mapper = inspect(db.SequencingRead) - relationship = mapper.relationships["contigs"] - rows = await dataloader.aggregate_loader_for(relationship, where, selections).load(root.id) # type:ignore - # Aggregate queries always return a single row, so just grab the first one - result = rows[0] if rows else None - aggregate_output = format_contig_aggregate_output(result) - return ContigAggregate(aggregate=aggregate_output) - - -""" ------------------------------------------------------------------------------- -Dataloader for File object ------------------------------------------------------------------------------- -""" - - -def load_files_from(attr_name: str) -> Callable: - @strawberry.field - async def load_files( - root: "SequencingRead", - info: Info, - where: Annotated["FileWhereClause", strawberry.lazy("api.files")] | None = None, - ) -> Optional[Annotated["File", strawberry.lazy("api.files")]]: - """ - Given a list of SequencingRead IDs for a certain file type, return related Files - """ - dataloader = info.context["sqlalchemy_loader"] - mapper = inspect(db.SequencingRead) - relationship = mapper.relationships[attr_name] - return await dataloader.loader_for(relationship, where).load(getattr(root, f"{attr_name}_id")) # type:ignore - - return load_files - - -""" ------------------------------------------------------------------------------- -Define Strawberry GQL types ------------------------------------------------------------------------------- -""" - -""" -Only let users specify IDs in WHERE clause when mutating data (for safety). 
-We can extend that list as we gather more use cases from the FE team. -""" - - -@strawberry.input -class SequencingReadWhereClauseMutations(TypedDict): - id: UUIDComparators | None - - -""" -Supported WHERE clause attributes -""" - - -@strawberry.input -class SequencingReadWhereClause(TypedDict): - id: UUIDComparators | None - producing_run_id: IntComparators | None - owner_user_id: IntComparators | None - collection_id: IntComparators | None - sample: Optional[Annotated["SampleWhereClause", strawberry.lazy("api.types.sample")]] | None - protocol: Optional[EnumComparators[SequencingProtocol]] | None - technology: Optional[EnumComparators[SequencingTechnology]] | None - nucleic_acid: Optional[EnumComparators[NucleicAcid]] | None - primer_file: Optional[Annotated["GenomicRangeWhereClause", strawberry.lazy("api.types.genomic_range")]] | None - contigs: Optional[Annotated["ContigWhereClause", strawberry.lazy("api.types.contig")]] | None - clearlabs_export: Optional[BoolComparators] | None - - -""" -Define SequencingRead type -""" - - -@strawberry.type -class SequencingRead(EntityInterface): - id: strawberry.ID - producing_run_id: Optional[int] - owner_user_id: int - collection_id: int - sample: Optional[Annotated["Sample", strawberry.lazy("api.types.sample")]] = load_sample_rows # type:ignore - protocol: SequencingProtocol - r1_file_id: Optional[strawberry.ID] - r1_file: Optional[Annotated["File", strawberry.lazy("api.files")]] = load_files_from("r1_file") # type: ignore - r2_file_id: Optional[strawberry.ID] - r2_file: Optional[Annotated["File", strawberry.lazy("api.files")]] = load_files_from("r2_file") # type: ignore - technology: SequencingTechnology - nucleic_acid: NucleicAcid - primer_file: Optional[Annotated["GenomicRange", strawberry.lazy("api.types.genomic_range")]] = ( - load_genomic_range_rows - ) # type:ignore - contigs: Sequence[Annotated["Contig", strawberry.lazy("api.types.contig")]] = load_contig_rows # type:ignore - contigs_aggregate: Optional[Annotated["ContigAggregate", strawberry.lazy("api.types.contig")]] = ( - load_contig_aggregate_rows - ) # type:ignore - clearlabs_export: bool - - -""" -We need to add this to each Queryable type so that strawberry will accept either our -Strawberry type *or* a SQLAlchemy model instance as a valid response class from a resolver -""" -SequencingRead.__strawberry_definition__.is_type_of = ( # type: ignore - lambda obj, info: type(obj) == db.SequencingRead or type(obj) == SequencingRead -) - -""" ------------------------------------------------------------------------------- -Aggregation types ------------------------------------------------------------------------------- -""" - -""" -Define columns that support numerical aggregations -""" - - -@strawberry.type -class SequencingReadNumericalColumns: - producing_run_id: Optional[int] = None - owner_user_id: Optional[int] = None - collection_id: Optional[int] = None - - -""" -Define columns that support min/max aggregations -""" - - -@strawberry.type -class SequencingReadMinMaxColumns: - producing_run_id: Optional[int] = None - owner_user_id: Optional[int] = None - collection_id: Optional[int] = None - - -""" -Define enum of all columns to support count and count(distinct) aggregations -""" - - -@strawberry.enum -class SequencingReadCountColumns(enum.Enum): - sample = "sample" - protocol = "protocol" - r1_file = "r1_file" - r2_file = "r2_file" - technology = "technology" - nucleic_acid = "nucleic_acid" - primer_file = "primer_file" - contigs = "contigs" - clearlabs_export = "clearlabs_export" 
- entity_id = "entity_id" - id = "id" - producing_run_id = "producing_run_id" - owner_user_id = "owner_user_id" - collection_id = "collection_id" - created_at = "created_at" - updated_at = "updated_at" - deleted_at = "deleted_at" - - -""" -All supported aggregation functions -""" - - -@strawberry.type -class SequencingReadAggregateFunctions: - # This is a hack to accept "distinct" and "columns" as arguments to "count" - @strawberry.field - def count( - self, distinct: Optional[bool] = False, columns: Optional[SequencingReadCountColumns] = None - ) -> Optional[int]: - # Count gets set with the proper value in the resolver, so we just return it here - return self.count # type: ignore - - sum: Optional[SequencingReadNumericalColumns] = None - avg: Optional[SequencingReadNumericalColumns] = None - min: Optional[SequencingReadMinMaxColumns] = None - max: Optional[SequencingReadMinMaxColumns] = None - stddev: Optional[SequencingReadNumericalColumns] = None - variance: Optional[SequencingReadNumericalColumns] = None - - -""" -Wrapper around SequencingReadAggregateFunctions -""" - - -@strawberry.type -class SequencingReadAggregate: - aggregate: Optional[SequencingReadAggregateFunctions] = None - - -""" ------------------------------------------------------------------------------- -Mutation types ------------------------------------------------------------------------------- -""" - - -@strawberry.input() -class SequencingReadCreateInput: - collection_id: int - sample_id: Optional[strawberry.ID] = None - protocol: SequencingProtocol - r1_file_id: Optional[strawberry.ID] = None - r2_file_id: Optional[strawberry.ID] = None - technology: SequencingTechnology - nucleic_acid: NucleicAcid - primer_file_id: Optional[strawberry.ID] = None - clearlabs_export: bool - - -@strawberry.input() -class SequencingReadUpdateInput: - collection_id: Optional[int] = None - sample_id: Optional[strawberry.ID] = None - protocol: Optional[SequencingProtocol] = None - r1_file_id: Optional[strawberry.ID] = None - r2_file_id: Optional[strawberry.ID] = None - technology: Optional[SequencingTechnology] = None - nucleic_acid: Optional[NucleicAcid] = None - primer_file_id: Optional[strawberry.ID] = None - clearlabs_export: Optional[bool] = None - - -""" ------------------------------------------------------------------------------- -Utilities ------------------------------------------------------------------------------- -""" - - -@strawberry.field(extensions=[DependencyExtension()]) -async def resolve_sequencing_reads( - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), - where: Optional[SequencingReadWhereClause] = None, -) -> typing.Sequence[SequencingRead]: - """ - Resolve SequencingRead objects. Used for queries (see api/queries.py). - """ - return await get_db_rows(db.SequencingRead, session, cerbos_client, principal, where, []) # type: ignore - - -def format_sequencing_read_aggregate_output(query_results: RowMapping) -> SequencingReadAggregateFunctions: - """ - Given a row from the DB containing the results of an aggregate query, - format the results using the proper GraphQL types. - """ - output = SequencingReadAggregateFunctions() - for aggregate_name, value in query_results.items(): - if aggregate_name == "count": - output.count = value - else: - aggregator_fn, col_name = aggregate_name.split("_", 1) - # Filter out the group_by key from the results if one was provided. 
- if aggregator_fn in aggregator_map.keys(): - if not getattr(output, aggregator_fn): - if aggregate_name in ["min", "max"]: - setattr(output, aggregator_fn, SequencingReadMinMaxColumns()) - else: - setattr(output, aggregator_fn, SequencingReadNumericalColumns()) - setattr(getattr(output, aggregator_fn), col_name, value) - return output - - -@strawberry.field(extensions=[DependencyExtension()]) -async def resolve_sequencing_reads_aggregate( - info: Info, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), - where: Optional[SequencingReadWhereClause] = None, -) -> SequencingReadAggregate: - """ - Aggregate values for SequencingRead objects. Used for queries (see api/queries.py). - """ - # Get the selected aggregate functions and columns to operate on - # TODO: not sure why selected_fields is a list - # The first list of selections will always be ["aggregate"], so just grab the first item - selections = info.selected_fields[0].selections[0].selections - rows = await get_aggregate_db_rows(db.SequencingRead, session, cerbos_client, principal, where, selections, []) # type: ignore - aggregate_output = format_sequencing_read_aggregate_output(rows) - return SequencingReadAggregate(aggregate=aggregate_output) - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def create_sequencing_read( - input: SequencingReadCreateInput, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), -) -> db.Entity: - """ - Create a new SequencingRead object. Used for mutations (see api/mutations.py). - """ - params = input.__dict__ - - # Validate that user can create entity in this collection - attr = {"collection_id": input.collection_id} - resource = Resource(id="NEW_ID", kind=db.SequencingRead.__tablename__, attr=attr) - if not cerbos_client.is_allowed("create", principal, resource): - raise PlatformicsException("Unauthorized: Cannot create entity in this collection") - - # Save to DB - params["owner_user_id"] = int(principal.id) - new_entity = db.SequencingRead(**params) - session.add(new_entity) - await session.commit() - return new_entity - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def update_sequencing_read( - input: SequencingReadUpdateInput, - where: SequencingReadWhereClauseMutations, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), -) -> Sequence[db.Entity]: - """ - Update SequencingRead objects. Used for mutations (see api/mutations.py). 
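(Illustrative sketch, not generated code: updates follow the same id-only where-clause rule, and only non-null input fields are applied, per the `if params[key]` loop in the update mutation below. The sketch assumes an updateSequencingRead mutation field, an `_eq` UUID comparator, and default camel-case names; the UUID is a placeholder.)

    # Hypothetical async test body, per the assumptions above
    mutation = """
        mutation {
            updateSequencingRead(
                where: { id: { _eq: "00000000-0000-0000-0000-000000000000" } }
                input: { nucleicAcid: DNA }
            ) {
                id
            }
        }
    """
    output = await gql_client.query(mutation, member_projects=[123])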
- """ - params = input.__dict__ - - # Need at least one thing to update - num_params = len([x for x in params if params[x] is not None]) - if num_params == 0: - raise PlatformicsException("No fields to update") - - # Fetch entities for update, if we have access to them - entities = await get_db_rows(db.SequencingRead, session, cerbos_client, principal, where, [], CerbosAction.UPDATE) - if len(entities) == 0: - raise PlatformicsException("Unauthorized: Cannot update entities") - - # Validate that the user has access to the new collection ID - if input.collection_id: - attr = {"collection_id": input.collection_id} - resource = Resource(id="SOME_ID", kind=db.SequencingRead.__tablename__, attr=attr) - if not cerbos_client.is_allowed(CerbosAction.UPDATE, principal, resource): - raise PlatformicsException("Unauthorized: Cannot access new collection") - - # Update DB - for entity in entities: - for key in params: - if params[key]: - setattr(entity, key, params[key]) - await session.commit() - return entities - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def delete_sequencing_read( - where: SequencingReadWhereClauseMutations, - session: AsyncSession = Depends(get_db_session, use_cache=False), - cerbos_client: CerbosClient = Depends(get_cerbos_client), - principal: Principal = Depends(require_auth_principal), -) -> Sequence[db.Entity]: - """ - Delete SequencingRead objects. Used for mutations (see api/mutations.py). - """ - # Fetch entities for deletion, if we have access to them - entities = await get_db_rows(db.SequencingRead, session, cerbos_client, principal, where, [], CerbosAction.DELETE) - if len(entities) == 0: - raise PlatformicsException("Unauthorized: Cannot delete entities") - - # Update DB - for entity in entities: - await session.delete(entity) - await session.commit() - return entities diff --git a/test_app/tests/output/cerbos/policies/contig.yaml b/test_app/tests/output/cerbos/policies/contig.yaml deleted file mode 100644 index 6df5ad5..0000000 --- a/test_app/tests/output/cerbos/policies/contig.yaml +++ /dev/null @@ -1,22 +0,0 @@ -# Auto-generated by running 'make codegen'. Do not edit. -# Make changes to the template codegen/templates/cerbos/policies/class_name.yaml.j2 instead. -# yaml-language-server: $schema=https://api.cerbos.dev/latest/cerbos/policy/v1/Policy.schema.json -apiVersion: api.cerbos.dev/v1 -resourcePolicy: - version: "default" - importDerivedRoles: - - common_roles - resource: "contig" - rules: - - actions: ['view', 'create', 'update'] - effect: EFFECT_ALLOW - derivedRoles: - - project_member - - - actions: ['download', 'delete'] - effect: EFFECT_ALLOW - derivedRoles: - - owner - schemas: - principalSchema: - ref: cerbos:///principal.json diff --git a/test_app/tests/output/cerbos/policies/genomic_range.yaml b/test_app/tests/output/cerbos/policies/genomic_range.yaml deleted file mode 100644 index cdee2d7..0000000 --- a/test_app/tests/output/cerbos/policies/genomic_range.yaml +++ /dev/null @@ -1,22 +0,0 @@ -# Auto-generated by running 'make codegen'. Do not edit. -# Make changes to the template codegen/templates/cerbos/policies/class_name.yaml.j2 instead. 
-# yaml-language-server: $schema=https://api.cerbos.dev/latest/cerbos/policy/v1/Policy.schema.json -apiVersion: api.cerbos.dev/v1 -resourcePolicy: - version: "default" - importDerivedRoles: - - common_roles - resource: "genomic_range" - rules: - - actions: ['view', 'create', 'update'] - effect: EFFECT_ALLOW - derivedRoles: - - project_member - - - actions: ['download', 'delete'] - effect: EFFECT_ALLOW - derivedRoles: - - owner - schemas: - principalSchema: - ref: cerbos:///principal.json diff --git a/test_app/tests/output/cerbos/policies/sample.yaml b/test_app/tests/output/cerbos/policies/sample.yaml deleted file mode 100644 index b34e82e..0000000 --- a/test_app/tests/output/cerbos/policies/sample.yaml +++ /dev/null @@ -1,22 +0,0 @@ -# Auto-generated by running 'make codegen'. Do not edit. -# Make changes to the template codegen/templates/cerbos/policies/class_name.yaml.j2 instead. -# yaml-language-server: $schema=https://api.cerbos.dev/latest/cerbos/policy/v1/Policy.schema.json -apiVersion: api.cerbos.dev/v1 -resourcePolicy: - version: "default" - importDerivedRoles: - - common_roles - resource: "sample" - rules: - - actions: ['view', 'create', 'update'] - effect: EFFECT_ALLOW - derivedRoles: - - project_member - - - actions: ['download', 'delete'] - effect: EFFECT_ALLOW - derivedRoles: - - owner - schemas: - principalSchema: - ref: cerbos:///principal.json diff --git a/test_app/tests/output/cerbos/policies/sequencing_read.yaml b/test_app/tests/output/cerbos/policies/sequencing_read.yaml deleted file mode 100644 index 964dd3c..0000000 --- a/test_app/tests/output/cerbos/policies/sequencing_read.yaml +++ /dev/null @@ -1,22 +0,0 @@ -# Auto-generated by running 'make codegen'. Do not edit. -# Make changes to the template codegen/templates/cerbos/policies/class_name.yaml.j2 instead. -# yaml-language-server: $schema=https://api.cerbos.dev/latest/cerbos/policy/v1/Policy.schema.json -apiVersion: api.cerbos.dev/v1 -resourcePolicy: - version: "default" - importDerivedRoles: - - common_roles - resource: "sequencing_read" - rules: - - actions: ['view', 'create', 'update'] - effect: EFFECT_ALLOW - derivedRoles: - - project_member - - - actions: ['download', 'delete'] - effect: EFFECT_ALLOW - derivedRoles: - - owner - schemas: - principalSchema: - ref: cerbos:///principal.json diff --git a/test_app/tests/output/database/models/__init__.py b/test_app/tests/output/database/models/__init__.py deleted file mode 100644 index 76a31d7..0000000 --- a/test_app/tests/output/database/models/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -Make database models importable - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/database/models/__init__.py.j2 instead. 
-""" - -# isort: skip_file - -from sqlalchemy.orm import configure_mappers - -from platformics.database.models.base import Base, meta, Entity # noqa: F401 -from database.models.sample import Sample # noqa: F401 -from database.models.sequencing_read import SequencingRead # noqa: F401 -from database.models.genomic_range import GenomicRange # noqa: F401 -from database.models.contig import Contig # noqa: F401 - -from database.models.file import File, FileStatus # noqa: F401 - -configure_mappers() diff --git a/test_app/tests/output/database/models/contig.py b/test_app/tests/output/database/models/contig.py deleted file mode 100644 index 9cfffb4..0000000 --- a/test_app/tests/output/database/models/contig.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -SQLAlchemy database model for Contig - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/database/models/class_name.py.j2 instead. -""" - -import uuid -from typing import TYPE_CHECKING - -from platformics.database.models.base import Entity -from sqlalchemy import ForeignKey, String -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Mapped, mapped_column, relationship - -if TYPE_CHECKING: - from database.models.file import File - from database.models.sequencing_read import SequencingRead -else: - File = "File" - SequencingRead = "SequencingRead" - - -class Contig(Entity): - __tablename__ = "contig" - __mapper_args__ = {"polymorphic_identity": __tablename__, "polymorphic_load": "inline"} - sequencing_read_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("sequencing_read.entity_id"), nullable=True) - sequencing_read: Mapped["SequencingRead"] = relationship( - "SequencingRead", back_populates="contigs", foreign_keys=sequencing_read_id - ) - sequence: Mapped[str] = mapped_column(String, nullable=False) - entity_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), nullable=False, primary_key=True) diff --git a/test_app/tests/output/database/models/genomic_range.py b/test_app/tests/output/database/models/genomic_range.py deleted file mode 100644 index 059dab0..0000000 --- a/test_app/tests/output/database/models/genomic_range.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -SQLAlchemy database model for GenomicRange - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/database/models/class_name.py.j2 instead. 
-""" - -import uuid -from typing import TYPE_CHECKING - -from platformics.database.models.base import Entity -from sqlalchemy import ForeignKey -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Mapped, mapped_column, relationship - -if TYPE_CHECKING: - from database.models.file import File - from database.models.sequencing_read import SequencingRead -else: - File = "File" - SequencingRead = "SequencingRead" - - -class GenomicRange(Entity): - __tablename__ = "genomic_range" - __mapper_args__ = {"polymorphic_identity": __tablename__, "polymorphic_load": "inline"} - file_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("file.id"), nullable=True) - file: Mapped["File"] = relationship("File", foreign_keys=file_id) - sequencing_reads: Mapped[list[SequencingRead]] = relationship( - "SequencingRead", back_populates="primer_file", uselist=True, foreign_keys="SequencingRead.primer_file_id" - ) - entity_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), nullable=False, primary_key=True) diff --git a/test_app/tests/output/database/models/sample.py b/test_app/tests/output/database/models/sample.py deleted file mode 100644 index a077737..0000000 --- a/test_app/tests/output/database/models/sample.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -SQLAlchemy database model for Sample - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/database/models/class_name.py.j2 instead. -""" - -import uuid -import datetime -from typing import TYPE_CHECKING - -from platformics.database.models.base import Entity -from sqlalchemy import ForeignKey, String, Boolean, DateTime -from sqlalchemy.orm import Mapped, mapped_column, relationship - -if TYPE_CHECKING: - from database.models.file import File - from database.models.sequencing_read import SequencingRead -else: - File = "File" - SequencingRead = "SequencingRead" - - -class Sample(Entity): - __tablename__ = "sample" - __mapper_args__ = {"polymorphic_identity": __tablename__, "polymorphic_load": "inline"} - name: Mapped[str] = mapped_column(String, nullable=False) - sample_type: Mapped[str] = mapped_column(String, nullable=False) - water_control: Mapped[bool] = mapped_column(Boolean, nullable=False) - collection_date: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=True) - collection_location: Mapped[str] = mapped_column(String, nullable=False) - description: Mapped[str] = mapped_column(String, nullable=True) - sequencing_reads: Mapped[list[SequencingRead]] = relationship( - "SequencingRead", back_populates="sample", uselist=True, foreign_keys="SequencingRead.sample_id" - ) - entity_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), nullable=False, primary_key=True) diff --git a/test_app/tests/output/database/models/sequencing_read.py b/test_app/tests/output/database/models/sequencing_read.py deleted file mode 100644 index 6d24748..0000000 --- a/test_app/tests/output/database/models/sequencing_read.py +++ /dev/null @@ -1,51 +0,0 @@ -""" -SQLAlchemy database model for SequencingRead - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/database/models/class_name.py.j2 instead. 
-""" - -import uuid -from typing import TYPE_CHECKING - -from platformics.database.models.base import Entity -from sqlalchemy import ForeignKey, Enum, Boolean -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Mapped, mapped_column, relationship -from support.enums import SequencingProtocol, SequencingTechnology, NucleicAcid - -if TYPE_CHECKING: - from database.models.file import File - from database.models.sample import Sample - from database.models.genomic_range import GenomicRange - from database.models.contig import Contig -else: - File = "File" - Sample = "Sample" - GenomicRange = "GenomicRange" - Contig = "Contig" - - -class SequencingRead(Entity): - __tablename__ = "sequencing_read" - __mapper_args__ = {"polymorphic_identity": __tablename__, "polymorphic_load": "inline"} - sample_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("sample.entity_id"), nullable=True) - sample: Mapped["Sample"] = relationship("Sample", back_populates="sequencing_reads", foreign_keys=sample_id) - protocol: Mapped[SequencingProtocol] = mapped_column(Enum(SequencingProtocol, native_enum=False), nullable=False) - r1_file_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("file.id"), nullable=True) - r1_file: Mapped["File"] = relationship("File", foreign_keys=r1_file_id) - r2_file_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("file.id"), nullable=True) - r2_file: Mapped["File"] = relationship("File", foreign_keys=r2_file_id) - technology: Mapped[SequencingTechnology] = mapped_column( - Enum(SequencingTechnology, native_enum=False), nullable=False - ) - nucleic_acid: Mapped[NucleicAcid] = mapped_column(Enum(NucleicAcid, native_enum=False), nullable=False) - primer_file_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("genomic_range.entity_id"), nullable=True) - primer_file: Mapped["GenomicRange"] = relationship( - "GenomicRange", back_populates="sequencing_reads", foreign_keys=primer_file_id - ) - contigs: Mapped[list[Contig]] = relationship( - "Contig", back_populates="sequencing_read", uselist=True, foreign_keys="Contig.sequencing_read_id" - ) - clearlabs_export: Mapped[bool] = mapped_column(Boolean, nullable=False) - entity_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), nullable=False, primary_key=True) diff --git a/test_app/tests/output/support/enums.py b/test_app/tests/output/support/enums.py deleted file mode 100644 index 1b1e7f3..0000000 --- a/test_app/tests/output/support/enums.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -GraphQL enums - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/support/enums.py.j2 instead. 
-""" - -import strawberry -import enum - - -@strawberry.enum -class FileStatus(enum.Enum): - SUCCESS = "SUCCESS" - FAILED = "FAILED" - PENDING = "PENDING" - - -@strawberry.enum -class FileAcessProtocol(enum.Enum): - s3 = "s3" - - -@strawberry.enum -class NucleicAcid(enum.Enum): - RNA = "RNA" - DNA = "DNA" - - -@strawberry.enum -class SequencingProtocol(enum.Enum): - ampliseq = "ampliseq" - artic = "artic" - artic_v3 = "artic_v3" - artic_v4 = "artic_v4" - artic_v5 = "artic_v5" - combined_msspe_artic = "combined_msspe_artic" - covidseq = "covidseq" - midnight = "midnight" - msspe = "msspe" - snap = "snap" - varskip = "varskip" - easyseq = "easyseq" - - -@strawberry.enum -class SequencingTechnology(enum.Enum): - Illumina = "Illumina" - Nanopore = "Nanopore" diff --git a/test_app/tests/output/test_infra/factories/contig.py b/test_app/tests/output/test_infra/factories/contig.py deleted file mode 100644 index a908c54..0000000 --- a/test_app/tests/output/test_infra/factories/contig.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Factory for generating Contig objects. - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/test_infra/factories/class_name.py.j2 instead. -""" - -# ruff: noqa: E501 Line too long - -import factory -from database.models import Contig -from test_infra.factories.main import CommonFactory -from test_infra.factories.sequencing_read import SequencingReadFactory -from factory import Faker, fuzzy -from faker_biology.bioseq import Bioseq -from faker_biology.physiology import Organ -from faker_enum import EnumProvider - -Faker.add_provider(Bioseq) -Faker.add_provider(Organ) -Faker.add_provider(EnumProvider) - - -class ContigFactory(CommonFactory): - class Meta: - sqlalchemy_session = None # workaround for a bug in factoryboy - model = Contig - # Match entity_id with existing db rows to determine whether we should - # create a new row or not. - sqlalchemy_get_or_create = ("entity_id",) - - sequencing_read = factory.SubFactory( - SequencingReadFactory, - owner_user_id=factory.SelfAttribute("..owner_user_id"), - collection_id=factory.SelfAttribute("..collection_id"), - ) - sequence = fuzzy.FuzzyText() diff --git a/test_app/tests/output/test_infra/factories/genomic_range.py b/test_app/tests/output/test_infra/factories/genomic_range.py deleted file mode 100644 index c814a0f..0000000 --- a/test_app/tests/output/test_infra/factories/genomic_range.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -Factory for generating GenomicRange objects. - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/test_infra/factories/class_name.py.j2 instead. -""" - -# ruff: noqa: E501 Line too long - -import factory -from database.models import GenomicRange -from test_infra.factories.main import CommonFactory, FileFactory -from factory import Faker -from faker_biology.bioseq import Bioseq -from faker_biology.physiology import Organ -from faker_enum import EnumProvider - -Faker.add_provider(Bioseq) -Faker.add_provider(Organ) -Faker.add_provider(EnumProvider) - - -class GenomicRangeFactory(CommonFactory): - class Meta: - sqlalchemy_session = None # workaround for a bug in factoryboy - model = GenomicRange - # Match entity_id with existing db rows to determine whether we should - # create a new row or not. 
- sqlalchemy_get_or_create = ("entity_id",) - - file = factory.RelatedFactory( - FileFactory, - factory_related_name="entity", - entity_field_name="file", - file_format="fastq", - ) diff --git a/test_app/tests/output/test_infra/factories/sample.py b/test_app/tests/output/test_infra/factories/sample.py deleted file mode 100644 index f893e95..0000000 --- a/test_app/tests/output/test_infra/factories/sample.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -Factory for generating Sample objects. - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/test_infra/factories/class_name.py.j2 instead. -""" - -# ruff: noqa: E501 Line too long - -import factory -from database.models import Sample -from test_infra.factories.main import CommonFactory -from factory import Faker, fuzzy -from faker_biology.bioseq import Bioseq -from faker_biology.physiology import Organ -from faker_enum import EnumProvider - -Faker.add_provider(Bioseq) -Faker.add_provider(Organ) -Faker.add_provider(EnumProvider) - - -class SampleFactory(CommonFactory): - class Meta: - sqlalchemy_session = None # workaround for a bug in factoryboy - model = Sample - # Match entity_id with existing db rows to determine whether we should - # create a new row or not. - sqlalchemy_get_or_create = ("entity_id",) - - name = fuzzy.FuzzyText() - sample_type = fuzzy.FuzzyText() - water_control = factory.Faker("boolean") - collection_date = factory.Faker("date") - collection_location = fuzzy.FuzzyText() - description = fuzzy.FuzzyText() diff --git a/test_app/tests/output/test_infra/factories/sequencing_read.py b/test_app/tests/output/test_infra/factories/sequencing_read.py deleted file mode 100644 index 9e1426d..0000000 --- a/test_app/tests/output/test_infra/factories/sequencing_read.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -Factory for generating SequencingRead objects. - -Auto-generated by running 'make codegen'. Do not edit. -Make changes to the template codegen/templates/test_infra/factories/class_name.py.j2 instead. -""" - -# ruff: noqa: E501 Line too long - -import factory -from database.models import SequencingRead -from test_infra.factories.main import CommonFactory, FileFactory -from test_infra.factories.sample import SampleFactory -from test_infra.factories.genomic_range import GenomicRangeFactory -from factory import Faker, fuzzy -from faker_biology.bioseq import Bioseq -from faker_biology.physiology import Organ -from faker_enum import EnumProvider - -Faker.add_provider(Bioseq) -Faker.add_provider(Organ) -Faker.add_provider(EnumProvider) - - -class SequencingReadFactory(CommonFactory): - class Meta: - sqlalchemy_session = None # workaround for a bug in factoryboy - model = SequencingRead - # Match entity_id with existing db rows to determine whether we should - # create a new row or not. 
- sqlalchemy_get_or_create = ("entity_id",) - - sample = factory.SubFactory( - SampleFactory, - owner_user_id=factory.SelfAttribute("..owner_user_id"), - collection_id=factory.SelfAttribute("..collection_id"), - ) - protocol = fuzzy.FuzzyChoice( - [ - "ampliseq", - "artic", - "artic_v3", - "artic_v4", - "artic_v5", - "combined_msspe_artic", - "covidseq", - "midnight", - "msspe", - "snap", - "varskip", - "easyseq", - ] - ) - r1_file = factory.RelatedFactory( - FileFactory, - factory_related_name="entity", - entity_field_name="r1_file", - file_format="fastq", - ) - r2_file = factory.RelatedFactory( - FileFactory, - factory_related_name="entity", - entity_field_name="r2_file", - file_format="fastq", - ) - technology = fuzzy.FuzzyChoice(["Illumina", "Nanopore"]) - nucleic_acid = fuzzy.FuzzyChoice(["RNA", "DNA"]) - primer_file = factory.SubFactory( - GenomicRangeFactory, - owner_user_id=factory.SelfAttribute("..owner_user_id"), - collection_id=factory.SelfAttribute("..collection_id"), - ) - clearlabs_export = factory.Faker("boolean") diff --git a/test_app/tests/test_schemas/overrides/api/.gitignore b/test_app/tests/test_schemas/overrides/api/.gitignore deleted file mode 100644 index f833585..0000000 --- a/test_app/tests/test_schemas/overrides/api/.gitignore +++ /dev/null @@ -1 +0,0 @@ -queries.py.j2 diff --git a/test_app/tests/test_schemas/overrides/api/extra_test_code.py.j2 b/test_app/tests/test_schemas/overrides/api/extra_test_code.py.j2 deleted file mode 100644 index cb2dfe0..0000000 --- a/test_app/tests/test_schemas/overrides/api/extra_test_code.py.j2 +++ /dev/null @@ -1,4 +0,0 @@ - @strawberry.field - def uncaught_exception(self) -> str: - # Trigger an AttributeException - return self.kaboom # type: ignore diff --git a/test_app/tests/test_schemas/platformics.yaml b/test_app/tests/test_schemas/platformics.yaml deleted file mode 100644 index 31d3440..0000000 --- a/test_app/tests/test_schemas/platformics.yaml +++ /dev/null @@ -1,443 +0,0 @@ -id: https://czid.org/entities/schema/platformics -title: CZID Platformics Bio-Entities Schema -name: platformics -default_range: string - -types: - string: - uri: xsd:string - base: str - description: A character string - - integer: - uri: xsd:integer - base: int - description: An integer - - uuid: - uri: xsd:string - typeof: str - base: str - description: A UUID - -enums: - FileStatus: - permissible_values: - SUCCESS: - FAILED: - PENDING: - FileAcessProtocol: - permissible_values: - s3: - NucleicAcid: - permissible_values: - RNA: - DNA: - SequencingProtocol: - permissible_values: - ampliseq: - artic: - artic_v3: - artic_v4: - artic_v5: - combined_msspe_artic: - covidseq: - midnight: - msspe: - snap: - varskip: - easyseq: - SequencingTechnology: - permissible_values: - Illumina: - Nanopore: - TaxonLevel: - permissible_values: - level_sublevel: - level_species: - level_genus: - level_family: - level_order: - level_class: - level_phylum: - level_kingdom: - level_superkingdom: - FileAccessProtocol: - permissible_values: - s3: - description: This file is accessible via the (AWS) S3 protocol - FileUploadClient: - permissible_values: - browser: - description: File uploaded from the user's browser - cli: - description: File uploaded from the CLI - s3: - description: File uploaded from S3 - basespace: - description: File uploaded from Illumina Basespace Cloud - -classes: - Entity: - attributes: - id: - identifier: true - range: uuid - readonly: true # The API handles generating the values for these fields - required: true - producing_run_id: - range: uuid - 
minimum_value: 0 - annotations: - mutable: false # This field can't be modified by an `Update` mutation - system_writable_only: True - owner_user_id: - range: integer - minimum_value: 0 - readonly: true - required: true - collection_id: - range: integer - minimum_value: 0 - required: false - annotations: - mutable: false - created_at: - range: date - required: true - readonly: true - updated_at: - range: date - readonly: true - deleted_at: - range: date - # NOTE - the LinkML schema doesn't support a native "plural name" field as far as I can tell, so - # we're using an annotation here to tack on the extra functionality that we need. We do this because - # English pluralization is hard, and we don't want to have to write a custom pluralization function. - # This basically means we now have our own "dialect" of LinkML to worry about. We may want to see if - # pluralization can be added to the core spec in the future. - annotations: - plural: Entities - - File: - attributes: - id: - identifier: true - range: uuid - entity_field_name: - range: string - required: true - entity: - range: Entity - required: true - status: - range: FileStatus - required: true - protocol: - range: FileAccessProtocol - required: true - namespace: - range: string - required: true - path: - range: string - required: true - file_format: - range: string - required: true - compression_type: - range: string - size: - range: integer - minimum_value: 0 - - Sample: - is_a: Entity - mixins: - - EntityMixin - attributes: - name: - range: string - required: true - sample_type: - range: string - required: true - water_control: - range: boolean - required: true - collection_date: - range: date - collection_location: - range: string - required: true - notes: - range: string - sequencing_reads: - range: SequencingRead - multivalued: true - inverse: SequencingRead.sample - annotations: - cascade_delete: true - system_mutable_field: - range: string - annotations: - system_writable_only: True - annotations: - plural: Samples - - SequencingRead: - is_a: Entity - mixins: - - EntityMixin - attributes: - sample: - range: Sample - inverse: Sample.sequencing_reads - protocol: - range: SequencingProtocol - required: true - annotations: - mutable: false - r1_file: - range: File - readonly: true - annotations: - cascade_delete: true - r2_file: - range: File - readonly: true - annotations: - cascade_delete: true - technology: - range: SequencingTechnology - required: true - nucleic_acid: - range: NucleicAcid - required: true - primer_file: - range: GenomicRange - inverse: GenomicRange.sequencing_reads - annotations: - mutable: false - consensus_genomes: - range: ConsensusGenome - inverse: ConsensusGenome.sequence_read - multivalued: true - annotations: - cascade_delete: true - clearlabs_export: - range: boolean - required: true - taxon: - range: Taxon - inverse: Taxon.sequencing_reads - annotations: - mutable: false - annotations: - plural: SequencingReads - - ConsensusGenome: - is_a: Entity - mixins: - - EntityMixin - attributes: - sequence_read: - range: SequencingRead - required: true - inverse: SequencingRead.consensus_genomes - annotations: - mutable: false - sequence: - readonly: true - range: File - annotations: - cascade_delete: true - metrics: - range: MetricConsensusGenome - inverse: MetricConsensusGenome.consensus_genome - inlined: true - annotations: - cascade_delete: true - intermediate_outputs: - range: File - readonly: true - annotations: - cascade_delete: true - annotations: - plural: ConsensusGenomes - - MetricConsensusGenome: 
- is_a: Entity - mixins: - - EntityMixin - attributes: - consensus_genome: - range: ConsensusGenome - inverse: ConsensusGenome.metrics - required: true - annotations: - mutable: false - total_reads: - range: integer - minimum_value: 0 - maximum_value: 999999999999 - annotations: - mutable: false - mapped_reads: - range: integer - annotations: - mutable: false - annotations: - plural: MetricsConsensusGenomes - - GenomicRange: - is_a: Entity - mixins: - - EntityMixin - attributes: - file: - range: File - readonly: true - sequencing_reads: - range: SequencingRead - inverse: SequencingRead.primer_file - multivalued: true - annotations: - plural: GenomicRanges - - Taxon: - is_a: Entity - mixins: - - EntityMixin - attributes: - name: - range: string - required: true - is_phage: - range: boolean - required: true - upstream_database: - range: UpstreamDatabase - required: true - inverse: UpstreamDatabase.taxa - upstream_database_identifier: - range: string - required: true - level: - range: TaxonLevel - required: true - sequencing_reads: - range: SequencingRead - inverse: SequencingRead.taxon - multivalued: true - annotations: - plural: Taxa - - UpstreamDatabase: - is_a: Entity - mixins: - - EntityMixin - attributes: - name: - range: string - required: true - taxa: - range: Taxon - multivalued: true - inverse: Taxon.upstream_database - annotations: - plural: UpstreamDatabases - - BulkDownload: - is_a: Entity - mixins: - - EntityMixin - attributes: - download_display_name: - range: string - required: true - annotations: - mutable: false - file: - range: File - readonly: true - annotations: - cascade_delete: true - annotations: - plural: BulkDownloads - - SystemWritableOnlyType: - is_a: Entity - mixins: - - EntityMixin - attributes: - name: - range: string - required: true - annotations: - system_writable_only: true - plural: SystemWritableOnlyTypes - - ImmutableType: - is_a: Entity - mixins: - - EntityMixin - attributes: - name: - range: string - required: true - annotations: - mutable: false - plural: ImmutableTypes - - ConstraintCheckedType: - is_a: Entity - mixins: - - EntityMixin - attributes: - length_3_to_8: - range: string - annotations: - minimum_length: 3 - maximum_length: 8 - regex_format_check: - range: string - pattern: '\d{3}-\d{2}-\d{4}' - min_value_0: - range: integer - minimum_value: 0 - enum_field: - range: NucleicAcid - bool_field: - range: boolean - max_value_9: - range: integer - maximum_value: 9 - min_value_0_max_value_9: - range: integer - minimum_value: 0 - maximum_value: 9 - float_1dot1_to_2dot2: - range: float - minimum_value: 1.1 - maximum_value: 2.2 - no_string_checks: - range: string - no_int_checks: - range: integer - no_float_checks: - range: float - annotations: - plural: ConstraintCheckedTypes - - EntityMixin: - mixin: true - attributes: - entity_id: - required: true - readonly: true - range: uuid - identifier: true - inverse: entity.id - annotations: - hidden: true From 6889a84095726967680aae46fe0f2a012e7f3bf5 Mon Sep 17 00:00:00 2001 From: Omar Valenzuela Date: Fri, 7 Jun 2024 10:08:21 -0700 Subject: [PATCH 07/16] update format_handlers, bump python image, update fixture path --- Dockerfile.dev | 2 +- platformics/support/format_handlers.py | 118 ++++++++++++++++++------- test_app/tests/test_file_mutations.py | 8 +- 3 files changed, 92 insertions(+), 36 deletions(-) diff --git a/Dockerfile.dev b/Dockerfile.dev index 5cf36e3..5204862 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -1,4 +1,4 @@ -FROM python:3.11-slim-bookworm AS build +FROM 
python:3.11-slim-bookworm AS build
+FROM
python:3.12-slim-bookworm AS build RUN apt-get update && \ apt-get install -y vim wget nginx nginx-extras procps ripgrep make gcc && \ apt-get clean diff --git a/platformics/support/format_handlers.py b/platformics/support/format_handlers.py index 96b7220..6f4f7be 100644 --- a/platformics/support/format_handlers.py +++ b/platformics/support/format_handlers.py @@ -3,70 +3,126 @@ """ import gzip -import tempfile -import typing +import io +import json from abc import abstractmethod -from mypy_boto3_s3.client import S3Client from Bio import SeqIO from typing import Protocol +from mypy_boto3_s3 import S3Client + + class FileFormatHandler(Protocol): """ Interface for a file format handler """ - @classmethod + s3client: S3Client + bucket: str + key: str + + def __init__(self, s3client: S3Client, bucket: str, key: str): + self.s3client = s3client + self.bucket = bucket + self.key = key + + def contents(self) -> str: + """ + Get the contents of the file + """ + body = self.s3client.get_object(Bucket=self.bucket, Key=self.key, Range="bytes=0-1000000")["Body"] + if self.key.endswith(".gz"): + with gzip.GzipFile(fileobj=body) as fp: + return fp.read().decode("utf-8") + return body.read().decode("utf-8") + @abstractmethod - def validate(cls, client: S3Client, bucket: str, file_path: str) -> None: + def validate(self) -> None: raise NotImplementedError +class FastaHandler(FileFormatHandler): + """ + Validate FASTA files. Note that even truncated FASTA files are supported: + ">" is a valid FASTA file, and so is ">abc" (without a sequence). + """ + + def validate(self) -> None: + sequences = 0 + for _ in SeqIO.parse(io.StringIO(self.contents()), "fasta"): + sequences += 1 + assert sequences > 0 + + class FastqHandler(FileFormatHandler): """ - Validate FASTQ files (contain sequencing reads) + Validate FASTQ files. Can't use biopython directly because large file would be truncated. + This removes truncated FASTQ records by assuming 1 read = 4 lines. """ - @classmethod - def validate(cls, client: S3Client, bucket: str, file_path: str) -> None: - fp = get_file_preview(client, bucket, file_path) - assert len([read for read in SeqIO.parse(fp, "fastq")]) > 0 + def validate(self) -> None: + # Load file and only keep non-truncated FASTQ records (4 lines per record) + fastq = self.contents().split("\n") + fastq = fastq[: len(fastq) - (len(fastq) % 4)] + # Validate it with SeqIO + reads = 0 + for _ in SeqIO.parse(io.StringIO("\n".join(fastq)), "fastq"): + reads += 1 + assert reads > 0 -class FastaHandler(FileFormatHandler): + +class BedHandler(FileFormatHandler): """ - Validate FASTA files (contain sequences) + Validate BED files using basic checks. 
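(Illustrative sketch, not part of the patch: the handlers above, and the remaining ones below, share one shape. Each is constructed with an S3 client plus bucket and key, reads at most the first ~1 MB of the object via contents(), and raises if validation fails. The caller names below are assumptions; the bucket and key values match the fixture used in the tests later in this patch.)

    # Hypothetical caller; `s3` is an assumed boto3-style S3 client
    handler = FastqHandler(s3, "local-bucket", "test1.fastq")
    handler.validate()  # raises AssertionError if no complete FASTQ record is found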
""" - @classmethod - def validate(cls, client: S3Client, bucket: str, file_path: str) -> None: - fp = get_file_preview(client, bucket, file_path) - assert len([read for read in SeqIO.parse(fp, "fasta")]) > 0 + def validate(self) -> None: + # Ignore last line since it could be truncated + records = self.contents().split("\n")[:-1] + assert len(records) > 0 + # BED files must have at least 3 columns - error out if the file incorrectly uses spaces instead of tabs + num_cols = -1 + for record in records: + assert len(record.split("\t")) >= 3 + # All rows should have the same number of columns + if num_cols == -1: + num_cols = len(record.split("\t")) + else: + assert num_cols == len(record.split("\t")) -def get_file_preview(client: S3Client, bucket: str, file_path: str) -> typing.TextIO: + +class JsonHandler(FileFormatHandler): """ - Get first 1MB of a file and save it in a temporary file + Validate JSON files """ - data = client.get_object(Bucket=bucket, Key=file_path, Range="bytes=0-1000000")["Body"].read() - fp = tempfile.NamedTemporaryFile("w+b") - fp.write(data) - fp.flush() - try: - data.decode("utf-8") - return open(fp.name, "r") - except UnicodeDecodeError: - return gzip.open(fp.name, "rt") + def validate(self) -> None: + json.loads(self.contents()) # throws an exception for invalid JSON +class ZipHandler(FileFormatHandler): + """ + Validate ZIP files + """ + + def validate(self) -> None: + assert self.key.endswith(".zip") # throws an exception if the file is not a zip file -def get_validator(format: str, compression_type: str) -> type[FileFormatHandler]: +def get_validator(format: str) -> type[FileFormatHandler]: """ Returns the validator for a given file format """ - if format == "fastq": - return FastqHandler - elif format == "fasta": + if format in ["fa", "fasta"]: return FastaHandler + elif format == "fastq": + return FastqHandler + elif format == "bed": + return BedHandler + elif format == "json": + return JsonHandler + elif format == "zip": + return ZipHandler else: raise Exception(f"Unknown file format '{format}'") diff --git a/test_app/tests/test_file_mutations.py b/test_app/tests/test_file_mutations.py index b277fd2..56729c3 100644 --- a/test_app/tests/test_file_mutations.py +++ b/test_app/tests/test_file_mutations.py @@ -35,7 +35,7 @@ async def test_file_validation( files = session.execute(sa.select(File)).scalars().all() file = list(filter(lambda file: file.entity_field_name == "r1_file", files))[0] - valid_fastq_file = "test_infra/fixtures/test1.fastq" + valid_fastq_file = "tests/fixtures/test1.fastq" file_size = os.stat(valid_fastq_file).st_size moto_client.put_object(Bucket=file.namespace, Key=file.path, Body=open(valid_fastq_file, "rb")) @@ -189,7 +189,7 @@ async def test_create_file( # Upload a fastq file to a mock bucket so we can create a file object from it file_namespace = "local-bucket" file_path = "test1.fastq" - file_path_local = "test_infra/fixtures/test1.fastq" + file_path_local = "tests/fixtures/test1.fasta" file_size = os.stat(file_path_local).st_size with open(file_path_local, "rb") as fp: moto_client.put_object(Bucket=file_namespace, Key=file_path, Body=fp) @@ -211,7 +211,7 @@ async def test_create_file( path size }} - }} + """ output = await gql_client.query(mutation, member_projects=[123], service_identity="workflows") assert output["data"]["createFile"]["size"] == file_size @@ -270,7 +270,7 @@ async def test_delete_from_s3( sequencing_read.r1_file.namespace = bucket session.commit() - valid_fastq_file = "test_infra/fixtures/test1.fastq" + 
valid_fastq_file = "tests/fixtures/test1.fastq" moto_client.put_object(Bucket=file.namespace, Key=file.path, Body=open(valid_fastq_file, "rb")) # Delete SequencingRead and cascade to File objects From a46fc14b64a99449f8810f5027c0633aa9211640 Mon Sep 17 00:00:00 2001 From: Omar Valenzuela Date: Fri, 7 Jun 2024 14:57:19 -0700 Subject: [PATCH 08/16] setup changes, update file paths, tests: 22 fail | 59 pass --- platformics/api/core/deps.py | 5 ---- platformics/api/core/gql_loaders.py | 17 ++----------- platformics/api/setup.py | 24 ++++++++++++------- platformics/api/types/entities.py | 6 ++--- .../templates/database/models/__init__.py.j2 | 2 -- test_app/main.py | 8 +++---- test_app/tests/test_file_concatenation.py | 4 ++-- test_app/tests/test_file_uploads.py | 4 ++-- 8 files changed, 28 insertions(+), 42 deletions(-) diff --git a/platformics/api/core/deps.py b/platformics/api/core/deps.py index 3950265..39a0eef 100644 --- a/platformics/api/core/deps.py +++ b/platformics/api/core/deps.py @@ -15,11 +15,6 @@ from platformics.api.core.error_handler import PlatformicsException -def get_db_module(request: Request) -> typing.Any: - """Get the DB module from our app state""" - return request.app.state.db_module - - def get_settings(request: Request) -> APISettings: """Get the settings object from the app state""" return request.app.state.settings diff --git a/platformics/api/core/gql_loaders.py b/platformics/api/core/gql_loaders.py index f61ca98..2cf820f 100644 --- a/platformics/api/core/gql_loaders.py +++ b/platformics/api/core/gql_loaders.py @@ -83,12 +83,7 @@ async def load_fn(keys: list[Any]) -> typing.Sequence[Any]: for _, remote in relationship.local_remote_pairs: filters.append(remote.in_(keys)) query = get_db_query( - related_model, - CerbosAction.VIEW, - self.cerbos_client, - self.principal, - where, - order_by, # type: ignore + related_model, CerbosAction.VIEW, self.cerbos_client, self.principal, where, order_by # type: ignore ) for item in filters: query = query.where(item) @@ -157,15 +152,7 @@ async def load_fn(keys: list[Any]) -> typing.Sequence[Any]: raise PlatformicsException("No aggregate functions selected") query, group_by = get_aggregate_db_query( - related_model, - CerbosAction.VIEW, - self.cerbos_client, - self.principal, - where, - aggregate_selections, - groupby_selections, - None, - remote, # type: ignore + related_model, CerbosAction.VIEW, self.cerbos_client, self.principal, where, aggregate_selections, groupby_selections, None, remote # type: ignore ) for item in filters: query = query.where(item) diff --git a/platformics/api/setup.py b/platformics/api/setup.py index 077641a..a088b47 100644 --- a/platformics/api/setup.py +++ b/platformics/api/setup.py @@ -4,14 +4,23 @@ import typing -import strawberry +import uvicorn from cerbos.sdk.client import CerbosClient from cerbos.sdk.model import Principal from fastapi import Depends, FastAPI -from platformics.api.core.deps import get_auth_principal, get_cerbos_client, get_engine, get_db_module +from platformics.api.core.deps import ( + get_auth_principal, + get_cerbos_client, + get_engine, + get_s3_client, +) +from platformics.api.core.error_handler import HandleErrors from platformics.api.core.gql_loaders import EntityLoader from platformics.database.connect import AsyncDB from platformics.settings import APISettings +from platformics.database.models.file import File + +import strawberry from strawberry.fastapi import GraphQLRouter from strawberry.schema.config import StrawberryConfig from strawberry.schema.name_converter import 
HasGraphQLName, NameConverter @@ -23,7 +32,6 @@ def get_context( engine: AsyncDB = Depends(get_engine), - db_module: AsyncDB = Depends(get_db_module), cerbos_client: CerbosClient = Depends(get_cerbos_client), principal: Principal = Depends(get_auth_principal), ) -> dict[str, typing.Any]: @@ -32,8 +40,6 @@ def get_context( """ return { "sqlalchemy_loader": EntityLoader(engine=engine, cerbos_client=cerbos_client, principal=principal), - # This is entirely to support automatically resolving Relay Nodes in the EntityInterface - "db_module": db_module, } @@ -48,19 +54,19 @@ def get_graphql_name(self, obj: HasGraphQLName) -> str: return super().get_graphql_name(obj) -def get_app(settings: APISettings, schema: strawberry.Schema, db_module: typing.Any) -> FastAPI: +def get_app(settings: APISettings, schema: strawberry.Schema) -> FastAPI: """ Make sure tests can get their own instances of the app. """ - settings = APISettings.model_validate({}) # Workaround for https://github.com/pydantic/pydantic/issues/3753 + File.set_settings(settings) + File.set_s3_client(get_s3_client(settings)) title = settings.SERVICE_NAME - graphql_app: GraphQLRouter = GraphQLRouter(schema, context_getter=get_context, graphiql=True) + graphql_app: GraphQLRouter = GraphQLRouter(schema, context_getter=get_context) _app = FastAPI(title=title, debug=settings.DEBUG) _app.include_router(graphql_app, prefix="/graphql") # Add a global settings object to the app that we can use as a dependency _app.state.settings = settings - _app.state.db_module = db_module return _app diff --git a/platformics/api/types/entities.py b/platformics/api/types/entities.py index 6485c38..aca6db1 100644 --- a/platformics/api/types/entities.py +++ b/platformics/api/types/entities.py @@ -1,5 +1,6 @@ from typing import Iterable +import database.models as db import strawberry from platformics.api import relay @@ -20,12 +21,11 @@ class EntityInterface(relay.Node): # In the Strawberry docs, this field is called `code`, but we're using `id` instead. 
# Otherwise, Strawberry SQLAlchemyMapper errors with: "SequencingRead object has no # attribute code" (unless you create a column `code` in the table) - id: relay.NodeID[str] + id: relay.NodeID[str] # type: ignore @classmethod async def resolve_nodes(cls, *, info: Info, node_ids: Iterable[str], required: bool = False) -> list: dataloader = info.context["sqlalchemy_loader"] - db_module = info.context["db_module"] gql_type: str = cls.__strawberry_definition__.name # type: ignore - sql_model = getattr(db_module, gql_type) + sql_model = getattr(db, gql_type) return await dataloader.resolve_nodes(sql_model, node_ids) diff --git a/platformics/codegen/templates/database/models/__init__.py.j2 b/platformics/codegen/templates/database/models/__init__.py.j2 index 423497d..cfebf59 100644 --- a/platformics/codegen/templates/database/models/__init__.py.j2 +++ b/platformics/codegen/templates/database/models/__init__.py.j2 @@ -11,9 +11,7 @@ from sqlalchemy.orm import configure_mappers from platformics.database.models import Base, meta, Entity, File, FileStatus # noqa: F401 {%- for class in classes %} - {%- if class.snake_name != "Entity" %} from database.models.{{ class.snake_name }} import {{ class.name }} # noqa: F401 - {%- endif %} {%- endfor %} from platformics.database.models.file import File, FileStatus # noqa: F401 diff --git a/test_app/main.py b/test_app/main.py index 961df11..3bf47b2 100644 --- a/test_app/main.py +++ b/test_app/main.py @@ -4,20 +4,20 @@ import strawberry import uvicorn -from platformics.api.setup import get_app, get_strawberry_config +from platformics.api.setup import get_app, get_strawberry_config, CustomNameConverter +from platformics.api.core.error_handler import HandleErrors from platformics.settings import APISettings -from database import models from api.mutations import Mutation from api.queries import Query settings = APISettings.model_validate({}) # Workaround for https://github.com/pydantic/pydantic/issues/3753 strawberry_config = get_strawberry_config() -schema = strawberry.Schema(query=Query, mutation=Mutation, config=strawberry_config) +schema = strawberry.Schema(query=Query, mutation=Mutation, config=get_strawberry_config(), extensions=[HandleErrors()]) # Create and run app -app = get_app(settings, schema, models) +app = get_app(settings, schema) if __name__ == "__main__": config = uvicorn.Config("main:app", host="0.0.0.0", port=9008, log_level="info") diff --git a/test_app/tests/test_file_concatenation.py b/test_app/tests/test_file_concatenation.py index d39c116..7a75c2b 100644 --- a/test_app/tests/test_file_concatenation.py +++ b/test_app/tests/test_file_concatenation.py @@ -27,8 +27,8 @@ async def test_concatenation( user_id = 12345 project_id = 111 member_projects = [project_id] - fasta_file_1 = f"test_infra/fixtures/{file_name_1}" - fasta_file_2 = f"test_infra/fixtures/{file_name_2}" + fasta_file_1 = f"tests/fixtures/{file_name_1}" + fasta_file_2 = f"tests/fixtures/{file_name_2}" # Create mock data with sync_db.session() as session: diff --git a/test_app/tests/test_file_uploads.py b/test_app/tests/test_file_uploads.py index 7b06cc2..38a79b0 100644 --- a/test_app/tests/test_file_uploads.py +++ b/test_app/tests/test_file_uploads.py @@ -60,7 +60,7 @@ async def test_upload_process( credentials = output["data"]["uploadFile"]["credentials"] # Upload the file - fastq_file = "test_infra/fixtures/test1.fastq" + fastq_file = "tests/fixtures/test1.fastq" fastq_file_size = os.stat(fastq_file).st_size moto_client.put_object(Bucket=credentials["namespace"], 
Key=credentials["path"], Body=open(fastq_file, "rb")) @@ -93,7 +93,7 @@ async def test_upload_process_multiple_files_per_entity( user_id = 12345 project_id = 111 member_projects = [project_id] - fastq_file = "test_infra/fixtures/test1.fastq" + fastq_file = "tests/fixtures/test1.fastq" # Create mock data with sync_db.session() as session: From aeffbba804eebf7fd7dbe4fd971bd4948d4a0645 Mon Sep 17 00:00:00 2001 From: Omar Valenzuela Date: Mon, 10 Jun 2024 10:57:01 -0700 Subject: [PATCH 09/16] set deleted_at -- 59 pass, 22 failed --- platformics/api/core/helpers.py | 18 +++++++----------- platformics/api/setup.py | 1 - test_app/main.py | 1 - test_app/schema/schema.yaml | 5 +++++ 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/platformics/api/core/helpers.py b/platformics/api/core/helpers.py index a94280b..05eec6f 100644 --- a/platformics/api/core/helpers.py +++ b/platformics/api/core/helpers.py @@ -126,7 +126,7 @@ def convert_where_clauses_to_sql( action, cerbos_query, related_cls, - join_info.get("where"), # type: ignore + join_info.get("where"), # type: ignore join_info.get("order_by"), join_info.get("group_by"), depth, @@ -170,15 +170,11 @@ def convert_where_clauses_to_sql( # For the variants of regexp_match, we pass in a dict with the comparator, should_negate, and flag elif isinstance(sa_comparator, dict): if sa_comparator["should_negate"]: - query = query.filter( - ~(getattr(getattr(sa_model, col), sa_comparator["comparator"])(value, sa_comparator["flag"])) - ) + query = query.filter(~(getattr(getattr(sa_model, col), sa_comparator["comparator"])(value, sa_comparator["flag"]))) else: - query = query.filter( - getattr(getattr(sa_model, col), sa_comparator["comparator"])(value, sa_comparator["flag"]) - ) + query = query.filter(getattr(getattr(sa_model, col), sa_comparator["comparator"])(value, sa_comparator["flag"])) else: - query = query.filter(getattr(getattr(sa_model, col), sa_comparator)(value)) # type: ignore + query = query.filter(getattr(getattr(sa_model, col), sa_comparator)(value)) # type: ignore return query, local_order_by, local_group_by @@ -207,9 +203,9 @@ def get_db_query( cerbos_client, action, query, - model_cls, # type: ignore + model_cls, # type: ignore where, - order_by, # type: ignore + order_by, # type: ignore [], 0, ) @@ -297,7 +293,7 @@ def get_aggregate_db_query( cerbos_client, action, query, - model_cls, # type: ignore + model_cls, # type: ignore where, [], group_by, diff --git a/platformics/api/setup.py b/platformics/api/setup.py index a088b47..be179d3 100644 --- a/platformics/api/setup.py +++ b/platformics/api/setup.py @@ -4,7 +4,6 @@ import typing -import uvicorn from cerbos.sdk.client import CerbosClient from cerbos.sdk.model import Principal from fastapi import Depends, FastAPI diff --git a/test_app/main.py b/test_app/main.py index 3bf47b2..0ed11a3 100644 --- a/test_app/main.py +++ b/test_app/main.py @@ -12,7 +12,6 @@ from api.queries import Query settings = APISettings.model_validate({}) # Workaround for https://github.com/pydantic/pydantic/issues/3753 -strawberry_config = get_strawberry_config() schema = strawberry.Schema(query=Query, mutation=Mutation, config=get_strawberry_config(), extensions=[HandleErrors()]) diff --git a/test_app/schema/schema.yaml b/test_app/schema/schema.yaml index 8b3b497..bc020be 100644 --- a/test_app/schema/schema.yaml +++ b/test_app/schema/schema.yaml @@ -79,6 +79,11 @@ classes: updated_at: range: date readonly: true + deleted_at: + range: date + annotations: + mutable: true + system_writable_only: True annotations: 
plural: Entities From 54d390a91e8fd6e555d63a29964bcfb28f3b431e Mon Sep 17 00:00:00 2001 From: Jessica Gadling Date: Thu, 13 Jun 2024 13:21:56 -0400 Subject: [PATCH 10/16] Put the db models patches back in place. --- platformics/api/core/deps.py | 5 +++++ platformics/api/setup.py | 10 ++++++++-- platformics/api/types/entities.py | 6 +++--- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/platformics/api/core/deps.py b/platformics/api/core/deps.py index 3685a01..b284ccc 100644 --- a/platformics/api/core/deps.py +++ b/platformics/api/core/deps.py @@ -16,6 +16,11 @@ from platformics.settings import APISettings +def get_db_module(request: Request) -> typing.Any: + """Get the DB module from our app state""" + return request.app.state.db_module + + def get_settings(request: Request) -> APISettings: """Get the settings object from the app state""" return request.app.state.settings diff --git a/platformics/api/setup.py b/platformics/api/setup.py index b5b811e..6c76589 100644 --- a/platformics/api/setup.py +++ b/platformics/api/setup.py @@ -12,9 +12,10 @@ from strawberry.schema.config import StrawberryConfig from strawberry.schema.name_converter import HasGraphQLName, NameConverter -from platformics.api.core.deps import get_auth_principal, get_cerbos_client, get_engine +from platformics.api.core.deps import get_auth_principal, get_cerbos_client, get_db_module, get_engine, get_s3_client from platformics.api.core.gql_loaders import EntityLoader from platformics.database.connect import AsyncDB +from platformics.database.models.file import File from platformics.settings import APISettings # ------------------------------------------------------------------------------ @@ -24,6 +25,7 @@ def get_context( engine: AsyncDB = Depends(get_engine), + db_module: AsyncDB = Depends(get_db_module), cerbos_client: CerbosClient = Depends(get_cerbos_client), principal: Principal = Depends(get_auth_principal), ) -> dict[str, typing.Any]: @@ -32,6 +34,8 @@ def get_context( """ return { "sqlalchemy_loader": EntityLoader(engine=engine, cerbos_client=cerbos_client, principal=principal), + # This is entirely to support automatically resolving Relay Nodes in the EntityInterface + "db_module": db_module, } @@ -50,10 +54,12 @@ def get_app(settings: APISettings, schema: strawberry.Schema, db_module: typing. """ Make sure tests can get their own instances of the app. """ + File.set_settings(settings) + File.set_s3_client(get_s3_client(settings)) settings = APISettings.model_validate({}) # Workaround for https://github.com/pydantic/pydantic/issues/3753 title = settings.SERVICE_NAME - graphql_app: GraphQLRouter = GraphQLRouter(schema, context_getter=get_context, graphiql=True) + graphql_app: GraphQLRouter = GraphQLRouter(schema, context_getter=get_context) _app = FastAPI(title=title, debug=settings.DEBUG) _app.include_router(graphql_app, prefix="/graphql") # Add a global settings object to the app that we can use as a dependency diff --git a/platformics/api/types/entities.py b/platformics/api/types/entities.py index 3721d04..358c5ba 100644 --- a/platformics/api/types/entities.py +++ b/platformics/api/types/entities.py @@ -1,6 +1,5 @@ from typing import Iterable -import database.models as db import strawberry from strawberry.types import Info @@ -21,11 +20,12 @@ class EntityInterface(relay.Node): # In the Strawberry docs, this field is called `code`, but we're using `id` instead. 
# Otherwise, Strawberry SQLAlchemyMapper errors with: "SequencingRead object has no # attribute code" (unless you create a column `code` in the table) - id: relay.NodeID[str] # type: ignore + id: relay.NodeID[str] @classmethod async def resolve_nodes(cls, *, info: Info, node_ids: Iterable[str], required: bool = False) -> list: dataloader = info.context["sqlalchemy_loader"] + db_module = info.context["db_module"] gql_type: str = cls.__strawberry_definition__.name # type: ignore - sql_model = getattr(db, gql_type) + sql_model = getattr(db_module, gql_type) return await dataloader.resolve_nodes(sql_model, node_ids) From cec0907ecd02786a807a6d1a6bdcdedc08af9e50 Mon Sep 17 00:00:00 2001 From: Jessica Gadling Date: Thu, 13 Jun 2024 13:27:17 -0400 Subject: [PATCH 11/16] Put this back. --- platformics/codegen/templates/database/models/__init__.py.j2 | 2 ++ 1 file changed, 2 insertions(+) diff --git a/platformics/codegen/templates/database/models/__init__.py.j2 b/platformics/codegen/templates/database/models/__init__.py.j2 index e2aeb73..f70c376 100644 --- a/platformics/codegen/templates/database/models/__init__.py.j2 +++ b/platformics/codegen/templates/database/models/__init__.py.j2 @@ -11,7 +11,9 @@ from sqlalchemy.orm import configure_mappers from platformics.database.models import Base, meta, Entity, File, FileStatus # noqa: F401 {%- for class in classes %} + {%- if class.snake_name != "Entity" %} from database.models.{{ class.snake_name }} import {{ class.name }} # noqa: F401 + {%- endif %} {%- endfor %} from platformics.database.models.file import File, FileStatus # noqa: F401 From 933740ef02416df8212739ae53cced218ef47a49 Mon Sep 17 00:00:00 2001 From: Jessica Gadling Date: Thu, 13 Jun 2024 15:13:15 -0400 Subject: [PATCH 12/16] Speed up docker image builds. --- .gitignore | 3 +++ Dockerfile | 2 ++ Makefile | 5 +++++ 3 files changed, 10 insertions(+) diff --git a/.gitignore b/.gitignore index a2d4326..5f643b7 100644 --- a/.gitignore +++ b/.gitignore @@ -257,3 +257,6 @@ test_app/test_infra/* # temp files /tmp/* + +# build artifacts +requirements.txt diff --git a/Dockerfile b/Dockerfile index 075330b..d9711ac 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,6 +4,8 @@ RUN apt update && \ apt install -y make gcc && \ apt clean +COPY requirements.txt . +RUN pip install -r requirements.txt COPY dist/platformics-0.1.0-py3-none-any.whl /tmp/platformics-0.1.0-py3-none-any.whl RUN cd /tmp/ && pip install platformics-0.1.0-py3-none-any.whl && rm -rf /tmp/*.whl diff --git a/Makefile b/Makefile index 36ac17a..2ddc373 100644 --- a/Makefile +++ b/Makefile @@ -71,7 +71,12 @@ gha-setup: ## Set up the environment in CI build: rm -rf dist/*.whl poetry build + # Export poetry dependency list as a requirements.txt, which makes Docker builds + # faster by not having to reinstall all dependencies every time we build a new wheel. + poetry export --without-hashes --format=requirements.txt > requirements.txt $(docker_compose) build + $(MAKE_TEST_APP) build + rm requirements.txt .PHONY: dev ## Launch a container suitable for developing the platformics library dev: From f617111f232c60252031aabf4331922106d46b38 Mon Sep 17 00:00:00 2001 From: Jessica Gadling Date: Thu, 13 Jun 2024 15:42:00 -0400 Subject: [PATCH 13/16] Fix aggregate query test. 
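
Two changes land together here: the order-by input enum is renamed from OrderBy to
orderBy (kept with a "# noqa" since the lower-camel name is intentional), and the
soft-delete step of the aggregate test now runs as a service identity, because only
service identities are allowed to soft delete entities. Roughly, the fixed call in
test_app/tests/test_aggregate_queries.py looks like the sketch below; the gql_client
fixture, the soft_delete_query string, and the surrounding variables already exist in
that test:

    # Only service identities are allowed to soft delete entities, so the
    # test impersonates the "workflows" service when issuing the mutation.
    output = await gql_client.query(
        soft_delete_query,
        user_id=user_id,
        member_projects=[project_id],
        service_identity="workflows",
    )
    assert output["data"]["updateSample"][0]["id"] == str(sample_to_delete.id)
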
--- platformics/api/core/query_builder.py | 8 ++++---- platformics/api/core/query_input_types.py | 2 +- test_app/tests/test_aggregate_queries.py | 5 ++++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/platformics/api/core/query_builder.py b/platformics/api/core/query_builder.py index 7d61259..58aa1d4 100644 --- a/platformics/api/core/query_builder.py +++ b/platformics/api/core/query_builder.py @@ -18,7 +18,7 @@ import platformics.database.models as db from platformics.api.core.errors import PlatformicsError -from platformics.api.core.query_input_types import OrderBy, aggregator_map, operator_map +from platformics.api.core.query_input_types import aggregator_map, operator_map, orderBy from platformics.database.models.base import Base from platformics.security.authorization import CerbosAction, get_resource_query @@ -26,7 +26,7 @@ T = typing.TypeVar("T") -def apply_order_by(field: str, direction: OrderBy, query: Select) -> Select: +def apply_order_by(field: str, direction: orderBy, query: Select) -> Select: match direction.value: case "asc": query = query.order_by(getattr(query.selected_columns, field).asc()) @@ -44,9 +44,9 @@ def apply_order_by(field: str, direction: OrderBy, query: Select) -> Select: class IndexedOrderByClause(TypedDict): - field: dict[str, OrderBy] | dict[str, dict[str, Any]] + field: dict[str, orderBy] | dict[str, dict[str, Any]] index: int - sort: OrderBy + sort: orderBy def convert_where_clauses_to_sql( diff --git a/platformics/api/core/query_input_types.py b/platformics/api/core/query_input_types.py index 5b2ff98..5911427 100644 --- a/platformics/api/core/query_input_types.py +++ b/platformics/api/core/query_input_types.py @@ -42,7 +42,7 @@ @strawberry.enum -class OrderBy(enum.Enum): +class orderBy(enum.Enum): # noqa # defaults to nulls last asc = "asc" asc_nulls_first = "asc_nulls_first" diff --git a/test_app/tests/test_aggregate_queries.py b/test_app/tests/test_aggregate_queries.py index 64050a7..0ad5f9e 100644 --- a/test_app/tests/test_aggregate_queries.py +++ b/test_app/tests/test_aggregate_queries.py @@ -405,7 +405,10 @@ async def test_soft_deleted_data_not_in_aggregate_query( }} }} """ - output = await gql_client.query(soft_delete_query, user_id=user_id, member_projects=[project_id]) + # Only service identities are allowed to soft delete entities + output = await gql_client.query( + soft_delete_query, user_id=user_id, member_projects=[project_id], service_identity="workflows" + ) assert output["data"]["updateSample"][0]["id"] == str(sample_to_delete.id) # The soft-deleted sample should not be included in the aggregate query anymore From 9be18effa944573df2cb85a8767d17818ca3b3b8 Mon Sep 17 00:00:00 2001 From: Omar Valenzuela Date: Mon, 17 Jun 2024 11:05:40 -0700 Subject: [PATCH 14/16] specify service_identity for deletion specific tests --- test_app/tests/test_file_mutations.py | 4 +++- test_app/tests/test_where_clause.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/test_app/tests/test_file_mutations.py b/test_app/tests/test_file_mutations.py index 0c32486..16d944b 100644 --- a/test_app/tests/test_file_mutations.py +++ b/test_app/tests/test_file_mutations.py @@ -286,7 +286,9 @@ async def test_delete_from_s3( assert "Contents" in moto_client.list_objects(Bucket=file.namespace, Prefix=file.path) # Issue deletion - result = await gql_client.query(query, user_id=user1_id, member_projects=[project1_id]) + result = await gql_client.query( + query, user_id=user1_id, member_projects=[project1_id], service_identity="workflows" + 
) assert result["data"]["deleteSequencingRead"][0]["id"] == str(file.entity_id) # Make sure file either does or does not exist diff --git a/test_app/tests/test_where_clause.py b/test_app/tests/test_where_clause.py index a7ee270..3963df6 100644 --- a/test_app/tests/test_where_clause.py +++ b/test_app/tests/test_where_clause.py @@ -271,7 +271,7 @@ async def test_soft_deleted_objects(sync_db: SyncDB, gql_client: GQLTestClient) }} }} """ - output = await gql_client.query(soft_delete_mutation, member_projects=[project_id]) + output = await gql_client.query(soft_delete_mutation, member_projects=[project_id], service_identity="workflows") assert len(output["data"]["updateSequencingRead"]) == 3 # Check that the soft-deleted sequencing reads are not returned @@ -313,7 +313,9 @@ async def test_soft_deleted_objects(sync_db: SyncDB, gql_client: GQLTestClient) }} """ - output = await gql_client.query(hard_delete_mutation, user_id=user_id, member_projects=[project_id]) + output = await gql_client.query( + hard_delete_mutation, user_id=user_id, member_projects=[project_id], service_identity="workflows" + ) assert len(output["data"]["deleteSequencingRead"]) == 3 # Check that the hard-deleted sequencing reads are not returned From abb500e0cbebc990e809944927beb87d0bcce1f8 Mon Sep 17 00:00:00 2001 From: Omar Valenzuela Date: Mon, 17 Jun 2024 13:14:37 -0700 Subject: [PATCH 15/16] enable cascade deletion, 60 pass / 21 fail --- test_app/schema/schema.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test_app/schema/schema.yaml b/test_app/schema/schema.yaml index b583295..f03cb47 100644 --- a/test_app/schema/schema.yaml +++ b/test_app/schema/schema.yaml @@ -180,8 +180,14 @@ classes: required: true r1_file: range: File + readonly: true + annotations: + cascade_delete: true r2_file: range: File + readonly: true + annotations: + cascade_delete: true technology: range: SequencingTechnology required: true @@ -209,6 +215,8 @@ classes: attributes: file: range: File + annotations: + cascade_delete: true sequencing_reads: range: SequencingRead inverse: SequencingRead.primer_file From a6b20a94fac96c8d8cd4afdc4895ea2c8d15639c Mon Sep 17 00:00:00 2001 From: Omar Valenzuela Date: Tue, 18 Jun 2024 13:40:47 -0700 Subject: [PATCH 16/16] bump python version in CICD --- .github/workflows/build-and-push.yml | 2 +- .github/workflows/build-and-test.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-and-push.yml b/.github/workflows/build-and-push.yml index ef6817c..a884dfc 100644 --- a/.github/workflows/build-and-push.yml +++ b/.github/workflows/build-and-push.yml @@ -16,7 +16,7 @@ jobs: - name: set up Python uses: actions/setup-python@v2 with: - python-version: '3.11' + python-version: '3.12' - name: install poetry run: | python -m pip install --no-cache-dir poetry==1.8 supervisor diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 5a26123..0c0a6ba 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -20,7 +20,7 @@ jobs: - name: set up Python uses: actions/setup-python@v2 with: - python-version: '3.11' + python-version: '3.12' - name: install poetry run: | python -m pip install --no-cache-dir poetry==1.8 supervisor
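
For reference, a minimal sketch of how the reworked file-format validators from the
format_handlers.py change earlier in this series might be driven outside the test
suite. The helper name validate_object and the fmt parameter are illustrative only
(they are not part of these patches); the sketch assumes an S3 client and an object
that has already been uploaded:

    from mypy_boto3_s3 import S3Client

    from platformics.support.format_handlers import get_validator


    def validate_object(s3_client: S3Client, bucket: str, key: str, fmt: str) -> None:
        # get_validator() maps a format string ("fa"/"fasta", "fastq", "bed",
        # "json", "zip") to a handler class; the handler reads at most the first
        # ~1MB of the object (gunzipping *.gz keys) and raises if the contents
        # are not valid for that format.
        handler_cls = get_validator(fmt)
        handler_cls(s3_client, bucket, key).validate()

    # e.g. validate_object(moto_client, "local-bucket", "test1.fastq", "fastq")
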