Skip to content

vega_templates: Handle content as dict instead of string. #124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 10 additions & 16 deletions src/dvc_render/vega.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from copy import deepcopy
import json
from pathlib import Path
from typing import List, Optional
from typing import Any, Dict, List, Optional
from warnings import warn

from .base import Renderer
from .exceptions import DvcRenderException
from .utils import list_dict_to_dict_list
from .vega_templates import LinearTemplate, get_template


class BadTemplateError(DvcRenderException):
pass
from .vega_templates import BadTemplateError, LinearTemplate, get_template


class VegaRenderer(Renderer):
Expand Down Expand Up @@ -44,16 +39,15 @@ def __init__(self, datapoints: List, name: str, **properties):

def get_filled_template(
self, skip_anchors: Optional[List[str]] = None, strict: bool = True
) -> str:
) -> Dict[str, Any]:
"""Returns a functional vega specification"""
self.template.reset()
if not self.datapoints:
return ""
return {}

if skip_anchors is None:
skip_anchors = []

content = deepcopy(self.template.content)
Copy link
Collaborator

@skshetry skshetry Mar 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without deep-copying, we may accidentally modify the same contents.

Copy link
Contributor Author

@daavoo daavoo Mar 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without deep-copying, we may accidentally modify the same contents.

I added a reset() method to avoid the need for deepcopy (which was also taking time).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That’s only protected because replace_value copies the whole collection, right?


if strict:
if self.properties.get("x"):
self.template.check_field_exists(
Expand All @@ -76,20 +70,20 @@ def get_filled_template(
if value is None:
continue
if name == "data":
if self.template.anchor_str(name) not in self.template.content:
if not self.template.has_anchor(name):
anchor = self.template.anchor(name)
raise BadTemplateError(
f"Template '{self.template.name}' "
f"is not using '{anchor}' anchor"
)
elif name in {"x", "y"}:
value = self.template.escape_special_characters(value)
content = self.template.fill_anchor(content, name, value)
self.template.fill_anchor(name, value)

return content
return self.template.content

def partial_html(self, **kwargs) -> str:
return self.get_filled_template()
return json.dumps(self.get_filled_template())

def generate_markdown(self, report_path=None) -> str:
if not isinstance(self.template, LinearTemplate):
Expand Down
114 changes: 75 additions & 39 deletions src/dvc_render/vega_templates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=missing-function-docstring
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
Expand Down Expand Up @@ -27,67 +28,102 @@ def __init__(self, template_name: str, path: str):
)


class BadTemplateError(DvcRenderException):
    """Raised when a template is unusable (e.g. its content is not a dict)."""


def dict_replace_value(d: dict, name: str, value: Any) -> dict:
    """Return a new dict where every string value equal to `name` is `value`.

    Nested dicts and lists are walked recursively. `d` itself is left
    untouched: a fresh dict is built at every level. Only values are
    compared (exact string equality); keys are never replaced.
    """
    replaced = {}
    for key, item in d.items():
        if isinstance(item, dict):
            replaced[key] = dict_replace_value(item, name, value)
        elif isinstance(item, list):
            replaced[key] = list_replace_value(item, name, value)
        elif isinstance(item, str) and item == name:
            replaced[key] = value
        else:
            # Non-matching strings and any other leaf are carried over as-is.
            replaced[key] = item
    return replaced


def list_replace_value(l: list, name: str, value: Any) -> list:  # noqa: E741
    """Return a new list where every string element equal to `name` is `value`.

    Nested lists and dicts are walked recursively. `l` itself is left
    untouched: a fresh list is built at every level.

    `value` is annotated as `Any` (not `str`) because the replacement is
    inserted verbatim and callers pass arbitrary objects, matching
    `dict_replace_value`.
    """
    replaced = []
    for item in l:
        if isinstance(item, list):
            replaced.append(list_replace_value(item, name, value))
        elif isinstance(item, dict):
            replaced.append(dict_replace_value(item, name, value))
        elif isinstance(item, str) and item == name:
            replaced.append(value)
        else:
            # Non-matching strings and any other leaf are carried over as-is.
            replaced.append(item)
    return replaced


def dict_find_value(d: dict, value: str) -> bool:
    """Return True if any string nested anywhere inside `d` equals `value`.

    The search covers the values of `d`, recursing through nested dicts and
    lists. Dict keys are not searched.

    The traversal keeps going after a nested container turns out not to hold
    `value`, so a match in a later sibling is still found. Lists are
    traversed as well, so this agrees with what `dict_replace_value` would
    actually replace.
    """
    # Explicit work stack instead of recursion: order does not matter for a
    # pure existence check.
    pending = list(d.values())
    while pending:
        item = pending.pop()
        if isinstance(item, dict):
            pending.extend(item.values())
        elif isinstance(item, list):
            pending.extend(item)
        elif isinstance(item, str) and item == value:
            return True
    return False


class Template:
INDENT = 4
SEPARATORS = (",", ": ")
EXTENSION = ".json"
ANCHOR = "<DVC_METRIC_{}>"

DEFAULT_CONTENT: Optional[Dict[str, Any]] = None
DEFAULT_NAME: Optional[str] = None

def __init__(self, content=None, name=None):
if content:
self.content = content
else:
self.content = (
json.dumps(
self.DEFAULT_CONTENT,
indent=self.INDENT,
separators=self.SEPARATORS,
)
+ "\n"
)

DEFAULT_CONTENT: Dict[str, Any] = {}
DEFAULT_NAME: str = ""

def __init__(
self, content: Optional[Dict[str, Any]] = None, name: Optional[str] = None
):
if (
content
and not isinstance(content, dict)
or self.DEFAULT_CONTENT
and not isinstance(self.DEFAULT_CONTENT, dict)
):
raise BadTemplateError()
self._original_content = content or self.DEFAULT_CONTENT
self.content: Dict[str, Any] = self._original_content
self.name = name or self.DEFAULT_NAME
assert self.content and self.name
self.filename = Path(self.name).with_suffix(self.EXTENSION)

@classmethod
def anchor(cls, name):
"Get ANCHOR formatted with name."
return cls.ANCHOR.format(name.upper())

def has_anchor(self, name) -> bool:
"Check if ANCHOR formatted with name is in content."
return self.anchor_str(name) in self.content

@classmethod
def fill_anchor(cls, content, name, value) -> str:
"Replace anchor `name` with `value` in content."
value_str = json.dumps(
value, indent=cls.INDENT, separators=cls.SEPARATORS, sort_keys=True
)
return content.replace(cls.anchor_str(name), value_str)

@classmethod
def escape_special_characters(cls, value: str) -> str:
"Escape special characters in `value`"
for character in (".", "[", "]"):
value = value.replace(character, "\\" + character)
return value

@classmethod
def anchor_str(cls, name) -> str:
"Get string wrapping ANCHOR formatted with name."
return f'"{cls.anchor(name)}"'

@staticmethod
def check_field_exists(data, field):
"Raise NoFieldInDataError if `field` not in `data`."
if not any(field in row for row in data):
raise NoFieldInDataError(field)

def reset(self):
"""Reset self.content to its original state."""
self.content = self._original_content

def has_anchor(self, name) -> bool:
"Check if ANCHOR formatted with name is in content."
found = dict_find_value(self.content, self.anchor(name))
return found

def fill_anchor(self, name, value) -> None:
"Replace anchor `name` with `value` in content."
self.content = dict_replace_value(self.content, self.anchor(name), value)


class BarHorizontalSortedTemplate(Template):
DEFAULT_NAME = "bar_horizontal_sorted"
Expand Down Expand Up @@ -606,7 +642,7 @@ def get_template(
_open = open if fs is None else fs.open
if template_path:
with _open(template_path, encoding="utf-8") as f:
content = f.read()
content = json.load(f)
return Template(content, name=template)

for template_cls in TEMPLATES:
Expand Down Expand Up @@ -635,6 +671,6 @@ def dump_templates(output: "StrPath", targets: Optional[List] = None) -> None:
if path.exists():
content = path.read_text(encoding="utf-8")
if content != template.content:
raise TemplateContentDoesNotMatch(template.DEFAULT_NAME or "", path)
raise TemplateContentDoesNotMatch(template.DEFAULT_NAME, str(path))
else:
path.write_text(template.content, encoding="utf-8")
path.write_text(json.dumps(template.content), encoding="utf-8")
15 changes: 10 additions & 5 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os

import pytest
Expand Down Expand Up @@ -38,8 +39,9 @@ def test_raise_on_no_template():
],
)
def test_get_template_from_dir(tmp_dir, template_path, target_name):
tmp_dir.gen(template_path, "template_content")
assert get_template(target_name, ".dvc/plots").content == "template_content"
template_content = {"template_content": "foo"}
tmp_dir.gen(template_path, json.dumps(template_content))
assert get_template(target_name, ".dvc/plots").content == template_content


def test_get_template_exact_match(tmp_dir):
Expand All @@ -51,13 +53,16 @@ def test_get_template_exact_match(tmp_dir):


def test_get_template_from_file(tmp_dir):
tmp_dir.gen("foo/bar.json", "template_content")
assert get_template("foo/bar.json").content == "template_content"
template_content = {"template_content": "foo"}
tmp_dir.gen("foo/bar.json", json.dumps(template_content))
assert get_template("foo/bar.json").content == template_content


def test_get_template_fs(tmp_dir, mocker):
tmp_dir.gen("foo/bar.json", "template_content")
template_content = {"template_content": "foo"}
tmp_dir.gen("foo/bar.json", json.dumps(template_content))
fs = mocker.MagicMock()
mocker.patch("json.load", return_value={})
get_template("foo/bar.json", fs=fs)
fs.open.assert_called()
fs.exists.assert_called()
Expand Down
17 changes: 5 additions & 12 deletions tests/test_vega.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json

import pytest

from dvc_render.vega import BadTemplateError, VegaRenderer
Expand Down Expand Up @@ -33,7 +31,6 @@ def test_init_empty():
assert renderer.name == ""
assert renderer.properties == {}

assert renderer.generate_html() == ""
assert renderer.generate_markdown("foo") == ""


Expand All @@ -43,7 +40,7 @@ def test_default_template_mark():
{"first_val": 200, "second_val": 300, "val": 3},
]

plot_content = json.loads(VegaRenderer(datapoints, "foo").partial_html())
plot_content = VegaRenderer(datapoints, "foo").get_filled_template()

assert plot_content["layer"][0]["mark"] == "line"

Expand All @@ -60,7 +57,7 @@ def test_choose_axes():
{"first_val": 200, "second_val": 300, "val": 3},
]

plot_content = json.loads(VegaRenderer(datapoints, "foo", **props).partial_html())
plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template()

assert plot_content["data"]["values"] == [
{
Expand All @@ -85,7 +82,7 @@ def test_confusion():
]
props = {"template": "confusion", "x": "predicted", "y": "actual"}

plot_content = json.loads(VegaRenderer(datapoints, "foo", **props).partial_html())
plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template()

assert plot_content["data"]["values"] == [
{"predicted": "B", "actual": "A"},
Expand All @@ -100,12 +97,8 @@ def test_confusion():


def test_bad_template():
datapoints = [{"val": 2}, {"val": 3}]
props = {"template": Template("name", "content")}
renderer = VegaRenderer(datapoints, "foo", **props)
with pytest.raises(BadTemplateError):
renderer.get_filled_template()
renderer.get_filled_template(skip_anchors=["data"])
Template("name", "content")


def test_raise_on_wrong_field():
Expand Down Expand Up @@ -177,7 +170,7 @@ def test_escape_special_characters():
]
props = {"template": "simple", "x": "foo.bar[0]", "y": "foo.bar[1]"}
renderer = VegaRenderer(datapoints, "foo", **props)
filled = json.loads(renderer.get_filled_template())
filled = renderer.get_filled_template()
# data is not escaped
assert filled["data"]["values"][0] == datapoints[0]
# field and title yes
Expand Down