Skip to content

Commit

Permalink
#51 contours_request now plots exactly the requested samples
Browse files Browse the repository at this point in the history
  • Loading branch information
funkchaser committed Sep 24, 2024
1 parent 56b97bd commit bd87ef5
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

* PlotContoursRequest plots now all the last requested designs
* Fixed small issues on documentation

### Removed
Expand Down
5 changes: 1 addition & 4 deletions src/aixd_ara/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,9 @@ def plot_contours_request():
data = json.loads(data)
session_id = data["session_id"]
sc = SessionController.create(session_id)

output_type = data["output_type"]
requested_values = data["request"]
n_samples = data["n_samples"]

result = sc.plot_contours_request(request=requested_values, n_samples=n_samples, output_type=output_type)
result = sc.plot_contours_request(output_type=output_type)
response = json.dumps(result, cls=DataEncoder)
return response

Expand Down
6 changes: 1 addition & 5 deletions src/aixd_ara/components/ara_PlotContoursRequest/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@

cid = component_id(session_id, ghenv.Component, "create_dataset_object")

n_samples = 3

if plot:
variable_types = get_dataobject_types(session_id())["dataobject_types"]
request_dict = reformat_request(request, variable_types)
print request_dict
st[cid] = plot_contours_request(session_id(), request_dict, n_samples, "interactive") # will launch the plotly fig in browser
st[cid] = plot_contours_request(session_id(), "interactive") # will launch the plotly fig in browser

if cid in st.keys():
print st[cid]
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,12 @@
"nickname": "PlotContReq",
"category": "ARA",
"subcategory": "5 Plotter",
"description": "Plots the requested and predicted values against the distribution contours for each pair of the corresponding variables.",
"description": "Plots the predicted values of the requested designs against the distribution contours for each pair of the corresponding variables.",
"exposure": 2,
"ghpython": {
"isAdvancedMode": false,
"iconDisplay": 0,
"inputParameters": [
{
"name": "request",
"description": "List of requested values, each formatted as a string with the following format: 'variable_name:value'.",
"typeHintID": "str",
"scriptParamAccess": 1
},
{
"name": "plot",
"description": "Set to True to (re-)create the plot.",
Expand Down
14 changes: 9 additions & 5 deletions src/aixd_ara/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self):
self.datamodule = None
self.samples_per_file = None
self.model_is_trained = False
self.requested_designs = None, None, None

def reset(self):
self.project_root = None
Expand All @@ -48,6 +49,7 @@ def reset(self):
self.datamodule = None
self.samples_per_file = None
self.model_is_trained = False
self.requested_designs = None, None, None

@property
def dataset_path(self):
Expand Down Expand Up @@ -345,7 +347,7 @@ def plot_contours(self, dataobjects, output_type):
fig = plotter.contours2d(block=block, attributes=dataobjects)
return _fig_output(fig, output_type)

def plot_contours_request(self, request, n_samples, output_type):
def plot_contours_request(self, output_type):
"""
request:
dictionary where keys are the names of dataobjects (usually performance attributes),
Expand All @@ -355,10 +357,12 @@ def plot_contours_request(self, request, n_samples, output_type):
raise ValueError("Dataset is not loaded.")
if not self.model:
raise ValueError("Model is not loaded.")
if self.requested_designs == (None, None, None):
raise ValueError("No designs have been requested yet.")

plotter = Plotter(datamodule=self.datamodule, output=None)
gen = Generator(model=self.model, datamodule=self.datamodule, over_sample=10)
_, detailed_results = gen.generate(request=request, n_samples=n_samples, format_out="dict_list")
detailed_results = self.requested_designs[2]
n_samples = len(self.requested_designs[1])

fig = plotter.generation_scatter([detailed_results], n_samples=n_samples)
return _fig_output(fig, output_type)
Expand Down Expand Up @@ -534,8 +538,8 @@ def request_designs(self, request, n_samples=1):
raise ValueError("Model is not loaded.")

gen = Generator(model=self.model, datamodule=self.datamodule, over_sample=100)
new_designs = gen.generate(request=request, n_samples=n_samples, format_out="dict_list")[0]

new_designs, detailed_results = gen.generate(request=request, n_samples=n_samples, format_out="dict_list")
self.requested_designs = request, new_designs, detailed_results
# split the result into separate dictionaries for design parameters and performance attributes
# assert len(new_designs) == n_samples
samples = []
Expand Down
4 changes: 2 additions & 2 deletions src/aixd_ara/gh_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def plot_contours(session_id, attributes, output_type):
return http_post_request(action="plot_contours", data=data)


def plot_contours_request(session_id, request, n_samples, output_type):
data = {"session_id": session_id, "request": request, "n_samples": n_samples, "output_type": output_type}
def plot_contours_request(session_id, output_type):
data = {"session_id": session_id, "output_type": output_type}
return http_post_request(action="plot_contours_request", data=data)


Expand Down

0 comments on commit bd87ef5

Please sign in to comment.