diff --git a/src/pyrovelocity/utils.py b/src/pyrovelocity/utils.py index 0e58fa48b..092cb3f31 100644 --- a/src/pyrovelocity/utils.py +++ b/src/pyrovelocity/utils.py @@ -5,15 +5,11 @@ import os import sys from inspect import getmembers -from numbers import Integral -from numbers import Real +from numbers import Integral, Real from pathlib import Path from pprint import pprint -from types import FunctionType -from types import ModuleType -from typing import Callable -from typing import List -from typing import Tuple +from types import FunctionType, ModuleType +from typing import Callable, List, Tuple import matplotlib.pyplot as plt import numpy as np @@ -25,12 +21,12 @@ import yaml from anndata._core.anndata import AnnData from beartype import beartype +from jaxtyping import ArrayLike from scvi.data import synthetic_iid from pyrovelocity.io.compressedpickle import CompressedPickle from pyrovelocity.logging import configure_logging - # import torch # from scipy.sparse import issparse # from sklearn.decomposition import PCA @@ -179,12 +175,15 @@ def pretty_log_dict(d: dict) -> str: for key, value in d.items(): # key_colored = colored(key, "green") key_colored = key - value_lines = str(value).split("\n") - value_colored = "\n".join( - # colored(line, "white") for line in value_lines - line - for line in value_lines - ) + if isinstance(value, ArrayLike): + value_colored = f"{value.shape}" + else: + value_lines = str(value).split("\n") + value_colored = "\n".join( + # colored(line, "white") for line in value_lines + line + for line in value_lines + ) dict_as_string += f"{key_colored}:\n{value_colored}\n" return dict_as_string