|
11 | 11 | # the License.
|
12 | 12 |
|
13 | 13 |
|
14 |
| -from plotly.offline import iplot |
| 14 | +import google.cloud.ml as ml |
| 15 | +import numpy as np |
| 16 | +import json |
| 17 | +import matplotlib.pyplot as plt |
| 18 | +import pandas as pd |
| 19 | +from sklearn.metrics import confusion_matrix |
| 20 | + |
| 21 | +import datalab.bigquery as bq |
| 22 | +import datalab.data as data |
15 | 23 |
|
16 | 24 |
|
17 | 25 | class ConfusionMatrix(object):
|
18 | 26 | """Represents a confusion matrix."""
|
19 | 27 |
|
20 |
| - def __init__(self, predicted_labels, true_labels, counts): |
21 |
| - """Initializes an instance of a ComfusionMatrix. the length of predicted_values, |
22 |
| - true_values, count must be the same. |
| 28 | + def __init__(self, cm, labels): |
| 29 | + """ |
| 30 | + Args: |
| 31 | + cm: a 2-dimensional matrix with row index being target, column index being predicted, |
| 32 | + and values being count. |
| 33 | + labels: the labels whose order matches the row/column indexes. |
| 34 | + """ |
| 35 | + self._cm = cm |
| 36 | + self._labels = labels |
23 | 37 |
|
| 38 | + @staticmethod |
| 39 | + def from_csv(input_csv, headers=None, schema_file=None): |
| 40 | + """Create a ConfusionMatrix from a csv file. |
24 | 41 | Args:
|
25 |
| - predicted_labels: a list of predicted labels. |
26 |
| - true_labels: a list of true labels. |
27 |
| - counts: a list of count for each (predicted, true) combination. |
| 42 | + input_csv: Path to a Csv file (with no header). Can be local or GCS path. |
| 43 | + headers: Csv headers. If present, it must include 'target' and 'predicted'. |
| 44 | + schema_file: Path to a JSON file containing BigQuery schema. Used if "headers" is None. |
| 45 | + If present, it must include 'target' and 'predicted' columns. |
| 46 | + Returns: |
| 47 | + A ConfusionMatrix that can be plotted. |
| 48 | + Raises: |
| 49 | + ValueError if both headers and schema_file are None, or it does not include 'target' |
| 50 | + or 'predicted' columns. |
| 51 | + """ |
28 | 52 |
|
29 |
| - Raises: Exception if predicted_labels, true_labels, and counts are not of the same size |
| 53 | + if headers is not None: |
| 54 | + names = headers |
| 55 | + elif schema_file is not None: |
| 56 | + with ml.util._file.open_local_or_gcs(schema_file, mode='r') as f: |
| 57 | + schema = json.load(f) |
| 58 | + names = [x['name'] for x in schema] |
| 59 | + else: |
| 60 | + raise ValueError('Either headers or schema_file is needed') |
| 61 | + with ml.util._file.open_local_or_gcs(input_csv, mode='r') as f: |
| 62 | + df = pd.read_csv(f, names=names) |
| 63 | + if 'target' not in df or 'predicted' not in df: |
| 64 | + raise ValueError('Cannot find "target" or "predicted" column') |
| 65 | + |
| 66 | + labels = sorted(set(df['target']) | set(df['predicted'])) |
| 67 | + cm = confusion_matrix(df['target'], df['predicted'], labels=labels) |
| 68 | + return ConfusionMatrix(cm, labels) |
| 69 | + |
| 70 | + @staticmethod |
| 71 | + def from_bigquery(sql): |
| 72 | + """Create a ConfusionMatrix from a BigQuery table or query. |
| 73 | + Args: |
| 74 | + sql: Can be one of: |
| 75 | + A SQL query string. |
| 76 | + A SQL Query module defined with '%%sql --name [module_name]'. |
| 77 | + A Bigquery table. |
| 78 | + The query results or table must include "target", "predicted" columns. |
| 79 | + Returns: |
| 80 | + A ConfusionMatrix that can be plotted. |
| 81 | + Raises: |
| 82 | + ValueError if query results or table does not include 'target' or 'predicted' columns. |
30 | 83 | """
|
31 |
| - if len(predicted_labels) != len(true_labels) or len(true_labels) != len(counts): |
32 |
| - raise Exception('The input predicted_labels, true_labels, counts need to be same size.') |
33 |
| - self._all_labels = list(set(predicted_labels) | set(true_labels)) |
34 |
| - data = [] |
35 |
| - for value in self._all_labels: |
36 |
| - predicts_for_current_true_label = \ |
37 |
| - {p: c for p, t, c in zip(predicted_labels, true_labels, counts) if t == value} |
38 |
| - # sort by all_values and fill in zeros if needed |
39 |
| - predicts_for_current_true_label = [predicts_for_current_true_label.get(v, 0) |
40 |
| - for v in self._all_labels] |
41 |
| - data.append(predicts_for_current_true_label) |
42 |
| - self._data = data |
| 84 | + |
| 85 | + query, _ = data.SqlModule.get_sql_statement_with_environment(sql, {}) |
| 86 | + sql = ('SELECT target, predicted, count(*) as count FROM (%s) group by target, predicted' |
| 87 | + % query.sql) |
| 88 | + df = bq.Query(sql).results().to_dataframe() |
| 89 | + labels = sorted(set(df['target']) | set(df['predicted'])) |
| 90 | + labels_count = len(labels) |
| 91 | + df['target'] = [labels.index(x) for x in df['target']] |
| 92 | + df['predicted'] = [labels.index(x) for x in df['predicted']] |
| 93 | + cm = [[0]*labels_count for i in range(labels_count)] |
| 94 | + for index, row in df.iterrows(): |
| 95 | + cm[row['target']][row['predicted']] = row['count'] |
| 96 | + return ConfusionMatrix(cm, labels) |
43 | 97 |
|
44 | 98 | def plot(self):
|
45 | 99 | """Plot the confusion matrix."""
|
46 |
| - figure_data = \ |
47 |
| - { |
48 |
| - "data": [ |
49 |
| - { |
50 |
| - "x": self._all_labels, |
51 |
| - "y": self._all_labels, |
52 |
| - "z": self._data, |
53 |
| - "colorscale": "YlGnBu", |
54 |
| - "type": "heatmap" |
55 |
| - } |
56 |
| - ], |
57 |
| - "layout": { |
58 |
| - "title": "Confusion Matrix", |
59 |
| - "xaxis": { |
60 |
| - "title": "Predicted value", |
61 |
| - }, |
62 |
| - "yaxis": { |
63 |
| - "title": "True Value", |
64 |
| - } |
65 |
| - } |
66 |
| - } |
67 |
| - iplot(figure_data) |
| 100 | + |
| 101 | + plt.imshow(self._cm, interpolation='nearest', cmap=plt.cm.Blues) |
| 102 | + plt.title('Confusion matrix') |
| 103 | + plt.colorbar() |
| 104 | + tick_marks = np.arange(len(self._labels)) |
| 105 | + plt.xticks(tick_marks, self._labels, rotation=45) |
| 106 | + plt.yticks(tick_marks, self._labels) |
| 107 | + plt.tight_layout() |
| 108 | + plt.ylabel('True label') |
| 109 | + plt.xlabel('Predicted label') |
| 110 | + |
0 commit comments