
boschresearch/gresit


gRESIT


License: AGPL v3. Code style: ruff.

This repository provides tools for learning and representing causal graphs from grouped data. Theoretical details are presented in the paper:

@misc{goebler2025,
  title={Nonlinear Causal Discovery for Grouped Data},
  author={Konstantin G\"obler and Tobias Windisch and Mathias Drton},
  year={2025},
  eprint={2506.05120},
  archivePrefix={arXiv},
  primaryClass={stat.ML},
  url={https://arxiv.org/abs/2506.05120},
}

Authors

Maintainer: Martin Roth (Bosch)


The documentation can be found here.

The package can be installed with

pip install gresit

Using the Makefile, the package can be installed in editable mode:

make sync-venv

To use the pre-commit hooks, enable them in the virtual environment with

pre-commit install

The hooks then run before every commit. You can also run them on all files manually with

pre-commit run --all-files

or, to skip the pip-compile hook, which takes some time, with

SKIP=pip-compile pre-commit run --all-files

or, equivalently,

make pre-commit

The following example illustrates a typical workflow. See the documentation for more detailed information.

Generating Synthetic Data

We first generate synthetic data from an Erdős–Rényi random graph model. The generator is configured with the number of nodes, the size of each group of variables, and the edge density of the graph.

from gresit.synthetic_data import GenERData

data_gen = GenERData(
    number_of_nodes=10,
    group_size=2,
    edge_density=0.2,
)

data_dict, _ = data_gen.generate_data(num_samples=1000)

The output data_dict is a dictionary in which each key identifies a group and each value holds the observed samples of that group.
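For intuition about what such grouped data looks like, here is a minimal NumPy-only sketch of one way to simulate it. It is not GenERData's implementation: the function names, the keys X0, X1, ..., the tanh mechanism, the Gaussian noise, and the reading of edge_density as a per-edge probability are all assumptions for illustration. The DAG is drawn Erdős–Rényi style over a random node ordering, and each group is simulated after its parent groups.

```python
import numpy as np


def sample_er_dag(num_nodes: int, edge_density: float, rng: np.random.Generator):
    """Sample an Erdős–Rényi DAG; adj[i, j] == 1 encodes the edge i -> j (assumed convention)."""
    # Keep each potential edge i -> j with i < j independently with
    # probability edge_density; a strictly upper-triangular matrix is acyclic.
    upper = np.triu(rng.random((num_nodes, num_nodes)) < edge_density, k=1)
    # Relabel nodes at random so the causal order is not simply 0, 1, 2, ...
    perm = rng.permutation(num_nodes)
    adj = upper[np.ix_(perm, perm)].astype(int)
    # New node a takes position perm[a] in the original (topological) order.
    order = np.argsort(perm)
    return adj, order


def generate_grouped_data(adj, order, group_size, num_samples, rng):
    """Simulate one (num_samples, group_size) array per group node (hypothetical mechanism)."""
    data = {}
    for node in order:  # parents are always simulated before their children
        signal = np.zeros((num_samples, group_size))
        for parent in np.flatnonzero(adj[:, node]):
            # Hypothetical mechanism: random linear map of the parent group, then tanh.
            weights = rng.normal(size=(group_size, group_size))
            signal += np.tanh(data[f"X{parent}"] @ weights)
        data[f"X{node}"] = signal + rng.normal(size=(num_samples, group_size))
    # Return the groups sorted by their (hypothetical) names X0, X1, ...
    return {f"X{i}": data[f"X{i}"] for i in range(adj.shape[0])}


rng = np.random.default_rng(0)
adj, order = sample_er_dag(num_nodes=10, edge_density=0.2, rng=rng)
data_dict = generate_grouped_data(adj, order, group_size=2, num_samples=1000, rng=rng)
print({k: v.shape for k, v in list(data_dict.items())[:3]})
# {'X0': (1000, 2), 'X1': (1000, 2), 'X2': (1000, 2)}
```

Permuting node labels after sampling the upper triangle keeps the graph acyclic while hiding the causal order, which is what a discovery method then has to recover.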

Fitting a Graph Model

We now fit a gRESIT model using Multioutcome_MLP as the regressor, HSIC as the independence test, and "murgs" as the pruning method.

from gresit.group_resit import GroupResit
from gresit.independence_tests import HSIC
from gresit.torch_models import Multioutcome_MLP

model = GroupResit(
    regressor=Multioutcome_MLP(),
    test=HSIC,
    pruning_method="murgs",
)
learned_dag = model.learn_graph(data_dict=data_dict)

# Show the learned graph:
learned_dag.show()
# or show interactive mode:
model.show_interactive()
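As its name suggests, gRESIT extends RESIT (regression with subsequent independence test) to groups. The sketch below illustrates only a RESIT-style ordering phase and is not the package's implementation: at each step, every remaining group is regressed on all other remaining groups, and the group whose residuals are least dependent on the regressors (smallest HSIC) is taken as a sink. Assumptions: ordinary least squares stands in for Multioutcome_MLP, so the sketch only suits linear relationships with non-Gaussian noise; the HSIC variant (biased estimator, Gaussian kernels, median-distance bandwidth) and the names hsic and resit_order are illustrative; no pruning step such as "murgs" is included.

```python
import numpy as np


def _gaussian_gram(x: np.ndarray) -> np.ndarray:
    """Gaussian kernel Gram matrix; bandwidth = median squared pairwise distance."""
    sq = np.sum(x**2, axis=1)
    dists = np.maximum(sq[:, None] + sq[None, :] - 2 * x @ x.T, 0.0)
    bandwidth = np.median(dists[dists > 0]) if np.any(dists > 0) else 1.0
    return np.exp(-dists / bandwidth)


def hsic(x: np.ndarray, y: np.ndarray) -> float:
    """Biased empirical HSIC statistic trace(K H L H) / n**2."""
    n = x.shape[0]
    h = np.eye(n) - np.ones((n, n)) / n  # centering matrix
    k, l = _gaussian_gram(x), _gaussian_gram(y)
    return float(np.trace(k @ h @ l @ h)) / n**2


def resit_order(data: dict) -> list:
    """Return group names from source to sink (RESIT-style sketch with OLS)."""
    remaining = list(data)
    sinks = []
    while len(remaining) > 1:
        scores = {}
        for cand in remaining:
            # Regress the candidate group on all other remaining groups.
            x = np.hstack([data[g] for g in remaining if g != cand])
            x1 = np.hstack([np.ones((x.shape[0], 1)), x])  # add intercept
            coef, *_ = np.linalg.lstsq(x1, data[cand], rcond=None)
            resid = data[cand] - x1 @ coef
            # Dependence between residuals and regressors.
            scores[cand] = hsic(resid, x)
        sink = min(scores, key=scores.get)  # most independent residuals
        sinks.insert(0, sink)
        remaining.remove(sink)
    return remaining + sinks


rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(300, 2))
y = x + rng.uniform(-1, 1, size=(300, 2))  # X -> Y with uniform (non-Gaussian) noise
ordering = resit_order({"X": x, "Y": y})
print(ordering)
```

In the correct direction the residuals of Y are essentially the independent noise, while regressing X on Y leaves residuals that are uncorrelated with Y but still dependent on it, which HSIC detects.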

Accessing the Learned Graph

The adjacency matrix of the estimated group-level graph and the estimated causal ordering can be accessed via:

model.adjacency_matrix
model.causal_ordering
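A causal ordering of a DAG is a topological order of its adjacency matrix, so the two outputs can be sanity-checked against each other. The sketch below assumes that adj[i, j] != 0 encodes the edge i → j and that the ordering is a sequence of integer node indices; both conventions are assumptions, so check the documentation for what model.adjacency_matrix and model.causal_ordering actually return. Because a DAG can admit several valid orderings, compare with the hypothetical helper is_consistent rather than testing for equality.

```python
import numpy as np


def topological_order(adj) -> list[int]:
    """Kahn's algorithm; assumes adj[i, j] != 0 encodes the edge i -> j."""
    adj = np.asarray(adj) != 0
    indegree = adj.sum(axis=0)  # number of parents per node
    ready = [int(i) for i in np.flatnonzero(indegree == 0)]
    order = []
    while ready:
        node = ready.pop(0)
        order.append(node)
        for child in np.flatnonzero(adj[node]):
            indegree[child] -= 1
            if indegree[child] == 0:
                ready.append(int(child))
    if len(order) != adj.shape[0]:
        raise ValueError("adjacency matrix contains a cycle")
    return order


def is_consistent(adj, ordering) -> bool:
    """True if every edge i -> j has i placed before j in `ordering`."""
    position = {node: k for k, node in enumerate(ordering)}
    rows, cols = np.nonzero(np.asarray(adj))
    return all(position[i] < position[j] for i, j in zip(rows.tolist(), cols.tolist()))


# Edges 1 -> 0, 1 -> 2 and 2 -> 0.
adj = np.array([[0, 0, 0], [1, 0, 1], [1, 0, 0]])
print(topological_order(adj))  # [1, 2, 0]
print(is_consistent(adj, [0, 1, 2]))  # False
```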

We use pytest; the test suite can be run locally with

python -m pytest

We use mkdocs to build the documentation; a corresponding workflow handles the build.

Automated issue workflow

This workflow automatically adds newly created issues to our MFD2 project.

Pre-commit

This workflow runs the pre-commit rules specified in .pre-commit-config.yaml.

To use pre-commit locally, please use

pre-commit install

Testing

This workflow runs the test suite.

Runtime dependencies

| Name | License | Type |
|------|---------|------|
| numpy | BSD-3-Clause License | Dependency |
| pandas | BSD-3-Clause License | Dependency |
| scikit-learn | BSD-3-Clause License | Dependency |
| statsmodels | BSD-3-Clause License | Dependency |
| plotly | MIT License | Dependency |
| xgboost | Apache License 2.0 | Dependency |
| torch | BSD-3-Clause License | Dependency |
| seaborn | BSD-3-Clause License | Dependency |
| pyspark | Apache License 2.0 | Dependency |
| scikit-misc | BSD-3-Clause License | Dependency |
| gadjid | MIT License | Dependency |
| tqdm | MIT License | Dependency |
| dcor | MIT License | Dependency |
| llvmlite | BSD-2-Clause License | Dependency |
| causal-learn | MIT License | Dependency |
| gcastle | Apache License 2.0 | Dependency |
| gpytorch | MIT License | Dependency |

Development dependencies

| Name | License | Type |
|------|---------|------|
| mike | BSD-3-Clause License | Optional |
| mkdocs | BSD-2-Clause License | Optional |
| mkdocs-material | MIT License | Optional |
| mkdocstrings | ISC License | Optional |
| pip-licenses | MIT License | Optional |
| pip-tools | BSD-3-Clause License | Optional |
| pre-commit | MIT License | Optional |
| pytest | MIT License | Optional |
| pytest-cov | MIT License | Optional |
| ruff | MIT License | Optional |
| uv | MIT License | Optional |