Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Glow Effect for 3d Scatter Plots #32

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mplcyberpunk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import matplotlib as mpl
import pkg_resources

from .core import add_glow_effects, make_lines_glow, add_underglow, make_scatter_glow, add_gradient_fill, add_bar_gradient
from .core import add_glow_effects, make_lines_glow, add_underglow, make_scatter_glow, add_gradient_fill, add_bar_gradient, make_3d_scatter_collection_glow

__version__ = pkg_resources.require("mplcyberpunk")[0].version
__author__ = 'Dominik Haitz <[email protected]>'
Expand Down
42 changes: 42 additions & 0 deletions mplcyberpunk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Polygon
import mpl_toolkits
from mpl_toolkits.mplot3d import Axes3D


def add_glow_effects(ax: Optional[plt.Axes] = None, gradient_fill: bool = False) -> None:
Expand Down Expand Up @@ -262,3 +264,43 @@ def add_bar_gradient(
)

bar.remove()

def make_3d_scatter_collection_glow(
ax: plt.axes,
collection: mpl_toolkits.mplot3d.art3d.Path3DCollection,
n_glow_lines: int = 10,
diff_dotwidth: float = 1.2,
alpha: float = 0.3,
) -> None:
"""
Add glow effect to a specific collection in the 3d scatter plot.
Copies the idea from make_scatter_glow(), but targets a single collection
in 3d space.

Done on only a single collection because applying to all collections,
using something like shown below, would only plot the glow scatters
for the first collection, and none of the other ones.

I suppose this may be nice if a user wants to glow only a specific label
on the scatter.

```py
for collection in ax.collections:
<code below here>
```

"""

try:
# get the x, y, and z cords of the points
x, y, z = collection._offsets3d
# get the colors of this collection of points
dot_color = collection.get_facecolors()
# get the size of dots from this collection
dot_size = collection.get_sizes()

alpha = alpha/n_glow_lines
for _ in range(1, n_glow_lines):
ax.scatter(x, y, z, s=dot_size*(diff_dotwidth**_), c=dot_color, alpha=alpha)
except:
pass
59 changes: 59 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import mplcyberpunk
import random
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

def test_plotting_working():
plt.style.use("cyberpunk")
Expand Down Expand Up @@ -183,3 +184,61 @@ def test_gradient_bars():
mplcyberpunk.add_bar_gradient(bars=bars, ax=ax)

fig.savefig('test_gradient_bars.png')

def test_3d_scatter_glow():
plt.style.use('cyberpunk')
plt.rcParams['figure.facecolor'] = '#000000'
plt.rcParams['axes.facecolor'] = '#000000'
plt.rcParams['axes3d.xaxis.panecolor'] = '#101010'
plt.rcParams['axes3d.yaxis.panecolor'] = '#101010'
plt.rcParams['axes3d.zaxis.panecolor'] = '#101010'

fig = plt.figure(figsize=(10, 10))

ax = fig.add_subplot(111, projection="3d")

colors = ["C0", "C1", "C2"]

labels = ["label1", "label2", "label3"]

data = {
"label1": {
"x": np.random.rand(30, 1),
"y": np.random.rand(30, 1),
"z": np.random.rand(30, 1),
},
"label2": {
"x": np.random.rand(10, 1),
"y": np.random.rand(10, 1),
"z": np.random.rand(10, 1),
},
"label3": {
"x": np.random.rand(50, 1),
"y": np.random.rand(50, 1),
"z": np.random.rand(50, 1),
}
}

for i, label in enumerate(labels):
single_scatter_x = data[label]["x"]
single_scatter_y = data[label]["y"]
single_scatter_z = data[label]["z"]

ax.scatter(
single_scatter_x,
single_scatter_y,
single_scatter_z,
s=180,
color=colors[i],
label=label,
)

collections = [c for c in ax.collections]
for collection in collections:
mplcyberpunk.make_3d_scatter_collection_glow(ax, collection, alpha=0.3, n_glow_lines=10, diff_dotwidth=1.2)

ax.set_title("3d Scatter Glow Test")

fig.savefig("test_3d_scatter_glow.png")

plt.show()