Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Jul 23, 2022
1 parent 2678175 commit 0cd2d55
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 37 deletions.
2 changes: 1 addition & 1 deletion petab/parameter_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def merge_preeq_and_sim_pars_condition(
This function is meant for the case where we cannot have different
parameters (and scales) for preequilibration and simulation. Therefore,
merge both and ensure matching scales and parameters.
``condition_map_sim`` and ``condition_scale_map_sim`` will ne modified in
``condition_map_sim`` and ``condition_scale_map_sim`` will be modified in
place.
Arguments:
Expand Down
54 changes: 32 additions & 22 deletions petab/visualize/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ def __init__(self, figure: Figure, data_provider: DataProvider):
self.data_provider = data_provider

@abstractmethod
def generate_figure(self, subplot_dir: Optional[str] = None
) -> Optional[Dict[str, plt.Subplot]]:
def generate_figure(
self,
subplot_dir: Optional[str] = None
) -> Optional[Dict[str, plt.Subplot]]:
pass


Expand Down Expand Up @@ -65,9 +67,12 @@ def _error_column_for_plot_type_data(plot_type_data: str) -> Optional[str]:
return 'noise_model'
return None

def generate_lineplot(self, ax: 'matplotlib.pyplot.Axes',
dataplot: DataPlot,
plotTypeData: str) -> None:
def generate_lineplot(
self,
ax: 'matplotlib.pyplot.Axes',
dataplot: DataPlot,
plotTypeData: str
) -> None:
"""
Generate lineplot.
Expand All @@ -82,7 +87,6 @@ def generate_lineplot(self, ax: 'matplotlib.pyplot.Axes',
plotTypeData:
Specifies how replicates should be handled.
"""

simu_color = None
measurements_to_plot, simulations_to_plot = \
self.data_provider.get_data_to_plot(dataplot,
Expand Down Expand Up @@ -152,9 +156,12 @@ def generate_lineplot(self, ax: 'matplotlib.pyplot.Axes',
label=label_base + " simulation", color=simu_color
)

def generate_barplot(self, ax: 'matplotlib.pyplot.Axes',
dataplot: DataPlot,
plotTypeData: str) -> None:
def generate_barplot(
self,
ax: 'matplotlib.pyplot.Axes',
dataplot: DataPlot,
plotTypeData: str
) -> None:
"""
Generate barplot.
Expand Down Expand Up @@ -200,9 +207,12 @@ def generate_barplot(self, ax: 'matplotlib.pyplot.Axes',
color='white', edgecolor=color, **bar_kwargs,
label='simulation')

def generate_scatterplot(self, ax: 'matplotlib.pyplot.Axes',
dataplot: DataPlot,
plotTypeData: str) -> None:
def generate_scatterplot(
self,
ax: 'matplotlib.pyplot.Axes',
dataplot: DataPlot,
plotTypeData: str
) -> None:
"""
Generate scatterplot.
Expand All @@ -215,7 +225,6 @@ def generate_scatterplot(self, ax: 'matplotlib.pyplot.Axes',
plotTypeData:
Specifies how replicates should be handled.
"""

measurements_to_plot, simulations_to_plot = \
self.data_provider.get_data_to_plot(dataplot,
plotTypeData == PROVIDED)
Expand All @@ -228,9 +237,11 @@ def generate_scatterplot(self, ax: 'matplotlib.pyplot.Axes',
label=getattr(dataplot, LEGEND_ENTRY))
self._square_plot_equal_ranges(ax)

def generate_subplot(self,
ax,
subplot: Subplot) -> None:
def generate_subplot(
self,
ax: plt.Axes,
subplot: Subplot
) -> None:
"""
Generate subplot based on markup provided by subplot.
Expand All @@ -241,7 +252,6 @@ def generate_subplot(self,
subplot:
Subplot visualization settings.
"""

# set yScale
if subplot.yScale == LIN:
ax.set_yscale("linear")
Expand Down Expand Up @@ -270,7 +280,6 @@ def generate_subplot(self,
for data_plot in subplot.data_plots:
self.generate_scatterplot(ax, data_plot, subplot.plotTypeData)
else:

# set xScale
if subplot.xScale == LIN:
ax.set_xscale("linear")
Expand Down Expand Up @@ -345,7 +354,6 @@ def generate_figure(
None:
In case subplots are saved to file.
"""

if subplot_dir is None:
# compute, how many rows and columns we need for the subplots
num_row = int(np.round(np.sqrt(self.figure.num_subplots)))
Expand All @@ -361,7 +369,7 @@ def generate_figure(
axes = dict(zip([plot.plotId for plot in self.figure.subplots],
axes.flat))

for idx, subplot in enumerate(self.figure.subplots):
for subplot in self.figure.subplots:
if subplot_dir is not None:
fig, ax = plt.subplots(figsize=self.figure.size)
fig.set_tight_layout(True)
Expand Down Expand Up @@ -419,6 +427,8 @@ class SeabornPlotter(Plotter):
def __init__(self, figure: Figure, data_provider: DataProvider):
super().__init__(figure, data_provider)

def generate_figure(self, subplot_dir: Optional[str] = None
) -> Optional[Dict[str, plt.Subplot]]:
def generate_figure(
self,
subplot_dir: Optional[str] = None
) -> Optional[Dict[str, plt.Subplot]]:
pass
34 changes: 20 additions & 14 deletions petab/visualize/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from numbers import Number, Real
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union, Literal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -130,6 +130,9 @@ def from_df(cls, plot_spec: pd.DataFrame):

return cls(vis_spec_dict)

def __repr__(self):
return f"{self.__class__.__name__}({self.__dict__})"


class Subplot:
"""
Expand Down Expand Up @@ -480,9 +483,13 @@ def _get_independent_var_values(self, data_df: pd.DataFrame,

return uni_condition_id, col_name_unique, conditions_

def get_data_series(self, data_df: pd.DataFrame, data_col: str,
dataplot: DataPlot,
provided_noise: bool) -> DataSeries:
def get_data_series(
self,
data_df: pd.DataFrame,
data_col: Literal['measurement', 'simulation'],
dataplot: DataPlot,
provided_noise: bool
) -> DataSeries:
"""
Get data to plot from measurement or simulation DataFrame.
Expand All @@ -499,10 +506,8 @@ def get_data_series(self, data_df: pd.DataFrame, data_col: str,
-------
Data to plot
"""

uni_condition_id, col_name_unique, conditions_ = \
self._get_independent_var_values(data_df,
dataplot)
self._get_independent_var_values(data_df, dataplot)

dataset_id = getattr(dataplot, DATASET_ID)

Expand Down Expand Up @@ -643,8 +648,10 @@ def _data_df(self):
None else self.simulations_data

@staticmethod
def create_subplot(plot_id: str,
subplot_vis_spec: pd.DataFrame) -> Subplot:
def create_subplot(
plot_id: str,
subplot_vis_spec: pd.DataFrame
) -> Subplot:
"""
Create subplot.
Expand All @@ -661,7 +668,6 @@ def create_subplot(plot_id: str,
Subplot
"""

subplot_columns = [col for col in subplot_vis_spec.columns if col in
VISUALIZATION_DF_SUBPLOT_LEVEL_COLS]
subplot = Subplot.from_df(plot_id,
Expand All @@ -677,9 +683,10 @@ def create_subplot(plot_id: str,

return subplot

def parse_from_vis_spec(self,
vis_spec: Optional[Union[str, Path, pd.DataFrame]],
) -> Tuple[Figure, DataProvider]:
def parse_from_vis_spec(
self,
vis_spec: Optional[Union[str, Path, pd.DataFrame]],
) -> Tuple[Figure, DataProvider]:
"""
Get visualization settings from a visualization specification.
Expand All @@ -694,7 +701,6 @@ def parse_from_vis_spec(self,
A figure template with visualization settings and a data provider
"""

# import visualization specification, if file was specified
if isinstance(vis_spec, (str, Path)):
vis_spec = core.get_visualization_df(vis_spec)
Expand Down

0 comments on commit 0cd2d55

Please sign in to comment.