Change how Variables are Provided to Visualizations (#1754)
* Changed the way visualization variables are passed from the request to the NotebookNode

Visualization variables are now saved to a JSON file and loaded by a NotebookNode upon execution.

* Updated roc_curve visualization to reflect changes made to dependency injection

* Fixed bug where checking if is_generated is provided to the roc_curve visualization would crash the visualization

Also changed single quotes (') to double quotes (")

* Changed test_exporter to always sort variables by key for testing

* Addressed PR suggestions
ajchili authored and k8s-ci-robot committed Aug 15, 2019
1 parent d238bef commit ea67c99
Showing 6 changed files with 145 additions and 71 deletions.
19 changes: 4 additions & 15 deletions backend/src/apiserver/visualization/exporter.py
@@ -17,7 +17,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from enum import Enum
import json
from pathlib import Path
@@ -90,27 +89,17 @@ def __init__(
)

@staticmethod
def create_cell_from_args(args: argparse.Namespace) -> NotebookNode:
"""Creates a NotebookNode object with provided arguments as variables.
def create_cell_from_args(variables: dict) -> NotebookNode:
"""Creates NotebookNode object containing dict of provided variables.
Args:
args: Arguments that need to be injected into a NotebookNode.
variables: Arguments that need to be injected into a NotebookNode.
Returns:
NotebookNode with provided arguments as variables.
"""
variables = ""
args = json.loads(args)
for key in sorted(args.keys()):
# Check type of variable to maintain type when converting from JSON
# to notebook cell
if args[key] is None or isinstance(args[key], bool):
variables += "{} = {}\n".format(key, args[key])
else:
variables += '{} = "{}"\n'.format(key, args[key])

return new_code_cell(variables)
return new_code_cell("variables = {}".format(repr(variables)))

@staticmethod
def create_cell_from_file(filepath: Text) -> NotebookNode:
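For illustration, a minimal standalone sketch of the new injection pattern, assuming nbformat is installed; make_variables_cell and the sample_size argument are invented for the example, while the real method is Exporter.create_cell_from_args shown above:

from nbformat.v4 import new_code_cell

# Standalone sketch; the real method lives on the Exporter class in exporter.py.
def make_variables_cell(variables: dict):
    # repr() of the whole dict keeps Python types intact (None, bool, int, str),
    # so the per-key type checks the old argparse-based code needed are gone.
    return new_code_cell("variables = {}".format(repr(variables)))

cell = make_variables_cell({
    "source": "gs://ml-pipeline/data.csv",
    "is_generated": False,
    "sample_size": 1000,  # hypothetical extra argument, not from this PR
})
print(cell.source)
# variables = {'source': 'gs://ml-pipeline/data.csv', 'is_generated': False, 'sample_size': 1000}

Any cell appended after this one can simply read values from the injected variables dict.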
22 changes: 11 additions & 11 deletions backend/src/apiserver/visualization/roc_curve.py
@@ -34,39 +34,39 @@
# trueclass
# true_score_column

if is_generated is False:
if "is_generated" is not in variables or variables["is_generated"] is False:
# Create data from specified csv file(s).
# The schema file provides column names for the csv file that will be used
# to generate the roc curve.
schema_file = Path(source) / 'schema.json'
schema_file = Path(source) / "schema.json"
schema = json.loads(file_io.read_file_to_string(schema_file))
names = [x['name'] for x in schema]
names = [x["name"] for x in schema]

dfs = []
files = file_io.get_matching_files(source)
for f in files:
dfs.append(pd.read_csv(f, names=names))

df = pd.concat(dfs)
if target_lambda:
df['target'] = df.apply(eval(target_lambda), axis=1)
if variables["target_lambda"]:
df["target"] = df.apply(eval(variables["target_lambda"]), axis=1)
else:
df['target'] = df['target'].apply(lambda x: 1 if x == trueclass else 0)
fpr, tpr, thresholds = roc_curve(df['target'], df[true_score_column])
source = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'thresholds': thresholds})
df["target"] = df["target"].apply(lambda x: 1 if x == variables["trueclass"] else 0)
fpr, tpr, thresholds = roc_curve(df["target"], df[variables["true_score_column"]])
df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": thresholds})
else:
# Load data from generated csv file.
source = pd.read_csv(
df = pd.read_csv(
source,
header=None,
names=['fpr', 'tpr', 'thresholds']
names=["fpr", "tpr", "thresholds"]
)

# Create visualization.
output_notebook()

p = figure(tools="pan,wheel_zoom,box_zoom,reset,hover,previewsave")
p.line('fpr', 'tpr', line_width=2, source=source)
p.line("fpr", "tpr", line_width=2, source=df)

hover = p.select(dict(type=HoverTool))
hover.tooltips = [("Threshold", "@thresholds")]
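A self-contained sketch of the pattern the updated roc_curve.py follows, with every input read from the injected variables dict. The in-memory DataFrame and its values are made up and stand in for the CSV/schema loading from source, and the snippet is meant to run inside a notebook cell the way the exporter executes it:

import pandas as pd
from bokeh.models import HoverTool
from bokeh.plotting import figure, output_notebook, show
from sklearn.metrics import roc_curve

# In a real run this dict is injected by the exporter as the first notebook cell.
variables = {"trueclass": 1, "true_score_column": "score"}

# Made-up labels and scores standing in for the CSV data read from `source`.
df = pd.DataFrame({
    "target": [1, 0, 1, 1, 0],
    "score": [0.9, 0.3, 0.8, 0.6, 0.4],
})
df["target"] = df["target"].apply(lambda x: 1 if x == variables["trueclass"] else 0)
fpr, tpr, thresholds = roc_curve(df["target"], df[variables["true_score_column"]])
df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": thresholds})

output_notebook()
p = figure(tools="pan,wheel_zoom,box_zoom,reset,hover")
p.line("fpr", "tpr", line_width=2, source=df)
hover = p.select(dict(type=HoverTool))
hover.tooltips = [("Threshold", "@thresholds")]
show(p)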
59 changes: 32 additions & 27 deletions backend/src/apiserver/visualization/server.py
@@ -13,7 +13,9 @@
# limitations under the License.

import argparse
from argparse import Namespace
import importlib
import json
from pathlib import Path
from typing import Text
import shlex
@@ -73,48 +75,47 @@ def initialize(self):
help="JSON string of arguments to be provided to visualizations."
)

def get_arguments_from_body(self) -> argparse.Namespace:
"""Converts arguments from post request to argparser.Namespace format.
def get_arguments_from_body(self) -> Namespace:
"""Converts arguments from post request to Namespace format.
This is done because arguments, by default, are provided in the
x-www-form-urlencoded format. This format is difficult to parse compared
to argparser.Namespace, which is a dict.
to Namespace, which is a dict.
Returns:
Arguments provided from post request as arparser.Namespace object.
Arguments provided from post request as a Namespace object.
"""
split_arguments = shlex.split(self.get_body_argument("arguments"))
return self.requestParser.parse_args(split_arguments)

def is_valid_request_arguments(self, arguments: argparse.Namespace) -> bool:
"""Validates arguments from post request and sends error if invalid.
def is_valid_request_arguments(self, arguments: Namespace):
"""Validates arguments from post request and raises error if invalid.
Args:
arguments: x-www-form-urlencoded formatted arguments
Returns:
Boolean value representing if provided arguments are valid.
arguments: Namespace formatted arguments
"""
if arguments.type is None:
self.send_error(400, reason="No type specified.")
return False
raise Exception("No type specified.")
if arguments.source is None:
self.send_error(400, reason="No source specified.")
return False
raise Exception("No source specified.")
try:
json.loads(arguments.arguments)
except json.JSONDecodeError:
raise Exception("Invalid JSON provided as arguments.")

return True

def generate_notebook_from_arguments(
self,
arguments: argparse.Namespace,
arguments: dict,
source: Text,
visualization_type: Text
) -> NotebookNode:
"""Generates a NotebookNode from provided arguments.
Args:
arguments: x-www-form-urlencoded formatted arguments.
input_path: Path or path pattern to be used as data reference for
arguments: JSON object containing provided arguments.
source: Path or path pattern to be used as data reference for
visualization.
visualization_type: Name of visualization to be generated.
@@ -139,16 +140,20 @@ def post(self):
# Parse arguments from request.
request_arguments = self.get_arguments_from_body()
# Validate arguments from request.
if self.is_valid_request_arguments(request_arguments):
# Create notebook with arguments from request.
nb = self.generate_notebook_from_arguments(
request_arguments.arguments,
request_arguments.source,
request_arguments.type
)
# Generate visualization (output for notebook).
html = _exporter.generate_html_from_notebook(nb)
self.write(html)
try:
self.is_valid_request_arguments(request_arguments)
except Exception as e:
return self.send_error(400, reason=str(e))

# Create notebook with arguments from request.
nb = self.generate_notebook_from_arguments(
json.loads(request_arguments.arguments),
request_arguments.source,
request_arguments.type
)
# Generate visualization (output for notebook).
html = _exporter.generate_html_from_notebook(nb)
self.write(html)


if __name__ == "__main__":
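A rough standalone sketch of the parse-and-validate path the handler now takes. The parser options are approximated from the snippet above, parse_and_validate is a name invented for this example, the --arguments default is an assumption, and the Tornado plumbing is omitted:

import argparse
import json
import shlex

parser = argparse.ArgumentParser()
parser.add_argument("--type")
parser.add_argument("--source")
parser.add_argument(
    "--arguments",
    default="{}",  # default is an assumption for this sketch
    help="JSON string of arguments to be provided to visualizations."
)

def parse_and_validate(body_arguments: str) -> dict:
    # The POST body field "arguments" carries a shell-style string that is
    # split with shlex, parsed with argparse, then JSON-decoded.
    args = parser.parse_args(shlex.split(body_arguments))
    if args.type is None:
        raise Exception("No type specified.")
    if args.source is None:
        raise Exception("No source specified.")
    try:
        variables = json.loads(args.arguments)
    except json.JSONDecodeError:
        raise Exception("Invalid JSON provided as arguments.")
    return {"type": args.type, "source": args.source, "variables": variables}

print(parse_and_validate('--type roc_curve --source gs://ml-pipeline/data.csv'))
# {'type': 'roc_curve', 'source': 'gs://ml-pipeline/data.csv', 'variables': {}}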
@@ -7,13 +7,73 @@

snapshots = Snapshot()

snapshots['TestExporterMethods::test_create_cell_from_args_with_multiple_args 1'] = '''source = "gs://ml-pipeline/data.csv"
target_lambda = "lambda x: (x[\'target\'] > x[\'fare\'] * 0.2)"
'''
snapshots['TestExporterMethods::test_create_cell_from_args_with_multiple_args 1'] = '''
<div class="output_wrapper">
<div class="output">
<div class="output_area">
<div class="prompt"></div>
<div class="output_subarea output_stream output_stdout output_text">
<pre>[&#39;gs://ml-pipeline/data.csv&#39;, &#34;lambda x: (x[&#39;target&#39;] &gt; x[&#39;fare&#39;] * 0.2)&#34;]
</pre>
</div>
</div>
</div>
</div>
'''

snapshots['TestExporterMethods::test_create_cell_from_args_with_no_args 1'] = '''
<div class="output_wrapper">
<div class="output">
<div class="output_area">
<div class="prompt"></div>
<div class="output_subarea output_stream output_stdout output_text">
<pre>{}
</pre>
</div>
</div>
</div>
</div>
'''

snapshots['TestExporterMethods::test_create_cell_from_args_with_no_args 1'] = ''
snapshots['TestExporterMethods::test_create_cell_from_args_with_one_arg 1'] = '''
<div class="output_wrapper">
<div class="output">
<div class="output_area">
<div class="prompt"></div>
<div class="output_subarea output_stream output_stdout output_text">
<pre>[&#39;gs://ml-pipeline/data.csv&#39;]
</pre>
</div>
</div>
</div>
</div>
'''

snapshots['TestExporterMethods::test_create_cell_from_args_with_one_arg 1'] = '''source = "gs://ml-pipeline/data.csv"
'''

snapshots['TestExporterMethods::test_create_cell_from_file 1'] = '''# Copyright 2019 Google LLC
37 changes: 23 additions & 14 deletions backend/src/apiserver/visualization/test_exporter.py
@@ -28,24 +28,33 @@ def setUp(self):

def test_create_cell_from_args_with_no_args(self):
self.maxDiff = None
args = "{}"
cell = self.exporter.create_cell_from_args(args)
self.assertMatchSnapshot(cell.source)
nb = new_notebook()
args = {}
nb.cells.append(self.exporter.create_cell_from_args(args))
nb.cells.append(new_code_cell("print(variables)"))
html = self.exporter.generate_html_from_notebook(nb)
self.assertMatchSnapshot(html)

def test_create_cell_from_args_with_one_arg(self):
self.maxDiff = None
args = '{"source": "gs://ml-pipeline/data.csv"}'
cell = self.exporter.create_cell_from_args(args)
self.assertMatchSnapshot(cell.source)
nb = new_notebook()
args = {"source": "gs://ml-pipeline/data.csv"}
nb.cells.append(self.exporter.create_cell_from_args(args))
nb.cells.append(new_code_cell("print([variables[key] for key in sorted(variables.keys())])"))
html = self.exporter.generate_html_from_notebook(nb)
self.assertMatchSnapshot(html)

def test_create_cell_from_args_with_multiple_args(self):
self.maxDiff = None
args = (
'{"source": "gs://ml-pipeline/data.csv", '
"\"target_lambda\": \"lambda x: (x['target'] > x['fare'] * 0.2)\"}"
)
cell = self.exporter.create_cell_from_args(args)
self.assertMatchSnapshot(cell.source)
nb = new_notebook()
args = {
"source": "gs://ml-pipeline/data.csv",
"target_lambda": "lambda x: (x['target'] > x['fare'] * 0.2)"
}
nb.cells.append(self.exporter.create_cell_from_args(args))
nb.cells.append(new_code_cell("print([variables[key] for key in sorted(variables.keys())])"))
html = self.exporter.generate_html_from_notebook(nb)
self.assertMatchSnapshot(html)

def test_create_cell_from_file(self):
self.maxDiff = None
@@ -55,9 +64,9 @@ def test_create_cell_from_file(self):
def test_generate_html_from_notebook(self):
self.maxDiff = None
nb = new_notebook()
args = '{"x": 2}'
args = {"x": 2}
nb.cells.append(self.exporter.create_cell_from_args(args))
nb.cells.append(new_code_cell("print(x)"))
nb.cells.append(new_code_cell("print(variables['x'])"))
html = self.exporter.generate_html_from_notebook(nb)
self.assertMatchSnapshot(html)

11 changes: 11 additions & 0 deletions backend/src/apiserver/visualization/test_server.py
@@ -70,6 +70,17 @@ def test_create_visualization_fails_when_missing_input_path(self):
response.body
)

def test_create_visualization_fails_when_invalid_json_is_provided(self):
response = self.fetch(
"/",
method="POST",
body='arguments=--type test --source gs://ml-pipeline/data.csv --arguments "{"')
self.assertEqual(400, response.code)
self.assertEqual(
wrap_error_in_html("400: Invalid JSON provided as arguments."),
response.body
)

def test_create_visualization(self):
response = self.fetch(
"/",
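For completeness, a hedged example of what a client call might look like, based on the body format the tests above post. The URL, port, visualization type, and argument values are assumptions, and any HTTP client would do:

import requests

# One form field, "arguments", holding a shell-style string that the server
# splits with shlex and parses with argparse (see server.py above).
body = {
    "arguments": '--type roc_curve '
                 '--source gs://ml-pipeline/data.csv '
                 '--arguments \'{"is_generated": false}\''
}
response = requests.post("http://localhost:8888/", data=body)
print(response.status_code)
print(response.text[:200])  # rendered HTML of the visualization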
