Skip to content

Commit 7ef62f5

Browse files
committed
prototype filling optional anchors
1 parent c31ff7c commit 7ef62f5

File tree

2 files changed

+180
-4
lines changed

2 files changed

+180
-4
lines changed

src/dvc_render/vega.py

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import base64
22
import io
33
import json
4+
from collections import defaultdict
45
from pathlib import Path
5-
from typing import Any, Dict, List, Optional, Union
6+
from typing import Any, Dict, List, Optional, Tuple, Union
67
from warnings import warn
78

89
from .base import Renderer
@@ -38,6 +39,29 @@ def __init__(self, datapoints: List, name: str, **properties):
3839
self.properties.get("template", None),
3940
self.properties.get("template_dir", None),
4041
)
42+
self._optional_anchor_ranges: Dict[
43+
str,
44+
Union[
45+
List[str],
46+
List[List[int]],
47+
],
48+
] = {
49+
"stroke_dash": [[1, 0], [8, 8], [8, 4], [4, 4], [4, 2], [2, 1], [1, 1]],
50+
"color": [
51+
"#945dd6",
52+
"#13adc7",
53+
"#f46837",
54+
"#48bb78",
55+
"#4299e1",
56+
"#ed8936",
57+
"#f56565",
58+
],
59+
"shape": ["square", "circle", "triangle", "diamond"],
60+
}
61+
self._optional_anchor_values: Dict[
62+
str,
63+
Dict[str, Dict[str, str]],
64+
] = defaultdict()
4165

4266
def get_filled_template(
4367
self,
@@ -85,6 +109,8 @@ def get_filled_template(
85109
value = self.template.escape_special_characters(value)
86110
self.template.fill_anchor(name, value)
87111

112+
self._fill_optional_anchors(skip_anchors)
113+
88114
if as_string:
89115
return json.dumps(self.template.content)
90116

@@ -136,3 +162,152 @@ def generate_markdown(self, report_path=None) -> str:
136162
return f"\n![{self.name}]({src})"
137163

138164
return ""
165+
166+
def _fill_optional_anchors(self, skip_anchors: List[str]):
167+
optional_anchors = [
168+
anchor
169+
for anchor in [
170+
"row",
171+
"group_by",
172+
"pivot_field",
173+
"color",
174+
"stroke_dash",
175+
"shape",
176+
]
177+
if anchor not in skip_anchors and self.template.has_anchor(anchor)
178+
]
179+
if not optional_anchors:
180+
return
181+
182+
self._fill_color(optional_anchors)
183+
184+
if not optional_anchors:
185+
return
186+
187+
y_defn = self.properties.get("anchors_y_defn", [])
188+
189+
if len(y_defn) <= 1:
190+
self._fill_optional_anchor(optional_anchors, "group_by", ["rev"])
191+
self._fill_optional_anchor(optional_anchors, "pivot_field", "rev")
192+
for anchor in optional_anchors:
193+
self.template.fill_anchor(anchor, {})
194+
self._update_datapoints(to_remove=["filename", "file"])
195+
return
196+
197+
keys, variations = self._collect_variations(y_defn)
198+
grouped_keys = ["rev", *keys]
199+
self._fill_optional_anchor(optional_anchors, "group_by", grouped_keys)
200+
self._fill_optional_anchor(
201+
optional_anchors, "pivot_field", "::".join(grouped_keys)
202+
)
203+
# concatenate grouped_keys together
204+
self._fill_optional_anchor(optional_anchors, "row", {"field": "::".join(keys)})
205+
206+
if not optional_anchors:
207+
return
208+
209+
if len(keys) == 2:
210+
self._update_datapoints(
211+
to_remove=["filename", "file"], to_concatenate=[["filename", "file"]]
212+
)
213+
domain = ["::".join([d.get("filename"), d.get("file")]) for d in y_defn]
214+
else:
215+
filenameOrField = keys[0]
216+
to_remove = ["filename", "file"]
217+
to_remove.remove(filenameOrField)
218+
self._update_datapoints(to_remove=to_remove)
219+
220+
domain = list(variations[filenameOrField])
221+
222+
stroke_dash_scale = self._set_optional_anchor_scale(
223+
optional_anchors, "stroke_dash", domain
224+
)
225+
self._fill_optional_anchor(optional_anchors, "stroke_dash", stroke_dash_scale)
226+
227+
shape_scale = self._set_optional_anchor_scale(optional_anchors, "shape", domain)
228+
self._fill_optional_anchor(optional_anchors, "shape", shape_scale)
229+
230+
def _fill_color(self, optional_anchors: List[str]):
231+
all_revs = self.properties.get("anchor_revs", [])
232+
self._fill_optional_anchor(
233+
optional_anchors,
234+
"color",
235+
{
236+
"scale": {
237+
"domain": list(all_revs),
238+
"range": self._optional_anchor_ranges.get("color", [])[
239+
: len(all_revs)
240+
],
241+
}
242+
},
243+
)
244+
245+
def _collect_variations(
246+
self, y_defn: List[Dict[str, str]]
247+
) -> Tuple[List[str], Dict[str, set]]:
248+
variations = defaultdict(set)
249+
for defn in y_defn:
250+
for key in ["filename", "field"]:
251+
variations[key].add(defn.get(key, None))
252+
253+
valuesMatchVariations = []
254+
lessValuesThanVariations = []
255+
256+
for filenameOrField, valueSet in variations.items():
257+
num_values = len(valueSet)
258+
if num_values == 1:
259+
continue
260+
if num_values == len(y_defn):
261+
valuesMatchVariations.append(filenameOrField)
262+
continue
263+
lessValuesThanVariations.append(filenameOrField)
264+
265+
if valuesMatchVariations:
266+
valuesMatchVariations.extend(lessValuesThanVariations)
267+
valuesMatchVariations.sort(reverse=True)
268+
return valuesMatchVariations, variations
269+
270+
lessValuesThanVariations.sort(reverse=True)
271+
return lessValuesThanVariations, variations
272+
273+
def _fill_optional_anchor(self, optional_anchors: List[str], name: str, value: Any):
274+
if name not in optional_anchors:
275+
return
276+
277+
optional_anchors.remove(name)
278+
self.template.fill_anchor(name, value)
279+
280+
def _set_optional_anchor_scale(
281+
self, optional_anchors: List[str], name: str, domain: List[str]
282+
):
283+
if name not in optional_anchors:
284+
return {"scale": {"domain": [], "range": []}}
285+
286+
full_range_values: List[Any] = self._optional_anchor_ranges.get(name, [])
287+
anchor_range_values = full_range_values.copy()
288+
anchor_range = []
289+
290+
for domain_value in domain:
291+
if not anchor_range_values:
292+
anchor_range_values = full_range_values.copy()
293+
range_value = anchor_range_values.pop()
294+
self._optional_anchor_values[name][domain_value] = range_value
295+
anchor_range.append(range_value)
296+
297+
return {"scale": {"domain": domain, "range": anchor_range}}
298+
299+
def _update_datapoints(
300+
self,
301+
to_remove: Optional[List[str]] = None,
302+
to_concatenate: Optional[List[List[str]]] = None,
303+
):
304+
if to_concatenate:
305+
for datapoint in self.datapoints:
306+
for keys in to_concatenate:
307+
concat_key = "::".join(keys)
308+
datapoint[concat_key] = "::".join([datapoint.get(k) for k in keys])
309+
310+
if to_remove:
311+
for datapoint in self.datapoints:
312+
for concat_key in to_remove:
313+
datapoint.pop(concat_key, None)

src/dvc_render/vega_templates.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ class SmoothLinearTemplate(Template):
529529
{
530530
"loess": Template.anchor("y"),
531531
"on": Template.anchor("x"),
532-
"groupby": ["rev", "filename", "field", "filename::field"],
532+
"groupby": Template.anchor("group_by"),
533533
"bandwidth": {"signal": "smooth"},
534534
},
535535
],
@@ -572,11 +572,12 @@ class SmoothLinearTemplate(Template):
572572
},
573573
{
574574
"transform": [
575+
{"calculate": Template.anchor("pivot_field"), "as": "pivot_field"},
575576
{
576-
"pivot": Template.anchor("group_by"),
577+
"pivot": "pivot_field",
577578
"value": Template.anchor("y"),
578579
"groupby": [Template.anchor("x")],
579-
}
580+
},
580581
],
581582
"mark": {
582583
"type": "rule",

0 commit comments

Comments
 (0)