Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding occupancy to py4DSTEM diffraction #538

Merged
merged 10 commits into from
Jan 23, 2024
Merged
113 changes: 102 additions & 11 deletions py4DSTEM/process/diffraction/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fractions import Fraction
from typing import Union, Optional
import sys
import warnings

from emdfile import PointList
from py4DSTEM.process.utils import single_atom_scatter, electron_wavelength_angstrom
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
positions,
numbers,
cell,
occupancy=None,
):
"""
Args:
Expand All @@ -76,7 +78,7 @@ def __init__(
3 numbers: the three lattice parameters for an orthorhombic cell
6 numbers: the a,b,c lattice parameters and ɑ,β,ɣ angles for any cell
3x3 array: row vectors containing the (u,v,w) lattice vectors.

occupancy (np.array): Partial occupancy values for each atomic site. Must match the length of positions
"""
# Initialize Crystal
self.positions = np.asarray(positions) #: fractional atomic coordinates
Expand Down Expand Up @@ -131,6 +133,17 @@ def __init__(
else:
raise Exception("Cell cannot contain " + np.size(cell) + " entries")

# occupancy
if occupancy is not None:
self.occupancy = np.array(occupancy)
# check the occupancy shape makes sense
if self.occupancy.shape[0] != self.positions.shape[0]:
raise Warning(
f"Number of occupancies ({self.occupancy.shape[0]}) and atomic positions ({self.positions.shape[0]}) do not match"
)
else:
self.occupancy = np.ones(self.positions.shape[0], dtype=np.float32)

# pymatgen flag
if "pymatgen" in sys.modules:
self.pymatgen_available = True
Expand Down Expand Up @@ -257,7 +270,70 @@ def get_strained_crystal(
else:
return crystal_strained

def from_CIF(CIF, conventional_standard_structure=True):
@staticmethod
def from_ase(
atoms,
):
"""
Create a py4DSTEM Crystal object from an ASE atoms object

Args:
atoms (ase.Atoms): an ASE atoms object

"""
# get the occupancies from the atoms object
occupancies = (
atoms.arrays["occupancies"]
if "occupancies" in atoms.arrays.keys()
else None
)

if "occupancy" in atoms.info.keys():
warnings.warn(
"This Atoms object contains occupancy information but it will be ignored."
)

xtal = Crystal(
positions=atoms.get_scaled_positions(), # fractional coords
numbers=atoms.numbers,
cell=atoms.cell.array,
occupancy=occupancies,
)
return xtal

@staticmethod
def from_prismatic(filepath):
"""
Create a py4DSTEM Crystal object from an prismatic style xyz co-ordinate file

Args:
filepath (str|Pathlib.Path): path to the prismatic format xyz file

"""

from ase import io

# read the atoms using ase
atoms = io.read(filepath, format="prismatic")

# get the occupancies from the atoms object
occupancies = (
atoms.arrays["occupancies"]
if "occupancies" in atoms.arrays.keys()
else None
)
xtal = Crystal(
positions=atoms.get_scaled_positions(), # fractional coords
numbers=atoms.numbers,
cell=atoms.cell.array,
occupancy=occupancies,
)
return xtal

@staticmethod
def from_CIF(
CIF, primitive: bool = True, conventional_standard_structure: bool = True
):
"""
Create a Crystal object from a CIF file, using pymatgen to import the CIF

Expand All @@ -273,12 +349,13 @@ def from_CIF(CIF, conventional_standard_structure=True):

parser = CifParser(CIF)

structure = parser.get_structures(False)[0]
structure = parser.get_structures(primitive=primitive)[0]

return Crystal.from_pymatgen_structure(
structure, conventional_standard_structure=conventional_standard_structure
)

@staticmethod
def from_pymatgen_structure(
structure=None,
formula=None,
Expand Down Expand Up @@ -375,8 +452,6 @@ def from_pymatgen_structure(
else selected["structure"]
)

positions = structure.frac_coords #: fractional atomic coordinates

cell = np.array(
[
structure.lattice.a,
Expand All @@ -388,10 +463,22 @@ def from_pymatgen_structure(
]
)

numbers = np.array([s.species.elements[0].Z for s in structure])
site_data = np.array(
[
(*site.frac_coords, elem.number, comp)
for site in structure
for elem, comp in site.species.items()
]
)
positions = site_data[:, :3]
numbers = site_data[:, 3]
occupancies = site_data[:, 4]

return Crystal(positions, numbers, cell)
return Crystal(
positions=positions, numbers=numbers, cell=cell, occupancy=occupancies
)

@staticmethod
def from_unitcell_parameters(
latt_params,
elements,
Expand Down Expand Up @@ -575,10 +662,14 @@ def calculate_structure_factors(
# Calculate structure factors
self.struct_factors = np.zeros(np.size(self.g_vec_leng, 0), dtype="complex64")
for a0 in range(self.positions.shape[0]):
self.struct_factors += f_all[:, a0] * np.exp(
(2j * np.pi)
* np.sum(
self.hkl * np.expand_dims(self.positions[a0, :], axis=1), axis=0
self.struct_factors += (
f_all[:, a0]
* self.occupancy[a0]
* np.exp(
(2j * np.pi)
* np.sum(
self.hkl * np.expand_dims(self.positions[a0, :], axis=1), axis=0
)
)
)

Expand Down
7 changes: 4 additions & 3 deletions py4DSTEM/process/diffraction/crystal_bloch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def calculate_dynamical_structure_factors(
tol_structure_factor: float = 0.0,
recompute_kinematic_structure_factors=True,
g_vec_precision=None,
verbose=True,
):
"""
Calculate and store the relativistic corrected structure factors used for Bloch computations
Expand Down Expand Up @@ -92,7 +91,7 @@ def calculate_dynamical_structure_factors(

# Calculate the reciprocal lattice points to include based on k_max

k_max = np.asarray(k_max)
k_max: np.ndarray = np.asarray(k_max)

if recompute_kinematic_structure_factors:
if hasattr(self, "struct_factors"):
Expand Down Expand Up @@ -215,7 +214,9 @@ def get_f_e(q, Z, thermal_sigma, method):

# Calculate structure factors
struct_factors = np.sum(
f_e * np.exp(2.0j * np.pi * np.squeeze(self.positions[:, None, :] @ hkl)),
f_e
* self.occupancy[:, None]
* np.exp(2.0j * np.pi * np.squeeze(self.positions[:, None, :] @ hkl)),
axis=0,
)

Expand Down
124 changes: 113 additions & 11 deletions py4DSTEM/process/diffraction/crystal_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from matplotlib.axes import Axes
import matplotlib.tri as mtri
from mpl_toolkits.mplot3d import Axes3D, art3d
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

from scipy.signal import medfilt
from scipy.ndimage import gaussian_filter
from scipy.ndimage import distance_transform_edt
Expand Down Expand Up @@ -91,18 +93,26 @@ def plot_structure(

# Fractional atomic coordinates
pos = self.positions
occ = self.occupancy

# x tile
sub = pos[:, 0] < tol_distance
pos = np.vstack([pos, pos[sub, :] + np.array([1, 0, 0])])
ID = np.hstack([ID, ID[sub]])
if occ is not None:
occ = np.hstack([occ, occ[sub]])
# y tile
sub = pos[:, 1] < tol_distance
pos = np.vstack([pos, pos[sub, :] + np.array([0, 1, 0])])
ID = np.hstack([ID, ID[sub]])
if occ is not None:
occ = np.hstack([occ, occ[sub]])
# z tile
sub = pos[:, 2] < tol_distance
pos = np.vstack([pos, pos[sub, :] + np.array([0, 0, 1])])
ID = np.hstack([ID, ID[sub]])
if occ is not None:
occ = np.hstack([occ, occ[sub]])

# Cartesian atomic positions
xyz = pos @ self.lat_real
Expand Down Expand Up @@ -141,17 +151,109 @@ def plot_structure(

# atoms
ID_all = np.unique(ID)
for ID_plot in ID_all:
sub = ID == ID_plot
ax.scatter(
xs=xyz[sub, 1], # + d[0],
ys=xyz[sub, 0], # + d[1],
zs=xyz[sub, 2], # + d[2],
s=size_marker,
linewidth=2,
facecolors=atomic_colors(ID_plot),
edgecolor=[0, 0, 0],
)
if occ is None:
for ID_plot in ID_all:
sub = ID == ID_plot
ax.scatter(
xs=xyz[sub, 1], # + d[0],
ys=xyz[sub, 0], # + d[1],
zs=xyz[sub, 2], # + d[2],
s=size_marker,
linewidth=2,
facecolors=atomic_colors(ID_plot),
edgecolor=[0, 0, 0],
)
else:
# init
tol = 1e-4
num_seg = 180
radius = 0.7
zp = np.zeros(num_seg + 1)

mark = np.ones(xyz.shape[0], dtype="bool")
for a0 in range(xyz.shape[0]):
if mark[a0]:
xyz_plot = xyz[a0, :]
inds = np.argwhere(np.sum((xyz - xyz_plot) ** 2, axis=1) < tol)
occ_plot = occ[inds]
mark[inds] = False
ID_plot = ID[inds]

if np.sum(occ_plot) < 1.0:
occ_plot = np.append(occ_plot, 1 - np.sum(occ_plot))
ID_plot = np.append(ID_plot, -1)
else:
occ_plot = occ_plot[0]
ID_plot = ID_plot[0]

# Plot site as series of filled arcs
theta0 = 0
for a1 in range(occ_plot.shape[0]):
theta1 = theta0 + occ_plot[a1] * 2.0 * np.pi
theta = np.linspace(theta0, theta1, num_seg + 1)
xp = np.cos(theta) * radius
yp = np.sin(theta) * radius

# Rotate towards camera
xyz_rot = np.vstack((xp.ravel(), yp.ravel(), zp.ravel()))
if occ_plot[a1] < 1.0:
xyz_rot = np.append(
xyz_rot, np.array((0, 0, 0))[:, None], axis=1
)
xyz_rot = orientation_matrix @ xyz_rot

# add to plot
verts = [
list(
zip(
xyz_rot[1, :] + xyz_plot[1],
xyz_rot[0, :] + xyz_plot[0],
xyz_rot[2, :] + xyz_plot[2],
)
)
]
# ax.add_collection3d(
# Poly3DCollection(
# verts
# )
# )
collection = Poly3DCollection(
verts,
linewidths=2.0,
alpha=1.0,
edgecolors="k",
)
face_color = [
0.5,
0.5,
1,
] # alternative: matplotlib.colors.rgb2hex([0.5, 0.5, 1])
if ID_plot[a1] == -1:
collection.set_facecolor((1.0, 1.0, 1.0))
else:
collection.set_facecolor(atomic_colors(ID_plot[a1]))
ax.add_collection3d(collection)

# update start point
if a1 < occ_plot.size:
theta0 = theta1

# for ID_plot in ID_all:
# sub = ID == ID_plot
# ax.scatter(
# xs=xyz[sub, 1], # + d[0],
# ys=xyz[sub, 0], # + d[1],
# zs=xyz[sub, 2], # + d[2],
# s=size_marker,
# linewidth=2,
# facecolors='none',
# edgecolor=[0, 0, 0],
# )
# poly = PolyCollection(
# verts,
# facecolors=['r', 'g', 'b', 'y'],
# alpha = 0.6)
# ax.add_collection3d(poly, zs=zs, zdir='y')

# plot limit
if plot_limit is None:
Expand Down