From a078d57ba66dfce15d0eb3c69b7dacfd3ff3e4ae Mon Sep 17 00:00:00 2001 From: Dhruv <83733638+dhruv-ahuja@users.noreply.github.com> Date: Sat, 24 Aug 2024 08:35:25 +0530 Subject: [PATCH] Fix Nested Filter-Sort and Model Querying Issues (#29) * refactor: make filter operators constant * fix: resolve filtering issues - ensure each filter is applied only once per query - fix failures on nested filtering * fix: allow nested sorting --- src/config/constants/app.py | 18 +++++++++++++ src/services/poe.py | 26 +++++++----------- src/utils/services.py | 54 ++++++++++++++++++++++--------------- 3 files changed, 61 insertions(+), 37 deletions(-) diff --git a/src/config/constants/app.py b/src/config/constants/app.py index e8e0f4b..1fe1d23 100644 --- a/src/config/constants/app.py +++ b/src/config/constants/app.py @@ -1,8 +1,10 @@ from typing import Literal from uuid import uuid4 import datetime as dt +import operator from beanie.odm.interfaces.find import FindType, DocumentProjectionType +from beanie.odm.operators.find.evaluation import RegEx as RegExOperator from beanie.odm.queries.find import FindMany @@ -31,3 +33,19 @@ FIND_MANY_QUERY = FindMany[FindType] | FindMany[DocumentProjectionType] FILTER_OPERATION = Literal["=", "!=", ">", ">=", "<", "<=", "like"] +FILTER_OPERATION_MAP = { + "=": operator.eq, + "!=": operator.ne, + ">": operator.gt, + "<": operator.lt, + ">=": operator.ge, + "<=": operator.le, + "like": RegExOperator, +} +NESTED_FILTER_OPERATION_MAP = { + "=": "$eq", + ">": "$gt", + ">=": "$gte", + "<": "$lt", + "<=": "$lte", +} diff --git a/src/services/poe.py b/src/services/poe.py index f4946c3..06e1879 100644 --- a/src/services/poe.py +++ b/src/services/poe.py @@ -42,30 +42,24 @@ async def get_items( ) -> tuple[list[ItemBase], int]: """Gets items by given category group, and the total items' count in the database.""" - query = Item.find() - chainer = QueryChainer(query, Item) - if filter_sort_input is None: - items_count = await query.find(fetch_links=True).count() - items = await chainer.paginate(pagination).query.find(fetch_links=True).project(ItemBase).to_list() + items_count = await Item.find().count() + items = await QueryChainer(Item.find(), Item).paginate(pagination).query.find().project(ItemBase).to_list() return items, items_count - base_query_chain = chainer.filter(filter_sort_input.filter_).sort(filter_sort_input.sort) - - # * clone the query for use with total record counts and pagination calculations - count_query = ( - base_query_chain.filter(filter_sort_input.filter_) + items_query = ( + QueryChainer(Item.find(), Item) + .filter(filter_sort_input.filter_) .sort(filter_sort_input.sort) - .clone() - .query.find(fetch_links=True) - .count() + .paginate(pagination) + .query.project(ItemBase) + .to_list() ) - - paginated_query = base_query_chain.paginate(pagination).query.find(fetch_links=True).project(ItemBase).to_list() + count_query = QueryChainer(Item.find(), Item).filter(filter_sort_input.filter_).query.count() try: - items = await paginated_query + items = await items_query items_count = await count_query except Exception as exc: logger.error(f"error getting items from database; filter_sort: {filter_sort_input}: {exc}") diff --git a/src/utils/services.py b/src/utils/services.py index 752b6df..de54e9a 100644 --- a/src/utils/services.py +++ b/src/utils/services.py @@ -1,15 +1,15 @@ import copy -import operator from typing import Self, Type, cast from beanie import Document from beanie.odm.operators.find.evaluation import RegEx as RegExOperator +from bson import Decimal128 from loguru import logger import orjson import pymongo from redis.asyncio import Redis, RedisError -from src.config.constants.app import FIND_MANY_QUERY +from src.config.constants.app import FILTER_OPERATION_MAP, FIND_MANY_QUERY, NESTED_FILTER_OPERATION_MAP from src.schemas.requests import FilterInputType, FilterSchema, PaginationInput, SortInputType, SortSchema from src.schemas.responses import E, T, BaseResponse @@ -71,7 +71,8 @@ def sort_on_query(query: FIND_MANY_QUERY, model: Type[Document], sort: SortInput field = entry.field operation = pymongo.ASCENDING if entry.operation == "asc" else pymongo.DESCENDING - model_field = getattr(model, field) + is_nested = "." in field + model_field = field if is_nested else getattr(model, field) expression = (model_field, operation) sort_expressions.append(expression) @@ -80,6 +81,23 @@ def sort_on_query(query: FIND_MANY_QUERY, model: Type[Document], sort: SortInput return query +def _build_nested_query(entry: FilterSchema, query: FIND_MANY_QUERY) -> FIND_MANY_QUERY: + """Builds queries for nested fields, using raw BSON query syntax to ensure nested fields are parsed properly.""" + + field = entry.field + operation = entry.operation + value = entry.value + + if operation != "like": + operation_function = NESTED_FILTER_OPERATION_MAP[operation] + filter_query = {field: {operation_function: Decimal128(value)}} + else: + filter_query = {field: {"$regex": value, "$options": "i"}} + + query = query.find(filter_query) + return query + + def filter_on_query(query: FIND_MANY_QUERY, model: Type[Document], filter_: FilterInputType) -> FIND_MANY_QUERY: """Parses, gathers and chains filter operations on the input query. Skips the process if filter input is empty.\n Maps the operation list to operator arguments that allow using the operator dynamically, to create expressions @@ -89,31 +107,25 @@ def filter_on_query(query: FIND_MANY_QUERY, model: Type[Document], filter_: Filt if not isinstance(filter_, list): return query - operation_map = { - "=": operator.eq, - "!=": operator.ne, - ">": operator.gt, - "<": operator.lt, - ">=": operator.ge, - "<=": operator.le, - "like": RegExOperator, - } - for entry in filter_: field = entry.field operation = entry.operation - operation_function = operation_map[operation] + operation_function = FILTER_OPERATION_MAP[operation] value = entry.value - model_field = getattr(model, field) - - if operation != "like": - query = query.find(operation_function(model_field, value)) + is_nested = "." in field + if is_nested: + query = _build_nested_query(entry, query) else: - operation_function = RegExOperator - options = "i" # case-insensitive search + model_field = getattr(model, field) + + if operation != "like": + query = query.find(operation_function(model_field, value)) + else: + operation_function = RegExOperator + options = "i" # case-insensitive search - query = query.find(operation_function(model_field, value, options=options)) + query = query.find(operation_function(model_field, value, options=options)) return query