#66 ghcomponent to merge datasets
funkchaser committed Feb 13, 2025
1 parent a25d6c5 commit a13ac1e
Showing 6 changed files with 204 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/aixd_ara/api.py
@@ -376,6 +376,22 @@ def ara_welcome(*args):
print()


@app.route("/merge_datasets", methods=["POST"])
def merge_datasets():
data = request.data
data = json.loads(data)
session_id = data["session_id"]
sc = SessionController.create(session_id)

result = sc.merge_datasets(
root_folder=data["root_folder"],
new_dataset_name=data["new_dataset_name"],
samples_per_file=data["samples_per_file"],
)
response = json.dumps(result, cls=DataEncoder)
return response


if __name__ == "__main__":
import sys

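For reference, the new route can be exercised directly over HTTP. A minimal sketch follows; the host and port are placeholder assumptions (they are not specified in this diff), and the payload keys mirror exactly what the route reads from the request body.

# Minimal sketch of posting to the new /merge_datasets route with the requests library.
# Host/port are placeholder assumptions; paths are illustrative only.
import json
import requests

payload = {
    "session_id": "my_session",
    "root_folder": r"C:\data\datasets",      # folder containing the dataset subfolders
    "new_dataset_name": "merged_dataset",
    "samples_per_file": 1000,
}
response = requests.post("http://127.0.0.1:5000/merge_datasets", data=json.dumps(payload))
result = json.loads(response.text)
print(result["status"], result["msg"])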
25 changes: 25 additions & 0 deletions src/aixd_ara/components/ara_DatasetsMerge/code.py
@@ -0,0 +1,25 @@
# flake8: noqa
from scriptcontext import sticky as st
from Grasshopper.Kernel.GH_RuntimeMessageLevel import Error, Warning

from aixd_ara.gh_ui import merge_datasets
from aixd_ara.gh_ui_helper import clear_sticky
from aixd_ara.gh_ui_helper import component_id
from aixd_ara.gh_ui_helper import session_id

# Default and sanity check for the optional samples_per_file input.
if samples_per_file is None:
    samples_per_file = 1000
assert samples_per_file > 0, "samples_per_file must be a positive integer, got {}".format(samples_per_file)

if root_folder:
    cid = component_id(session_id(), ghenv.Component, "DatasetsMerge")

    if merge:
        st[cid] = merge_datasets(session_id(), root_folder, new_dataset_name, samples_per_file)

    # Report the cached result of the last merge via the component's runtime messages.
    if cid in st:
        status = st[cid]["status"]
        msg = st[cid]["msg"]
        if status == "error":
            ghenv.Component.AddRuntimeMessage(Error, msg)
        elif status == "warning":
            ghenv.Component.AddRuntimeMessage(Warning, msg)
(One additional file in this commit could not be displayed.)
45 changes: 45 additions & 0 deletions src/aixd_ara/components/ara_DatasetsMerge/metadata.json
@@ -0,0 +1,45 @@
{
"name": "DatasetsMerge",
"nickname": "DatasetsMerge",
"category": "ARA",
"subcategory": "2 Dataset",
"description": "Merges multiple datasets into a single dataset. Requires that the datasets have the same schema (variable names, types, dimensions).",
"exposure": 2,
"ghpython": {
"isAdvancedMode": false,
"iconDisplay": 0,
"inputParameters": [
{
"name": "root_folder",
"description": "Path to the folder containing the datasets to merge.",
"typeHintID": "str",
"scriptParamAccess": 0
},
{
"name": "new_dataset_name",
"description": "Name of the merged dataset. (Optional, default: 'merged_dataset'.)",
"typeHintID": "str",
"scriptParamAccess": 0
},
{
"name": "samples_per_file",
"description": "Number of samples to be saved in each file of the new dataset. (Optional, default: 1000.)",
"typeHintID": "int",
"scriptParamAccess": 0
},
{
"name": "merge",
"description": "Triggers the merge process.",
"typeHintID": "bool",
"scriptParamAccess": 0
}
],

"outputParameters": [
{
"name": "msg",
"description": "Message logs."
}
]
}
}
108 changes: 108 additions & 0 deletions src/aixd_ara/controller.py
@@ -6,6 +6,8 @@
import base64
import os
import random
import shutil
import pandas as pd

import pytorch_lightning as pl
import torch
@@ -608,6 +610,112 @@ def blocknames_from_dataobjects(self, dataobjects):

return list(set(blocknames))

@staticmethod
def merge_datasets(root_folder: str, new_dataset_name: str, samples_per_file: int):
    """Merge all datasets found in root_folder into a single new dataset named new_dataset_name.

    Prerequisites:
    - All datasets are stored in subfolders of root_folder.
    - All datasets have the same design_parameters and performance_attributes (variable names, types, dimensions).

    Parameters
    ----------
    root_folder : str
        Path to the root folder containing the datasets.
    new_dataset_name : str
        Name of the new dataset. The new dataset will be created in a subfolder with this name
        (default: "merged_dataset"). If such a folder already exists, it will be overwritten.
    samples_per_file : int
        Number of samples to be saved in each file of the new dataset.

    Returns
    -------
    dict
        Dictionary with the keys "status" ("warning", "error" or None) and "msg"
        (a message with information about the process).
    """
    # TODO: check what happens with the domains, may need to get a union explicitly.

    status = None
    msg = ""

    if not os.path.exists(root_folder):
        txt = f"Folder {root_folder} does not exist.\n"
        status = "error"
        return {"status": status, "msg": txt}

    if not new_dataset_name:
        new_dataset_name = "merged_dataset"

    if os.path.exists(os.path.join(root_folder, new_dataset_name)):
        txt = f"[WARNING] Folder {new_dataset_name} already exists. It will be overwritten.\n\n"
        status = "warning"
        print(txt)
        msg += txt
        shutil.rmtree(os.path.join(root_folder, new_dataset_name))

    dataset_names = [
        name
        for name in os.listdir(root_folder)
        if os.path.isdir(os.path.join(root_folder, name)) and name != new_dataset_name
    ]
    txt = f"Found the following subfolders: {dataset_names}.\n\n"
    print(txt)
    msg += txt

    def _load_df(root_folder, dataset_name):
        # Load old sharded data from pickled dataframes

        # DPs
        directory = os.path.join(root_folder, dataset_name, "design_parameters")
        df_dp_all = []

        for filename in os.listdir(directory):
            if filename.endswith(".pkl"):
                filepath = os.path.join(directory, filename)
                df = pd.read_pickle(filepath)
                df_dp_all.append(df)

        df_dp_all = pd.concat(df_dp_all, axis=0)

        # PAs
        directory = os.path.join(root_folder, dataset_name, "performance_attributes")
        df_pa_all = []

        for filename in os.listdir(directory):
            if filename.endswith(".pkl"):
                filepath = os.path.join(directory, filename)
                df = pd.read_pickle(filepath)
                df_pa_all.append(df)

        df_pa_all = pd.concat(df_pa_all, axis=0)

        # Join DPs and PAs on the sample uid, then drop the uid column.
        df_all = pd.merge(df_dp_all, df_pa_all, how="inner", on=["uid"])
        df_all = df_all.drop(columns=["uid"])
        return df_all

    # Load all datasets into a single dataframe.
    dfs = []
    for dataset_name in dataset_names:
        df = _load_df(root_folder, dataset_name)
        dfs.append(df)
    df_all = pd.concat(dfs)

    # Create the new Dataset, using one of the old datasets as a schema template.
    dataset_temp = Dataset.from_dataset_folder(os.path.join(root_folder, dataset_names[0]))

    dataset_new = Dataset(
        name=new_dataset_name,
        root_path=root_folder,
        file_format="json",
        design_par=dataset_temp.design_par,
        perf_attributes=dataset_temp.perf_attributes,
        overwrite=True,
    )

    # Import the merged data into the new dataset.
    dataset_new.import_data_from_df(df_all, samples_perfile=samples_per_file, flag_fromscratch=True)
    msg += f"New dataset with {len(df_all)} samples has been created in {os.path.join(root_folder, new_dataset_name)}.\n"
    return {"status": status, "msg": msg}


# --------------------------------------------------------------
# helper methods
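
As a usage sketch, the static method can also be called directly from Python, assuming a folder layout that matches the prerequisites in the docstring. The import path and the paths below are illustrative assumptions, not taken verbatim from this diff.

# Sketch of a direct call to the new static method (illustrative paths only).
# Assumed layout under root_folder, per the docstring:
#   root_folder/
#       dataset_a/design_parameters/*.pkl
#       dataset_a/performance_attributes/*.pkl
#       dataset_b/...
from aixd_ara.controller import SessionController

result = SessionController.merge_datasets(
    root_folder="/path/to/datasets",
    new_dataset_name="merged_dataset",
    samples_per_file=1000,
)
print(result["status"])  # None, "warning" or "error"
print(result["msg"])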
10 changes: 10 additions & 0 deletions src/aixd_ara/gh_ui.py
@@ -103,3 +103,13 @@ def model_input_output_dimensions(session_id):
def request_designs(session_id, request, n_designs):
data = {"session_id": session_id, "requested_values": request, "n_designs": n_designs}
return http_post_request(action="request_designs", data=data)


def merge_datasets(session_id, root_folder, new_dataset_name, samples_per_file):
data = {
"session_id": session_id,
"root_folder": root_folder,
"new_dataset_name": new_dataset_name,
"samples_per_file": samples_per_file,
}
return http_post_request(action="merge_datasets", data=data)
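This wrapper is what the DatasetsMerge component calls; it sends the request to the ARA server and returns the parsed response. A minimal sketch of calling it outside the shipped component follows; the session id and path are illustrative (in the component, the session id comes from gh_ui_helper.session_id()).

# Sketch of using the gh_ui wrapper directly (illustrative values only).
from aixd_ara.gh_ui import merge_datasets

result = merge_datasets(
    session_id="my_session",
    root_folder=r"C:\data\datasets",   # placeholder path
    new_dataset_name="merged_dataset",
    samples_per_file=1000,
)
print(result["msg"])  # the component exposes this via its msg output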
