Skip to content

Commit

Permalink
Reformat docs
Browse files Browse the repository at this point in the history
  • Loading branch information
LiQian-XC committed Nov 4, 2022
1 parent 3ae065c commit 44681d1
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 79 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ build_doc.sh
build/
build_wheel.sh
/dist/
docs/source/generated/
35 changes: 35 additions & 0 deletions docs/source/_templates/autosummary/class.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{{ fullname | escape | underline}}

.. currentmodule:: {{ module }}

.. add toctree option to make autodoc generate the pages
.. autoclass:: {{ objname }}

{% block attributes %}
{% if attributes %}
.. rubric:: Attributes

.. autosummary::
:toctree: .
{% for item in attributes %}
{% if has_attr(fullname, item) %}
~{{ fullname }}.{{ item }}
{% endif %}
{%- endfor %}
{% endif %}
{% endblock %}

{% block methods %}
{% if methods %}
.. rubric:: Methods

.. autosummary::
:toctree: .
{% for item in methods %}
{%- if item != '__init__' %}
~{{ fullname }}.{{ item }}
{%- endif -%}
{%- endfor %}
{% endif %}
{% endblock %}
10 changes: 10 additions & 0 deletions docs/source/api_train.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Model training and prediction
==================================

.. currentmodule:: sctour

.. autosummary::
:toctree: generated/
:nosignatures:

train.Trainer
4 changes: 4 additions & 0 deletions docs/source/api_vf.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Vector field visualization
==================================

.. autofunction:: sctour.vf.plot_vector_field
59 changes: 19 additions & 40 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath('../../'))
Expand All @@ -22,47 +10,32 @@
project = 'sctour'
author = 'Qian Li'
copyright = f'{datetime.now():%Y}, {author}'

# The full version, including alpha/beta/rc tags
release = sctour.__version__


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.intersphinx',
'sphinx.ext.viewcode',
'myst_parser',
'nbsphinx'
'nbsphinx',
'sphinx.ext.autosummary',
]
autosummary_generate = True
autodoc_member_order = 'bysource'
napoleon_include_init_with_doc = False
napoleon_numpy_docstring = True
napoleon_use_rtype = True
napoleon_use_param = True

# Add any paths that contain templates here, relative to this directory.
source_suffix = ['.rst', '.md']
# settings
master_doc = 'index'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []

templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
pygments_style = 'sphinx'
# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
napoleon_numpy_docstring = True

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

intersphinx_mapping = dict(
python=('https://docs.python.org/3/', None),
Expand All @@ -71,5 +44,11 @@
anndata=('https://anndata.readthedocs.io/en/stable/', None),
scanpy=('https://scanpy.readthedocs.io/en/stable/', None),
scipy=('https://docs.scipy.org/doc/scipy/reference/', None),
torch=('https://pytorch.org/docs/master/', None)
torch=('https://pytorch.org/docs/master/', None),
matplotlib=('https://matplotlib.org/stable/', None),
)


# -- Options for HTML output -------------------------------------------------
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
12 changes: 4 additions & 8 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
.. sctour documentation master file, created by
sphinx-quickstart on Fri Apr 8 21:15:33 2022.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to sctour's documentation!
Welcome to scTour's documentation!
==================================

.. include:: ../../README.md
Expand All @@ -26,7 +21,8 @@ Welcome to sctour's documentation!

.. toctree::
:maxdepth: 2
:caption: Main API:
:caption: API:
:hidden:

sctour
api_train
api_vf
20 changes: 0 additions & 20 deletions docs/source/sctour.rst

This file was deleted.

22 changes: 11 additions & 11 deletions sctour/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def __init__(
self.log = defaultdict(list)


def get_data_loaders(self) -> None:
def _get_data_loaders(self) -> None:
"""
Generate Data Loaders for training and validation datasets.
"""
Expand All @@ -213,23 +213,23 @@ def get_data_loaders(self) -> None:


def train(self):
self.get_data_loaders()
self._get_data_loaders()

params = filter(lambda p: p.requires_grad, self.model.parameters())
self.optimizer = torch.optim.Adam(params, lr = self.lr, weight_decay = self.wt_decay, eps = self.eps)

with tqdm(total=self.nepoch, unit='epoch') as t:
for tepoch in range(t.total):
train_loss = self.on_epoch_train(self.train_dl)
val_loss = self.on_epoch_val(self.val_dl)
train_loss = self._on_epoch_train(self.train_dl)
val_loss = self._on_epoch_val(self.val_dl)
self.log['train_loss'].append(train_loss)
self.log['validation_loss'].append(val_loss)
t.set_description(f"Epoch {tepoch + 1}")
t.set_postfix({'train_loss': train_loss, 'val_loss': val_loss}, refresh=False)
t.update()


def on_epoch_train(self, DL) -> float:
def _on_epoch_train(self, DL) -> float:
"""
Go through the model and update the model parameters.
Expand Down Expand Up @@ -263,7 +263,7 @@ def on_epoch_train(self, DL) -> float:


@torch.no_grad()
def on_epoch_val(self, DL) -> float:
def _on_epoch_val(self, DL) -> float:
"""
Validate using validation dataset.
Expand Down Expand Up @@ -354,7 +354,7 @@ def get_vector_field(
The estimated vector field.
"""

model = self.get_model(model)
model = self._get_model(model)
model.eval()
if not (isinstance(T, np.ndarray) and isinstance(Z, np.ndarray)):
raise TypeError('The inputs must be numpy arrays.')
Expand Down Expand Up @@ -440,7 +440,7 @@ def get_latentsp(
3-tuple of mixed latent space, encoder-derived latent space, and ODE-solver-derived latent space.
"""

model = self.get_model(model)
model = self._get_model(model)
model.eval()

if (alpha_z < 0) or (alpha_z > 1):
Expand Down Expand Up @@ -563,7 +563,7 @@ def predict_time(
The predicted pseudotime and (if `get_ltsp = True`) the latent space.
"""

model = self.get_model(model)
model = self._get_model(model)
model.eval()

if self.time_reverse is None:
Expand Down Expand Up @@ -680,7 +680,7 @@ def predict_ltsp_from_time(
Predicted latent space for the unobserved time interval.
"""

model = self.get_model(model)
model = self._get_model(model)
model.eval()

if not isinstance(T, np.ndarray):
Expand Down Expand Up @@ -735,7 +735,7 @@ def predict_ltsp_from_time(
return pred_T_zs.numpy()


def get_model(self, model):
def _get_model(self, model):
"""
Get the model for inference/prediction (for internal use).
"""
Expand Down

0 comments on commit 44681d1

Please sign in to comment.