Skip to content

Commit 5a87f83

Browse files
dweindldilpath
andauthored
get_parameter_df: Allow any collection of parameter tables (#153)
not just lists. Co-authored-by: Dilan Pathirana <59329744+dilpath@users.noreply.github.com>
1 parent 12a07b9 commit 5a87f83

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

petab/parameters.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,38 +28,38 @@
2828

2929

3030
def get_parameter_df(
31-
parameter_file: Union[str, Path, List[str], pd.DataFrame, None]
32-
) -> pd.DataFrame:
31+
parameter_file: Union[str, Path, pd.DataFrame,
32+
Iterable[Union[str, Path, pd.DataFrame]], None]
33+
) -> Union[pd.DataFrame, None]:
3334
"""
3435
Read the provided parameter file into a ``pandas.Dataframe``.
3536
3637
Arguments:
37-
parameter_file: Name of the file to read from or pandas.Dataframe.
38+
parameter_file: Name of the file to read from or pandas.Dataframe,
39+
or an Iterable.
3840
3941
Returns:
40-
Parameter DataFrame
42+
Parameter ``DataFrame``, or ``None`` if ``None`` was passed.
4143
"""
4244
if parameter_file is None:
43-
return parameter_file
44-
45-
parameter_df = None
46-
45+
return None
4746
if isinstance(parameter_file, pd.DataFrame):
4847
parameter_df = parameter_file
49-
50-
if isinstance(parameter_file, (str, Path)):
48+
elif isinstance(parameter_file, (str, Path)):
5149
parameter_df = pd.read_csv(parameter_file, sep='\t',
5250
float_precision='round_trip')
51+
elif isinstance(parameter_file, Iterable):
52+
dfs = [get_parameter_df(x) for x in parameter_file if x]
5353

54-
if isinstance(parameter_file, list):
55-
parameter_df = pd.concat([pd.read_csv(subset_file, sep='\t',
56-
float_precision='round_trip')
57-
for subset_file in parameter_file])
54+
if not dfs:
55+
return None
56+
57+
parameter_df = pd.concat(dfs)
5858
# Remove identical parameter definitions
59-
parameter_df.drop_duplicates(inplace=True, ignore_index=True)
59+
parameter_df.drop_duplicates(inplace=True, ignore_index=False)
6060
# Check for contradicting parameter definitions
61-
parameter_duplicates = set(parameter_df[PARAMETER_ID].loc[
62-
parameter_df[PARAMETER_ID].duplicated()])
61+
parameter_duplicates = set(parameter_df.index.values[
62+
parameter_df.index.duplicated()])
6363
if parameter_duplicates:
6464
raise ValueError(
6565
f'The values of {PARAMETER_ID} must be unique or'
@@ -68,6 +68,8 @@ def get_parameter_df(
6868
f'{parameter_duplicates}'
6969
)
7070

71+
return parameter_df
72+
7173
lint.assert_no_leading_trailing_whitespace(
7274
parameter_df.columns.values, "parameter")
7375

@@ -76,9 +78,9 @@ def get_parameter_df(
7678

7779
try:
7880
parameter_df.set_index([PARAMETER_ID], inplace=True)
79-
except KeyError:
81+
except KeyError as e:
8082
raise KeyError(
81-
f"Parameter table missing mandatory field {PARAMETER_ID}.")
83+
f"Parameter table missing mandatory field {PARAMETER_ID}.") from e
8284

8385
return parameter_df
8486

0 commit comments

Comments
 (0)