Skip to content

Commit

Permalink
gpu bug
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Jan 14, 2025
1 parent 8f88fa8 commit 0fb9234
Showing 1 changed file with 78 additions and 78 deletions.
156 changes: 78 additions & 78 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,93 +962,93 @@ def _solve_for_indicies(
device = self._device

tilt_deg = self._tilt_deg[datacube_number]
tilt = -xp.deg2rad(tilt_deg)
tilt = -np.deg2rad(tilt_deg)

# solve for real space coordinates
line_z = xp.linspace(0, 1, num_points) * (s[2] - 1)
line_y = line_z * xp.tan(tilt)
line_z = np.linspace(0, 1, num_points) * (s[2] - 1)
line_y = line_z * np.tan(tilt)
line_y -= np.mean(line_y)
offset = xp.arange(s[1], dtype="int")
offset = np.arange(s[1], dtype="int")

yF = xp.floor(line_y).astype("int")
zF = xp.floor(line_z).astype("int")
yF = np.floor(line_y).astype("int")
zF = np.floor(line_z).astype("int")
dy = line_y - yF
dz = line_z - zF

ind0 = np.hstack(
(
xp.tile(yF, (s[1], 1)) + offset[:, None],
xp.tile(yF + 1, (s[1], 1)) + offset[:, None],
xp.tile(yF, (s[1], 1)) + offset[:, None],
xp.tile(yF + 1, (s[1], 1)) + offset[:, None],
np.tile(yF, (s[1], 1)) + offset[:, None],
np.tile(yF + 1, (s[1], 1)) + offset[:, None],
np.tile(yF, (s[1], 1)) + offset[:, None],
np.tile(yF + 1, (s[1], 1)) + offset[:, None],
)
)

ind1 = np.hstack(
(
xp.tile(zF, (s[1], 1)),
xp.tile(zF, (s[1], 1)),
xp.tile(zF + 1, (s[1], 1)),
xp.tile(zF + 1, (s[1], 1)),
np.tile(zF, (s[1], 1)),
np.tile(zF, (s[1], 1)),
np.tile(zF + 1, (s[1], 1)),
np.tile(zF + 1, (s[1], 1)),
)
)

weights_real = np.hstack(
(
xp.tile(((1 - dy) * (1 - dz)), (s[1], 1)),
xp.tile(((dy) * (1 - dz)), (s[1], 1)),
xp.tile(((1 - dy) * (dz)), (s[1], 1)),
xp.tile(((dy) * (dz)), (s[1], 1)),
np.tile(((1 - dy) * (1 - dz)), (s[1], 1)),
np.tile(((dy) * (1 - dz)), (s[1], 1)),
np.tile(((1 - dy) * (dz)), (s[1], 1)),
np.tile(((dy) * (dz)), (s[1], 1)),
)
)

ind_real = xp.ravel_multi_index((ind0, ind1), (s[1], s[2]), mode="clip")
ind_real = np.ravel_multi_index((ind0, ind1), (s[1], s[2]), mode="clip")

# solve for diffraction space coordinates
length = s[-1] * np.cos(tilt)
line_y_diff = xp.arange(-(s[-1] - 1) / 2, s[-1] / 2) * length / s[-1]
line_z_diff = line_y_diff * xp.tan(tilt) + (s[-1] - 1) / 2
line_y_diff = np.arange(-(s[-1] - 1) / 2, s[-1] / 2) * length / s[-1]
line_z_diff = line_y_diff * np.tan(tilt) + (s[-1] - 1) / 2
line_y_diff += (s[-1] - 1) / 2

yF_diff = xp.floor(line_y_diff).astype("int")
zF_diff = xp.floor(line_z_diff).astype("int")
yF_diff = np.floor(line_y_diff).astype("int")
zF_diff = np.floor(line_z_diff).astype("int")
dy_diff = line_y_diff - yF_diff
dz_diff = line_z_diff - zF_diff

qx = xp.arange(s[-1])
qy = xp.arange(s[-1])
qxx, qyy = xp.meshgrid(qx, qy, indexing="ij")
qx = np.arange(s[-1])
qy = np.arange(s[-1])
qxx, qyy = np.meshgrid(qx, qy, indexing="ij")

ind0_diff = np.hstack(
(
xp.tile(yF_diff, s[-1]),
xp.tile(yF_diff + 1, s[-1]),
xp.tile(yF_diff, s[-1]),
xp.tile(yF_diff + 1, s[-1]),
np.tile(yF_diff, s[-1]),
np.tile(yF_diff + 1, s[-1]),
np.tile(yF_diff, s[-1]),
np.tile(yF_diff + 1, s[-1]),
)
)

ind1_diff = np.hstack(
(
xp.tile(zF_diff, s[-1]),
xp.tile(zF_diff, s[-1]),
xp.tile(zF_diff + 1, s[-1]),
xp.tile(zF_diff + 1, s[-1]),
np.tile(zF_diff, s[-1]),
np.tile(zF_diff, s[-1]),
np.tile(zF_diff + 1, s[-1]),
np.tile(zF_diff + 1, s[-1]),
)
)

weights_diff = np.hstack(
(
xp.tile(((1 - dy_diff) * (1 - dz_diff)), s[-1]),
xp.tile(((dy_diff) * (1 - dz_diff)), s[-1]),
xp.tile(((1 - dy_diff) * (dz_diff)), s[-1]),
xp.tile(((dy_diff) * (dz_diff)), s[-1]),
np.tile(((1 - dy_diff) * (1 - dz_diff)), s[-1]),
np.tile(((dy_diff) * (1 - dz_diff)), s[-1]),
np.tile(((1 - dy_diff) * (dz_diff)), s[-1]),
np.tile(((dy_diff) * (dz_diff)), s[-1]),
)
)

ind_diff = xp.ravel_multi_index(
ind_diff = np.ravel_multi_index(
(
xp.tile(qxx.ravel(), 4),
np.tile(qxx.ravel(), 4),
ind0_diff.ravel(),
ind1_diff.ravel(),
),
Expand All @@ -1057,45 +1057,45 @@ def _solve_for_indicies(
)

# solve for diffraction normalization
line_y_diff_norm = xp.arange(-(s[-1] - 1) / 2, s[-1] / 2)
line_z_diff_norm = line_y_diff_norm * xp.tan(tilt) + (s[-1] - 1) / 2
line_y_diff_norm = np.arange(-(s[-1] - 1) / 2, s[-1] / 2)
line_z_diff_norm = line_y_diff_norm * np.tan(tilt) + (s[-1] - 1) / 2
line_y_diff_norm += (s[-1] - 1) / 2

yF_diff_norm = xp.floor(line_y_diff_norm).astype("int")
zF_diff_norm = xp.floor(line_z_diff_norm).astype("int")
yF_diff_norm = np.floor(line_y_diff_norm).astype("int")
zF_diff_norm = np.floor(line_z_diff_norm).astype("int")
dy_diff_norm = line_y_diff_norm - yF_diff_norm
dz_diff_norm = line_z_diff_norm - zF_diff_norm

ind0_diff_norm = np.hstack(
(
xp.tile(yF_diff_norm, s[-1]),
xp.tile(yF_diff_norm + 1, s[-1]),
xp.tile(yF_diff_norm, s[-1]),
xp.tile(yF_diff_norm + 1, s[-1]),
np.tile(yF_diff_norm, s[-1]),
np.tile(yF_diff_norm + 1, s[-1]),
np.tile(yF_diff_norm, s[-1]),
np.tile(yF_diff_norm + 1, s[-1]),
)
)

ind1_diff_norm = np.hstack(
(
xp.tile(zF_diff_norm, s[-1]),
xp.tile(zF_diff_norm, s[-1]),
xp.tile(zF_diff_norm + 1, s[-1]),
xp.tile(zF_diff_norm + 1, s[-1]),
np.tile(zF_diff_norm, s[-1]),
np.tile(zF_diff_norm, s[-1]),
np.tile(zF_diff_norm + 1, s[-1]),
np.tile(zF_diff_norm + 1, s[-1]),
)
)

weights_diff_norm = np.hstack(
(
xp.tile(((1 - dy_diff) * (1 - dz_diff_norm)), s[-1]),
xp.tile(((dy_diff) * (1 - dz_diff_norm)), s[-1]),
xp.tile(((1 - dy_diff) * (dz_diff_norm)), s[-1]),
xp.tile(((dy_diff) * (dz_diff_norm)), s[-1]),
np.tile(((1 - dy_diff) * (1 - dz_diff_norm)), s[-1]),
np.tile(((dy_diff) * (1 - dz_diff_norm)), s[-1]),
np.tile(((1 - dy_diff) * (dz_diff_norm)), s[-1]),
np.tile(((dy_diff) * (dz_diff_norm)), s[-1]),
)
)

ind_diff_norm = xp.ravel_multi_index(
ind_diff_norm = np.ravel_multi_index(
(
xp.tile(qxx.ravel(), 4),
np.tile(qxx.ravel(), 4),
ind0_diff_norm.ravel(),
ind1_diff_norm.ravel(),
),
Expand All @@ -1104,16 +1104,16 @@ def _solve_for_indicies(
)

# normalization real space
ind_real_bincount_weight = xp.bincount(
ind_real_bincount_weight = np.bincount(
ind_real.ravel(), weights_real.ravel(), minlength=ind_real.max()
)
ind_real_bincount = xp.bincount(ind_real.ravel(), minlength=ind_real.max())
ind_real_bincount = np.bincount(ind_real.ravel(), minlength=ind_real.max())
ind_real_bincount_weight = ind_real_bincount_weight[ind_real_bincount > 0]
ind_real_bincount = ind_real_bincount[ind_real_bincount > 0]
ind_real_bincount_weight[ind_real_bincount_weight == 0] = 1
correction_factor_real = 1 / ind_real_bincount_weight
correction_factor_real = xp.repeat(correction_factor_real, ind_real_bincount)
sorted_indicies = xp.argsort(xp.argsort(ind_real.ravel()))
correction_factor_real = np.repeat(correction_factor_real, ind_real_bincount)
sorted_indicies = np.argsort(np.argsort(ind_real.ravel()))
correction_factor_real = correction_factor_real[sorted_indicies].reshape(
ind_real.shape
)
Expand All @@ -1122,15 +1122,15 @@ def _solve_for_indicies(
# normalization reciprocal space
bincount_max = np.max((ind_diff.max(), ind_diff_norm.max())) + 1

ind_diff_bincount_weight = xp.bincount(
ind_diff_bincount_weight = np.bincount(
ind_diff.ravel(), weights_diff.ravel(), minlength=bincount_max
)
ind_diff_bincount = xp.bincount(ind_diff.ravel(), minlength=bincount_max)
ind_diff_bincount = np.bincount(ind_diff.ravel(), minlength=bincount_max)

ind_diff_bincount_weight_norm = xp.bincount(
ind_diff_bincount_weight_norm = np.bincount(
ind_diff_norm.ravel(), weights_diff_norm.ravel(), minlength=bincount_max
)
ind_diff_bincount_norm = xp.bincount(
ind_diff_bincount_norm = np.bincount(
ind_diff_norm.ravel(), minlength=bincount_max
)

Expand All @@ -1147,8 +1147,8 @@ def _solve_for_indicies(
ind_diff_bincount_weight_norm / ind_diff_bincount_weight
)

correction_factor_diff = xp.repeat(correction_factor_diff, ind_diff_bincount)
sorted_indicies = xp.argsort(xp.argsort(ind_diff.ravel()))
correction_factor_diff = np.repeat(correction_factor_diff, ind_diff_bincount)
sorted_indicies = np.argsort(np.argsort(ind_diff.ravel()))
correction_factor_diff = correction_factor_diff[sorted_indicies].reshape(
ind_diff.shape
)
Expand All @@ -1166,16 +1166,16 @@ def _solve_for_indicies(
self._ind_diff_norm = []
self._weights_diff_norm = []

self._ind_real.append(ind_real)
self._ind_diff.append(ind_diff)
self._weights_real.append(weights_real)
self._weights_diff.append(weights_diff)
self._ind0_diff.append(ind0_diff)
self._ind1_diff.append(ind1_diff)
self._ind0_diff_norm.append(ind0_diff_norm)
self._ind1_diff_norm.append(ind1_diff_norm)
self._ind_diff_norm.append(ind_diff_norm)
self._weights_diff_norm.append(weights_diff_norm)
self._ind_real.append(xp.asarray(ind_real))
self._ind_diff.append(xp.asarray(ind_diff))
self._weights_real.append(xp.asarray(weights_real))
self._weights_diff.append(xp.asarray(weights_diff))
self._ind0_diff.append(xp.asarray(ind0_diff))
self._ind1_diff.append(xp.asarray(ind1_diff))
self._ind0_diff_norm.append(xp.asarray(ind0_diff_norm))
self._ind1_diff_norm.append(xp.asarray(ind1_diff_norm))
self._ind_diff_norm.append(xp.asarray(ind_diff_norm))
self._weights_diff_norm.append(xp.asarray(weights_diff_norm))

def _reshape_4D_array_to_2D(
self,
Expand Down

0 comments on commit 0fb9234

Please sign in to comment.