Skip to content

Commit

Permalink
Add traces with subplots cleanup
Browse files Browse the repository at this point in the history
  - Add optional row/col params to add_traces
  - Add singular add_trace method with optional row/col params
  - Deprecate append_trace and remap to add_trace
  - Add row/col paras to the add_* figure methods
  • Loading branch information
Jon M. Mease committed Apr 14, 2018
1 parent 2455dc1 commit c6d6540
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 20 deletions.
21 changes: 19 additions & 2 deletions codegen/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def reindent_validator_description(validator, extra_indent):
validator.description().strip().split('\n'))


def add_constructor_params(buffer, subtype_nodes):
def add_constructor_params(buffer, subtype_nodes, extras=()):
"""
Write datatype constructor params to a buffer
Expand All @@ -267,6 +267,8 @@ def add_constructor_params(buffer, subtype_nodes):
Buffer to write to
subtype_nodes : list of PlotlyNode
List of datatype nodes to be written as constructor params
extras : list[str]
List of extra parameters to include at the end of the params
Returns
-------
None
Expand All @@ -275,13 +277,17 @@ def add_constructor_params(buffer, subtype_nodes):
buffer.write(f""",
{subtype_node.name_property}=None""")

for extra in extras:
buffer.write(f""",
{extra}=None""")

buffer.write(""",
**kwargs""")
buffer.write(f"""
):""")


def add_docstring(buffer, node, header):
def add_docstring(buffer, node, header, extras=()):
"""
Write docstring for a compound datatype node
Expand Down Expand Up @@ -328,6 +334,17 @@ def add_docstring(buffer, node, header):
buffer.write(node.get_constructor_params_docstring(
indent=8))

# Write any extras
for p, v in extras:
v_wrapped = '\n'.join(textwrap.wrap(
v,
width=79-12,
initial_indent=' ' * 12,
subsequent_indent=' ' * 12))
buffer.write(f"""
{p}
{v_wrapped}""")

# Write return block and close docstring
# --------------------------------------
buffer.write(f"""
Expand Down
17 changes: 14 additions & 3 deletions codegen/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,22 @@ def __init__(self, data=None, layout=None, frames=None):
def add_{trace_node.plotly_name}(self""")

# #### Function params####
add_constructor_params(buffer, trace_node.child_datatypes)
add_constructor_params(buffer, trace_node.child_datatypes,
['row', 'col'])

# #### Docstring ####
header = f"Add a new {trace_node.name_datatype_class} trace"
add_docstring(buffer, trace_node, header)

extras = (('row : int or None (default)',
'Subplot row index (starting from 1) for the trace to be '
'added. Only valid if figure was created using '
'`plotly.tools.make_subplots`'),
('col : int or None (default)',
'Subplot col index (starting from 1) for the trace to be '
'added. Only valid if figure was created using '
'`plotly.tools.make_subplots`'))

add_docstring(buffer, trace_node, header, extras=extras)

# #### Function body ####
buffer.write(f"""
Expand All @@ -111,7 +122,7 @@ def add_{trace_node.plotly_name}(self""")
**kwargs)""")

buffer.write(f"""
return self.add_traces(new_trace)[0]""")
return self.add_trace(new_trace, row=row, col=col)""")

# Return source string
# --------------------
Expand Down
180 changes: 165 additions & 15 deletions plotly/basedatatypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import re
import typing as typ
import warnings
from contextlib import contextmanager
from copy import deepcopy
from typing import Dict, Tuple, Union, Callable, List
Expand Down Expand Up @@ -791,14 +792,114 @@ def _set_in(d, key_path_str, v):

# Add traces
# ----------
def add_traces(self, data):
@staticmethod
def _raise_invalid_rows_cols(name, n, invalid):
rows_err_msg = """
If specified, the {name} parameter must be a list or tuple of integers
of length {n} (The number of traces being added)
Received: {invalid}
""".format(name=name, n=n, invalid=invalid)

raise ValueError(rows_err_msg)

@staticmethod
def _validate_rows_cols(name, n, vals):
if vals is None:
pass
elif isinstance(vals, (list, tuple)):
if len(vals) != n:
BaseFigure._raise_invalid_rows_cols(
name=name, n=n, invalid=vals)

if [r for r in vals if not isinstance(r, int)]:
BaseFigure._raise_invalid_rows_cols(
name=name, n=n, invalid=vals)
else:
BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals)

def add_trace(self, trace, row=None, col=None):
"""
Add one or more traces to the figure
Add a trace to the figure
Parameters
----------
data : BaseTraceType or dict or list[BaseTraceType or dict]
A trace specification or list of trace specifications to be added.
trace : BaseTraceType or dict
Either:
- An instances of a trace classe from the plotly.graph_objs
package (e.g plotly.graph_objs.Scatter, plotly.graph_objs.Bar)
- or a dicts where:
- The 'type' property specifies the trace type (e.g.
'scatter', 'bar', 'area', etc.). If the dict has no 'type'
property then 'scatter' is assumed.
- All remaining properties are passed to the constructor
of the specified trace type.
row : int or None (default)
Subplot row index (starting from 1) for the trace to be added.
Only valid if figure was created using
`plotly.tools.make_subplots`
col : int or None (default)
Subplot col index (starting from 1) for the trace to be added.
Only valid if figure was created using
`plotly.tools.make_subplots`
Returns
-------
BaseTraceType
The newly added trace
Examples
--------
>>> from plotly import tools
>>> import plotly.graph_objs as go
Add two Scatter traces to a figure
>>> fig = go.Figure()
>>> fig.add_trace(go.Scatter(x=[1,2,3], y=[2,1,2]))
>>> fig.add_trace(go.Scatter(x=[1,2,3], y=[2,1,2]))
Add two Scatter traces to vertically stacked subplots
>>> fig = tools.make_subplots(rows=2)
This is the format of your plot grid:
[ (1,1) x1,y1 ]
[ (2,1) x2,y2 ]
>>> fig.add_trace(go.Scatter(x=[1,2,3], y=[2,1,2]), row=1, col=1)
>>> fig.add_trace(go.Scatter(x=[1,2,3], y=[2,1,2]), row=2, col=1)
"""
# Validate row/col
if row is not None and not isinstance(row, int):
pass

if col is not None and not isinstance(col, int):
pass

# Make sure we have both row and col or neither
if row is not None and col is None:
raise ValueError(
'Received row parameter but not col.\n'
'row and col must be specified together')
elif col is not None and row is None:
raise ValueError(
'Received col parameter but not row.\n'
'row and col must be specified together')

return self.add_traces(data=[trace],
rows=[row] if row is not None else None,
cols=[col] if col is not None else None
)[0]

def add_traces(self, data, rows=None, cols=None):
"""
Add traces to the figure
Parameters
----------
data : list[BaseTraceType or dict]
A list of trace specifications to be added.
Trace specifications may be either:
- Instances of trace classes from the plotly.graph_objs
Expand All @@ -810,23 +911,70 @@ def add_traces(self, data):
property then 'scatter' is assumed.
- All remaining properties are passed to the constructor
of the specified trace type.
rows : None or list[int] (default None)
List of subplot row indexes (starting from 1) for the traces to be
added. Only valid if figure was created using
`plotly.tools.make_subplots`
cols : None or list[int] (default None)
List of subplot column indexes (starting from 1) for the traces
to be added. Only valid if figure was created using
`plotly.tools.make_subplots`
Returns
-------
tuple[BaseTraceType]
Tuple of the newly added trace(s)
Tuple of the newly added traces
Examples
--------
>>> from plotly import tools
>>> import plotly.graph_objs as go
Add two Scatter traces to a figure
>>> fig = go.Figure()
>>> fig.add_traces([go.Scatter(x=[1,2,3], y=[2,1,2]),
... go.Scatter(x=[1,2,3], y=[2,1,2])])
Add two Scatter traces to vertically stacked subplots
>>> fig = tools.make_subplots(rows=2)
This is the format of your plot grid:
[ (1,1) x1,y1 ]
[ (2,1) x2,y2 ]
>>> fig.add_traces([go.Scatter(x=[1,2,3], y=[2,1,2]),
... go.Scatter(x=[1,2,3], y=[2,1,2])],
... rows=[1, 2], cols=[1, 1])
"""

if self._in_batch_mode:
self._batch_layout_edits.clear()
self._batch_trace_edits.clear()
raise ValueError('Traces may not be added in a batch context')

if not isinstance(data, (list, tuple)):
data = [data]

# Validate
# Validate traces
data = self._data_validator.validate_coerce(data)

# Validate rows / cols
n = len(data)
BaseFigure._validate_rows_cols('rows', n, rows)
BaseFigure._validate_rows_cols('cols', n, cols)

# Make sure we have both rows and cols or neither
if rows is not None and cols is None:
raise ValueError(
'Received rows parameter but not cols.\n'
'rows and cols must be specified together')
elif cols is not None and rows is None:
raise ValueError(
'Received cols parameter but not rows.\n'
'rows and cols must be specified together')

# Apply rows / cols
if rows is not None:
for trace, row, col in zip(data, rows, cols):
self._set_trace_grid_position(trace, row, col)

# Make deep copy of trace data (Optimize later if needed)
new_traces_data = [deepcopy(trace._props) for trace in data]

Expand Down Expand Up @@ -877,10 +1025,6 @@ def append_trace(self, trace, row, col):
col: int
Subplot column index (see Figure.print_grid)
:param (dict) trace: The data trace to be bound.
:param (int) row: Subplot row index (see Figure.print_grid).
:param (int) col: Subplot column index (see Figure.print_grid).
Examples
--------
>>> from plotly import tools
Expand All @@ -894,6 +1038,14 @@ def append_trace(self, trace, row, col):
>>> fig.append_trace(go.Scatter(x=[1,2,3], y=[2,1,2]), row=1, col=1)
>>> fig.append_trace(go.Scatter(x=[1,2,3], y=[2,1,2]), row=2, col=1)
"""
warnings.warn("""\
The append_trace method is deprecated and will be removed in a future version.
Please use the add_trace method with the row and col parameters.
""", DeprecationWarning)

self.add_trace(trace=trace, row=row, col=col)

def _set_trace_grid_position(self, trace, row, col):
try:
grid_ref = self._grid_ref
except AttributeError:
Expand Down Expand Up @@ -931,8 +1083,6 @@ def append_trace(self, trace, row, col):
trace['xaxis'] = ref[0]
trace['yaxis'] = ref[1]

self.add_traces([trace])

# Child property operations
# -------------------------
def _get_child_props(self, child):
Expand Down

0 comments on commit c6d6540

Please sign in to comment.