Skip to content

Commit

Permalink
Use class number as columns
Browse files Browse the repository at this point in the history
This is more common in the field and expected by users.
  • Loading branch information
sachaMorin committed Aug 15, 2023
1 parent 8f2b262 commit 14ba163
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
6 changes: 4 additions & 2 deletions stepmix/emission/emission.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def permute_classes(self, perm):
self.parameters[key] = item[perm]

def print_parameters(
self, indent=1, feature_names=None, index=["param", "class_no"], model_name=None
self, indent=1, feature_names=None, index=["param", "variable"], columns=["model_name", "class_no"], model_name=None
):
"""Print parameters with nice formatting.
Expand All @@ -195,6 +195,8 @@ def print_parameters(
Variable names.
index: List of str
Column names in self.get_parameters_df to use as index in the displayed dataframe.
columns: List of str
Column names in self.get_parameters_df to use as columns in the displayed dataframe.
model_name: str
str to display as model name.
"""
Expand All @@ -203,7 +205,7 @@ def print_parameters(
if model_name is not None:
df["model_name"] = model_name
df = pd.pivot_table(
df, index=index, columns=["model_name", "variable"], values="value"
df, index=index, columns=columns, values="value"
)
print(
indent_string + df.round(4).to_string().replace("\n", "\n" + indent_string)
Expand Down
12 changes: 6 additions & 6 deletions stepmix/emission/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,11 @@ def __init__(self, **kwargs):
self.model_str = "gaussian_full"

def print_parameters(
self, indent=1, feature_names=None, index=["class_no", "param"]
self, indent=1, feature_names=None, index=["class_no", "param"], columns=["model_name", "variable"]
):
"""Flipping class_no and index is nicer for full covariances."""
"""Flipping class_no and variable is nicer for full covariances."""
super().print_parameters(
indent=indent, feature_names=feature_names, index=index
indent=indent, feature_names=feature_names, index=index, columns=columns
)

def get_parameters_df(self, feature_names=None):
Expand Down Expand Up @@ -267,11 +267,11 @@ def __init__(self, **kwargs):
self.model_str = "gaussian_tied"

def print_parameters(
self, indent=1, feature_names=None, index=["class_no", "param"]
self, indent=1, feature_names=None, index=["class_no", "param"], columns=["model_name", "variable"]
):
"""Flipping class_no and index is nicer for full covariances."""
"""Flipping class_no and variable is nicer for full covariances."""
super().print_parameters(
indent=indent, feature_names=feature_names, index=index
indent=indent, feature_names=feature_names, index=index, columns=columns
)

def get_parameters_df(self, feature_names=None):
Expand Down

0 comments on commit 14ba163

Please sign in to comment.