Skip to content

Commit

Permalink
#15 add plot contours request (generation_scatter)
Browse files Browse the repository at this point in the history
  • Loading branch information
funkchaser committed May 2, 2024
1 parent 7085125 commit af8ace4
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 10 deletions.
Binary file added docs/_images/icons/aixd_PlotContoursRequest.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
32 changes: 28 additions & 4 deletions docs/documentation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ Retrieves one sample from the dataset (at a given or random index) and instantia

**Outputs**

- **sample** -- Summary of the retrieved sample.
- **sample_summary** -- Summary of the retrieved sample.

DatasetSummary
--------------
Expand Down Expand Up @@ -203,13 +203,16 @@ Runs a generation campaing to create new designs using the trained model.

**Inputs**

- **requested_values** *(str)* -- List of requested values, each formatted as a string with the following format: 'variable_name:value'.
- **requested_values** *[List of (str)]* -- List of requested values, each formatted as a string with the following format: 'variable_name:value'.
- **n_designs** *(int)* -- Number of designs to generate.
- **run** *(none)* -- Set to True to start the generation process.
- **generate** *(bool)* -- Set to True to start the generation process.
- **clear** *(bool)* -- Forget the previously generated designs.
- **pick_previous** *(bool)* -- Iterate backward through the list of generated designs, instantiate the previous sample.
- **pick_next** *(bool)* -- Iterate forward through the list of generated designs, instantiate the next sample.

**Outputs**

- **predicions** -- List of generated designs.
- **sample_summary** -- Selected sample.

GenSampleEval
-------------
Expand Down Expand Up @@ -376,6 +379,27 @@ Plots the distribution contours for each pair of variables from the data in the

- **img** -- Bitmap image if output_type is 'static', otherwise None.

PlotContoursRequest
-------------------
.. image:: _images/icons/aixd_PlotContoursRequest.png
:align: left
:height: 24
:width: 24

Plots the requested and predicted values against the distribution contours for each pair of the corresponding variables.


**Inputs**

- **request** *[List of (str)]* -- List of requested values, each formatted as a string with the following format: 'variable_name:value'.
- **output_type** *(str)* -- Plot type: 'static' creates a bitmap image, 'interactive' launches an interactive plot in a browser.
- **plot** *(bool)* -- Set to True to (re-)create the plot.
- **scale** *(float)* -- Resize factor for the static plot.

**Outputs**

- **img** -- Bitmap image if output_type is 'static', otherwise None.

PlotCorrelations
----------------
.. image:: _images/icons/aixd_PlotCorrelations.png
Expand Down
16 changes: 16 additions & 0 deletions src/aixd_grasshopper/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,22 @@ def plot_contours():
return response


@app.route("/plot_contours_request", methods=["POST"])
def plot_contours_request():
data = request.data
data = json.loads(data)
session_id = data["session_id"]
sc = SessionController.create(session_id)

output_type = data["output_type"]
requested_values = data["request"]
n_samples = data["n_samples"]

result = sc.plot_contours_request(request=requested_values, n_samples=n_samples, output_type=output_type)
response = json.dumps(result, cls=DataEncoder)
return response


@app.route("/design_parameters", methods=["GET"])
def get_design_parameters():
data = request.args
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "PlotContours",
"nickname": "PlotCountours",
"nickname": "PlotContours",
"category": "AIXD",
"subcategory": "5 Plotter",
"description": "Plots the distribution contours for each pair of variables from the data in the dataset.",
Expand Down
26 changes: 26 additions & 0 deletions src/aixd_grasshopper/components/aixd_PlotContoursRequest/code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# flake8: noqa
from scriptcontext import sticky as st
from aixd_grasshopper.gh_ui import plot_contours_request
from aixd_grasshopper.gh_ui import get_dataobject_types
from aixd_grasshopper.gh_ui_helper import session_id
from aixd_grasshopper.gh_ui_helper import component_id
from aixd_grasshopper.gh_ui_helper import convert_str_to_bitmap
from aixd_grasshopper.gh_ui_helper import reformat_request


cid = component_id(session_id, ghenv.Component, "create_dataset_object")

n_samples = 3

if plot:
variable_types = get_dataobject_types(session_id())["dataobject_types"]
request_dict = reformat_request(request, variable_types)
print request_dict
st[cid] = plot_contours_request(session_id(), request_dict, n_samples, output_type) # if output_type interactive: will launch the plotly fig in browser

if cid in st.keys():
print st[cid]
#TODO: add error msg here
if output_type == "static" and 'imgstr' in st[cid].keys():
imgstr = st[cid]['imgstr']
img = convert_str_to_bitmap(imgstr, scale)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
{
"name": "PlotContoursRequest",
"nickname": "PlotContReq",
"category": "AIXD",
"subcategory": "5 Plotter",
"description": "Plots the requested and predicted values against the distribution contours for each pair of the corresponding variables.",
"exposure": 2,
"ghpython": {
"isAdvancedMode": false,
"iconDisplay": 0,
"inputParameters": [
{
"name": "request",
"description": "List of requested values, each formatted as a string with the following format: 'variable_name:value'.",
"typeHintID": "str",
"scriptParamAccess": 1
},
{
"name": "output_type",
"description": "Plot type: 'static' creates a bitmap image, 'interactive' launches an interactive plot in a browser.",
"typeHintID": "str",
"scriptParamAccess": 0
},
{
"name": "plot",
"description": "Set to True to (re-)create the plot.",
"typeHintID": "bool",
"scriptParamAccess": 0
},
{
"name": "scale",
"description": "Resize factor for the static plot.",
"typeHintID": "float",
"scriptParamAccess": 0
}
],
"outputParameters": [
{
"name": "img",
"description": "Bitmap image if output_type is 'static', otherwise None."
}
]
}
}
24 changes: 20 additions & 4 deletions src/aixd_grasshopper/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,22 @@ def plot_contours(self, dataobjects, output_type):
fig = plotter.contours2d(block=block, attributes=dataobjects)
return _fig_output(fig, output_type)

def plot_contours_request(self, request, n_samples, output_type):
"""
request: dictionary where keys are the names of dataobjects (usually performance attributes), and values the requested target value(s).
"""
if not self.dataset:
raise ValueError("Dataset is not loaded.")
if not self.model:
raise ValueError("Model is not loaded.")

plotter = Plotter(datamodule=self.datamodule, output=None)
gen = Generator(model=self.model, datamodule=self.datamodule, over_sample=10)
_, detailed_results = gen.generation(request=request, n_samples=n_samples, format_out="dict_list")

fig = plotter.generation_scatter([detailed_results], n_samples=n_samples)
return _fig_output(fig, output_type)

def model_setup(self, model_type, inputML, outputML, latent_dim, layer_widths, batch_size):
# TODO: set defaults here if missing?
if not self.dataset:
Expand All @@ -336,7 +352,7 @@ def model_setup(self, model_type, inputML, outputML, latent_dim, layer_widths, b
self.datamodule = datamodule

save_dir = self.dataset_path

if model_type == "CAE":
model = CondAEModel.from_datamodule(
datamodule, layer_widths=layer_widths, latent_dim=latent_dim, save_dir=save_dir
Expand All @@ -347,7 +363,7 @@ def model_setup(self, model_type, inputML, outputML, latent_dim, layer_widths, b
)
else:
raise ValueError("Model type not recognized. Choose 'CAE' or 'CVAE'.")

self.model = model
self.model_is_trained = False

Expand Down Expand Up @@ -405,7 +421,7 @@ def model_load(self, model_type, checkpoint_path, checkpoint_name):
model = CondVAEModel.load_model_from_checkpoint(checkpoint_filepath)
else:
raise ValueError("Model type not recognized. Choose 'CAE' or 'CVAE'.")

self.model = model
self.model_is_trained = True
self.datamodule = self._datamodule_from_dataset()
Expand Down Expand Up @@ -482,7 +498,7 @@ def request_designs(self, request, n_samples=1):
if not self.dataset:
raise ValueError("Dataset is not loaded.")
if not self.model:
raise ValueError("NN model is not loaded.")
raise ValueError("Model is not loaded.")

gen = Generator(model=self.model, datamodule=self.datamodule, over_sample=100)
new_designs = gen.generation(request=request, n_samples=n_samples, format_out="dict_list")[0]
Expand Down
12 changes: 11 additions & 1 deletion src/aixd_grasshopper/gh_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def plot_contours(session_id, attributes, output_type):
return http_post_request(action="plot_contours", data=data)


def plot_contours_request(session_id, request, n_samples, output_type):
data = {"session_id": session_id, "request": request, "n_samples": n_samples, "output_type": output_type}
return http_post_request(action="plot_contours_request", data=data)


def plot_correlations(session_id, attributes, output_type):
data = {"session_id": session_id, "attributes": attributes, "output_type": output_type}
return http_post_request(action="plot_correlations", data=data)
Expand All @@ -76,7 +81,12 @@ def model_train(session_id, epochs, wb):


def model_load(session_id, model_type, checkpoint_name, checkpoint_path):
data = {"session_id": session_id, "model_type": model_type, "checkpoint_name": checkpoint_name, "checkpoint_path": checkpoint_path}
data = {
"session_id": session_id,
"model_type": model_type,
"checkpoint_name": checkpoint_name,
"checkpoint_path": checkpoint_path,
}
return http_post_request(action="model_load", data=data)


Expand Down

0 comments on commit af8ace4

Please sign in to comment.