Skip to content
This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Commit cd2f434

Browse files
brandondutraqimingj
authored andcommitted
confusion matrix can glob files (#289)
* cm can glob files * comment
1 parent 973a7a8 commit cd2f434

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

datalab/ml/_confusion_matrix.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def from_csv(input_csv, headers=None, schema_file=None):
4141
"""Create a ConfusionMatrix from a csv file.
4242
Args:
4343
input_csv: Path to a Csv file (with no header). Can be local or GCS path.
44+
Path may contain wildcards.
4445
headers: Csv headers. If present, it must include 'target' and 'predicted'.
4546
schema_file: Path to a JSON file containing BigQuery schema. Used if "headers" is None.
4647
If present, it must include 'target' and 'predicted' columns.
@@ -59,8 +60,14 @@ def from_csv(input_csv, headers=None, schema_file=None):
5960
names = [x['name'] for x in schema]
6061
else:
6162
raise ValueError('Either headers or schema_file is needed')
62-
with _util.open_local_or_gcs(input_csv, mode='r') as f:
63-
df = pd.read_csv(f, names=names)
63+
64+
all_files = _util.glob_files(input_csv)
65+
all_df = []
66+
for file_name in all_files:
67+
with _util.open_local_or_gcs(file_name, mode='r') as f:
68+
all_df.append(pd.read_csv(f, names=names))
69+
df = pd.concat(all_df, ignore_index=True)
70+
6471
if 'target' not in df or 'predicted' not in df:
6572
raise ValueError('Cannot find "target" or "predicted" column')
6673

0 commit comments

Comments
 (0)