Skip to content

Commit

Permalink
Update parameters for vector field visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
LiQian-XC committed Aug 10, 2022
1 parent 2494ac3 commit 63c4e8e
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 65 deletions.
9 changes: 9 additions & 0 deletions sctour/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import logging

# Package-wide logging setup: messages are emitted bare (no level/name
# prefix) and everything from INFO upwards is shown. `basicConfig` only
# takes effect if the root logger has not been configured already.
logging.basicConfig(format="%(message)s", level=logging.INFO)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Convenience aliases so callers can write `logger.info(...)`,
# `logger.warn(...)`, etc. directly on the imported module.
# Note that `warn` maps to `Logger.warning`.
info, warn, error, debug = (
    logger.info,
    logger.warning,
    logger.error,
    logger.debug,
)
165 changes: 100 additions & 65 deletions sctour/vector_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
from typing import Optional, Union

from ._utils import l2_norm
from . import logger


def cosine_similarity(
adata: AnnData,
zs_key: str,
reverse: bool = False,
use_rep_neigh: Optional[str] = None,
vf_key: str = 'VF',
vf_key: str = 'X_VF',
run_neigh: bool = True,
n_neigh: int = 20,
t_key: Optional[str] = None,
Expand All @@ -31,23 +32,23 @@ def cosine_similarity(
adata
An :class:`~anndata.AnnData` object.
reverse
Whether to reverse the direction of vector field.
Whether to reverse the direction of the vector field. When the pseudotime returned by get_time() function was in reverse order and you used the post-inference adjustment (reverse_time() function), please set this parameter to `True`.
(Default: `False`)
zs_key
The key in `.obsm` for storing the latent space.
vf_key
The key in `.obsm` for storing the vector field.
(Default: `'VF'`)
(Default: `'X_VF'`)
run_neigh
Whether to run neighbor detection.
(Default: `True`)
use_rep_neigh
The representation in `.obsm` for neighbor detection.
The representation in `.obsm` which will be used for neighbor detection.
n_neigh
The number of neighbors for each cell.
The number of neighbors considered for each cell.
(Default: 20)
t_key:
The key in `.obs` for estimated time for neighbor detection.
The key in `.obs` for estimated pseudotime which will be considered when detecting neighbors.
var_stabilize_transform
Whether to perform variance-stabilizing transformation for vector field and cell-neighbor latent state difference.
(Default: `False`)
Expand All @@ -58,21 +59,29 @@ def cosine_similarity(
A sparse matrix with cosine similarities.
"""

Z = adata.obsm[f'X_{zs_key}']
V = adata.obsm[f'X_{vf_key}']
Z = np.array(adata.obsm[zs_key])
V = np.array(adata.obsm[vf_key])
if reverse:
V = -V
if var_stabilize_transform:
V = np.sqrt(np.abs(V)) * np.sign(V)

ncells = adata.n_obs

if run_neigh:
sc.pp.neighbors(adata, use_rep = f'X_{use_rep_neigh}', n_neighbors = n_neigh)
if run_neigh or ('neighbors' not in adata.uns):
if use_rep_neigh is None:
use_rep_neigh = zs_key
logger.warn(f"Warning: the parameter `use_rep_neigh` in function `plot_vector_field` is not provided. Use `{zs_key}` in `.obsm` of the AnnData instead.")
else:
if use_rep_neigh not in adata.obsm:
raise KeyError(f"`{use_rep_neigh}` not found in `.obsm` of the AnnData. Please provide valid `use_rep_neigh` for neighbor detection.")
sc.pp.neighbors(adata, use_rep = use_rep_neigh, n_neighbors = n_neigh)
n_neigh = adata.uns['neighbors']['params']['n_neighbors'] - 1
# indices_matrix = adata.obsp['distances'].indices.reshape(-1, n_neigh)

if t_key is not None:
if t_key not in adata.obs:
raise KeyError(f"`{t_key}` not found in `.obs` of the AnnData. Please provide valid `t_key` for estimated pseudotime.")
ts = adata.obs[t_key].values
indices_matrix2 = np.zeros((ncells, n_neigh), dtype = int)
for i in range(ncells):
Expand All @@ -94,6 +103,7 @@ def cosine_similarity(
if var_stabilize_transform:
dZ = np.sqrt(np.abs(dZ)) * np.sign(dZ)
cos_sim = np.einsum("ij, j", dZ, V[i]) / (l2_norm(dZ, axis = 1) * l2_norm(V[i]))
cos_sim[np.isnan(cos_sim)] = 0
vals.extend(cos_sim)
rows.extend(np.repeat(i, len(idx)))
cols.extend(idx)
Expand Down Expand Up @@ -172,7 +182,7 @@ def vector_field_embedding(
The weighted unitary displacement vectors.
"""

T = adata.obsp[T_key]
T = adata.obsp[T_key].copy()

if self_transition:
max_t = T.max(1).A.flatten()
Expand All @@ -186,7 +196,7 @@ def vector_field_embedding(
T.setdiag(0)
T.eliminate_zeros()

E = adata.obsm[f'X_{E_key}']
E = np.array(adata.obsm[E_key])
V = np.zeros(E.shape)

for i in range(adata.n_obs):
Expand Down Expand Up @@ -219,7 +229,7 @@ def vector_field_embedding_grid(
V
The unitary displacement vectors under the embedding.
smooth
The factor for scale in Gaussian kernel.
The factor for scale in Gaussian pdf.
(Default: 0.5)
stream
Whether to adjust for streamplot.
Expand Down Expand Up @@ -260,7 +270,7 @@ def vector_field_embedding_grid(

if stream:
E_grid = np.stack(grs)
ns = int(50 * density)
ns = E_grid.shape[1]
V_grid = V_grid.T.reshape(2, ns, ns)

mass = np.sqrt((V_grid * V_grid).sum(0))
Expand All @@ -283,100 +293,126 @@ def vector_field_embedding_grid(

def plot_vector_field(
adata: AnnData,
zs_key: str,
reverse: bool = False,
zs_key: Optional[str] = None,
vf_key: str = 'VF',
vf_key: str = 'X_VF',
run_neigh: bool = True,
use_rep_neigh: Optional[str] = None,
t_key: Optional[str] = None,
n_neigh: int = 20,
var_stabilize_transform: bool = False,
E_key: str = 'umap',
E_key: str = 'X_umap',
scale: int = 10,
self_transition: bool = False,
smooth: float = 0.5,
density: float = 1.,
grid: bool = False,
stream: bool = True,
stream_density: int = 2,
stream_color: str = 'k',
linewidth: int = 1,
arrowsize: int = 1,
density: float = 1.,
arrow_size_grid: int = 1,
arrow_length_grid: int = 1,
arrow_color_grid: str = 'grey',
stream_linewidth: int = 1,
stream_arrowsize: int = 1,
grid_density: float = 1.,
color: Optional[str] = None,
grid_arrowcolor: str = 'grey',
grid_arrowlength: int = 1,
grid_arrowsize: int = 1,
# color: Optional[str] = None,
# ax: Optional[Axes] = None,
**kwargs,
):
"""
Visualize the vector field.
The visulization of vector field under an embedding borrows the ideas from scvelo: https://github.com/theislab/scvelo.
The visualization of vector field under an embedding borrows the ideas from scvelo: https://github.com/theislab/scvelo.
Parameters
----------
adata
An :class:`~anndata.AnnData` object.
reverse
Whether to reverse the direction of vector field.
zs_key
The key in `.obsm` for storing the latent space.
reverse
Whether to reverse the direction of the vector field. When the pseudotime returned by get_time() function was in reverse order and you used the post-inference adjustment (reverse_time() function), please set this parameter to `True`.
(Default: `False`)
vf_key
The key in `.obsm` for storing the vector field.
run_neigh
Whether to run neighbor detection.
(Default: `True`)
use_rep_neigh
The representation in `.obsm` for neighbor detection.
The representation in `.obsm` which will be used for neighbor detection.
t_key:
The key in `.obs` for estimated time for neighbor detection.
The key in `.obs` for estimated pseudotime which will be considered when detecting neighbors.
n_neigh
The number of neighbors for each cell.
The number of neighbors considered for each cell.
(Default: 20)
var_stabilize_transform
Whether to perform variance-stabilizing transformation for vector field and cell-neighbor latent state difference.
(Default: `False`)
E_key
The key in `.obsm` for embedding.
(Default: `'X_umap'`)
scale
Scale factor for cosine similarity.
(Default: 10)
self_transition
Whether to take self-transition into consideration.
(Default: `False`)
smooth
The factor for scale in Gaussian kernel.
The factor for scale in Gaussian pdf.
(Default: 0.5)
density
Percentage of cells to show when displaying the vector field in the per-cell level.
(Default: 1.)
grid
Whether to draw grid-level vector field.
Whether to display vector field as arrows in grid level.
(Default: `False`)
stream
Whether to draw streamplot.
Whether to display vector field as streamplot.
(Default: `True`)
stream_density
The density parameter for streamplot for controlling the closeness of the streamlines.
The density parameter in streamplot for controlling the closeness of the streamlines.
(Default: 2)
stream_color
The streamline color for streamplot.
linewidth
(Default: 'k')
stream_linewidth
The line width for streamplot.
arrowsize
(Default: 1)
stream_arrowsize
The arrow size for streamplot.
density
Percentage of cell positions to show.
arrow_size_grid
The arrow size in grid-level vector field.
arrow_length_grid
The arrow length in grid-level vector field
arrow_color_grid
The arrow color in grid-level vector field
(Default: 1)
grid_density
The grid-level density for showing vector field
color
`color` parameter in :func:`scanpy.pl.umap`.
ax
The matplotlib axes
The grid-level density for showing vector field.
(Default: 1.)
grid_arrowcolor
The arrow color when showing vector field as arrows in grid level.
(Default: `'grey'`)
grid_arrowlength
The arrow length when showing vector field as arrows in grid level.
(Default: 1)
grid_arrowsize
The arrow size when showing vector field as arrows in grid level.
(Default: 1)
kwargs
Parameters passed to :func:`scanpy.pl.umap`
Parameters passed to :func:`scanpy.pl.embedding`.
Returns
----------
:class:`~matplotlib.axes.Axes`
An :class:`~matplotlib.axes.Axes` object.
"""

if zs_key not in adata.obsm:
raise KeyError(f"`{zs_key}` not found in `.obsm` of the AnnData. Please provide valid `zs_key` for latent space.")
if vf_key not in adata.obsm:
raise KeyError(f"`{vf_key}` not found in `.obsm` of the AnnData. Please provide valid `vf_key` for vector field.")
if E_key not in adata.obsm:
raise KeyError(f"`{E_key}` not found in `.obsm` of the AnnData. Please provide valid `E_key` for embedding.")
if (grid_density < 0) or (grid_density > 1):
raise ValueError("`grid_density` must be between 0 and 1.")
if (density < 0) or (density > 1):
raise ValueError("`density` must be between 0 and 1.")

##calculate cosine similarity
adata.obsp['cosine_similarity'] = cosine_similarity(
adata,
Expand All @@ -398,8 +434,8 @@ def plot_vector_field(
self_transition = self_transition,
)

E = adata.obsm[f'X_{E_key}']
V = adata.obsm[f'X_DV']
E = np.array(adata.obsm[E_key])
V = adata.obsm['X_DV']

if grid:
stream = False
Expand All @@ -413,28 +449,29 @@ def plot_vector_field(
density = grid_density,
)

ax = sc.pl.embedding(adata, basis = E_key, color = color, show=False, **kwargs)
ax = sc.pl.embedding(adata, basis = E_key, show=False, **kwargs)
if stream:
lengths = np.sqrt((V * V).sum(0))
linewidth *= 2 * lengths / lengths[~np.isnan(lengths)].max()
stream_linewidth *= 2 * lengths / lengths[~np.isnan(lengths)].max()
stream_kwargs = dict(
linewidth = linewidth,
linewidth = stream_linewidth,
density = stream_density,
zorder = 3,
color = stream_color,
arrowsize = arrowsize,
arrowsize = stream_arrowsize,
arrowstyle = '-|>',
maxlength = 4,
integration_direction = 'both',
)
ax.streamplot(E[0], E[1], V[0], V[1], **stream_kwargs)
else:
if density < 1:
idx = np.random.choice(len(E), int(len(E) * density), replace = False)
E = E[idx]
V = V[idx]
scale = 1 / arrow_length_grid
hl, hw, hal = 6 * arrow_size_grid, 5 * arrow_size_grid, 4 * arrow_size_grid
if not grid:
if density < 1:
idx = np.random.choice(len(E), int(len(E) * density), replace = False)
E = E[idx]
V = V[idx]
scale = 1 / grid_arrowlength
hl, hw, hal = 6 * grid_arrowsize, 5 * grid_arrowsize, 4 * grid_arrowsize
quiver_kwargs = dict(
angles = 'xy',
scale_units = 'xy',
Expand All @@ -444,12 +481,10 @@ def plot_vector_field(
headlength = hl,
headwidth = hw,
headaxislength = hal,
color = arrow_color_grid,
color = grid_arrowcolor,
linewidth = 0.2,
zorder = 3,
)
ax.quiver(E[:, 0], E[:, 1], V[:, 0], V[:, 1], **quiver_kwargs)

# ax = sc.pl.embedding(adata, basis = E_key, color = color, ax = ax, show = False, **kwargs)

return ax

0 comments on commit 63c4e8e

Please sign in to comment.