-
Notifications
You must be signed in to change notification settings - Fork 78
Remove old feature-slicing pipeline implementation (is replaced by BigQuery) Add Confusion matrix magic. #129
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,9 +17,15 @@ | |
raise Exception('This module can only be loaded in ipython.') | ||
|
||
import collections | ||
import google.cloud.ml as cloudml | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import os | ||
import pandas as pd | ||
from sklearn.metrics import confusion_matrix | ||
import yaml | ||
|
||
|
||
import datalab.context | ||
import datalab.mlalpha | ||
import datalab.utils.commands | ||
|
@@ -87,6 +93,23 @@ def ml(line, cell=None): | |
required=True) | ||
batch_predict_parser.set_defaults(func=_batch_predict) | ||
|
||
confusion_matrix_parser = parser.subcommand('confusion_matrix', | ||
'Plot confusion matrix. The source is provided ' + | ||
'in one of "csv", "bqtable", and "sql" params.') | ||
confusion_matrix_parser.add_argument('--csv', | ||
help='GCS or local path of CSV file which contains ' + | ||
'"target", "predicted" columns at least. The CSV ' + | ||
'either comes with a schema file in the same dir, ' + | ||
'or specify "headers: name1, name2..." in cell.') | ||
confusion_matrix_parser.add_argument('--bqtable', | ||
help='name of the BigQuery table in the form of ' + | ||
'dataset.table.') | ||
confusion_matrix_parser.add_argument('--sql', | ||
help='name of the sql module defined in previous cell ' + | ||
'which should return "target", "predicted", ' + | ||
'and "count" columns at least in results.') | ||
confusion_matrix_parser.set_defaults(func=_confusion_matrix) | ||
|
||
namespace = datalab.utils.commands.notebook_environment() | ||
return datalab.utils.commands.handle_magic_line(line, cell, parser, namespace=namespace) | ||
|
||
|
@@ -158,3 +181,66 @@ def _predict(args, cell): | |
|
||
def _batch_predict(args, cell):
    """Handle the `batch_predict` subcommand.

    Delegates to `_run_package`, passing the parsed arguments and the cell
    body through unchanged together with the 'batch_predict' mode name.
    """
    mode = 'batch_predict'
    return _run_package(args, cell, mode)
|
||
|
||
def _plot_confusion_matrix(cm, labels):
    """Draw a confusion matrix as a heat map on the current matplotlib figure.

    Args:
      cm: square matrix of counts (nested list or array). Rows correspond to
          the true label and columns to the predicted label, in the same
          order as `labels`.
      labels: sequence of class labels used for the tick marks on both axes.
    """
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion matrix')
    plt.colorbar()
    tick_marks = np.arange(len(labels))
    plt.xticks(tick_marks, labels, rotation=45)
    plt.yticks(tick_marks, labels)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # tight_layout() only accounts for artists that exist when it is called,
    # so it must run after the axis labels are set or they may be clipped.
    plt.tight_layout()
|
||
|
||
def _confusion_matrix_from_csv(input_csv, cell):
    """Build a confusion matrix from a CSV file of prediction results.

    Column names come from a "headers: name1, name2, ..." entry in the cell
    body when present; otherwise from a schema file located at
    `input_csv + '.schema.yaml'`. The data must contain "target" and
    "predicted" columns.

    Args:
      input_csv: path of the CSV file (opened via the local-or-GCS helper).
      cell: the cell body text, or None when the magic has no cell content.

    Returns:
      A (cm, labels) tuple: the matrix computed by sklearn's
      `confusion_matrix` and the sorted list of distinct labels seen in
      either column.

    Raises:
      Exception: if neither cell headers nor a schema file is available.
    """
    schema_path = input_csv + '.schema.yaml'

    column_names = None
    if cell is not None:
        environment = datalab.utils.commands.notebook_environment()
        cell_config = datalab.utils.commands.parse_config(cell, environment)
        raw_headers = cell_config.get('headers', None)
        if raw_headers is not None:
            column_names = [name.strip() for name in raw_headers.split(',')]

    if column_names is not None:
        # Explicit headers from the cell take precedence over any schema file.
        with cloudml.util._file.open_local_or_gcs(input_csv, mode='r') as csv_file:
            frame = pd.read_csv(csv_file, names=column_names)
    else:
        if not cloudml.util._file.file_exists(schema_path):
            raise Exception('headers is missing from cell, ' +
                            'and there is no schema file in the same dir as csv')
        frame = datalab.mlalpha.csv_to_dataframe(input_csv, schema_path)

    targets = frame['target']
    predictions = frame['predicted']
    # Use the union of both columns so classes that were only ever predicted
    # (or only ever the target) still get a row/column.
    labels = sorted(set(targets).union(set(predictions)))
    matrix = confusion_matrix(targets, predictions, labels=labels)
    return matrix, labels
|
||
|
||
def _confusion_matrix_from_query(sql_module_name, bq_table):
    """Build a confusion matrix from BigQuery results.

    Args:
      sql_module_name: name of a SQL module defined in the notebook whose
          results contain "target", "predicted" and "count" columns, or None.
      bq_table: BigQuery table name used to build an aggregate query when
          `sql_module_name` is None.

    Returns:
      A (cm, labels) tuple. `labels` is the sorted list of distinct values
      found in the "target" and "predicted" columns; `cm` is a nested list
      where cm[i][j] is the count of rows whose target is labels[i] and
      whose prediction is labels[j].

    Raises:
      Exception: if both `sql_module_name` and `bq_table` are None.
    """
    if sql_module_name is not None:
        item = datalab.utils.commands.get_notebook_item(sql_module_name)
        query, _ = datalab.data.SqlModule.get_sql_statement_with_environment(item, {})
    elif bq_table is not None:
        # NOTE(review): the table name is interpolated directly into the SQL
        # text; it comes from the magic's --bqtable argument. Confirm that is
        # acceptable for every caller.
        query = ('select target, predicted, count(*) as count from %s '
                 'group by target, predicted' % bq_table)
    else:
        # Without this guard the query would be built against a table
        # literally named "None".
        raise Exception('One of "bqtable" and "sql" is needed.')

    dfbq = datalab.bigquery.Query(query).results().to_dataframe()
    labels = sorted(set(dfbq['target']) | set(dfbq['predicted']))
    # Dict lookup instead of list.index(): constant time per row.
    position = {label: i for i, label in enumerate(labels)}
    cm = [[0] * len(labels) for _ in labels]
    # Iterate column-wise rather than with iterrows(): iterrows() coerces each
    # row to a single dtype, which can turn integer values into floats.
    for target, predicted, count in zip(dfbq['target'], dfbq['predicted'], dfbq['count']):
        # Accumulate so duplicate (target, predicted) rows returned by a
        # user-supplied SQL module are summed rather than overwritten.
        cm[position[target]][position[predicted]] += count
    return cm, labels
|
||
|
||
def _confusion_matrix(args, cell):
    """Handle the `confusion_matrix` subcommand: load the data and plot it.

    The source is selected from the parsed arguments: "csv" takes priority,
    then "sql" / "bqtable". Per the review discussion on this change,
    BigQuery is the preferred source and the CSV path is offered for smaller
    datasets.

    Args:
      args: parsed subcommand arguments with "csv", "sql" and "bqtable" keys.
      cell: the cell body text, or None; forwarded to the CSV loader.

    Raises:
      Exception: if none of "csv", "bqtable" and "sql" is provided.
    """
    csv_path = args['csv']
    if csv_path is not None:
        #TODO: Maybe add cloud run for large CSVs with federated table.
        matrix, names = _confusion_matrix_from_csv(csv_path, cell)
    else:
        sql_module = args['sql']
        table = args['bqtable']
        if sql_module is None and table is None:
            raise Exception('One of "csv", "bqtable", and "sql" param is needed.')
        matrix, names = _confusion_matrix_from_query(sql_module, table)
    _plot_confusion_matrix(matrix, names)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you give an example of the schema file in the comment?
After loading the schema file, the order of the column names in the python object is preserved, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The schema would be:
type: STRING
type: STRING
...
It's a list of dicts, so the column order is preserved even after it is loaded into memory.