Skip to content
This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Commit e117138

Browse files
committed
Move confusion matrix from %%ml to library. (#159)
* Move confusion matrix from %%ml to library. This is part of the effort to move %%ml magic functionality into the library to provide a consistent (Python-only) experience.
* Add a comment.
1 parent 0611c53 commit e117138

File tree

2 files changed

+85
-124
lines changed

2 files changed

+85
-124
lines changed

datalab/mlalpha/_confusion_matrix.py

Lines changed: 85 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,57 +11,100 @@
1111
# the License.
1212

1313

14-
from plotly.offline import iplot
14+
import google.cloud.ml as ml
15+
import numpy as np
16+
import json
17+
import matplotlib.pyplot as plt
18+
import pandas as pd
19+
from sklearn.metrics import confusion_matrix
20+
21+
import datalab.bigquery as bq
22+
import datalab.data as data
1523

1624

1725
class ConfusionMatrix(object):
1826
"""Represents a confusion matrix."""
1927

20-
def __init__(self, predicted_labels, true_labels, counts):
21-
"""Initializes an instance of a ComfusionMatrix. the length of predicted_values,
22-
true_values, count must be the same.
28+
def __init__(self, cm, labels):
29+
"""
30+
Args:
31+
cm: a 2-dimensional matrix with row index being target, column index being predicted,
32+
and values being count.
33+
labels: the labels whose order matches the row/column indexes.
34+
"""
35+
self._cm = cm
36+
self._labels = labels
2337

38+
@staticmethod
39+
def from_csv(input_csv, headers=None, schema_file=None):
40+
"""Create a ConfusionMatrix from a csv file.
2441
Args:
25-
predicted_labels: a list of predicted labels.
26-
true_labels: a list of true labels.
27-
counts: a list of count for each (predicted, true) combination.
42+
input_csv: Path to a Csv file (with no header). Can be local or GCS path.
43+
headers: Csv headers. If present, it must include 'target' and 'predicted'.
44+
schema_file: Path to a JSON file containing BigQuery schema. Used if "headers" is None.
45+
If present, it must include 'target' and 'predicted' columns.
46+
Returns:
47+
A ConfusionMatrix that can be plotted.
48+
Raises:
49+
ValueError if both headers and schema_file are None, or it does not include 'target'
50+
or 'predicted' columns.
51+
"""
2852

29-
Raises: Exception if predicted_labels, true_labels, and counts are not of the same size
53+
if headers is not None:
54+
names = headers
55+
elif schema_file is not None:
56+
with ml.util._file.open_local_or_gcs(schema_file, mode='r') as f:
57+
schema = json.load(f)
58+
names = [x['name'] for x in schema]
59+
else:
60+
raise ValueError('Either headers or schema_file is needed')
61+
with ml.util._file.open_local_or_gcs(input_csv, mode='r') as f:
62+
df = pd.read_csv(f, names=names)
63+
if 'target' not in df or 'predicted' not in df:
64+
raise ValueError('Cannot find "target" or "predicted" column')
65+
66+
labels = sorted(set(df['target']) | set(df['predicted']))
67+
cm = confusion_matrix(df['target'], df['predicted'], labels=labels)
68+
return ConfusionMatrix(cm, labels)
69+
70+
@staticmethod
71+
def from_bigquery(sql):
72+
"""Create a ConfusionMatrix from a BigQuery table or query.
73+
Args:
74+
sql: Can be one of:
75+
A SQL query string.
76+
A SQL Query module defined with '%%sql --name [module_name]'.
77+
A Bigquery table.
78+
The query results or table must include "target", "predicted" columns.
79+
Returns:
80+
A ConfusionMatrix that can be plotted.
81+
Raises:
82+
ValueError if query results or table does not include 'target' or 'predicted' columns.
3083
"""
31-
if len(predicted_labels) != len(true_labels) or len(true_labels) != len(counts):
32-
raise Exception('The input predicted_labels, true_labels, counts need to be same size.')
33-
self._all_labels = list(set(predicted_labels) | set(true_labels))
34-
data = []
35-
for value in self._all_labels:
36-
predicts_for_current_true_label = \
37-
{p: c for p, t, c in zip(predicted_labels, true_labels, counts) if t == value}
38-
# sort by all_values and fill in zeros if needed
39-
predicts_for_current_true_label = [predicts_for_current_true_label.get(v, 0)
40-
for v in self._all_labels]
41-
data.append(predicts_for_current_true_label)
42-
self._data = data
84+
85+
query, _ = data.SqlModule.get_sql_statement_with_environment(sql, {})
86+
sql = ('SELECT target, predicted, count(*) as count FROM (%s) group by target, predicted'
87+
% query.sql)
88+
df = bq.Query(sql).results().to_dataframe()
89+
labels = sorted(set(df['target']) | set(df['predicted']))
90+
labels_count = len(labels)
91+
df['target'] = [labels.index(x) for x in df['target']]
92+
df['predicted'] = [labels.index(x) for x in df['predicted']]
93+
cm = [[0]*labels_count for i in range(labels_count)]
94+
for index, row in df.iterrows():
95+
cm[row['target']][row['predicted']] = row['count']
96+
return ConfusionMatrix(cm, labels)
4397

4498
def plot(self):
4599
"""Plot the confusion matrix."""
46-
figure_data = \
47-
{
48-
"data": [
49-
{
50-
"x": self._all_labels,
51-
"y": self._all_labels,
52-
"z": self._data,
53-
"colorscale": "YlGnBu",
54-
"type": "heatmap"
55-
}
56-
],
57-
"layout": {
58-
"title": "Confusion Matrix",
59-
"xaxis": {
60-
"title": "Predicted value",
61-
},
62-
"yaxis": {
63-
"title": "True Value",
64-
}
65-
}
66-
}
67-
iplot(figure_data)
100+
101+
plt.imshow(self._cm, interpolation='nearest', cmap=plt.cm.Blues)
102+
plt.title('Confusion matrix')
103+
plt.colorbar()
104+
tick_marks = np.arange(len(self._labels))
105+
plt.xticks(tick_marks, self._labels, rotation=45)
106+
plt.yticks(tick_marks, self._labels)
107+
plt.tight_layout()
108+
plt.ylabel('True label')
109+
plt.xlabel('Predicted label')
110+

datalab/mlalpha/commands/_ml.py

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,9 @@
1818

1919
import collections
2020
import google.cloud.ml as cloudml
21-
import matplotlib.pyplot as plt
2221
import numpy as np
2322
import os
2423
import pandas as pd
25-
from sklearn.metrics import confusion_matrix
2624
import yaml
2725

2826

@@ -93,23 +91,6 @@ def ml(line, cell=None):
9391
required=True)
9492
batch_predict_parser.set_defaults(func=_batch_predict)
9593

96-
confusion_matrix_parser = parser.subcommand('confusion_matrix',
97-
'Plot confusion matrix. The source is provided ' +
98-
'in one of "csv", "bqtable", and "sql" params.')
99-
confusion_matrix_parser.add_argument('--csv',
100-
help='GCS or local path of CSV file which contains ' +
101-
'"target", "predicted" columns at least. The CSV ' +
102-
'either comes with a schema file in the same dir, ' +
103-
'or specify "headers: name1, name2..." in cell.')
104-
confusion_matrix_parser.add_argument('--bqtable',
105-
help='name of the BigQuery table in the form of ' +
106-
'dataset.table.')
107-
confusion_matrix_parser.add_argument('--sql',
108-
help='name of the sql module defined in previous cell ' +
109-
'which should return "target", "predicted", ' +
110-
'and "count" columns at least in results.')
111-
confusion_matrix_parser.set_defaults(func=_confusion_matrix)
112-
11394
namespace = datalab.utils.commands.notebook_environment()
11495
return datalab.utils.commands.handle_magic_line(line, cell, parser, namespace=namespace)
11596

@@ -181,66 +162,3 @@ def _predict(args, cell):
181162

182163
def _batch_predict(args, cell):
183164
return _run_package(args, cell, 'batch_predict')
184-
185-
186-
def _plot_confusion_matrix(cm, labels):
187-
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
188-
plt.title('Confusion matrix')
189-
plt.colorbar()
190-
tick_marks = np.arange(len(labels))
191-
plt.xticks(tick_marks, labels, rotation=45)
192-
plt.yticks(tick_marks, labels)
193-
plt.tight_layout()
194-
plt.ylabel('True label')
195-
plt.xlabel('Predicted label')
196-
197-
198-
def _confusion_matrix_from_csv(input_csv, cell):
199-
schema_file = input_csv + '.schema.yaml'
200-
headers = None
201-
if cell is not None:
202-
env = datalab.utils.commands.notebook_environment()
203-
config = datalab.utils.commands.parse_config(cell, env)
204-
headers_str = config.get('headers', None)
205-
if headers_str is not None:
206-
headers = [x.strip() for x in headers_str.split(',')]
207-
if headers is not None:
208-
with cloudml.util._file.open_local_or_gcs(input_csv, mode='r') as f:
209-
df = pd.read_csv(f, names=headers)
210-
elif cloudml.util._file.file_exists(schema_file):
211-
df = datalab.mlalpha.csv_to_dataframe(input_csv, schema_file)
212-
else:
213-
raise Exception('headers is missing from cell, ' +
214-
'and there is no schema file in the same dir as csv')
215-
labels = sorted(set(df['target']) | set(df['predicted']))
216-
cm = confusion_matrix(df['target'], df['predicted'], labels=labels)
217-
return cm, labels
218-
219-
220-
def _confusion_matrix_from_query(sql_module_name, bq_table):
221-
if sql_module_name is not None:
222-
item = datalab.utils.commands.get_notebook_item(sql_module_name)
223-
query, _ = datalab.data.SqlModule.get_sql_statement_with_environment(item, {})
224-
else:
225-
query = ('select target, predicted, count(*) as count from %s group by target, predicted'
226-
% bq_table)
227-
dfbq = datalab.bigquery.Query(query).results().to_dataframe()
228-
labels = sorted(set(dfbq['target']) | set(dfbq['predicted']))
229-
labels_count = len(labels)
230-
dfbq['target'] = [labels.index(x) for x in dfbq['target']]
231-
dfbq['predicted'] = [labels.index(x) for x in dfbq['predicted']]
232-
cm = [[0]*labels_count for i in range(labels_count)]
233-
for index, row in dfbq.iterrows():
234-
cm[row['target']][row['predicted']] = row['count']
235-
return cm, labels
236-
237-
238-
def _confusion_matrix(args, cell):
239-
if args['csv'] is not None:
240-
#TODO: Maybe add cloud run for large CSVs with federated table.
241-
cm, labels = _confusion_matrix_from_csv(args['csv'], cell)
242-
elif args['sql'] is not None or args['bqtable'] is not None:
243-
cm, labels = _confusion_matrix_from_query(args['sql'], args['bqtable'])
244-
else:
245-
raise Exception('One of "csv", "bqtable", and "sql" param is needed.')
246-
_plot_confusion_matrix(cm, labels)

0 commit comments

Comments
 (0)