diff --git a/.github/workflows/build-and-push.yml b/.github/workflows/build-and-push.yml index 862c5b8..a884dfc 100644 --- a/.github/workflows/build-and-push.yml +++ b/.github/workflows/build-and-push.yml @@ -16,11 +16,11 @@ jobs: - name: set up Python uses: actions/setup-python@v2 with: - python-version: '3.11' + python-version: '3.12' - name: install poetry run: | python -m pip install --no-cache-dir poetry==1.8 supervisor - + - name: set up docker run: | make gha-setup diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index a0a62f4..0c0a6ba 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -20,11 +20,11 @@ jobs: - name: set up Python uses: actions/setup-python@v2 with: - python-version: '3.11' + python-version: '3.12' - name: install poetry run: | python -m pip install --no-cache-dir poetry==1.8 supervisor - + - name: set up docker run: make gha-setup diff --git a/.gitignore b/.gitignore index a2d4326..5f643b7 100644 --- a/.gitignore +++ b/.gitignore @@ -257,3 +257,6 @@ test_app/test_infra/* # temp files /tmp/* + +# build artifacts +requirements.txt diff --git a/Dockerfile b/Dockerfile index 075330b..d9711ac 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,6 +4,8 @@ RUN apt update && \ apt install -y make gcc && \ apt clean +COPY requirements.txt . +RUN pip install -r requirements.txt COPY dist/platformics-0.1.0-py3-none-any.whl /tmp/platformics-0.1.0-py3-none-any.whl RUN cd /tmp/ && pip install platformics-0.1.0-py3-none-any.whl && rm -rf /tmp/*.whl diff --git a/Makefile b/Makefile index 34bbfee..2ddc373 100644 --- a/Makefile +++ b/Makefile @@ -20,10 +20,6 @@ help: ## display help for this makefile @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' @echo "### SHARED FUNCTIONS END ###" -.PHONY: codegen -codegen: - $(MAKE_TEST_APP) codegen - .PHONY: codegen codegen: ## Run codegen to convert the LinkML schema to a GQL API $(docker_compose_run) $(CONTAINER) python3 -m platformics.cli.main api generate --schemafile ./schema/schema.yaml --output-prefix . @@ -75,7 +71,12 @@ gha-setup: ## Set up the environment in CI build: rm -rf dist/*.whl poetry build + # Export poetry dependency list as a requirements.txt, which makes Docker builds + # faster by not having to reinstall all dependencies every time we build a new wheel. 
+ poetry export --without-hashes --format=requirements.txt > requirements.txt $(docker_compose) build + $(MAKE_TEST_APP) build + rm requirements.txt .PHONY: dev ## Launch a container suitable for developing the platformics library dev: diff --git a/platformics/api/core/gql_loaders.py b/platformics/api/core/gql_loaders.py index ea3512f..98cfd7a 100644 --- a/platformics/api/core/gql_loaders.py +++ b/platformics/api/core/gql_loaders.py @@ -9,7 +9,7 @@ import platformics.database.models as db from platformics.api.core.errors import PlatformicsError -from platformics.api.core.helpers import get_aggregate_db_query, get_db_query, get_db_rows +from platformics.api.core.query_builder import get_aggregate_db_query, get_db_query, get_db_rows from platformics.database.connect import AsyncDB from platformics.security.authorization import CerbosAction diff --git a/platformics/api/core/helpers.py b/platformics/api/core/query_builder.py similarity index 98% rename from platformics/api/core/helpers.py rename to platformics/api/core/query_builder.py index f8f00f8..58aa1d4 100644 --- a/platformics/api/core/helpers.py +++ b/platformics/api/core/query_builder.py @@ -18,7 +18,7 @@ import platformics.database.models as db from platformics.api.core.errors import PlatformicsError -from platformics.api.core.gql_to_sql import OrderBy, aggregator_map, operator_map +from platformics.api.core.query_input_types import aggregator_map, operator_map, orderBy from platformics.database.models.base import Base from platformics.security.authorization import CerbosAction, get_resource_query @@ -26,7 +26,7 @@ T = typing.TypeVar("T") -def apply_order_by(field: str, direction: OrderBy, query: Select) -> Select: +def apply_order_by(field: str, direction: orderBy, query: Select) -> Select: match direction.value: case "asc": query = query.order_by(getattr(query.selected_columns, field).asc()) @@ -44,9 +44,9 @@ def apply_order_by(field: str, direction: OrderBy, query: Select) -> Select: class IndexedOrderByClause(TypedDict): - field: dict[str, OrderBy] | dict[str, dict[str, Any]] + field: dict[str, orderBy] | dict[str, dict[str, Any]] index: int - sort: OrderBy + sort: orderBy def convert_where_clauses_to_sql( diff --git a/platformics/api/core/gql_to_sql.py b/platformics/api/core/query_input_types.py similarity index 99% rename from platformics/api/core/gql_to_sql.py rename to platformics/api/core/query_input_types.py index 5b2ff98..5911427 100644 --- a/platformics/api/core/gql_to_sql.py +++ b/platformics/api/core/query_input_types.py @@ -42,7 +42,7 @@ @strawberry.enum -class OrderBy(enum.Enum): +class orderBy(enum.Enum): # noqa # defaults to nulls last asc = "asc" asc_nulls_first = "asc_nulls_first" diff --git a/platformics/api/files.py b/platformics/api/files.py index 865d9aa..8929ea1 100644 --- a/platformics/api/files.py +++ b/platformics/api/files.py @@ -33,8 +33,8 @@ require_auth_principal, require_system_user, ) -from platformics.api.core.gql_to_sql import EnumComparators, IntComparators, StrComparators, UUIDComparators -from platformics.api.core.helpers import get_db_rows +from platformics.api.core.query_builder import get_db_rows +from platformics.api.core.query_input_types import EnumComparators, IntComparators, StrComparators, UUIDComparators from platformics.api.core.strawberry_extensions import DependencyExtension from platformics.api.types.entities import Entity from platformics.security.authorization import CerbosAction, get_resource_query @@ -262,7 +262,8 @@ async def validate_file( # Validate data try: - 
validator.validate(client=s3_client, bucket=file.namespace, file_path=file.path) + validator(s3_client, file.namespace, file.path).validate() + file_size = s3_client.head_object(Bucket=file.namespace, Key=file.path)["ContentLength"] except: # noqa file.status = db.FileStatus.FAILED diff --git a/platformics/api/setup.py b/platformics/api/setup.py index 387e738..6c76589 100644 --- a/platformics/api/setup.py +++ b/platformics/api/setup.py @@ -12,9 +12,10 @@ from strawberry.schema.config import StrawberryConfig from strawberry.schema.name_converter import HasGraphQLName, NameConverter -from platformics.api.core.deps import get_auth_principal, get_cerbos_client, get_db_module, get_engine +from platformics.api.core.deps import get_auth_principal, get_cerbos_client, get_db_module, get_engine, get_s3_client from platformics.api.core.gql_loaders import EntityLoader from platformics.database.connect import AsyncDB +from platformics.database.models.file import File from platformics.settings import APISettings # ------------------------------------------------------------------------------ @@ -53,10 +54,12 @@ def get_app(settings: APISettings, schema: strawberry.Schema, db_module: typing. """ Make sure tests can get their own instances of the app. """ + File.set_settings(settings) + File.set_s3_client(get_s3_client(settings)) settings = APISettings.model_validate({}) # Workaround for https://github.com/pydantic/pydantic/issues/3753 title = settings.SERVICE_NAME - graphql_app: GraphQLRouter = GraphQLRouter(schema, context_getter=get_context, graphiql=True) + graphql_app: GraphQLRouter = GraphQLRouter(schema, context_getter=get_context) _app = FastAPI(title=title, debug=settings.DEBUG) _app.include_router(graphql_app, prefix="/graphql") # Add a global settings object to the app that we can use as a dependency diff --git a/platformics/codegen/templates/api/types/class_name.py.j2 b/platformics/codegen/templates/api/types/class_name.py.j2 index 7664c57..75c492c 100644 --- a/platformics/codegen/templates/api/types/class_name.py.j2 +++ b/platformics/codegen/templates/api/types/class_name.py.j2 @@ -17,7 +17,7 @@ import platformics.database.models as base_db import database.models as db import strawberry import datetime -from platformics.api.core.helpers import get_db_rows, get_aggregate_db_rows +from platformics.api.core.query_builder import get_db_rows, get_aggregate_db_rows from api.validators.{{cls.snake_name}} import {{cls.name}}CreateInputValidator, {{cls.name}}UpdateInputValidator {%- if render_files %} from platformics.api.files import File, FileWhereClause @@ -34,7 +34,7 @@ from cerbos.sdk.model import Principal, Resource from fastapi import Depends from platformics.api.core.errors import PlatformicsError from platformics.api.core.deps import get_cerbos_client, get_db_session, require_auth_principal, is_system_user -from platformics.api.core.gql_to_sql import aggregator_map, orderBy, EnumComparators, DatetimeComparators, IntComparators, FloatComparators, StrComparators, UUIDComparators, BoolComparators +from platformics.api.core.query_input_types import aggregator_map, orderBy, EnumComparators, DatetimeComparators, IntComparators, FloatComparators, StrComparators, UUIDComparators, BoolComparators from platformics.api.core.strawberry_extensions import DependencyExtension from platformics.security.authorization import CerbosAction, get_resource_query from sqlalchemy import inspect diff --git a/platformics/support/format_handlers.py b/platformics/support/format_handlers.py index afe2efc..83ef53a 100644 
--- a/platformics/support/format_handlers.py +++ b/platformics/support/format_handlers.py @@ -3,8 +3,8 @@ """ import gzip -import tempfile -import typing +import io +import json from abc import abstractmethod from typing import Protocol @@ -17,62 +17,113 @@ class FileFormatHandler(Protocol): Interface for a file format handler """ - @classmethod + s3client: S3Client + bucket: str + key: str + + def __init__(self, s3client: S3Client, bucket: str, key: str): + self.s3client = s3client + self.bucket = bucket + self.key = key + + def contents(self) -> str: + """ + Get the contents of the file + """ + body = self.s3client.get_object(Bucket=self.bucket, Key=self.key, Range="bytes=0-1000000")["Body"] + if self.key.endswith(".gz"): + with gzip.GzipFile(fileobj=body) as fp: + return fp.read().decode("utf-8") + return body.read().decode("utf-8") + @abstractmethod - def validate(cls, client: S3Client, bucket: str, file_path: str) -> None: + def validate(self) -> None: raise NotImplementedError +class FastaHandler(FileFormatHandler): + """ + Validate FASTA files. Note that even truncated FASTA files are supported: + ">" is a valid FASTA file, and so is ">abc" (without a sequence). + """ + + def validate(self) -> None: + sequences = 0 + for _ in SeqIO.parse(io.StringIO(self.contents()), "fasta"): + sequences += 1 + assert sequences > 0 + + class FastqHandler(FileFormatHandler): """ - Validate FASTQ files (contain sequencing reads) + Validate FASTQ files. Can't use biopython directly because large file would be truncated. + This removes truncated FASTQ records by assuming 1 read = 4 lines. """ - @classmethod - def validate(cls, client: S3Client, bucket: str, file_path: str) -> None: - for fp in get_file_preview(client, bucket, file_path): - assert len(list(SeqIO.parse(fp, "fastq"))) > 0 + def validate(self) -> None: + # Load file and only keep non-truncated FASTQ records (4 lines per record) + fastq = self.contents().split("\n") + fastq = fastq[: len(fastq) - (len(fastq) % 4)] + # Validate it with SeqIO + reads = 0 + for _ in SeqIO.parse(io.StringIO("\n".join(fastq)), "fastq"): + reads += 1 + assert reads > 0 -class FastaHandler(FileFormatHandler): + +class BedHandler(FileFormatHandler): """ - Validate FASTA files (contain sequences) + Validate BED files using basic checks. 
""" - @classmethod - def validate(cls, client: S3Client, bucket: str, file_path: str) -> None: - for fp in get_file_preview(client, bucket, file_path): - assert len(list(SeqIO.parse(fp, "fasta"))) > 0 + def validate(self) -> None: + # Ignore last line since it could be truncated + records = self.contents().split("\n")[:-1] + assert len(records) > 0 + # BED files must have at least 3 columns - error out if the file incorrectly uses spaces instead of tabs + num_cols = -1 + for record in records: + assert len(record.split("\t")) >= 3 + # All rows should have the same number of columns + if num_cols == -1: + num_cols = len(record.split("\t")) + else: + assert num_cols == len(record.split("\t")) -def get_file_preview( - client: S3Client, - bucket: str, - file_path: str, -) -> typing.Generator[typing.TextIO, typing.Any, typing.Any]: + +class JsonHandler(FileFormatHandler): + """ + Validate JSON files + """ + + def validate(self) -> None: + json.loads(self.contents()) # throws an exception for invalid JSON + + +class ZipHandler(FileFormatHandler): """ - Get first 1MB of a file and save it in a temporary file + Validate ZIP files """ - data = client.get_object(Bucket=bucket, Key=file_path, Range="bytes=0-1000000")["Body"].read() - fp = tempfile.NamedTemporaryFile("w+b") - fp.write(data) - fp.flush() - try: - data.decode("utf-8") - with open(fp.name, "r") as fh: - yield fh - except UnicodeDecodeError: - return gzip.open(fp.name, "rt") + def validate(self) -> None: + assert self.key.endswith(".zip") # throws an exception if the file is not a zip file def get_validator(format: str) -> type[FileFormatHandler]: """ Returns the validator for a given file format """ - if format == "fastq": - return FastqHandler - elif format == "fasta": + if format in ["fa", "fasta"]: return FastaHandler + elif format == "fastq": + return FastqHandler + elif format == "bed": + return BedHandler + elif format == "json": + return JsonHandler + elif format == "zip": + return ZipHandler else: raise Exception(f"Unknown file format '{format}'") diff --git a/test_app/Makefile b/test_app/Makefile index 63b7021..aed147c 100644 --- a/test_app/Makefile +++ b/test_app/Makefile @@ -69,6 +69,10 @@ codegen: ## Run codegen to convert the LinkML schema to a GQL API test: init ## Run tests $(docker_compose) exec $(APP_CONTAINER) pytest -vvv +.PHONY: test-file +test-file: init ## Run tests for a specific file, ex: make test-file FILE=tests/test_file.py + $(docker_compose) exec $(APP_CONTAINER) pytest -vvv $(FILE) + .PHONY: restart restart: ## Restart the GQL service $(docker_compose_run) $(APP_CONTAINER) supervisorctl restart graphql_api diff --git a/test_app/main.py b/test_app/main.py index 961df11..db4ccd7 100644 --- a/test_app/main.py +++ b/test_app/main.py @@ -5,6 +5,7 @@ import strawberry import uvicorn from platformics.api.setup import get_app, get_strawberry_config +from platformics.api.core.error_handler import HandleErrors from platformics.settings import APISettings from database import models @@ -12,8 +13,7 @@ from api.queries import Query settings = APISettings.model_validate({}) # Workaround for https://github.com/pydantic/pydantic/issues/3753 -strawberry_config = get_strawberry_config() -schema = strawberry.Schema(query=Query, mutation=Mutation, config=strawberry_config) +schema = strawberry.Schema(query=Query, mutation=Mutation, config=get_strawberry_config(), extensions=[HandleErrors()]) # Create and run app diff --git a/test_app/schema/README.md b/test_app/schema/README.md index cb605f8..0248609 100644 --- 
a/test_app/schema/README.md +++ b/test_app/schema/README.md @@ -333,4 +333,4 @@ classes: range: string description: range: string -``` \ No newline at end of file +``` diff --git a/test_app/schema/schema.yaml b/test_app/schema/schema.yaml index f82266e..f03cb47 100644 --- a/test_app/schema/schema.yaml +++ b/test_app/schema/schema.yaml @@ -79,6 +79,11 @@ classes: updated_at: range: date readonly: true + deleted_at: + range: date + annotations: + mutable: true + system_writable_only: True annotations: plural: Entities @@ -175,8 +180,14 @@ classes: required: true r1_file: range: File + readonly: true + annotations: + cascade_delete: true r2_file: range: File + readonly: true + annotations: + cascade_delete: true technology: range: SequencingTechnology required: true @@ -189,10 +200,11 @@ classes: primer_file: range: GenomicRange inverse: GenomicRange.sequencing_reads - contigs: + contig: range: Contig - inverse: Contig.sequencing_read - multivalued: true + inverse: Contig.sequencing_reads + annotations: + mutable: false annotations: plural: SequencingReads @@ -203,6 +215,8 @@ classes: attributes: file: range: File + annotations: + cascade_delete: true sequencing_reads: range: SequencingRead inverse: SequencingRead.primer_file @@ -215,11 +229,18 @@ classes: mixins: - EntityMixin attributes: - sequencing_read: + sequencing_reads: range: SequencingRead - inverse: Sample.contigs + inverse: SequencingRead.contig + multivalued: true sequence: required: true + upstream_database: + range: UpstreamDatabase + inverse: UpstreamDatabase.contig + required: true + annotations: + mutable: false annotations: plural: Contigs @@ -234,3 +255,61 @@ classes: inverse: entity.id annotations: hidden: true + + UpstreamDatabase: + is_a: Entity + mixins: + - EntityMixin + attributes: + name: + range: string + required: true + annotations: + indexed: true + contig: + range: Contig + inverse: Contig.upstream_database + multivalued: true + # This is where NCBI indexes would live + annotations: + plural: UpstreamDatabases + + ConstraintCheckedType: + is_a: Entity + mixins: + - EntityMixin + attributes: + length_3_to_8: + range: string + annotations: + minimum_length: 3 + maximum_length: 8 + regex_format_check: + range: string + pattern: '\d{3}-\d{2}-\d{4}' + min_value_0: + range: integer + minimum_value: 0 + enum_field: + range: NucleicAcid + bool_field: + range: boolean + max_value_9: + range: integer + maximum_value: 9 + min_value_0_max_value_9: + range: integer + minimum_value: 0 + maximum_value: 9 + float_1dot1_to_2dot2: + range: float + minimum_value: 1.1 + maximum_value: 2.2 + no_string_checks: + range: string + no_int_checks: + range: integer + no_float_checks: + range: float + annotations: + plural: ConstraintCheckedTypes diff --git a/test_app/tests/test_aggregate_queries.py b/test_app/tests/test_aggregate_queries.py index dfeb722..0ad5f9e 100644 --- a/test_app/tests/test_aggregate_queries.py +++ b/test_app/tests/test_aggregate_queries.py @@ -3,10 +3,12 @@ """ import pytest -from conftest import GQLTestClient, SessionStorage from platformics.database.connect import SyncDB +from conftest import GQLTestClient, SessionStorage from test_infra.factories.sample import SampleFactory from test_infra.factories.sequencing_read import SequencingReadFactory +from test_infra.factories.contig import ContigFactory +from test_infra.factories.upstream_database import UpstreamDatabaseFactory @pytest.mark.asyncio @@ -60,12 +62,12 @@ async def test_basic_aggregate_query( } """ output = await gql_client.query(query, user_id=user_id, 
member_projects=[project_id, secondary_project_id]) - avg_collectionId = output["data"]["samplesAggregate"]["aggregate"]["avg"]["collectionId"] - count = output["data"]["samplesAggregate"]["aggregate"]["count"] - max_collectionLocation = output["data"]["samplesAggregate"]["aggregate"]["max"]["collectionLocation"] - min_collectionLocation = output["data"]["samplesAggregate"]["aggregate"]["min"]["collectionLocation"] - stddev_collectionId = output["data"]["samplesAggregate"]["aggregate"]["stddev"]["collectionId"] - sum_ownerUserId = output["data"]["samplesAggregate"]["aggregate"]["sum"]["ownerUserId"] + avg_collectionId = output["data"]["samplesAggregate"]["aggregate"][0]["avg"]["collectionId"] + count = output["data"]["samplesAggregate"]["aggregate"][0]["count"] + max_collectionLocation = output["data"]["samplesAggregate"]["aggregate"][0]["max"]["collectionLocation"] + min_collectionLocation = output["data"]["samplesAggregate"]["aggregate"][0]["min"]["collectionLocation"] + stddev_collectionId = output["data"]["samplesAggregate"]["aggregate"][0]["stddev"]["collectionId"] + sum_ownerUserId = output["data"]["samplesAggregate"]["aggregate"][0]["sum"]["ownerUserId"] assert avg_collectionId == 189 assert count == 5 @@ -106,8 +108,8 @@ async def test_nested_aggregate_query( } """ results = await gql_client.query(query, user_id=111, member_projects=[888]) - assert results["data"]["samples"][0]["sequencingReadsAggregate"]["aggregate"]["count"] == 2 - assert results["data"]["samples"][1]["sequencingReadsAggregate"]["aggregate"]["count"] == 3 + assert results["data"]["samples"][0]["sequencingReadsAggregate"]["aggregate"][0]["count"] == 2 + assert results["data"]["samples"][1]["sequencingReadsAggregate"]["aggregate"][0]["count"] == 3 @pytest.mark.asyncio @@ -148,3 +150,280 @@ async def test_count_distinct_query( """ results = await gql_client.query(query, user_id=111, member_projects=[888]) assert results["data"]["samplesAggregate"]["aggregate"][0]["count"] == 2 + + +@pytest.mark.asyncio +async def test_groupby_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can perform a groupby query + """ + with sync_db.session() as session: + SessionStorage.set_session(session) + SampleFactory.create_batch(2, owner_user_id=111, collection_id=888, collection_location="San Francisco, CA") + SampleFactory.create_batch(3, owner_user_id=111, collection_id=888, collection_location="Mountain View, CA") + + query = """ + query MyQuery { + samplesAggregate { + aggregate { + groupBy { + collectionLocation + } + count + } + } + } + """ + results = await gql_client.query(query, user_id=111, member_projects=[888]) + aggregate = results["data"]["samplesAggregate"]["aggregate"] + for group in aggregate: + if group["groupBy"]["collectionLocation"] == "San Francisco, CA": + assert group["count"] == 2 + elif group["groupBy"]["collectionLocation"] == "Mountain View, CA": + assert group["count"] == 3 + + +@pytest.mark.asyncio +async def test_groupby_query_with_nested_fields( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can perform a groupby query with nested fields + """ + with sync_db.session() as session: + SessionStorage.set_session(session) + sample_1 = SampleFactory(owner_user_id=111, collection_id=888, collection_location="San Francisco, CA") + sample_2 = SampleFactory(owner_user_id=111, collection_id=888, collection_location="Mountain View, CA") + SequencingReadFactory.create_batch( + 2, sample=sample_1, owner_user_id=sample_1.owner_user_id, 
collection_id=sample_1.collection_id + ) + SequencingReadFactory.create_batch( + 3, sample=sample_2, owner_user_id=sample_2.owner_user_id, collection_id=sample_2.collection_id + ) + + query = """ + query MyQuery { + sequencingReadsAggregate { + aggregate { + groupBy { + sample { + collectionLocation + } + } + count + } + } + } + """ + results = await gql_client.query(query, user_id=111, member_projects=[888]) + aggregate = results["data"]["sequencingReadsAggregate"]["aggregate"] + for group in aggregate: + if group["groupBy"]["sample"]["collectionLocation"] == "San Francisco, CA": + assert group["count"] == 2 + elif group["groupBy"]["sample"]["collectionLocation"] == "Mountain View, CA": + assert group["count"] == 3 + + +@pytest.mark.asyncio +async def test_groupby_query_with_multiple_fields( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can perform a groupby query with fields nested at multiple levels + """ + with sync_db.session() as session: + SessionStorage.set_session(session) + sample_1 = SampleFactory(owner_user_id=111, collection_id=888, collection_location="San Francisco, CA") + sample_2 = SampleFactory(owner_user_id=111, collection_id=888, collection_location="Mountain View, CA") + SequencingReadFactory.create_batch( + 1, + sample=sample_1, + owner_user_id=sample_1.owner_user_id, + collection_id=sample_1.collection_id, + technology="Illumina", + ) + SequencingReadFactory.create_batch( + 2, + sample=sample_1, + owner_user_id=sample_1.owner_user_id, + collection_id=sample_1.collection_id, + technology="Nanopore", + ) + SequencingReadFactory.create_batch( + 3, + sample=sample_2, + owner_user_id=sample_2.owner_user_id, + collection_id=sample_2.collection_id, + technology="Illumina", + ) + SequencingReadFactory.create_batch( + 4, + sample=sample_2, + owner_user_id=sample_2.owner_user_id, + collection_id=sample_2.collection_id, + technology="Nanopore", + ) + + query = """ + query MyQuery { + sequencingReadsAggregate { + aggregate { + groupBy { + sample { + collectionLocation + } + technology + } + count + } + } + } + """ + results = await gql_client.query(query, user_id=111, member_projects=[888]) + aggregate = results["data"]["sequencingReadsAggregate"]["aggregate"] + for group in aggregate: + if ( + group["groupBy"]["sample"]["collectionLocation"] == "San Francisco, CA" + and group["groupBy"]["technology"] == "Illumina" + ): + assert group["count"] == 1 + elif ( + group["groupBy"]["sample"]["collectionLocation"] == "San Francisco, CA" + and group["groupBy"]["technology"] == "Nanopore" + ): + assert group["count"] == 2 + elif ( + group["groupBy"]["sample"]["collectionLocation"] == "Mountain View, CA" + and group["groupBy"]["technology"] == "Illumina" + ): + assert group["count"] == 3 + elif ( + group["groupBy"]["sample"]["collectionLocation"] == "Mountain View, CA" + and group["groupBy"]["technology"] == "Nanopore" + ): + assert group["count"] == 4 + + +@pytest.mark.asyncio +async def test_deeply_nested_groupby_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can perform a deeply nested groupby query + """ + + user_id = 12345 + project_id = 123 + + with sync_db.session() as session: + SessionStorage.set_session(session) + SessionStorage.set_session(session) + + upstream_db_1 = UpstreamDatabaseFactory(owner_user_id=user_id, collection_id=project_id, name="NCBI") + upstream_db_2 = UpstreamDatabaseFactory(owner_user_id=user_id, collection_id=project_id, name="GTDB") + contig_1 = ContigFactory(owner_user_id=user_id, 
collection_id=project_id, upstream_database=upstream_db_1) + contig_2 = ContigFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_2) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_1) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_1) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_2) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_2) + + query = """ + query MyQuery { + sequencingReadsAggregate { + aggregate { + count + groupBy { + contig { + upstreamDatabase { + name + } + } + } + } + } + } + """ + results = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + aggregate = results["data"]["sequencingReadsAggregate"]["aggregate"] + for group in aggregate: + if group["groupBy"]["contig"]["upstreamDatabase"]["name"] == "NCBI": + assert group["count"] == 2 + elif group["groupBy"]["contig"]["upstreamDatabase"]["name"] == "GTDB": + assert group["count"] == 2 + + +@pytest.mark.asyncio +async def test_soft_deleted_data_not_in_aggregate_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that soft-deleted data is not included in aggregate queries + """ + user_id = 12345 + project_id = 123 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + SampleFactory.create_batch( + 4, collection_location="San Francisco, CA", owner_user_id=user_id, collection_id=project_id + ) + sample_to_delete = SampleFactory( + collection_location="Mountain View, CA", owner_user_id=user_id, collection_id=project_id + ) + + # Verify that there are 5 samples in the database + aggregate_query = """ + query MyQuery { + samplesAggregate { + aggregate { + count + } + } + } + """ + output = await gql_client.query(aggregate_query, user_id=user_id, member_projects=[project_id]) + assert output["data"]["samplesAggregate"]["aggregate"][0]["count"] == 5 + + # Soft-delete a sample by updating its deleted_at field + soft_delete_query = f""" + mutation MyMutation {{ + updateSample ( + where: {{ id: {{ _eq: "{sample_to_delete.id}" }} }}, + input: {{ deletedAt: "2021-01-01T00:00:00Z" }} + ) + {{ + id + }} + }} + """ + # Only service identities are allowed to soft delete entities + output = await gql_client.query( + soft_delete_query, user_id=user_id, member_projects=[project_id], service_identity="workflows" + ) + assert output["data"]["updateSample"][0]["id"] == str(sample_to_delete.id) + + # The soft-deleted sample should not be included in the aggregate query anymore + output = await gql_client.query(aggregate_query, user_id=user_id, member_projects=[project_id]) + assert output["data"]["samplesAggregate"]["aggregate"][0]["count"] == 4 + + # The soft-deleted sample should be included in the aggregate query if we specifically ask for it + aggregate_soft_deleted_query = """ + query MyQuery { + samplesAggregate(where: { deletedAt: { _is_null: false } }) { + aggregate { + count + } + } + } + """ + output = await gql_client.query(aggregate_soft_deleted_query, user_id=user_id, member_projects=[project_id]) + assert output["data"]["samplesAggregate"]["aggregate"][0]["count"] == 1 diff --git a/test_app/tests/test_authorization.py b/test_app/tests/test_authorization.py index 3f794e9..5bd486d 100644 --- a/test_app/tests/test_authorization.py +++ b/test_app/tests/test_authorization.py @@ -2,10 +2,14 @@ Authorization spot-checks """ +import uuid + import pytest 
-from platformics.database.connect import SyncDB +from database.models import Sample from conftest import GQLTestClient, SessionStorage from test_infra.factories.sample import SampleFactory +from test_infra.factories.sequencing_read import SequencingReadFactory +from platformics.database.connect import SyncDB @pytest.mark.asyncio @@ -45,3 +49,235 @@ async def test_collection_authorization( output = await gql_client.query(query, user_id=user_id, member_projects=project_ids) assert len(output["data"]["samples"]) == num_results assert {sample["collectionLocation"] for sample in output["data"]["samples"]} == set(cities) + + +@pytest.mark.asyncio +async def test_system_fields_only_creatable_by_system( + gql_client: GQLTestClient, +) -> None: + """ + Make sure only system users can set system fields + """ + user_id = 12345 + project_ids = [333] + producing_run_id = str(uuid.uuid4()) + query = f""" + mutation MyMutation {{ + createGenomicRange(input: {{collectionId: {project_ids[0]}, producingRunId: "{producing_run_id}" }}) {{ + collectionId + producingRunId + }} + }} + """ + + # Our mutation should have been saved because we are a system user. + output = await gql_client.query(query, user_id=user_id, member_projects=project_ids, service_identity="workflows") + assert output["data"]["createGenomicRange"]["collectionId"] == 333 + assert output["data"]["createGenomicRange"]["producingRunId"] == producing_run_id + + # Our mutation should have ignored producing run because we're not a system user. + output = await gql_client.query(query, user_id=user_id, member_projects=project_ids) + assert output["data"]["createGenomicRange"]["collectionId"] == 333 + assert output["data"]["createGenomicRange"]["producingRunId"] is None + + +@pytest.mark.asyncio +async def test_system_fields_only_mutable_by_system( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Make sure only system users can mutate system fields + """ + user_id = 12345 + project_ids = [333] + + with sync_db.session() as session: + SessionStorage.set_session(session) + sample = SampleFactory.create(collection_location="City1", owner_user_id=999, collection_id=333) + + # Fetch all samples + def get_query(input_value: str) -> str: + return f""" + mutation MyMutation {{ + updateSample( + where: {{id: {{_eq: "{sample.id}" }} }}, + input: {{systemMutableField: "{input_value}"}}) {{ + id + systemMutableField + }} + }} + """ + + output = await gql_client.query( + get_query("hello"), user_id=user_id, member_projects=project_ids, service_identity="workflows" + ) + # Our mutation should have been saved because we are a system user. 
+ assert output["data"]["updateSample"][0]["systemMutableField"] == "hello" + + output = await gql_client.query(get_query("goodbye"), user_id=user_id, member_projects=project_ids) + # This field should have been ignored because we're not a system user + assert output["data"]["updateSample"][0]["systemMutableField"] == "hello" + + +@pytest.mark.asyncio +async def test_system_types_only_mutable_by_system( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Make sure only system users can mutate system fields + """ + user_id = 12345 + project_ids = [333] + + # Fetch all samples + def get_update_query(id: str, input_value: str) -> str: + return f""" + mutation MyMutation {{ + updateSystemWritableOnlyType( + where: {{id: {{_eq: "{id}" }} }}, + input: {{name: "{input_value}"}}) {{ + id + name + }} + }} + """ + + # Fetch all samples + create_query = """ + mutation MyMutation { + createSystemWritableOnlyType( + input: { + collectionId: 333, + name: "row name here" + } + ) { id, name } + } + """ + + # Our mutation should have been saved because we are a system user. + output = await gql_client.query( + create_query, user_id=user_id, member_projects=project_ids, service_identity="workflows" + ) + assert output["data"]["createSystemWritableOnlyType"]["name"] == "row name here" + item_id = output["data"]["createSystemWritableOnlyType"]["id"] + + # Our mutation should have failed with an authorization error because we are not a system user + output = await gql_client.query(create_query, user_id=user_id, member_projects=project_ids) + assert "Unauthorized" in output["errors"][0]["message"] + assert "not creatable" in output["errors"][0]["message"] + + # This field should have been ignored because we're not a system user + output = await gql_client.query(get_update_query(item_id, "new_name"), user_id=user_id, member_projects=project_ids) + assert "Unauthorized" in output["errors"][0]["message"] + assert "not mutable" in output["errors"][0]["message"] + + # This field should have been ignored because we're not a system user + output = await gql_client.query( + get_update_query(item_id, "new_name"), + user_id=user_id, + member_projects=project_ids, + service_identity="workflows", + ) + assert output["data"]["updateSystemWritableOnlyType"][0]["name"] == "new_name" + + +@pytest.mark.asyncio +async def test_update_wont_associate_inaccessible_relationships( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Make sure users can only see samples in collections they have access to. 
+ """ + owner_user_id = 333 + user_id = 12345 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + test_sample0 = SampleFactory.create(collection_location="City1", owner_user_id=999, collection_id=444) + test_sample1 = SampleFactory.create(collection_location="City2", owner_user_id=999, collection_id=444) + test_sample2 = SampleFactory.create(collection_location="City3", owner_user_id=999, collection_id=444) + test_sample3 = SampleFactory.create(collection_location="City4", owner_user_id=999, collection_id=444) + test_sequencing_read = SequencingReadFactory.create( + sample=test_sample0, owner_user_id=owner_user_id, collection_id=111 + ) + + def gen_query(test_sample: Sample) -> str: + # Fetch all samples + query = f""" + mutation MyMutation {{ + updateSequencingRead( + where: {{id: {{_eq: "{test_sequencing_read.id}"}} }}, + input: {{ + sampleId: "{test_sample.id}" + }} + ) {{ + id + sample {{ + id + }} + }} + }} + """ + return query + + # We are a member of 444 so this should work. + output = await gql_client.query(gen_query(test_sample1), user_id=user_id, member_projects=[111, 444]) + assert output["data"]["updateSequencingRead"][0]["sample"]["id"] == str(test_sample1.id) + + # We are NOT a member of 444 so this should break. + output = await gql_client.query(gen_query(test_sample2), user_id=user_id, member_projects=[111, 555]) + assert "Unauthorized" in output["errors"][0]["message"] + + # Project 444 is public so this should work + output = await gql_client.query( + gen_query(test_sample3), user_id=user_id, member_projects=[111, 555], viewer_projects=[444] + ) + assert output["data"]["updateSequencingRead"][0]["sample"]["id"] == str(test_sample3.id) + + +@pytest.mark.asyncio +async def test_create_wont_associate_inaccessible_relationships( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Make sure users can only see samples in collections they have access to. + """ + owner_user_id = 333 + user_id = 12345 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + test_sample = SampleFactory.create(collection_location="City2", owner_user_id=owner_user_id, collection_id=444) + + # Fetch all samples + query = f""" + mutation MyMutation {{ + createSequencingRead( + input: {{ + collectionId: 111, + sampleId: "{test_sample.id}", + protocol: artic_v4, + technology: Illumina, + nucleicAcid: RNA, + clearlabsExport: false + }} + ) {{ id }} + }} + """ + # We are a member of 444 so this should work. + output = await gql_client.query(query, user_id=user_id, member_projects=[111, 444]) + assert output["data"]["createSequencingRead"]["id"] + + # We are NOT a member of 444 so this should break. 
+ output = await gql_client.query(query, user_id=user_id, member_projects=[111, 555]) + assert "Unauthorized" in output["errors"][0]["message"] + + # Project 444 is public so this should work + output = await gql_client.query(query, user_id=user_id, member_projects=[111, 555], viewer_projects=[444]) + assert output["data"]["createSequencingRead"]["id"] diff --git a/test_app/tests/test_basic_queries.py b/test_app/tests/test_basic_queries.py index ac9d741..0ab1f6a 100644 --- a/test_app/tests/test_basic_queries.py +++ b/test_app/tests/test_basic_queries.py @@ -2,11 +2,14 @@ Test basic queries and mutations """ +import datetime import pytest from platformics.database.connect import SyncDB from conftest import GQLTestClient, SessionStorage from test_infra.factories.sample import SampleFactory +date_now = datetime.datetime.now() + @pytest.mark.asyncio async def test_graphql_query( @@ -24,13 +27,25 @@ async def test_graphql_query( with sync_db.session() as session: SessionStorage.set_session(session) SampleFactory.create_batch( - 2, collection_location="San Francisco, CA", owner_user_id=user_id, collection_id=project_id + 2, + collection_location="San Francisco, CA", + collection_date=date_now, + owner_user_id=user_id, + collection_id=project_id, ) SampleFactory.create_batch( - 6, collection_location="Mountain View, CA", owner_user_id=user_id, collection_id=project_id + 6, + collection_location="Mountain View, CA", + collection_date=date_now, + owner_user_id=user_id, + collection_id=project_id, ) SampleFactory.create_batch( - 4, collection_location="Phoenix, AZ", owner_user_id=secondary_user_id, collection_id=9999 + 4, + collection_location="Phoenix, AZ", + collection_date=date_now, + owner_user_id=secondary_user_id, + collection_id=9999, ) # Fetch all samples @@ -69,6 +84,7 @@ async def test_graphql_mutations( sampleType: "Type 1" waterControl: false collectionLocation: "San Francisco, CA" + collectionDate: "2024-01-01" collectionId: 123 }) { id, diff --git a/test_app/tests/test_cascade_deletion.py b/test_app/tests/test_cascade_deletion.py new file mode 100644 index 0000000..20bb38d --- /dev/null +++ b/test_app/tests/test_cascade_deletion.py @@ -0,0 +1,83 @@ +""" +Test cascade deletion +""" + +import pytest +from platformics.database.connect import SyncDB +from conftest import SessionStorage, GQLTestClient, FileFactory +from test_infra.factories.sequencing_read import SequencingReadFactory + + +@pytest.mark.asyncio +async def test_cascade_delete( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can make cascade deletions + """ + user_id = 12345 + project_id = 123 + + # Create mock data: 2 SequencingReads, each with a different Sample, and each with R1/R2 + with sync_db.session() as session: + SessionStorage.set_session(session) + sequencing_reads = SequencingReadFactory.create_batch( + 2, technology="Illumina", owner_user_id=user_id, collection_id=project_id + ) + FileFactory.update_file_ids() + + # Delete the first Sample + query = f""" + mutation MyMutation {{ + deleteSample (where: {{ id: {{ _eq: "{sequencing_reads[0].sample_id}" }} }}) {{ + id + }} + }} + """ + result = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + assert result["data"]["deleteSample"][0]["id"] == str(sequencing_reads[0].sample_id) + + # The first SequencingRead should be deleted + query = f""" + query MyQuery {{ + sequencingReadsAggregate(where: {{ id: {{ _eq: "{sequencing_reads[0].entity_id}"}} }}) {{ + aggregate {{ count }} + }} + }} + """ + result = await 
gql_client.query(query, member_projects=[project_id]) + assert result["data"]["sequencingReadsAggregate"]["aggregate"][0]["count"] == 0 + + # The second SequencingRead should still exist + query = f""" + query MyQuery {{ + sequencingReadsAggregate(where: {{ id: {{ _eq: "{sequencing_reads[1].entity_id}"}} }}) {{ + aggregate {{ count }} + }} + }} + """ + result = await gql_client.query(query, member_projects=[project_id]) + assert result["data"]["sequencingReadsAggregate"]["aggregate"][0]["count"] == 1 + + # Files from the first SequencingRead should be deleted + query = f""" + query MyQuery {{ + files(where: {{ entityId: {{ _eq: "{sequencing_reads[0].entity_id}" }} }}) {{ + id + }} + }} + """ + result = await gql_client.query(query, member_projects=[project_id]) + assert len(result["data"]["files"]) == 0 + + # Files from the second SequencingRead should NOT be deleted + query = f""" + query MyQuery {{ + files(where: {{ entityId: {{ _eq: "{sequencing_reads[1].entity_id}" }} }}) {{ + id + }} + }} + """ + result = await gql_client.query(query, member_projects=[project_id]) + assert len(result["data"]["files"]) == 2 diff --git a/test_app/tests/test_error_handling.py b/test_app/tests/test_error_handling.py new file mode 100644 index 0000000..fd92f35 --- /dev/null +++ b/test_app/tests/test_error_handling.py @@ -0,0 +1,54 @@ +""" +Test basic error handling +""" + +import pytest +from conftest import GQLTestClient + + +@pytest.mark.asyncio +async def test_unauthorized_error( + gql_client: GQLTestClient, +) -> None: + """ + Validate that expected errors don't get masked by our error handler. + """ + query = """ + mutation createOneSample { + createSample(input: { + name: "Test Sample" + sampleType: "Type 1" + waterControl: false + collectionLocation: "San Francisco, CA" + collectionDate: "2024-01-01" + collectionId: 123 + }) { + id, + collectionLocation + } + } + """ + output = await gql_client.query(query, member_projects=[456]) + + # Make sure we haven't masked expected errors. + assert output["data"] is None + assert "Unauthorized: Cannot create entity in this collection" in output["errors"][0]["message"] + + +@pytest.mark.asyncio +async def test_python_error( + gql_client: GQLTestClient, +) -> None: + """ + Validate that unexpected errors do get masked by our error handler. + """ + query = """ + query causeException { + uncaughtException + } + """ + output = await gql_client.query(query, member_projects=[456]) + + # Make sure we have masked unexpected errors. 
+ assert output["data"] is None + assert "Unexpected error" in output["errors"][0]["message"] diff --git a/test_app/tests/test_field_constraints.py b/test_app/tests/test_field_constraints.py new file mode 100644 index 0000000..8091437 --- /dev/null +++ b/test_app/tests/test_field_constraints.py @@ -0,0 +1,133 @@ +""" +Authorization spot-checks +""" + +import uuid +import json +from typing import Any + +import pytest +from conftest import GQLTestClient, SessionStorage +from test_infra.factories.constraint_checked_type import ConstraintCheckedTypeFactory +from platformics.database.connect import SyncDB + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "field_name,invalid_values,valid_values", + [ + ("length3To8", ["a", "", " a ", " abcdefghi "], ["abcde", " abc ", "abcdefgh "]), + ("regexFormatCheck", ["hi", "aaa-bbb-ccc", "100-100-10000"], ["222-33-4444", " aaa333-44-5555ccc "]), + ("minValue0", [-7, -14], [0, 3, 999999]), + ("maxValue9", [11, 44444], [0, -17, 9]), + ("minValue0MaxValue9", [-39, 10, 9999], [0, 6, 9]), + ("float1dot1To2dot2", [0, 3.3], [1.1, 1.2, 2.2]), + ("noStringChecks", [], ["", "aasdfasdfasd", " lorem ipsum !@#$%^&*() "]), + ("noIntChecks", [], [65, 0, 400, 7]), + ("noFloatChecks", [], [-65.50, 0.00001, 400.6, 7]), + ("boolField", [], [True, False]), + ("enumField", [], ["RNA", "DNA"]), + ], +) +async def test_create_validation( + field_name: str, + invalid_values: list[Any], + valid_values: list[Any], + gql_client: GQLTestClient, +) -> None: + """ + Make sure input field validation works. + """ + user_id = 12345 + project_ids = [333] + + def get_query(field_name: str, value: Any) -> str: + if "enum" not in field_name: + value = json.dumps(value) + query = f""" + mutation MyMutation {{ + createConstraintCheckedType(input: {{collectionId: {project_ids[0]}, {field_name}: {value} }}) {{ + collectionId + {field_name} + }} + }} + """ + return query + + # These should succeed + for value in valid_values: + query = get_query(field_name, value) + output = await gql_client.query(query, user_id=user_id, member_projects=project_ids) + if isinstance(value, str): + assert output["data"]["createConstraintCheckedType"][field_name] == value.strip() + else: + assert output["data"]["createConstraintCheckedType"][field_name] == value + + # These should fail + for value in invalid_values: + query = get_query(field_name, value) + output = await gql_client.query(query, user_id=user_id, member_projects=project_ids) + assert "Validation Error" in output["errors"][0]["message"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "field_name,invalid_values,valid_values", + [ + ("length3To8", ["a", "", " a ", " abcdefghi "], ["abcde", " abc ", "abcdefgh "]), + ("regexFormatCheck", ["hi", "aaa-bbb-ccc", "100-100-10000"], ["222-33-4444", " aaa333-44-5555ccc "]), + ("minValue0", [-7, -14], [0, 3, 999999]), + ("maxValue9", [11, 44444], [0, -17, 9]), + ("minValue0MaxValue9", [-39, 10, 9999], [0, 6, 9]), + ("float1dot1To2dot2", [0, 3.3], [1.1, 1.2, 2.2]), + ("noStringChecks", [], ["", "aasdfasdfasd", " lorem ipsum !@#$%^&*() "]), + ("noIntChecks", [], [65, 0, 400, 7]), + ("noFloatChecks", [], [-65.50, 0.00001, 400.6, 7]), + ("boolField", [], [True, False]), + ("enumField", [], ["RNA", "DNA"]), + ], +) +async def test_update_validation( + field_name: str, + invalid_values: list[Any], + valid_values: list[Any], + gql_client: GQLTestClient, + sync_db: SyncDB, +) -> None: + """ + Make sure input field validation works. 
+ """ + with sync_db.session() as session: + SessionStorage.set_session(session) + instance = ConstraintCheckedTypeFactory.create(owner_user_id=999, collection_id=333) + + user_id = 12345 + project_ids = [333] + + def get_query(id: uuid.UUID, field_name: str, value: Any) -> str: + if "enum" not in field_name: + value = json.dumps(value) + query = f""" + mutation MyMutation {{ + updateConstraintCheckedType(where: {{id: {{_eq: "{id}" }} }}, input: {{ {field_name}: {value} }}) {{ + collectionId + {field_name} + }} + }} + """ + return query + + # These should succeed + for value in valid_values: + query = get_query(instance.id, field_name, value) + output = await gql_client.query(query, user_id=user_id, member_projects=project_ids) + if isinstance(value, str): + assert output["data"]["updateConstraintCheckedType"][0][field_name] == value.strip() + else: + assert output["data"]["updateConstraintCheckedType"][0][field_name] == value + + # These should fail + for value in invalid_values: + query = get_query(instance.id, field_name, value) + output = await gql_client.query(query, user_id=user_id, member_projects=project_ids) + assert "Validation Error" in output["errors"][0]["message"] diff --git a/test_app/tests/test_field_visibility.py b/test_app/tests/test_field_visibility.py new file mode 100644 index 0000000..5779ab3 --- /dev/null +++ b/test_app/tests/test_field_visibility.py @@ -0,0 +1,184 @@ +""" +Test basic queries and mutations +""" + +import datetime + +import pytest +from conftest import GQLTestClient + +date_now = datetime.datetime.now() + + +@pytest.mark.asyncio +async def test_hidden_fields( + gql_client: GQLTestClient, +) -> None: + """ + Test that we can hide fields from the GQL interface + """ + user_id = 12345 + project_id = 123 + + # Introspect the GenomicRange type + query = """ + query MyQuery { + __type(name: "GenomicRange") { + name + fields { + name + type { + name + kind + } + } + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + field_names = [field["name"] for field in output["data"]["__type"]["fields"]] + + # entityId is a hidden field, make sure it's not part if our type def. + assert "entityId" not in field_names + + # ownerUserId is a regular field inherited from Entity, make sure we can see it. + assert "ownerUserId" in field_names + + # file is a regular field on the GR table, make sure we can see it. + assert "file" in field_names + assert "fileId" in field_names + + +@pytest.mark.asyncio +async def test_hidden_mutations( + gql_client: GQLTestClient, +) -> None: + """ + Test that we don't generate mutations unless they make sense + """ + user_id = 12345 + project_id = 123 + + # Introspect the Mutations fields + query = """ + query MyQuery { + __schema { + mutationType { + fields { + name + } + } + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + mutations = [field["name"] for field in output["data"]["__schema"]["mutationType"]["fields"]] + + # ImmutableType has `mutable: false` set on the entire table so it shouldn't have a mutation + assert "updateImmutableType" not in mutations + + # ImmutableType still allows creation. + assert "createImmutableType" in mutations + + +# Make sure we only allow certain fields to be set at entity creation time. +@pytest.mark.asyncio +async def test_update_fields( + gql_client: GQLTestClient, +) -> None: + """ + Test that we don't show immutable fields in update mutations. 
+ """ + user_id = 12345 + project_id = 123 + + # Introspect the Mutations fields + query = """ + fragment FullType on __Type { + kind + name + inputFields { + ...InputValue + } + } + fragment InputValue on __InputValue { + name + type { + ...TypeRef + } + defaultValue + } + fragment TypeRef on __Type { + kind + name + ofType { + kind + name + } + } + query IntrospectionQuery { + __schema { + types { + ...FullType + } + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + create_type = [item for item in output["data"]["__schema"]["types"] if item["name"] == "SequencingReadUpdateInput"][ + 0 + ] + fields = [field["name"] for field in create_type["inputFields"]] + # We have a limited subset of mutable fields on SequencingRead + assert set(fields) == set(["nucleicAcid", "clearlabsExport", "technology", "sampleId", "deletedAt"]) + + +# Make sure we only allow certain fields to be set at entity creation time. +@pytest.mark.asyncio +async def test_creation_fields( + gql_client: GQLTestClient, +) -> None: + """ + Test that we don't generate mutations unless they actually do something + """ + user_id = 12345 + project_id = 123 + + # Introspect the Mutations fields + query = """ + fragment FullType on __Type { + kind + name + inputFields { + ...InputValue + } + } + fragment InputValue on __InputValue { + name + type { + ...TypeRef + } + defaultValue + } + fragment TypeRef on __Type { + kind + name + ofType { + kind + name + } + } + query IntrospectionQuery { + __schema { + types { + ...FullType + } + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + create_type = [item for item in output["data"]["__schema"]["types"] if item["name"] == "GenomicRangeCreateInput"][0] + fields = [field["name"] for field in create_type["inputFields"]] + # Producing run id and collection id are always settable on a new entity. + # producingRunId is only settable by a system user, and collectionId is settable by users. 
+ assert set(fields) == set(["producingRunId", "collectionId", "deletedAt"]) diff --git a/test_app/tests/test_file_concatenation.py b/test_app/tests/test_file_concatenation.py index d1ea3e0..7a75c2b 100644 --- a/test_app/tests/test_file_concatenation.py +++ b/test_app/tests/test_file_concatenation.py @@ -4,9 +4,9 @@ import pytest import requests -from conftest import GQLTestClient, SessionStorage from mypy_boto3_s3.client import S3Client from platformics.database.connect import SyncDB +from conftest import SessionStorage, GQLTestClient from test_infra.factories.sequencing_read import SequencingReadFactory diff --git a/test_app/tests/test_file_mutations.py b/test_app/tests/test_file_mutations.py index 68336f6..16d944b 100644 --- a/test_app/tests/test_file_mutations.py +++ b/test_app/tests/test_file_mutations.py @@ -3,15 +3,15 @@ """ import os - import pytest +import typing import sqlalchemy as sa -from conftest import FileFactory, GQLTestClient, SessionStorage -from database.models import SequencingRead from mypy_boto3_s3.client import S3Client from platformics.database.connect import SyncDB -from platformics.database.models import File, FileStatus +from database.models import File, FileStatus +from conftest import SessionStorage, FileFactory, GQLTestClient from test_infra.factories.sequencing_read import SequencingReadFactory +from database.models import SequencingRead @pytest.mark.asyncio @@ -106,16 +106,8 @@ async def test_invalid_fastq( @pytest.mark.parametrize( "member_projects,project_id,entity_field", [ - ( - [456], - 123, - "r1_file", - ), # Can't create file for entity you don't have access to - ( - [123], - 123, - "does_not_exist", - ), # Can't create file for entity that isn't connected to a valid file type + ([456], 123, "r1_file"), # Can't create file for entity you don't have access to + ([123], 123, "does_not_exist"), # Can't create file for entity that isn't connected to a valid file type ([123], 123, "r1_file"), # Can create file for entity you have access to ], ) @@ -170,7 +162,7 @@ async def test_upload_file( assert output["errors"] is not None return - # Moto produces a hard-coded tokens + # Moto produces hard-coded tokens assert output["data"]["uploadFile"]["credentials"]["accessKeyId"].endswith("EXAMPLE") assert output["data"]["uploadFile"]["credentials"]["secretAccessKey"].endswith("EXAMPLEKEY") @@ -197,7 +189,7 @@ async def test_create_file( # Upload a fastq file to a mock bucket so we can create a file object from it file_namespace = "local-bucket" file_path = "test1.fastq" - file_path_local = "tests/fixtures/test1.fastq" + file_path_local = "tests/fixtures/test1.fasta" file_size = os.stat(file_path_local).st_size with open(file_path_local, "rb") as fp: moto_client.put_object(Bucket=file_namespace, Key=file_path, Body=fp) @@ -211,7 +203,7 @@ async def test_create_file( file: {{ name: "{file_path}", fileFormat: "fastq", - protocol: "s3", + protocol: s3, namespace: "{file_namespace}", path: "{file_path}" }} @@ -219,7 +211,99 @@ async def test_create_file( path size }} - }} + """ - output = await gql_client.query(mutation, member_projects=[123]) + output = await gql_client.query(mutation, member_projects=[123], service_identity="workflows") assert output["data"]["createFile"]["size"] == file_size + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "file_path,multiple_files_for_one_path,should_delete", + [ + ("nextgen/test1.fastq", False, True), + ("bla/test1.fastq", False, False), + ("nextgen/test1.fastq", True, False), + ("bla/test1.fastq", True, False), + ], +) 
+async def test_delete_from_s3( + file_path: str, + should_delete: bool, + multiple_files_for_one_path: bool, + sync_db: SyncDB, + gql_client: GQLTestClient, + moto_client: S3Client, + monkeypatch: typing.Any, +) -> None: + """ + Test that we delete a file from S3 under the right circumstances + """ + user1_id = 12345 + project1_id = 123 + user2_id = 67890 + project2_id = 456 + bucket = "local-bucket" + + # Patch the S3 client to make sure tests are operating on the same mock bucket + monkeypatch.setattr(File, "get_s3_client", lambda: moto_client) + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + SequencingReadFactory.create(owner_user_id=user1_id, collection_id=project1_id) + FileFactory.update_file_ids() + session.commit() + files = session.execute(sa.select(File)).scalars().all() + file = list(filter(lambda file: file.entity_field_name == "r1_file", files))[0] + file.path = file_path + file.namespace = bucket # set the bucket to make sure the mock file is in the right place + session.commit() + + # Also test the case where multiple files point to the same path + if multiple_files_for_one_path: + sequencing_read = SequencingReadFactory.create(owner_user_id=user2_id, collection_id=project2_id) + FileFactory.update_file_ids() + session.commit() + session.refresh(sequencing_read) + sequencing_read.r1_file.path = file_path + sequencing_read.r1_file.namespace = bucket + session.commit() + + valid_fastq_file = "tests/fixtures/test1.fastq" + moto_client.put_object(Bucket=file.namespace, Key=file.path, Body=open(valid_fastq_file, "rb")) + + # Delete SequencingRead and cascade to File objects + query = f""" + mutation MyMutation {{ + deleteSequencingRead(where: {{ id: {{ _eq: "{file.entity_id}" }} }}) {{ + id + }} + }} + """ + + # File should exist on S3 before the deletion + assert "Contents" in moto_client.list_objects(Bucket=file.namespace, Prefix=file.path) + + # Issue deletion + result = await gql_client.query( + query, user_id=user1_id, member_projects=[project1_id], service_identity="workflows" + ) + assert result["data"]["deleteSequencingRead"][0]["id"] == str(file.entity_id) + + # Make sure file either does or does not exist + if should_delete: + assert "Contents" not in moto_client.list_objects(Bucket=file.namespace, Prefix=file.path) + else: + assert "Contents" in moto_client.list_objects(Bucket=file.namespace, Prefix=file.path) + + # Make sure File object doesn't exist either + query = f""" + query MyQuery {{ + files(where: {{ id: {{ _eq: "{file.id}" }} }}) {{ + id + }} + }} + """ + result = await gql_client.query(query, user_id=user1_id, member_projects=[project1_id]) + assert result["data"]["files"] == [] diff --git a/test_app/tests/test_file_queries.py b/test_app/tests/test_file_queries.py index 0cb555c..1488b51 100644 --- a/test_app/tests/test_file_queries.py +++ b/test_app/tests/test_file_queries.py @@ -3,9 +3,9 @@ """ import pytest -from platformics.database.connect import SyncDB from conftest import FileFactory, GQLTestClient, SessionStorage from test_infra.factories.sequencing_read import SequencingReadFactory +from platformics.database.connect import SyncDB @pytest.mark.asyncio @@ -46,12 +46,11 @@ async def test_file_query( } """ output = await gql_client.query(query, member_projects=[project1_id]) - # Each SequencingRead results in 5 files: + # Each SequencingRead results in 3 files: # r1_file, r2_file # primer_file -> GenomicRange file - # GenomicRange produces ReferenceGenome -> file and file_index - # so we expect 8 * 5 = 
40 files. - assert len(output["data"]["files"]) == 40 + # so we expect 8 * 3 = 24 files. + assert len(output["data"]["files"]) == 24 for file in output["data"]["files"]: assert file["path"] is not None assert file["entity"]["collectionId"] == project1_id diff --git a/test_app/tests/test_limit_offset_queries.py b/test_app/tests/test_limit_offset_queries.py new file mode 100644 index 0000000..acd5497 --- /dev/null +++ b/test_app/tests/test_limit_offset_queries.py @@ -0,0 +1,108 @@ +""" +Test limit/offset on top-level queries +""" + +import datetime +import pytest +from platformics.database.connect import SyncDB +from conftest import GQLTestClient, SessionStorage +from test_infra.factories.sample import SampleFactory + +date_now = datetime.datetime.now() + + +@pytest.mark.asyncio +async def test_limit_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can limit the number of samples returned + """ + user_id = 12345 + project_id = 123 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + SampleFactory.create_batch(10, owner_user_id=user_id, collection_id=project_id) + + # Fetch all samples + query = """ + query limitQuery { + samples(limitOffset: {limit: 3}) { + id, + collectionLocation + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + assert len(output["data"]["samples"]) == 3 + + +@pytest.mark.asyncio +async def test_offset_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can offset the number of samples returned + """ + user_id = 12345 + project_id = 123 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + for i in range(10): + SampleFactory.create(name=f"Sample {i}", owner_user_id=user_id, collection_id=project_id) + + # Fetch all samples + all_samples_query = """ + query allSamples { + samples(orderBy: {name: asc}) { + name + } + } + """ + + output = await gql_client.query(all_samples_query, user_id=user_id, member_projects=[project_id]) + all_sample_names = [sample["name"] for sample in output["data"]["samples"]] + + # Fetch samples with limit: 3, offset: 3 + query = """ + query offsetQuery { + samples(limitOffset: {limit: 3, offset: 3}, orderBy: {name: asc}) { + name + } + } + """ + + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + offset_sample_names = [sample["name"] for sample in output["data"]["samples"]] + assert offset_sample_names == all_sample_names[3:6] + + # If we offset by 10, we should get an empty list + query = """ + query offsetQuery { + samples(limitOffset: {limit: 1, offset: 10}, orderBy: {name: asc}) { + name + } + } + """ + + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + assert len(output["data"]["samples"]) == 0 + + # If a user includes an offset without a limit, we should get an error + query = """ + query offsetQuery { + samples(limitOffset: {offset: 1}, orderBy: {name: asc}) { + name + } + } + """ + + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + assert output["data"] is None + assert "Cannot use offset without limit" in output["errors"][0]["message"] diff --git a/test_app/tests/test_nested_queries.py b/test_app/tests/test_nested_queries.py index d1adb7b..3a0d8f8 100644 --- a/test_app/tests/test_nested_queries.py +++ b/test_app/tests/test_nested_queries.py @@ -3,14 +3,13 @@ """ import base64 -from collections import 
defaultdict
-
 import pytest
-from conftest import GQLTestClient, SessionStorage
-from platformics.api.types.entities import Entity
+from collections import defaultdict
 from platformics.database.connect import SyncDB
+from conftest import GQLTestClient, SessionStorage
 from test_infra.factories.sample import SampleFactory
 from test_infra.factories.sequencing_read import SequencingReadFactory
+from platformics.api.types.entities import Entity
 
 
 def get_id(entity: Entity) -> str:
diff --git a/test_app/tests/test_sorting_queries.py b/test_app/tests/test_sorting_queries.py
new file mode 100644
index 0000000..08ed951
--- /dev/null
+++ b/test_app/tests/test_sorting_queries.py
@@ -0,0 +1,244 @@
+"""
+Test queries with an ORDER BY clause
+"""
+
+import pytest
+from platformics.database.connect import SyncDB
+from conftest import GQLTestClient, SessionStorage
+from test_infra.factories.sample import SampleFactory
+from test_infra.factories.sequencing_read import SequencingReadFactory
+from test_infra.factories.contig import ContigFactory
+from test_infra.factories.upstream_database import UpstreamDatabaseFactory
+
+
+@pytest.mark.asyncio
+async def test_basic_order_by_query(
+    sync_db: SyncDB,
+    gql_client: GQLTestClient,
+) -> None:
+    """
+    Test that we can add an ORDER BY clause to a query
+    """
+    user_id = 12345
+    project_id = 123
+
+    # Create mock data
+    with sync_db.session() as session:
+        SessionStorage.set_session(session)
+        SampleFactory.create_batch(
+            2, collection_location="San Francisco, CA", owner_user_id=user_id, collection_id=project_id
+        )
+        SampleFactory.create_batch(
+            1, collection_location="Mountain View, CA", owner_user_id=user_id, collection_id=project_id
+        )
+        SampleFactory.create_batch(
+            2, collection_location="Los Angeles, CA", owner_user_id=user_id, collection_id=project_id
+        )
+
+    # Fetch all samples, in descending order of collection location
+    query = """
+        query MyQuery {
+            samples(orderBy: {collectionLocation: desc}) {
+                collectionLocation
+            }
+        }
+    """
+    output = await gql_client.query(query, user_id=user_id, member_projects=[project_id])
+    locations = [sample["collectionLocation"] for sample in output["data"]["samples"]]
+    assert locations == [
+        "San Francisco, CA",
+        "San Francisco, CA",
+        "Mountain View, CA",
+        "Los Angeles, CA",
+        "Los Angeles, CA",
+    ]
+
+
+@pytest.mark.asyncio
+async def test_order_multiple_fields_query(
+    sync_db: SyncDB,
+    gql_client: GQLTestClient,
+) -> None:
+    """
+    Test that we can sort by multiple fields, and that the order of the fields is preserved
+    """
+    user_id = 12345
+    project_id = 123
+    secondary_project_id = 234
+
+    # Create mock data
+    with sync_db.session() as session:
+        SessionStorage.set_session(session)
+        SampleFactory(owner_user_id=user_id, collection_id=project_id, collection_location="San Francisco, CA")
+        SampleFactory(
+            owner_user_id=user_id, collection_id=secondary_project_id, collection_location="San Francisco, CA"
+        )
+        SampleFactory(owner_user_id=user_id, collection_id=project_id, collection_location="Mountain View, CA")
+        SampleFactory(
+            owner_user_id=user_id, collection_id=secondary_project_id, collection_location="Mountain View, CA"
+        )
+
+    # Fetch all samples, in descending order of collection id and then ascending order of collection location
+    query = """
+        query MyQuery {
+            samples(orderBy: [{collectionId: desc}, {collectionLocation: asc}]) {
+                collectionLocation,
+                collectionId
+            }
+        }
+    """
+    output = await gql_client.query(query, user_id=user_id, member_projects=[project_id, secondary_project_id])
+
locations = [sample["collectionLocation"] for sample in output["data"]["samples"]] + collection_ids = [sample["collectionId"] for sample in output["data"]["samples"]] + assert locations == ["Mountain View, CA", "San Francisco, CA", "Mountain View, CA", "San Francisco, CA"] + assert collection_ids == [234, 234, 123, 123] + + # Fetch all samples, in ascending order of collection location and then descending order of collection id + query = """ + query MyQuery { + samples(orderBy: [{collectionLocation: asc}, {collectionId: desc}]) { + collectionLocation, + collectionId + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id, secondary_project_id]) + locations = [sample["collectionLocation"] for sample in output["data"]["samples"]] + collection_ids = [sample["collectionId"] for sample in output["data"]["samples"]] + assert locations == ["Mountain View, CA", "Mountain View, CA", "San Francisco, CA", "San Francisco, CA"] + assert collection_ids == [234, 123, 234, 123] + + +@pytest.mark.asyncio +async def test_sort_nested_objects_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can sort nested objects + """ + user_id = 12345 + project_id = 123 + + with sync_db.session() as session: + SessionStorage.set_session(session) + sample_1 = SampleFactory( + owner_user_id=user_id, collection_id=project_id, collection_location="San Francisco, CA" + ) + sample_2 = SampleFactory( + owner_user_id=user_id, collection_id=project_id, collection_location="Mountain View, CA" + ) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_1, nucleic_acid="DNA") + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_1, nucleic_acid="RNA") + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_2, nucleic_acid="DNA") + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_2, nucleic_acid="RNA") + + # Fetch all samples, in descending order of collection location, and then in ascending order of + # the related sequencing read's nucleic acid + query = """ + query MyQuery { + samples(orderBy: {collectionLocation: desc}) { + collectionLocation + sequencingReads(orderBy: {nucleicAcid: asc}) { + edges { + node { + nucleicAcid + } + } + } + } + } + """ + output = await gql_client.query(query, user_id=user_id, member_projects=[project_id]) + locations = [sample["collectionLocation"] for sample in output["data"]["samples"]] + assert locations == ["San Francisco, CA", "Mountain View, CA"] + nucleic_acids = [] + for sample in output["data"]["samples"]: + for sr in sample["sequencingReads"]["edges"]: + nucleic_acids.append(sr["node"]["nucleicAcid"]) + assert nucleic_acids == ["DNA", "RNA", "DNA", "RNA"] + + +@pytest.mark.asyncio +async def test_order_by_related_field_query( + sync_db: SyncDB, + gql_client: GQLTestClient, +) -> None: + """ + Test that we can sort by fields of a related object + """ + user_id = 12345 + project_id = 123 + + # Create mock data + with sync_db.session() as session: + SessionStorage.set_session(session) + sample_1 = SampleFactory( + owner_user_id=user_id, collection_id=project_id, collection_location="Mountain View, CA" + ) + sample_2 = SampleFactory( + owner_user_id=user_id, collection_id=project_id, collection_location="San Francisco, CA" + ) + SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_1, nucleic_acid="DNA") + SequencingReadFactory(owner_user_id=user_id, 
collection_id=project_id, sample=sample_1, nucleic_acid="RNA")
+        SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_2, nucleic_acid="DNA")
+        SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, sample=sample_2, nucleic_acid="RNA")
+
+    # Fetch all sequencing reads, in descending order of the related sample's collection location
+    query = """
+        query MyQuery {
+            sequencingReads(orderBy: {sample: {collectionLocation: desc}}) {
+                nucleicAcid
+                sample {
+                    collectionLocation
+                }
+            }
+        }
+    """
+    output = await gql_client.query(query, user_id=user_id, member_projects=[project_id])
+    nucleic_acids = [sr["nucleicAcid"] for sr in output["data"]["sequencingReads"]]
+    collection_locations = [sr["sample"]["collectionLocation"] for sr in output["data"]["sequencingReads"]]
+    assert nucleic_acids == ["DNA", "RNA", "DNA", "RNA"]
+    assert collection_locations == ["San Francisco, CA", "San Francisco, CA", "Mountain View, CA", "Mountain View, CA"]
+
+
+@pytest.mark.asyncio
+async def test_deeply_nested_query(
+    sync_db: SyncDB,
+    gql_client: GQLTestClient,
+) -> None:
+    """
+    Test that we can sort by fields of a very deeply nested object
+    """
+    user_id = 12345
+    project_id = 123
+
+    # Create mock data
+    with sync_db.session() as session:
+        SessionStorage.set_session(session)
+        upstream_db_1 = UpstreamDatabaseFactory(owner_user_id=user_id, collection_id=project_id, name="NCBI")
+        upstream_db_2 = UpstreamDatabaseFactory(owner_user_id=user_id, collection_id=project_id, name="GTDB")
+        contig_1 = ContigFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_1)
+        contig_2 = ContigFactory(owner_user_id=user_id, collection_id=project_id, upstream_database=upstream_db_2)
+        SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_1)
+        SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_1)
+        SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_2)
+        SequencingReadFactory(owner_user_id=user_id, collection_id=project_id, contig=contig_2)
+
+    # Fetch all sequencing reads, in descending order of the related contig's upstream database's name
+    query = """
+        query MyQuery {
+            sequencingReads(orderBy: {contig: {upstreamDatabase: {name: desc}}}) {
+                id
+                contig {
+                    upstreamDatabase {
+                        name
+                    }
+                }
+            }
+        }
+    """
+
+    output = await gql_client.query(query, user_id=user_id, member_projects=[project_id])
+    upstream_database_names = [d["contig"]["upstreamDatabase"]["name"] for d in output["data"]["sequencingReads"]]
+    assert upstream_database_names == ["NCBI", "NCBI", "GTDB", "GTDB"]
diff --git a/test_app/tests/test_test_setup.py b/test_app/tests/test_test_setup.py
index 89c9f7e..509c637 100644
--- a/test_app/tests/test_test_setup.py
+++ b/test_app/tests/test_test_setup.py
@@ -12,7 +12,7 @@ async def test_graphql_query(gql_client: GQLTestClient, api_test_schema: FastAPI
     """
     Make sure we're using the right schema and http client
     """
-    assert api_test_schema.title == "Codegen Tests"
+    assert api_test_schema.title == "Platformics"
     assert gql_client.http_client.base_url.host == "test-codegen"
 
     # Reference genomes is not an entity in the mock schema but is one in the real schema
diff --git a/test_app/tests/test_where_clause.py b/test_app/tests/test_where_clause.py
index 7910d4f..3963df6 100644
--- a/test_app/tests/test_where_clause.py
+++ b/test_app/tests/test_where_clause.py
@@ -4,7 +4,8 @@
 
 import pytest
 from platformics.database.connect import SyncDB
-from 
conftest import GQLTestClient, SessionStorage +from conftest import GQLTestClient, SessionStorage, FileFactory +from test_infra.factories.sample import SampleFactory from test_infra.factories.sequencing_read import SequencingReadFactory from support.enums import SequencingTechnology @@ -155,3 +156,168 @@ async def test_where_clause_mutations(sync_db: SyncDB, gql_client: GQLTestClient assert sequencing_read["technology"] == new_technology else: assert sequencing_read["technology"] == prev_technology + + +@pytest.mark.asyncio +async def test_where_clause_regex_match(sync_db: SyncDB, gql_client: GQLTestClient) -> None: + """ + Verify that the regex operators work as expected. + """ + # Regex for any string with "MATCH" in it + regex = ".*MATCH.*" + with sync_db.session() as session: + SessionStorage.set_session(session) + # Sample names that match the regex (case-sensitive) + case_sensitive_matches = ["A MATCH", "A MATCHING SAMPLE"] + # Sample names that match the regex if case is ignored, but do not match if case is considered + case_insensitive_matches = ["a match if ignore case", "a matching sample if ignore case"] + # Sample names that don't match the regex at all + no_matches = ["asdf1234", "HCTAM"] + # Create the samples + all_sample_names = case_sensitive_matches + case_insensitive_matches + no_matches + for name in all_sample_names: + SampleFactory.create(name=name, owner_user_id=user_id, collection_id=project_id) + + match_case_query = f""" + query GetSamplesMatchingCase {{ + samples ( where: {{ + name: {{ + _regex: "{regex}" + }} + }}) {{ + name + }} + }} + """ + + match_case_query_output = await gql_client.query(match_case_query, member_projects=[project_id]) + assert len(match_case_query_output["data"]["samples"]) == 2 + output_sample_names = [sample["name"] for sample in match_case_query_output["data"]["samples"]] + assert sorted(output_sample_names) == sorted(case_sensitive_matches) + + ignore_case_query = f""" + query GetSamplesMatchingIgnoreCase {{ + samples ( where: {{ + name: {{ + _iregex: "{regex}" + }} + }}) {{ + name + }} + }} + """ + + ignore_case_query_output = await gql_client.query(ignore_case_query, member_projects=[project_id]) + assert len(ignore_case_query_output["data"]["samples"]) == 4 + output_sample_names = [sample["name"] for sample in ignore_case_query_output["data"]["samples"]] + assert sorted(output_sample_names) == sorted(case_sensitive_matches + case_insensitive_matches) + + no_match_query = f""" + query GetSamplesNoMatch {{ + samples ( where: {{ + name: {{ + _nregex: "{regex}" + }} + }}) {{ + name + }} + }} + """ + + no_match_query_output = await gql_client.query(no_match_query, member_projects=[project_id]) + assert len(no_match_query_output["data"]["samples"]) == 4 + output_sample_names = [sample["name"] for sample in no_match_query_output["data"]["samples"]] + assert sorted(output_sample_names) == sorted(no_matches + case_insensitive_matches) + + no_match_ignore_case_query = f""" + query GetSamplesNoMatchIgnoreCase {{ + samples ( where: {{ + name: {{ + _niregex: "{regex}" + }} + }}) {{ + name + }} + }} + """ + + no_match_ignore_case_query_output = await gql_client.query(no_match_ignore_case_query, member_projects=[project_id]) + assert len(no_match_ignore_case_query_output["data"]["samples"]) == 2 + output_sample_names = [sample["name"] for sample in no_match_ignore_case_query_output["data"]["samples"]] + assert sorted(output_sample_names) == sorted(no_matches) + + +@pytest.mark.asyncio +async def test_soft_deleted_objects(sync_db: SyncDB, 
gql_client: GQLTestClient) -> None: + """ + Verify that the where clause works as expected with soft-deleted objects. + By default, soft-deleted objects should not be returned. + """ + sequencing_reads = generate_sequencing_reads(sync_db) + FileFactory.update_file_ids() + # Soft delete the first 3 sequencing reads by updating the deleted_at field + deleted_ids = [str(sequencing_reads[0].id), str(sequencing_reads[1].id), str(sequencing_reads[2].id)] + soft_delete_mutation = f""" + mutation SoftDeleteSequencingReads {{ + updateSequencingRead ( + where: {{ + id: {{ _in: [ "{deleted_ids[0]}", "{deleted_ids[1]}", "{deleted_ids[2]}" ] }}, + }}, + input: {{ + deletedAt: "2021-01-01T00:00:00Z", + }} + ) {{ + id + }} + }} + """ + output = await gql_client.query(soft_delete_mutation, member_projects=[project_id], service_identity="workflows") + assert len(output["data"]["updateSequencingRead"]) == 3 + + # Check that the soft-deleted sequencing reads are not returned + regular_query = """ + query GetSequencingReads { + sequencingReads { + id + } + } + """ + output = await gql_client.query(regular_query, member_projects=[project_id]) + assert len(output["data"]["sequencingReads"]) == 2 + for sequencing_read in output["data"]["sequencingReads"]: + assert sequencing_read["id"] not in deleted_ids + + # Check that the soft-deleted sequencing reads are returned when explicitly requested + soft_deleted_query = """ + query GetSequencingReads { + sequencingReads ( where: { deletedAt: { _is_null: false } }) { + id + } + } + """ + output = await gql_client.query(soft_deleted_query, member_projects=[project_id]) + assert len(output["data"]["sequencingReads"]) == 3 + for sequencing_read in output["data"]["sequencingReads"]: + assert sequencing_read["id"] in deleted_ids + + # Check that we can hard-delete the soft-deleted objects + hard_delete_mutation = f""" + mutation DeleteSequencingReads {{ + deleteSequencingRead ( + where: {{ + id: {{ _in: [ "{deleted_ids[0]}", "{deleted_ids[1]}", "{deleted_ids[2]}" ] }}, + }} + ) {{ + id + }} + }} + """ + + output = await gql_client.query( + hard_delete_mutation, user_id=user_id, member_projects=[project_id], service_identity="workflows" + ) + assert len(output["data"]["deleteSequencingRead"]) == 3 + + # Check that the hard-deleted sequencing reads are not returned + output = await gql_client.query(soft_deleted_query, member_projects=[project_id]) + assert len(output["data"]["sequencingReads"]) == 0