Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions py4DSTEM/process/diffraction/crystal_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ def plot_structure(
camera_dist: Optional[float] = None,
show_axes: bool = False,
perspective_axes: bool = True,
figsize: Union[tuple, list, np.ndarray] = (8, 8),
figsize: tuple = (8, 8),
returnfig: bool = False,
):
"""
Quick 3D plot of the untit cell /atomic structure.
Quick 3D plot of the unit cell/atomic structure.

Args:
orientation_matrix (array): (3,3) orientation matrix, where columns represent projection directions.
zone_axis_lattice (array): (3,) projection direction in lattice indices
proj_x_lattice (array): (3,) x-axis direction in lattice indices
zone_axis_cartesian (array): (3,) cartesian projection direction
proj_x_cartesian (array): (3,) cartesian projection direction
scale_markers (float): Size scaling for markers
size_marker (float): Size scaling for markers
tol_distance (float): Tolerance for repeating atoms on edges on cell boundaries.
plot_limit (float): (2,3) numpy array containing x y z plot min and max in columns.
Default is 1.1* unit cell dimensions.
Expand Down Expand Up @@ -98,20 +98,17 @@ def plot_structure(
sub = pos[:, 0] < tol_distance
pos = np.vstack([pos, pos[sub, :] + np.array([1, 0, 0])])
ID = np.hstack([ID, ID[sub]])
if occ is not None:
occ = np.hstack([occ, occ[sub]])
occ = np.hstack([occ, occ[sub]])
# y tile
sub = pos[:, 1] < tol_distance
pos = np.vstack([pos, pos[sub, :] + np.array([0, 1, 0])])
ID = np.hstack([ID, ID[sub]])
if occ is not None:
occ = np.hstack([occ, occ[sub]])
occ = np.hstack([occ, occ[sub]])
# z tile
sub = pos[:, 2] < tol_distance
pos = np.vstack([pos, pos[sub, :] + np.array([0, 0, 1])])
ID = np.hstack([ID, ID[sub]])
if occ is not None:
occ = np.hstack([occ, occ[sub]])
occ = np.hstack([occ, occ[sub]])

# Cartesian atomic positions
xyz = pos @ self.lat_real
Expand Down Expand Up @@ -150,7 +147,7 @@ def plot_structure(

# atoms
ID_all = np.unique(ID)
if occ is None:
if np.all(occ == 1.0):
for ID_plot in ID_all:
sub = ID == ID_plot
ax.scatter(
Expand All @@ -166,7 +163,7 @@ def plot_structure(
# init
tol = 1e-4
num_seg = 180
radius = 0.7
radius = size_marker / 800
zp = np.zeros(num_seg + 1)

mark = np.ones(xyz.shape[0], dtype="bool")
Expand Down