Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update: add plotting functions #13

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ classifiers = [
]
dependencies = [
# add requirements here
"numpy"
"numpy",
"plotly==5.18",
"matplotlib>=3.4.1",
"ipython<=8.0",
]

[project.optional-dependencies]
Expand Down
214 changes: 214 additions & 0 deletions src/py/mat3ra/utils/jupyterlite/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from IPython.display import display


def scatter_plot_2d(
x_values: List[float],
y_values: List[float],
hover_texts: List[str],
settings: Dict[str, Any],
trace_names: Optional[List[str]] = None,
) -> go.Figure:
"""
Create a generic 2D scatter plot.

Args:
x_values: List of x-coordinates
y_values: List of y-coordinates
hover_texts: List of hover texts for each point
settings: Plot settings including scales, height, and titles
trace_names: Optional list of names for each trace
"""
data = []
for i in range(len(x_values)):
trace = go.Scatter(
x=[x_values[i]],
y=[y_values[i]],
text=[hover_texts[i]],
mode="markers",
hoverinfo="text",
name=trace_names[i] if trace_names else f"Point {i}",
)
data.append(trace)

layout = go.Layout(
xaxis=dict(title=settings.get("x_title", "X"), type=settings.get("x_scale", "linear")),
yaxis=dict(title=settings.get("y_title", "Y"), type=settings.get("y_scale", "linear")),
hovermode="closest",
height=settings.get("height", 600),
title=settings.get("title", ""),
legend_title_text=settings.get("legend_title", ""),
)

return go.Figure(data=data, layout=layout)


def create_realtime_plot(
title: str = "Real-time Progress", x_label: str = "Step", y_label: str = "Value"
) -> go.FigureWidget:
"""
Create a real-time updating plot.
"""
fig = make_subplots(rows=1, cols=1, specs=[[{"type": "scatter"}]])
scatter = go.Scatter(x=[], y=[], mode="lines+markers", name="Progress")
fig.add_trace(scatter)
fig.update_layout(title_text=title, xaxis_title=x_label, yaxis_title=y_label)
widget = go.FigureWidget(fig)
display(widget) # Automatically display the widget
return widget


def create_update_callback(
dynamic_object: Any,
value_getter: Union[Callable, Any],
figure: go.FigureWidget,
steps: List[int],
values: List[float],
step_attr: str = "nsteps",
print_format: str = "Step: {}, Value: {:.4f}",
) -> Callable:
"""
Create a general update callback for real-time plotting.

Args:
dynamic_object: Object containing step information
value_getter: Either a callable function or an object with a getter method
figure: Plotly figure widget to update
steps: List to store step values
values: List to store measured values
step_attr: Attribute name for step count in dynamic_object
print_format: Format string for progress printing
"""

def update():
step = getattr(dynamic_object, step_attr)
# Handle both callable and object with getter method
value = value_getter() if callable(value_getter) else value_getter.get_total_energy()

steps.append(step)
values.append(value)

print(print_format.format(step, value))
with figure.batch_update():
figure.data[0].x = steps
figure.data[0].y = values

return update


def plot_distribution_function(
bin_centers: np.ndarray,
distribution: np.ndarray,
xlabel: str = "Distance",
ylabel: str = "g(r)",
title: str = "Distribution Function",
figsize: Tuple[int, int] = (8, 5),
) -> None:
"""
Plot a generic distribution function.

Args:
bin_centers: The bin centers.
distribution: The distribution values.
xlabel: The x-axis label.
ylabel: The y-axis label.
title: The title of the plot.
figsize: The size of the figure.
"""
plt.figure(figsize=figsize)
plt.plot(bin_centers, distribution, label=title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(title)
plt.legend()
plt.grid()
plt.show()


def plot_3d_surface(
x_matrix: np.ndarray,
y_matrix: np.ndarray,
z_matrix: np.ndarray,
optimal_point: Optional[Tuple[float, float]] = None,
title: str = "Surface Plot",
labels: Optional[Dict[str, str]] = None,
) -> None:
"""
Create a 3D surface plot with optional optimal point.

Args:
x_matrix: The x-axis matrix.
y_matrix: The y-axis matrix.
z_matrix: The z-axis matrix.
optimal_point: The optimal point to highlight.
title: The title of the plot.
labels: The labels for the axes.
"""
if labels is None:
labels = {"x": "X", "y": "Y", "z": "Z"}

fig = go.Figure(data=[go.Surface(x=x_matrix, y=y_matrix, z=z_matrix, colorscale="Viridis")])

if optimal_point is not None:
x_opt, y_opt = optimal_point
z_opt = np.min(z_matrix)
fig.add_trace(
go.Scatter3d(
x=[x_opt], y=[y_opt], z=[z_opt], mode="markers", marker=dict(size=8, color="red"), name="Optimal Point"
)
)

fig.update_layout(
title=title,
scene=dict(xaxis_title=labels["x"], yaxis_title=labels["y"], zaxis_title=labels["z"]),
width=800,
height=800,
)
fig.show()


def plot_2d_heatmap(
x_values: np.ndarray,
y_values: np.ndarray,
z_matrix: np.ndarray,
optimal_point: Optional[Tuple[float, float]] = None,
title: str = "Heatmap",
labels: Optional[Dict[str, str]] = None,
) -> None:
"""
Create a 2D heatmap with optional optimal point.

Args:
x_values: The x-axis values.
y_values: The y-axis values.
z_matrix: The z-axis matrix.
optimal_point: The optimal point to highlight.
title: The title of the plot.
labels: The labels for the axes.
"""
if labels is None:
labels = {"x": "X", "y": "Y", "z": "Z"}

fig = go.Figure(
data=go.Heatmap(x=x_values, y=y_values, z=z_matrix, colorscale="Viridis", colorbar=dict(title=labels["z"]))
)

if optimal_point is not None:
x_opt, y_opt = optimal_point
fig.add_trace(
go.Scatter(
x=[x_opt],
y=[y_opt],
mode="markers",
marker=dict(size=12, color="red", symbol="x"),
name="Optimal Point",
)
)

fig.update_layout(title=title, xaxis_title=labels["x"], yaxis_title=labels["y"], width=800, height=600)
fig.show()
Loading