Skip to content

Networks docstrings: add summary/inference network indicator #462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bayesflow/experimental/cif/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# disable module check, use potential module after moving from experimental
@serializable("bayesflow.networks", disable_module_check=True)
class CIF(InferenceNetwork):
"""Implements a continuously indexed flow (CIF) with a `CouplingFlow`
"""(IN) Implements a continuously indexed flow (CIF) with a `CouplingFlow`
bijection and `ConditionalGaussian` distributions p and q. Improves on
eliminating leaky sampling found topologically in normalizing flows.
Built in reference to [1].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# disable module check, use potential module after moving from experimental
@serializable("bayesflow.networks", disable_module_check=True)
class ContinuousTimeConsistencyModel(InferenceNetwork):
"""Implements an sCM (simple, stable, and scalable Consistency Model)
"""(IN) Implements an sCM (simple, stable, and scalable Consistency Model)
with continous-time Consistency Training (CT) as described in [1].
The sampling procedure is taken from [2].

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/experimental/free_form_flow/free_form_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# disable module check, use potential module after moving from experimental
@serializable("bayesflow.networks", disable_module_check=True)
class FreeFormFlow(InferenceNetwork):
"""Implements a dimensionality-preserving Free-form Flow.
"""(IN) Implements a dimensionality-preserving Free-form Flow.
Incorporates ideas from [1-2].

[1] Draxler, F., Sorrenson, P., Zimmermann, L., Rousselot, A., & Köthe, U. (2024).F
Expand Down
2 changes: 2 additions & 0 deletions bayesflow/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""
A rich collection of neural network architectures for use in :py:class:`~bayesflow.approximators.Approximator`\ s.

The module features inference networks (IN), summary networks (SN), as well as general purpose networks.
"""

from .consistency_models import ConsistencyModel
Expand Down
4 changes: 2 additions & 2 deletions bayesflow/networks/consistency_models/consistency_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

@serializable("bayesflow.networks")
class ConsistencyModel(InferenceNetwork):
"""Implements a Consistency Model with Consistency Training (CT) a described in [1-2]. The adaptations to CT
described in [2] were taken into account in our implementation for ABI [3].
"""(IN) Implements a Consistency Model with Consistency Training (CT) as described in [1-2].
The adaptations to CT described in [2] were taken into account in our implementation for ABI [3].

[1] Song, Y., Dhariwal, P., Chen, M. & Sutskever, I. (2023). Consistency Models. arXiv preprint arXiv:2303.01469

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/networks/coupling_flow/coupling_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@serializable("bayesflow.networks")
class CouplingFlow(InferenceNetwork):
"""Implements a coupling flow as a sequence of dual couplings with permutations and activation
"""(IN) Implements a coupling flow as a sequence of dual couplings with permutations and activation
normalization. Incorporates ideas from [1-5].

[1] Kingma, D. P., & Dhariwal, P. (2018).
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/networks/deep_set/deep_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

@serializable("bayesflow.networks")
class DeepSet(SummaryNetwork):
"""Implements a deep set encoder introduced in [1] for learning permutation-invariant representations of
"""(SN) Implements a deep set encoder introduced in [1] for learning permutation-invariant representations of
set-based data, as generated by exchangeable models.

[1] Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., & Smola, A. J. (2017).
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/networks/flow_matching/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

@serializable("bayesflow.networks")
class FlowMatching(InferenceNetwork):
"""Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas incorporated
"""(IN) Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas incorporated
from [1-3].

[1] Rectified Flow: arXiv:2209.03003
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
@serializable("bayesflow.networks")
class TimeSeriesNetwork(SummaryNetwork):
"""
Implements a LSTNet Architecture as described in [1]
(SN) Implements a LSTNet Architecture as described in [1]

[1] Y. Zhang and L. Mikelsons, Solving Stochastic Inverse Problems with Stochastic BayesFlow,
2023 IEEE/ASME International Conference on Advanced Intelligent Mechatronics (AIM),
Expand Down
3 changes: 2 additions & 1 deletion bayesflow/networks/transformers/fusion_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

@serializable("bayesflow.networks")
class FusionTransformer(SummaryNetwork):
"""Implements a more flexible version of the TimeSeriesTransformer that applies a series of self-attention layers
"""
(SN) Implements a more flexible version of the TimeSeriesTransformer that applies a series of self-attention layers
followed by cross-attention between the representation and a learnable template summarized via a recurrent net."""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/networks/transformers/set_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

@serializable("bayesflow.networks")
class SetTransformer(SummaryNetwork):
"""Implements the set transformer architecture from [1] which ultimately represents
"""(SN) Implements the set transformer architecture from [1] which ultimately represents
a learnable permutation-invariant function. Designed to naturally model interactions in
the input set, which may be hard to capture with the simpler ``DeepSet`` architecture.

Expand Down
6 changes: 3 additions & 3 deletions bayesflow/networks/transformers/time_series_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def __init__(
time_axis: int = None,
**kwargs,
):
"""Creates a regular transformer coupled with Time2Vec embeddings of time used to flexibly compress time series.
If the time intervals vary across batches, it is highly recommended that your simulator also returns a "time"
vector appended to the simulator outputs and specified via the "time_axis" argument.
"""(SN) Creates a regular transformer coupled with Time2Vec embeddings of time used to flexibly compress time
series. If the time intervals vary across batches, it is highly recommended that your simulator also returns a
"time" vector appended to the simulator outputs and specified via the "time_axis" argument.

Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/wrappers/mamba/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@serializable("bayesflow.wrappers")
class Mamba(SummaryNetwork):
"""
Wraps a sequence of Mamba modules using the simple Mamba module from:
(SN) Wraps a sequence of Mamba modules using the simple Mamba module from:
https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

Copyright (c) 2023, Tri Dao, Albert Gu.
Expand Down