Skip to content

Commit 38d7977

Browse files
committed
refactor
1 parent 5ed34de commit 38d7977

File tree

2 files changed

+287
-86
lines changed

2 files changed

+287
-86
lines changed

src/dvc_render/vega.py

Lines changed: 133 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .base import Renderer
1010
from .utils import list_dict_to_dict_list
11-
from .vega_templates import BadTemplateError, LinearTemplate, get_template
11+
from .vega_templates import BadTemplateError, LinearTemplate, Template, get_template
1212

1313

1414
class VegaRenderer(Renderer):
@@ -58,14 +58,12 @@ def __init__(self, datapoints: List, name: str, **properties):
5858
],
5959
"shape": ["square", "circle", "triangle", "diamond"],
6060
}
61-
self._optional_anchor_values: Dict[
62-
str,
63-
Dict[str, Dict[str, str]],
64-
] = defaultdict(dict)
61+
62+
self._split_content: Dict[str, Any] = {}
6563

6664
def get_filled_template(
6765
self,
68-
skip_anchors: Optional[List[str]] = None,
66+
split_anchors: Optional[List[str]] = None,
6967
strict: bool = True,
7068
as_string: bool = True,
7169
) -> Union[str, Dict[str, Any]]:
@@ -74,8 +72,8 @@ def get_filled_template(
7472
if not self.datapoints:
7573
return {}
7674

77-
if skip_anchors is None:
78-
skip_anchors = []
75+
if split_anchors is None:
76+
split_anchors = []
7977

8078
if strict:
8179
if self.properties.get("x"):
@@ -91,15 +89,18 @@ def get_filled_template(
9189
self.properties.setdefault("y_label", self.properties.get("y"))
9290
self.properties.setdefault("data", self.datapoints)
9391

94-
self._process_optional_anchors(skip_anchors)
92+
self._process_optional_anchors(split_anchors)
9593

9694
names = ["title", "x", "y", "x_label", "y_label", "data"]
9795
for name in names:
98-
if name in skip_anchors:
99-
continue
10096
value = self.properties.get(name)
10197
if value is None:
10298
continue
99+
100+
if name in split_anchors:
101+
self._set_split_content(name, value)
102+
continue
103+
103104
if name == "data":
104105
if not self.template.has_anchor(name):
105106
anchor = self.template.anchor(name)
@@ -116,6 +117,15 @@ def get_filled_template(
116117

117118
return self.template.content
118119

120+
def get_partial_filled_template(self):
121+
"""
122+
Returns a partially filled template along with the split out anchor content
123+
"""
124+
content = self.get_filled_template(
125+
split_anchors=["data", "color", "stroke_dash", "shape"], strict=True
126+
)
127+
return content, self._split_content
128+
119129
def partial_html(self, **kwargs) -> str:
120130
return self.get_filled_template() # type: ignore
121131

@@ -164,7 +174,7 @@ def generate_markdown(self, report_path=None) -> str:
164174

165175
return ""
166176

167-
def _process_optional_anchors(self, skip_anchors: List[str]):
177+
def _process_optional_anchors(self, split_anchors: List[str]):
168178
optional_anchors = [
169179
anchor
170180
for anchor in [
@@ -177,79 +187,85 @@ def _process_optional_anchors(self, skip_anchors: List[str]):
177187
]
178188
if self.template.has_anchor(anchor)
179189
]
180-
if optional_anchors:
181-
# split varied_keys out from _fill_optional_anchors to avoid bugs
182-
# but first.... tests
183-
varied_keys = self._fill_optional_anchors(skip_anchors, optional_anchors)
184-
self._update_datapoints(varied_keys)
185-
186-
def _fill_optional_anchors(
187-
self, skip_anchors: List[str], optional_anchors: List[str]
188-
) -> List[str]:
189-
self._fill_color(skip_anchors, optional_anchors)
190-
191190
if not optional_anchors:
192-
return []
191+
return
193192

194193
y_defn = self.properties.get("anchors_y_defn", [])
194+
is_single_source = len(y_defn) <= 1
195195

196-
if len(y_defn) <= 1:
197-
self._fill_optional_anchor(
198-
skip_anchors, optional_anchors, "group_by", ["rev"]
199-
)
200-
self._fill_optional_anchor(
201-
skip_anchors, optional_anchors, "pivot_field", "datum.rev"
202-
)
203-
for anchor in optional_anchors:
204-
self.template.fill_anchor(anchor, {})
205-
return []
196+
if is_single_source:
197+
self._process_single_source_plot(split_anchors, optional_anchors)
198+
return
199+
200+
self._process_multi_source_plot(split_anchors, optional_anchors, y_defn)
201+
202+
def _process_single_source_plot(
203+
self, split_anchors: List[str], optional_anchors: List[str]
204+
):
205+
self._fill_color(split_anchors, optional_anchors)
206+
self._fill_optional_anchor(split_anchors, optional_anchors, "group_by", ["rev"])
207+
self._fill_optional_anchor(
208+
split_anchors, optional_anchors, "pivot_field", "datum.rev"
209+
)
210+
for anchor in optional_anchors:
211+
self.template.fill_anchor(anchor, {})
212+
213+
self._update_datapoints([])
214+
215+
def _process_multi_source_plot(
216+
self,
217+
split_anchors: List[str],
218+
optional_anchors: List[str],
219+
y_defn: List[Dict[str, str]],
220+
):
221+
varied_keys, varied_values = self._collect_variations(y_defn)
222+
domain = self._get_domain(varied_keys, varied_values, y_defn)
223+
224+
self._fill_optional_multi_source_anchors(
225+
split_anchors, optional_anchors, varied_keys, domain
226+
)
227+
self._update_datapoints(varied_keys)
228+
229+
def _fill_optional_multi_source_anchors(
230+
self,
231+
split_anchors: List[str],
232+
optional_anchors: List[str],
233+
varied_keys: List[str],
234+
domain: List[str],
235+
):
236+
self._fill_color(split_anchors, optional_anchors)
237+
238+
if not optional_anchors:
239+
return
206240

207-
varied_keys, variations = self._collect_variations(y_defn)
208241
grouped_keys = ["rev", *varied_keys]
209-
concat_field = "::".join(varied_keys)
210242
self._fill_optional_anchor(
211-
skip_anchors, optional_anchors, "group_by", grouped_keys
243+
split_anchors, optional_anchors, "group_by", grouped_keys
212244
)
213245
self._fill_optional_anchor(
214-
skip_anchors,
246+
split_anchors,
215247
optional_anchors,
216248
"pivot_field",
217249
" + '::' + ".join([f"datum.{key}" for key in grouped_keys]),
218250
)
219-
# concatenate grouped_keys together
220-
self._fill_optional_anchor(
221-
skip_anchors, optional_anchors, "row", {"field": concat_field}
222-
)
223-
224-
if not optional_anchors:
225-
return varied_keys
226251

227-
if len(varied_keys) == 2:
228-
domain = ["::".join([d.get("filename"), d.get("field")]) for d in y_defn]
229-
else:
230-
filenameOrField = varied_keys[0]
231-
domain = list(variations[filenameOrField])
232-
233-
domain.sort()
234-
235-
stroke_dash_scale = self._set_optional_anchor_scale(
236-
optional_anchors, concat_field, "stroke_dash", domain
237-
)
252+
concat_field = "::".join(varied_keys)
238253
self._fill_optional_anchor(
239-
skip_anchors, optional_anchors, "stroke_dash", stroke_dash_scale
254+
split_anchors, optional_anchors, "row", {"field": concat_field}
240255
)
241256

242-
shape_scale = self._set_optional_anchor_scale(
243-
optional_anchors, concat_field, "shape", domain
244-
)
245-
self._fill_optional_anchor(skip_anchors, optional_anchors, "shape", shape_scale)
257+
if not optional_anchors:
258+
return
246259

247-
return varied_keys
260+
for field in ["stroke_dash", "shape"]:
261+
self._fill_optional_anchor_mapping(
262+
split_anchors, optional_anchors, concat_field, field, domain
263+
)
248264

249-
def _fill_color(self, skip_anchors: List[str], optional_anchors: List[str]):
265+
def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]):
250266
all_revs = self.properties.get("anchor_revs", [])
251267
self._fill_optional_anchor(
252-
skip_anchors,
268+
split_anchors,
253269
optional_anchors,
254270
"color",
255271
{
@@ -266,15 +282,15 @@ def _fill_color(self, skip_anchors: List[str], optional_anchors: List[str]):
266282
def _collect_variations(
267283
self, y_defn: List[Dict[str, str]]
268284
) -> Tuple[List[str], Dict[str, set]]:
269-
variations = defaultdict(set)
285+
varied_values = defaultdict(set)
270286
for defn in y_defn:
271287
for key in ["filename", "field"]:
272-
variations[key].add(defn.get(key, None))
288+
varied_values[key].add(defn.get(key, None))
273289

274290
values_match_variations = []
275291
less_values_than_variations = []
276292

277-
for filenameOrField, valueSet in variations.items():
293+
for filenameOrField, valueSet in varied_values.items():
278294
num_values = len(valueSet)
279295
if num_values == 1:
280296
continue
@@ -286,14 +302,14 @@ def _collect_variations(
286302
if values_match_variations:
287303
values_match_variations.extend(less_values_than_variations)
288304
values_match_variations.sort(reverse=True)
289-
return values_match_variations, variations
305+
return values_match_variations, varied_values
290306

291307
less_values_than_variations.sort(reverse=True)
292-
return less_values_than_variations, variations
308+
return less_values_than_variations, varied_values
293309

294310
def _fill_optional_anchor(
295311
self,
296-
skip_anchors: List[str],
312+
split_anchors: List[str],
297313
optional_anchors: List[str],
298314
name: str,
299315
value: Any,
@@ -303,26 +319,63 @@ def _fill_optional_anchor(
303319

304320
optional_anchors.remove(name)
305321

306-
if name in skip_anchors:
322+
if name in split_anchors:
307323
return
308324

309325
self.template.fill_anchor(name, value)
310326

311-
def _set_optional_anchor_scale(
312-
self, optional_anchors: List[str], field: str, name: str, domain: List[str]
327+
def _get_domain(
328+
self,
329+
varied_keys: List[str],
330+
varied_values: Dict[str, set],
331+
y_defn: List[Dict[str, str]],
332+
):
333+
if len(varied_keys) == 2:
334+
domain = [
335+
"::".join([d.get("filename", ""), d.get("field", "")]) for d in y_defn
336+
]
337+
else:
338+
filenameOrField = varied_keys[0]
339+
domain = list(varied_values[filenameOrField])
340+
341+
domain.sort()
342+
return domain
343+
344+
def _fill_optional_anchor_mapping(
345+
self,
346+
split_anchors: List[str],
347+
optional_anchors: List[str],
348+
field: str,
349+
name: str,
350+
domain: List[str],
313351
):
314352
if name not in optional_anchors:
315-
return {"field": field, "scale": {"domain": [], "range": []}}
353+
return
354+
355+
optional_anchors.remove(name)
356+
357+
encoding = self._get_optional_anchor_mapping(field, name, domain)
316358

359+
if name in split_anchors:
360+
self._set_split_content(name, encoding)
361+
return
362+
363+
self.template.fill_anchor(name, encoding)
364+
365+
def _get_optional_anchor_mapping(
366+
self,
367+
field: str,
368+
name: str,
369+
domain: List[str],
370+
):
317371
full_range_values: List[Any] = self._optional_anchor_ranges.get(name, [])
318372
anchor_range_values = full_range_values.copy()
319-
anchor_range = []
320373

321-
for domain_value in domain:
374+
anchor_range = []
375+
for _ in range(len(domain)):
322376
if not anchor_range_values:
323377
anchor_range_values = full_range_values.copy()
324378
range_value = anchor_range_values.pop(0)
325-
self._optional_anchor_values[name][domain_value] = range_value
326379
anchor_range.append(range_value)
327380

328381
return {
@@ -347,3 +400,6 @@ def _update_datapoints(self, varied_keys: List[str]):
347400
)
348401
for key in to_remove:
349402
datapoint.pop(key, None)
403+
404+
def _set_split_content(self, name: str, value: Any):
405+
self._split_content[Template.anchor(name)] = value

0 commit comments

Comments
 (0)