Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
Pipfile.lock

# Poetry
poetry.lock

# PEP 582
__pypackages__/

# Celery
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# IDEs
.idea/
.vscode/
*.swp
*.swo
*~
.DS_Store

# Claude settings
.claude/*
116 changes: 116 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
[tool.poetry]
name = "lightweight_mmm"
version = "0.1.9"
description = "Package for Media-Mix-Modelling"
authors = ["Google LLC <no-reply@google.com>"]
license = "Apache-2.0"
readme = "README.md"
homepage = "https://github.com/google/lightweight_mmm"
repository = "https://github.com/google/lightweight_mmm"
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Mathematics",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12"
]
packages = [{include = "lightweight_mmm"}]

[tool.poetry.dependencies]
python = "^3.8"
absl-py = "*"
arviz = ">=0.11.2"
immutabledict = ">=2.0.0"
jax = ">=0.3.18"
jaxlib = ">=0.3.18"
matplotlib = "==3.6.1"
numpy = ">=1.21.0"
numpyro = ">=0.9.2"
pandas = ">=1.1.5"
scipy = "*"
seaborn = "==0.11.1"
scikit-learn = "*"
statsmodels = ">=0.13.0"
tensorflow = ">=2.7.2"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest-cov = "^4.1.0"
pytest-mock = "^3.11.1"
pytest-xdist = "^3.3.1"

[tool.poetry.scripts]
test = "pytest:main"
tests = "pytest:main"

[tool.pytest.ini_options]
minversion = "7.0"
testpaths = ["tests"]
addopts = [
"-ra",
"--strict-markers",
"--cov=lightweight_mmm",
"--cov-report=term-missing:skip-covered",
"--cov-report=html",
"--cov-report=xml",
"--cov-fail-under=80",
"-vv",
"--tb=short",
"--maxfail=3"
]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
markers = [
"unit: marks tests as unit tests (fast, isolated)",
"integration: marks tests as integration tests (may be slower)",
"slow: marks tests as slow (deselect with '-m \"not slow\"')"
]
filterwarnings = [
"ignore::DeprecationWarning",
"ignore::PendingDeprecationWarning"
]

[tool.coverage.run]
source = ["lightweight_mmm"]
branch = true
omit = [
"*/tests/*",
"*/__pycache__/*",
"*/conftest.py",
"*/setup.py",
"*/.venv/*",
"*/venv/*"
]

[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"if self.debug:",
"if __name__ == .__main__.:",
"raise AssertionError",
"raise NotImplementedError",
"if 0:",
"if False:",
"class .*\\bProtocol\\):",
"@(abc\\.)?abstractmethod"
]
show_missing = true
precision = 2
fail_under = 80

[tool.coverage.html]
directory = "htmlcov"

[tool.coverage.xml]
output = "coverage.xml"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
Empty file added tests/__init__.py
Empty file.
119 changes: 119 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Shared pytest fixtures and configuration for lightweight_mmm tests."""

import os
import tempfile
from pathlib import Path
from typing import Generator

import numpy as np
import pandas as pd
import pytest


@pytest.fixture
def temp_dir() -> Generator[Path, None, None]:
"""Create a temporary directory for test files."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)


@pytest.fixture
def mock_data():
"""Create mock data for testing MMM models."""
np.random.seed(42)
n_time_periods = 52
n_media_channels = 3
n_geos = 2

# Generate synthetic data
data = {
'date': pd.date_range('2023-01-01', periods=n_time_periods, freq='W'),
'sales': np.random.poisson(1000, n_time_periods) + np.random.normal(0, 50, n_time_periods),
}

# Add media spend data
for i in range(n_media_channels):
data[f'media_{i}'] = np.random.exponential(1000, n_time_periods)

# Add geo data
for i in range(n_geos):
data[f'geo_{i}_sales'] = np.random.poisson(500, n_time_periods)

return pd.DataFrame(data)


@pytest.fixture
def mock_config():
"""Create a mock configuration dictionary."""
return {
'n_media_channels': 3,
'n_geos': 2,
'model_type': 'adstock',
'priors': {
'intercept': {'mean': 0, 'std': 1},
'coef_media': {'mean': 0, 'std': 0.1},
},
'hyperparameters': {
'learning_rate': 0.001,
'n_iterations': 1000,
'batch_size': 32,
}
}


@pytest.fixture
def sample_media_data():
"""Generate sample media spend data."""
np.random.seed(123)
return np.random.rand(52, 3) * 10000 # 52 weeks, 3 channels


@pytest.fixture
def sample_target_data():
"""Generate sample target (sales) data."""
np.random.seed(123)
base_sales = 10000
trend = np.linspace(0, 1000, 52)
seasonality = 500 * np.sin(np.linspace(0, 4 * np.pi, 52))
noise = np.random.normal(0, 200, 52)
return base_sales + trend + seasonality + noise


@pytest.fixture(autouse=True)
def reset_random_seed():
"""Reset random seeds before each test for reproducibility."""
np.random.seed(42)
import random
random.seed(42)

# Reset JAX random seed if JAX is available
try:
import jax
jax.random.PRNGKey(42)
except ImportError:
pass


@pytest.fixture
def mock_model_params():
"""Create mock model parameters."""
return {
'intercept': np.array([1000.0]),
'coef_media': np.array([0.1, 0.2, 0.15]),
'coef_trend': np.array([10.0]),
'saturation_parameters': {
'alphas': np.array([2.0, 1.5, 2.5]),
'betas': np.array([0.5, 0.6, 0.4])
},
'adstock_parameters': {
'convolve_window': 3,
'decay_rates': np.array([0.3, 0.4, 0.35])
}
}


@pytest.fixture
def capture_logs(caplog):
"""Fixture to capture and assert log messages."""
with caplog.at_level('DEBUG'):
yield caplog
Empty file added tests/integration/__init__.py
Empty file.
Loading