-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathwave3d.py
202 lines (140 loc) · 4.87 KB
/
wave3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""
This is a 3D wave simulation test written for educational purposes. It is by
no means optimized, and I don't fully guarantee its correctness.
Repo:
https://github.com/bean-mhm/wave-simulation-py
"""
import time
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import plotly.graph_objects as go
# Visualization method
#   'slice':  advance the simulation live and render a 2D slice of it
#   'volume': advance the simulation `volume_iters` times, then render the
#             whole 3D volume once, at that fixed point in time
visual = 'slice'

# Grid resolution (number of points along x, y, and z)
nx, ny, nz = 20, 20, 20

# Z index for slices. Derived from the grid resolution so that it stays in
# range (and keeps passing through the initial impulse below) when the grid
# is resized; set it to any value in [0, nz) to view another plane.
slice_z_index = nz // 2

# How many times to increment before volume visualization
volume_iters = 10

# Minimum distance in the grid
step = 0.5

# Propagation speed in 3D
speed = 10

# Initial values
u = np.zeros(shape=(nx, ny, nz), dtype=float, order='C')

# Boost a specific point's value
# You can modify this to have different initial values. The impulse is placed
# at the center of the grid rather than at a hard-coded index so that smaller
# grids do not raise an IndexError.
u[nx // 2, ny // 2, nz // 2] = 1.0

# Previous values
u_last = np.copy(u)

# Velocities
vel = np.zeros(shape=(nx, ny, nz), dtype=float, order='C')

# Maximum timestep
# NOTE(review): presumably the stability limit of the explicit 3D scheme
# (hence the sqrt(3) factor) - confirm before relying on it.
max_dt = step / (speed * np.sqrt(3.0))
print(f'{max_dt = }')

# Timestep, kept below the maximum with a safety margin
dt = 0.8 * max_dt

# Stiffness
# Must be greater than or equal to 1 to function properly.
# The formula is made up and likely not physically correct.
stiffness = 2.0
def get_bound(arr, index: int, default=0.0):
    """Return ``arr[index]``, or *default* when *index* is out of bounds.

    An index is out of bounds when it is negative or not smaller than
    ``len(arr)``; negative indices are never treated as counting from the end.
    """
    out_of_range = index < 0 or index >= len(arr)
    return default if out_of_range else arr[index]
def get_bound_md(arr: np.ndarray, indices: tuple, default=0.0):
    """Return ``arr[indices]``, or *default* when any index is out of bounds.

    Each index is checked against the size of the matching axis of *arr*; an
    index is out of bounds when it is negative or not smaller than that size
    (negative indices are never treated as counting from the end).

    *indices* may be any sequence of integers. It is converted to a tuple
    before indexing so that a list addresses a single element instead of
    triggering NumPy's fancy indexing. Passing fewer indices than *arr* has
    dimensions returns the corresponding sub-array.

    Raises:
        IndexError: if more indices are given than *arr* has dimensions and
            none of the checked indices is out of bounds.
    """
    indices = tuple(indices)
    for index, size in zip(indices, arr.shape):
        if index < 0 or index >= size:
            return default
    return arr[indices]
def increment():
    """Advance the simulation by one timestep of length ``dt``.

    Reads the module-level ``u``, ``u_last``, ``step``, ``speed``, ``dt`` and
    ``stiffness``; updates ``u`` and ``vel`` in place and rebinds ``u_last``
    to a copy of the values ``u`` held before the step.

    For every grid point the second differences of ``u`` along x, y and z are
    summed and scaled by ``speed**2`` to get an acceleration. The velocity is
    estimated from the change since the previous step, advanced by that
    acceleration, damped by ``stiffness ** (-dt)``, and finally used to move
    ``u``. Neighbors that fall outside the grid are treated as 0.0.

    The whole grid is processed with NumPy slicing; every element goes
    through the same arithmetic, in the same order, as a per-point loop would
    apply.
    """
    global u
    global u_last

    # Precalculate constants
    step2 = step**2
    speed2 = speed**2
    stiffen_mul = stiffness ** (-dt)

    # Keep the pre-step values; they become u_last once the step is done
    previous = np.copy(u)

    # Surround the grid with a one-cell layer of zeros so that every point,
    # including those on the faces, has a neighbor on both sides of each axis
    padded = np.pad(u, 1, mode='constant', constant_values=0.0)
    core = slice(1, -1)

    # Second differences with respect to x, y, and z
    grad_x = ((padded[2:, core, core] - u) -
              (u - padded[:-2, core, core])) / step2
    grad_y = ((padded[core, 2:, core] - u) -
              (u - padded[core, :-2, core])) / step2
    grad_z = ((padded[core, core, 2:] - u) -
              (u - padded[core, core, :-2])) / step2

    # How much the velocity needs to be adjusted
    acc = speed2 * (grad_x + grad_y + grad_z)

    # Current velocity, adjusted by the acceleration and then "stiffened"
    velocity = (u - u_last) / dt
    velocity += acc * dt
    velocity *= stiffen_mul

    # Store the velocities in the existing array so that other references to
    # it stay valid
    vel[...] = velocity

    # Adjust the values based on the velocities (in place)
    u += vel * dt

    # Use the backup we made before
    u_last = previous
def get_slice():
    """Return a 2D slice of the 3D grid at ``z = slice_z_index``.

    Reads the module-level ``u``, ``nx``, ``ny`` and ``slice_z_index``. The
    result is a new C-ordered float array covering the first ``nx`` by ``ny``
    points of that plane; it is a copy, so modifying it does not affect
    ``u``.
    """
    return np.array(u[:nx, :ny, slice_z_index], dtype=float, order='C')
# Visualize
if visual == 'volume':
    # Advance the simulation a fixed number of steps, then render the whole
    # volume once
    for _ in range(volume_iters):
        increment()

    X, Y, Z = np.mgrid[:nx, :ny, :nz]
    fig = go.Figure(data=go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=u.flatten(),
        colorscale='twilight',
        isomin=-0.1,
        isomax=0.1,
        opacity=0.1,  # needs to be small to see through all surfaces
        surface_count=50,  # needs to be a large number for good volume rendering
    ))
    fig.show()
elif visual == 'slice':
    plt.ion()
    fig = plt.figure()
    img: matplotlib.image.AxesImage = plt.imshow(
        get_slice().T, cmap='twilight', interpolation='bicubic')
    img.set_clim(vmin=-.1, vmax=.1)

    # Upper bound on the simulation/redraw rate
    stepsPerSecond = 30.0
    secondsPerStep = 1.0 / stepsPerSecond

    # A monotonic clock is used for pacing so that system clock adjustments
    # cannot stall or rush the animation
    last = time.monotonic()

    # Run until the figure window is closed instead of looping forever
    while plt.fignum_exists(fig.number):
        # Sleep for whatever is left of the current frame rather than
        # spinning on the clock at full CPU load
        remaining = secondsPerStep - (time.monotonic() - last)
        if remaining > 0:
            time.sleep(remaining)
        last = time.monotonic()

        increment()
        img.set_data(get_slice().T)
        # NOTE: calling img.autoscale() here would rescale the color limits
        # to the data every frame instead of keeping the fixed limits above.
        fig.canvas.draw()
        fig.canvas.flush_events()
else:
    print('Invalid visualization method')