Skip to content

Commit

Permalink
Merge pull request #224 from lorenzorubi-db/table_acls_in_chunks
Browse files Browse the repository at this point in the history
fix for driver OOM in export of table ACLs
  • Loading branch information
gregwood-db authored Jan 11, 2023
2 parents 48354a0 + 7f24c2b commit 3de69f9
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 30 deletions.
68 changes: 38 additions & 30 deletions data/notebooks/Export_Table_ACLs.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,7 @@ def create_grants_df(database_name: str,object_type: str, object_key: str):
return grants_df


def create_table_ACLSs_df_for_databases(database_names: List[str]):

# TODO check Catalog heuristic:
# if all databases are exported, we include the Catalog grants as well
#. if only a few databases are exported: we exclude the Catalog
if database_names is None or database_names == '':
database_names = get_database_names()
include_catalog = True
else:
include_catalog = False

def create_table_ACLSs_df_for_databases(database_names: List[str], include_catalog: bool):
num_databases_processed = len(database_names)
num_tables_or_views_processed = 0

Expand Down Expand Up @@ -201,35 +191,53 @@ def create_table_ACLSs_df_for_databases(database_names: List[str]):
# COMMAND ----------

# DBTITLE 1,Run Export
def chunks(lst, n):
    """Yield successive n-sized chunks from lst.

    The final chunk is shorter when len(lst) is not a multiple of n.
    Works on any sliceable sequence (list, tuple, str, ...).
    """
    start = 0
    total = len(lst)
    while start < total:
        yield lst[start:start + n]
        start += n


databases_raw = dbutils.widgets.get("Databases")
output_path = dbutils.widgets.get("OutputPath")

# TODO check Catalog heuristic:
#   - exporting all databases   -> include the Catalog grants as well
#   - exporting a subset only   -> exclude the Catalog
if databases_raw.rstrip():
    databases = [name.strip() for name in databases_raw.split(",")]
    include_catalog = False
    print(f"Exporting the following databases: {databases}")
else:
    databases = get_database_names()
    include_catalog = True
    print(f"Exporting all databases")

# Export the table ACLs one database chunk at a time so the driver never has
# to hold the ACLs for every database in memory at once (driver OOM fix).
# Each chunk is collected, written out, and released before the next one.
counter = 1
for databases_chunks in chunks(databases, 1):
    table_ACLs_df, num_databases_processed, num_tables_or_views_processed = create_table_ACLSs_df_for_databases(
        databases_chunks, include_catalog
    )

    print(
        f"{datetime.datetime.now()} total number processed chunk {counter}: databases: {num_databases_processed}, tables or views: {num_tables_or_views_processed}")
    print(f"{datetime.datetime.now()} writing table ACLs to {output_path}")

    # With table ACLs active, a direct write to DBFS is not allowed, so the
    # dataframe is written out as gzip-compressed JSON instead. The first
    # chunk overwrites any previous export; later chunks append, so the
    # final output is the union of all chunks.
    (
        table_ACLs_df
        .selectExpr("Database", "Principal", "ActionTypes", "ObjectType", "ObjectKey", "ExportTimestamp")
        .write
        .format("JSON")
        .option("compression", "gzip")
        .mode("append" if counter > 1 else "overwrite")
        .save(output_path)
    )
    counter += 1
    # Catalog-level grants (when included at all) belong to the first chunk
    # only; disable them afterwards so they are not duplicated per database.
    include_catalog = False


# COMMAND ----------
Expand Down
6 changes: 6 additions & 0 deletions data/notebooks/Import_Table_ACLs.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ def execute_sql_statements(sqls):
l = [ str(o) for o in error_causing_sqls ]
print("\n".join(l))

# COMMAND ----------

# DBTITLE 1,Nicer error output
# Render the failed statements as a DataFrame so the notebook shows a
# readable table of (sql, error) pairs instead of raw printed output.
if error_causing_sqls:  # idiomatic truthiness instead of len(...) != 0
    rows = [{'sql': str(o.get('sql')), 'error': str(o.get('error'))} for o in error_causing_sqls]
    display(spark.createDataFrame(rows))

# COMMAND ----------

Expand Down

0 comments on commit 3de69f9

Please sign in to comment.