Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Do not drop 'public' schema in setup_system_catalog #48

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 42 additions & 28 deletions prestogres/pgsql/prestogres.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# See the document about system column names: http://www.postgresql.org/docs/9.3/static/ddl-system-columns.html
SYSTEM_COLUMN_NAMES = set(["oid", "tableoid", "xmin", "cmin", "xmax", "cmax", "ctid"])


# convert Presto query result field types to PostgreSQL types
def _pg_result_type(presto_type):
if presto_type == "varchar": # for old Presto
Expand All @@ -29,6 +30,7 @@ def _pg_result_type(presto_type):
# assuming Presto and PostgreSQL use the same SQL standard name
return presto_type


# convert Presto table column types to PostgreSQL types
def _pg_table_type(presto_type):
if presto_type == "varchar": # for old Presto
Expand All @@ -43,6 +45,7 @@ def _pg_table_type(presto_type):
# assuming Presto and PostgreSQL use the same SQL standard name
return presto_type


# queries can include same column name twice but tables can't.
def _rename_duplicated_column_names(column_names, where):
renamed = []
Expand All @@ -53,15 +56,16 @@ def _rename_duplicated_column_names(column_names, where):
name += "_"
if name != original_name:
if name in SYSTEM_COLUMN_NAMES:
plpy.warning("Column %s is renamed to %s because the name in %s conflicts with PostgreSQL system column names" % \
(plpy.quote_ident(original_name), plpy.quote_ident(name), where))
plpy.warning("Column %s is renamed to %s because the name in %s conflicts with PostgreSQL system column names" %
(plpy.quote_ident(original_name), plpy.quote_ident(name), where))
else:
plpy.warning("Column %s is renamed to %s because the name appears twice in %s" % \
(plpy.quote_ident(original_name), plpy.quote_ident(name), where))
plpy.warning("Column %s is renamed to %s because the name appears twice in %s" %
(plpy.quote_ident(original_name), plpy.quote_ident(name), where))
used_names.add(name)
renamed.append(name)
return renamed


# build CREATE TEMPORARY TABLE statement
def _build_create_temp_table_sql(table_name, column_names, column_types):
create_sql = ["create temporary table %s (\n " % plpy.quote_ident(table_name)]
Expand All @@ -80,6 +84,7 @@ def _build_create_temp_table_sql(table_name, column_names, column_types):
create_sql.append("\n)")
return ''.join(create_sql)


# build CREATE TABLE statement
def _build_create_table(schema_name, table_name, column_names, column_types, not_nulls):
alter_sql = ["create table %s.%s (\n " % (plpy.quote_ident(schema_name), plpy.quote_ident(table_name))]
Expand All @@ -99,15 +104,18 @@ def _build_create_table(schema_name, table_name, column_names, column_types, not
alter_sql.append("\n)")
return ''.join(alter_sql)


def _get_session_time_zone():
    """Return the time zone setting of the current PostgreSQL session.

    Runs ``show timezone`` through PL/Python and returns the first value of
    the first result row.
    """
    rows = plpy.execute("show timezone")
    # list() keeps the positional lookup working whether values() hands back
    # a list or a view object.
    return list(rows[0].values())[0]


def _get_session_search_path_array():
    """Return the session's ``search_path`` setting as an array.

    The setting string is wrapped in braces and cast to ``text[]`` by
    PostgreSQL itself; the first value of the first result row is returned.
    """
    rows = plpy.execute("select ('{' || current_setting('search_path') || '}')::text[]")
    # list() keeps the positional lookup working whether values() hands back
    # a list or a view object.
    return list(rows[0].values())[0]

# Maps the NUL character to None.
# NOTE(review): presumably consumed by remove_null() as a translation table;
# its body is not fully visible here -- confirm before changing the key.
NULL_PATTERN = {'\0': None}


def remove_null(bs):
if isinstance(bs, str):
Expand All @@ -117,6 +125,7 @@ def remove_null(bs):
else:
return bs


class QueryAutoClose(object):
def __init__(self, query):
self.query = query
Expand All @@ -126,6 +135,7 @@ def __init__(self, query):
def __del__(self):
self.query.close()


class QueryAutoCloseIterator(object):
def __init__(self, gen, query_auto_close):
self.gen = gen
Expand All @@ -140,6 +150,7 @@ def next(self):
row[i] = remove_null(v)
return row


class QueryAutoCloseIteratorWithJsonConvert(QueryAutoCloseIterator):
def __init__(self, gen, query_auto_close, json_columns):
QueryAutoCloseIterator.__init__(self, gen, query_auto_close)
Expand All @@ -153,12 +164,14 @@ def next(self):
row[i] = json.dumps(row[i])
return row


class SessionData(object):
    """Holder for state that is kept between calls in this module.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        # Starts empty; fetch_presto_query_results() reads this attribute and
        # resets it to None.
        self.query_auto_close = None


# Module-level instance used by the functions in this file.
session = SessionData()


def start_presto_query(presto_server, presto_user, presto_catalog, presto_schema, function_name, query):
try:
# preserve search_path if explicitly set
Expand Down Expand Up @@ -200,8 +213,8 @@ def start_presto_query(presto_server, presto_user, presto_catalog, presto_schema
(plpy.quote_ident(function_name), plpy.quote_ident(type_name))

# run statements
plpy.execute("drop table if exists pg_temp.%s cascade" % \
(plpy.quote_ident(type_name)))
plpy.execute("drop table if exists pg_temp.%s cascade" %
(plpy.quote_ident(type_name)))
plpy.execute(create_type_sql)
plpy.execute(create_function_sql)

Expand All @@ -220,10 +233,11 @@ def start_presto_query(presto_server, presto_user, presto_catalog, presto_schema
e.__class__.__module__ = "__main__"
raise


def fetch_presto_query_results():
try:
# TODO should throw an exception?
#if session.query_auto_close is None:
# if session.query_auto_close is None:

query_auto_close = session.query_auto_close
session.query_auto_close = None # close of the iterator closes query
Expand All @@ -245,6 +259,7 @@ def fetch_presto_query_results():

# Lightweight record with three fields: name, type and nullable.
Column = namedtuple("Column", "name type nullable")


def setup_system_catalog(presto_server, presto_user, presto_catalog, presto_schema, access_role):
search_path = _get_session_search_path_array()
if search_path == ['$user', 'public']:
Expand Down Expand Up @@ -274,43 +289,43 @@ def setup_system_catalog(presto_server, presto_user, presto_catalog, presto_sche
continue

if len(schema_name) > PG_NAMEDATALEN - 1:
plpy.warning("Schema %s is skipped because its name is longer than %d characters" % \
(plpy.quote_ident(schema_name), PG_NAMEDATALEN - 1))
plpy.warning("Schema %s is skipped because its name is longer than %d characters" %
(plpy.quote_ident(schema_name), PG_NAMEDATALEN - 1))
continue

tables = schemas.setdefault(schema_name, {})

if len(table_name) > PG_NAMEDATALEN - 1:
plpy.warning("Table %s.%s is skipped because its name is longer than %d characters" % \
(plpy.quote_ident(schema_name), plpy.quote_ident(table_name), PG_NAMEDATALEN - 1))
plpy.warning("Table %s.%s is skipped because its name is longer than %d characters" %
(plpy.quote_ident(schema_name), plpy.quote_ident(table_name), PG_NAMEDATALEN - 1))
continue

columns = tables.setdefault(table_name, [])

if len(column_name) > PG_NAMEDATALEN - 1:
plpy.warning("Column %s.%s.%s is skipped because its name is longer than %d characters" % \
(plpy.quote_ident(schema_name), plpy.quote_ident(table_name), \
plpy.quote_ident(column_name), PG_NAMEDATALEN - 1))
plpy.warning("Column %s.%s.%s is skipped because its name is longer than %d characters" %
(plpy.quote_ident(schema_name), plpy.quote_ident(table_name),
plpy.quote_ident(column_name), PG_NAMEDATALEN - 1))
continue

columns.append(Column(column_name, column_type, is_nullable))

# drop all schemas excepting prestogres_catalog, information_schema and pg_%
sql = "select n.nspname as schema_name from pg_catalog.pg_namespace n" \
" where n.nspname not in ('prestogres_catalog', 'information_schema')" \
" and n.nspname not like 'pg_%'"
sql = ("select n.nspname as schema_name from pg_catalog.pg_namespace n "
"where n.nspname not in ('prestogres_catalog', 'information_schema') "
"and n.nspname not like 'pg_%' and n.nspname != 'public'")
for row in plpy.cursor(sql):
plpy.execute("drop schema %s cascade" % plpy.quote_ident(row["schema_name"]))

# create schema and tables
for schema_name, tables in sorted(schemas.items(), key=lambda (k,v): k):
for schema_name, tables in sorted(schemas.items(), key=lambda (k, v): k):
try:
plpy.execute("create schema %s" % (plpy.quote_ident(schema_name)))
except:
# ignore error?
pass

for table_name, columns in sorted(tables.items(), key=lambda (k,v): k):
for table_name, columns in sorted(tables.items(), key=lambda (k, v): k):
column_names = []
column_types = []
not_nulls = []
Expand All @@ -325,27 +340,26 @@ def setup_system_catalog(presto_server, presto_user, presto_catalog, presto_sche

# change columns
column_names = _rename_duplicated_column_names(column_names,
"%s.%s table" % (plpy.quote_ident(schema_name), plpy.quote_ident(table_name)))
"%s.%s table" % (plpy.quote_ident(schema_name), plpy.quote_ident(table_name)))
create_sql = _build_create_table(schema_name, table_name, column_names, column_types, not_nulls)
plpy.execute(create_sql)

# grant access on the schema to the restricted user so that
# pg_table_is_visible(reloid) used by \d of psql command returns true
plpy.execute("grant usage on schema %s to %s" % \
(plpy.quote_ident(schema_name), plpy.quote_ident(access_role)))
plpy.execute("grant usage on schema %s to %s" %
(plpy.quote_ident(schema_name), plpy.quote_ident(access_role)))
# this SELECT privilege is unnecessary because queries against those tables
# won't run on PostgreSQL. causing an exception is good if Prestogres has
# a bug sending a presto query to PostgreSQL without rewriting.
# TODO however, it's granted for now because some BI tools might check
# has_table_privilege. the best solution is to grant privilege but
# actually selecting from those tables causes an exception.
plpy.execute("grant select on all tables in schema %s to %s" % \
(plpy.quote_ident(schema_name), plpy.quote_ident(access_role)))
plpy.execute("grant select on all tables in schema %s to %s" %
(plpy.quote_ident(schema_name), plpy.quote_ident(access_role)))

# fake current_database() to return Presto's catalog name to be compatible with some
# applications that use db.schema.table syntax to identify a table
if plpy.execute("select pg_catalog.current_database()")[0].values()[0] != presto_catalog:
plpy.execute("delete from pg_catalog.pg_proc where proname='current_database'")
plpy.execute("create function pg_catalog.current_database() returns name as $$begin return %s::name; end$$ language plpgsql stable strict" % \
plpy.quote_literal(presto_catalog))

plpy.execute("create function pg_catalog.current_database() returns name as $$begin return %s::name; end$$ language plpgsql stable strict" %
plpy.quote_literal(presto_catalog))