Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change how Variables are Provided to Visualizations #1754

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions backend/src/apiserver/visualization/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import uuid
from enum import Enum
import json
import os
from pathlib import Path
from typing import Text
from jupyter_client import KernelManager
Expand Down Expand Up @@ -90,27 +91,38 @@ def __init__(
)

@staticmethod
def create_cell_from_args(variables: dict) -> NotebookNode:
    """Creates NotebookNode object that loads provided variables from JSON.

    The provided dict is serialized to a uniquely named JSON file in the
    current working directory. When the returned code cell is executed it
    reads that file into a dict named ``variables`` and then deletes the
    file, even if reading or parsing the file fails.

    Args:
        variables: Arguments that need to be injected into a NotebookNode.
            Must be serializable by ``json.dump``.

    Returns:
        NotebookNode (code cell) that exposes the provided arguments as a
        dict named ``variables``.

    """
    # Generates random file name to ensure variables for visualization are
    # not overwritten by future visualizations.
    file_name = "variables-{}.json".format(uuid.uuid4())
    # Mode "w" creates the file or truncates an existing one, so there is no
    # need to check for and remove a pre-existing file first.
    with open(file_name, "w") as f:
        json.dump(variables, f)

    # The file name is embedded with !r so that it is always emitted as a
    # valid, correctly escaped Python string literal. try/finally in the
    # generated code guarantees the temporary file is cleaned up even when
    # loading the JSON raises.
    return new_code_cell("""
import json
import os

variables = dict()

try:
    with open({path!r}, "r") as f:
        variables = json.load(f)
finally:
    os.remove({path!r})
""".format(path=file_name))

@staticmethod
def create_cell_from_file(filepath: Text) -> NotebookNode:
Expand Down
22 changes: 11 additions & 11 deletions backend/src/apiserver/visualization/roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,39 +34,39 @@
# trueclass
# true_score_column

# NOTE(review): `variables` and `source` are not defined in this part of the
# script; presumably they are injected by earlier notebook cells - confirm.

# A missing "is_generated" key is treated the same as an explicit False, so
# the curve is computed from raw prediction data in both cases.
if "is_generated" not in variables or variables["is_generated"] is False:
    # Create data from specified csv file(s).
    # The schema file provides column names for the csv file that will be used
    # to generate the roc curve.
    schema_file = Path(source) / "schema.json"
    schema = json.loads(file_io.read_file_to_string(schema_file))
    names = [x["name"] for x in schema]

    df = pd.concat(
        [pd.read_csv(f, names=names) for f in file_io.get_matching_files(source)]
    )
    # "target_lambda" is optional; .get() avoids a KeyError when it is absent.
    target_lambda = variables.get("target_lambda")
    if target_lambda:
        # SECURITY NOTE(review): eval() executes arbitrary code. The lambda
        # text appears to originate from request arguments - confirm that
        # only trusted callers can reach this code path.
        df["target"] = df.apply(eval(target_lambda), axis=1)
    else:
        # Binarize the target column: 1 for the positive class, 0 otherwise.
        df["target"] = df["target"].apply(
            lambda x: 1 if x == variables["trueclass"] else 0
        )
    fpr, tpr, thresholds = roc_curve(
        df["target"], df[variables["true_score_column"]]
    )
    df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": thresholds})
else:
    # Load data from generated csv file.
    df = pd.read_csv(
        source,
        header=None,
        names=["fpr", "tpr", "thresholds"]
    )

# Create visualization.
output_notebook()

p = figure(tools="pan,wheel_zoom,box_zoom,reset,hover,previewsave")
p.line("fpr", "tpr", line_width=2, source=df)

hover = p.select(dict(type=HoverTool))
hover.tooltips = [("Threshold", "@thresholds")]
Expand Down
16 changes: 11 additions & 5 deletions backend/src/apiserver/visualization/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import argparse
import importlib
import json
from pathlib import Path
from typing import Text
import shlex
Expand Down Expand Up @@ -90,7 +91,7 @@ def is_valid_request_arguments(self, arguments: argparse.Namespace) -> bool:
"""Validates arguments from post request and sends error if invalid.

Args:
arguments: x-www-form-urlencoded formatted arguments
arguments: argparse.Namespace formatted arguments

Returns:
Boolean value representing if provided arguments are valid.
Expand All @@ -101,20 +102,25 @@ def is_valid_request_arguments(self, arguments: argparse.Namespace) -> bool:
if arguments.source is None:
self.send_error(400, reason="No source specified.")
return False
try:
json.loads(arguments.arguments)
except json.JSONDecodeError:
self.send_error(400, reason="Invalid JSON provided as arguments.")
return False

return True

def generate_notebook_from_arguments(
self,
arguments: argparse.Namespace,
arguments: dict,
source: Text,
visualization_type: Text
) -> NotebookNode:
"""Generates a NotebookNode from provided arguments.

Args:
arguments: x-www-form-urlencoded formatted arguments.
input_path: Path or path pattern to be used as data reference for
arguments: JSON object containing provided arguments.
source: Path or path pattern to be used as data reference for
visualization.
visualization_type: Name of visualization to be generated.

Expand Down Expand Up @@ -142,7 +148,7 @@ def post(self):
if self.is_valid_request_arguments(request_arguments):
# Create notebook with arguments from request.
nb = self.generate_notebook_from_arguments(
request_arguments.arguments,
json.loads(request_arguments.arguments),
request_arguments.source,
request_arguments.type
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,73 @@

snapshots = Snapshot()

snapshots['TestExporterMethods::test_create_cell_from_args_with_multiple_args 1'] = '''source = "gs://ml-pipeline/data.csv"
target_lambda = "lambda x: (x[\'target\'] > x[\'fare\'] * 0.2)"
snapshots['TestExporterMethods::test_create_cell_from_args_with_multiple_args 1'] = '''
<div class="output_wrapper">
<div class="output">


<div class="output_area">

<div class="prompt"></div>


<div class="output_subarea output_stream output_stdout output_text">
<pre>[&#39;gs://ml-pipeline/data.csv&#39;, &#34;lambda x: (x[&#39;target&#39;] &gt; x[&#39;fare&#39;] * 0.2)&#34;]
</pre>
</div>
</div>

</div>
</div>



'''

snapshots['TestExporterMethods::test_create_cell_from_args_with_no_args 1'] = '''
<div class="output_wrapper">
<div class="output">


<div class="output_area">

<div class="prompt"></div>


<div class="output_subarea output_stream output_stdout output_text">
<pre>{}
</pre>
</div>
</div>

</div>
</div>



'''

snapshots['TestExporterMethods::test_create_cell_from_args_with_no_args 1'] = ''
snapshots['TestExporterMethods::test_create_cell_from_args_with_one_arg 1'] = '''
<div class="output_wrapper">
<div class="output">


<div class="output_area">

<div class="prompt"></div>


<div class="output_subarea output_stream output_stdout output_text">
<pre>[&#39;gs://ml-pipeline/data.csv&#39;]
</pre>
</div>
</div>

</div>
</div>



snapshots['TestExporterMethods::test_create_cell_from_args_with_one_arg 1'] = '''source = "gs://ml-pipeline/data.csv"
'''

snapshots['TestExporterMethods::test_create_cell_from_file 1'] = '''# Copyright 2019 Google LLC
Expand Down
55 changes: 41 additions & 14 deletions backend/src/apiserver/visualization/test_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import importlib
import os
import unittest
from nbformat.v4 import new_code_cell
from nbformat.v4 import new_notebook
Expand All @@ -25,27 +26,53 @@ class TestExporterMethods(snapshottest.TestCase):

def setUp(self):
    """Build a fresh exporter and clear stray JSON files before each test."""
    self.exporter = exporter.Exporter(100, exporter.TemplateType.BASIC)
    # test_create_cell_from_args_deletes_file_on_execution asserts that no
    # json files remain in the working directory, so anything left behind by
    # earlier test or development runs is removed up front.
    for leftover in os.listdir("./"):
        if leftover.endswith(".json"):
            os.remove(leftover)

def test_create_cell_from_args_deletes_file_on_execution(self):
    """Executing the generated cell must leave no json file behind."""
    self.maxDiff = None
    notebook = new_notebook()
    variables_cell = self.exporter.create_cell_from_args({})
    notebook.cells.append(variables_cell)
    self.exporter.generate_html_from_notebook(notebook)
    remaining = [
        name for name in os.listdir("./") if name.endswith(".json")
    ]
    self.assertEqual(remaining, [])

def test_create_cell_from_args_with_no_args(self):
    """An empty dict is exposed to the notebook as an empty `variables`."""
    self.maxDiff = None
    notebook = new_notebook()
    notebook.cells.extend([
        self.exporter.create_cell_from_args({}),
        new_code_cell("print(variables)"),
    ])
    rendered = self.exporter.generate_html_from_notebook(notebook)
    self.assertMatchSnapshot(rendered)

def test_create_cell_from_args_with_one_arg(self):
    """A single provided variable is readable from `variables`."""
    self.maxDiff = None
    notebook = new_notebook()
    provided = {"source": "gs://ml-pipeline/data.csv"}
    notebook.cells.extend([
        self.exporter.create_cell_from_args(provided),
        # Values are printed in sorted key order to keep output deterministic.
        new_code_cell("print([variables[key] for key in sorted(variables.keys())])"),
    ])
    rendered = self.exporter.generate_html_from_notebook(notebook)
    self.assertMatchSnapshot(rendered)

def test_create_cell_from_args_with_multiple_args(self):
    """Several provided variables are all readable from `variables`."""
    self.maxDiff = None
    notebook = new_notebook()
    provided = dict(
        source="gs://ml-pipeline/data.csv",
        target_lambda="lambda x: (x['target'] > x['fare'] * 0.2)",
    )
    notebook.cells.extend([
        self.exporter.create_cell_from_args(provided),
        # Values are printed in sorted key order to keep output deterministic.
        new_code_cell("print([variables[key] for key in sorted(variables.keys())])"),
    ])
    rendered = self.exporter.generate_html_from_notebook(notebook)
    self.assertMatchSnapshot(rendered)

def test_create_cell_from_file(self):
self.maxDiff = None
Expand All @@ -55,9 +82,9 @@ def test_create_cell_from_file(self):
def test_generate_html_from_notebook(self):
    """A notebook that prints an injected variable renders to the snapshot."""
    self.maxDiff = None
    notebook = new_notebook()
    notebook.cells.extend([
        self.exporter.create_cell_from_args({"x": 2}),
        new_code_cell("print(variables['x'])"),
    ])
    rendered = self.exporter.generate_html_from_notebook(notebook)
    self.assertMatchSnapshot(rendered)

Expand Down
11 changes: 11 additions & 0 deletions backend/src/apiserver/visualization/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ def test_create_visualization_fails_when_missing_input_path(self):
response.body
)

def test_create_visualization_fails_when_invalid_json_is_provided(self):
    """A request whose --arguments value is not valid JSON returns HTTP 400."""
    # The --arguments value is a lone opening brace, which is not valid JSON.
    request_body = (
        'arguments=--type test --source gs://ml-pipeline/data.csv --arguments "{"'
    )
    response = self.fetch("/", method="POST", body=request_body)
    expected_body = wrap_error_in_html("400: Invalid JSON provided as arguments.")
    self.assertEqual(400, response.code)
    self.assertEqual(expected_body, response.body)

def test_create_visualization(self):
response = self.fetch(
"/",
Expand Down