Skip to content

Commit dc0be75

Browse files
authored
Fix sympy symbol name clashes (#202)
Make sure symbols like `N`, `beta`, ... are used as scalar symbols, and not as sympy functions with the same name. See also ICB-DCM/pyPESTO#1048
1 parent bd0176f commit dc0be75

File tree

4 files changed

+26
-9
lines changed

4 files changed

+26
-9
lines changed

petab/calculate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
import numpy as np
88
import pandas as pd
9-
import petab
109
import sympy
10+
from sympy.abc import _clash
1111

12+
import petab
1213
from .C import *
1314

1415
__all__ = ['calculate_residuals', 'calculate_residuals_for_table',
@@ -138,7 +139,7 @@ def get_symbolic_noise_formulas(observable_df) -> Dict[str, sympy.Expr]:
138139
if NOISE_FORMULA not in observable_df.columns:
139140
noise_formula = None
140141
else:
141-
noise_formula = sympy.sympify(row.noiseFormula)
142+
noise_formula = sympy.sympify(row.noiseFormula, locals=_clash)
142143
noise_formulas[observable_id] = noise_formula
143144
return noise_formulas
144145

petab/lint.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@
44
import logging
55
import numbers
66
import re
7-
from typing import Optional, Iterable, Any
87
from collections import Counter
8+
from typing import Any, Iterable, Optional
99

1010
import numpy as np
1111
import pandas as pd
1212
import sympy as sp
13+
from sympy.abc import _clash
1314

1415
import petab
15-
from . import (core, parameters, measurements)
16-
from .models import Model
16+
from . import (core, measurements, parameters)
1717
from .C import * # noqa: F403
18+
from .models import Model
1819

1920
logger = logging.getLogger(__name__)
2021
__all__ = ['assert_all_parameters_present_in_parameter_df',
@@ -287,15 +288,15 @@ def check_observable_df(observable_df: pd.DataFrame) -> None:
287288
for row in observable_df.itertuples():
288289
obs = getattr(row, OBSERVABLE_FORMULA)
289290
try:
290-
sp.sympify(obs)
291+
sp.sympify(obs, locals=_clash)
291292
except sp.SympifyError as e:
292293
raise AssertionError(
293294
f"Cannot parse expression '{obs}' "
294295
f"for observable {row.Index}: {e}") from e
295296

296297
noise = getattr(row, NOISE_FORMULA)
297298
try:
298-
sympified_noise = sp.sympify(noise)
299+
sympified_noise = sp.sympify(noise, locals=_clash)
299300
if sympified_noise is None \
300301
or (sympified_noise.is_Number
301302
and not sympified_noise.is_finite):

petab/observables.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import re
44
from collections import OrderedDict
55
from pathlib import Path
6-
from typing import List, Union, Literal
6+
from typing import List, Literal, Union
77

88
import pandas as pd
99
import sympy as sp
10+
from sympy.abc import _clash
1011

1112
from . import core, lint
1213
from .C import * # noqa: F403
@@ -97,7 +98,7 @@ def get_output_parameters(
9798
output_parameters = OrderedDict()
9899

99100
for formula in formulas:
100-
free_syms = sorted(sp.sympify(formula).free_symbols,
101+
free_syms = sorted(sp.sympify(formula, locals=_clash).free_symbols,
101102
key=lambda symbol: symbol.name)
102103
for free_sym in free_syms:
103104
sym = str(free_sym)

tests/test_observables.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,20 @@ def test_get_output_parameters():
8383

8484
assert output_parameters == ['offset', 'scaling']
8585

86+
# test sympy-special symbols (e.g. N, beta, ...)
87+
# see https://github.com/ICB-DCM/pyPESTO/issues/1048
88+
observable_df = pd.DataFrame(data={
89+
OBSERVABLE_ID: ['observable_1'],
90+
OBSERVABLE_NAME: ['observable name 1'],
91+
OBSERVABLE_FORMULA: ['observable_1 * N + beta'],
92+
NOISE_FORMULA: [1],
93+
}).set_index(OBSERVABLE_ID)
94+
95+
output_parameters = petab.get_output_parameters(
96+
observable_df, SbmlModel(sbml_model=ss_model.model))
97+
98+
assert output_parameters == ['N', 'beta']
99+
86100

87101
def test_get_formula_placeholders():
88102
"""Test get_formula_placeholders"""

0 commit comments

Comments
 (0)