|
1 |
| -# Copyright 2016 Google Inc. All rights reserved. |
2 |
| -# |
3 |
| -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except |
4 |
| -# in compliance with the License. You may obtain a copy of the License at |
5 |
| -# |
6 |
| -# http://www.apache.org/licenses/LICENSE-2.0 |
7 |
| -# |
8 |
| -# Unless required by applicable law or agreed to in writing, software distributed under the License |
9 |
| -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express |
10 |
| -# or implied. See the License for the specific language governing permissions and limitations under |
11 |
| -# the License. |
12 |
| - |
13 |
| -"""Implements Cloud ML Summary wrapper.""" |
14 |
| - |
15 | 1 | import datetime
|
| 2 | +import fnmatch |
16 | 3 | import glob
|
| 4 | +import google.cloud.ml as ml |
| 5 | +import matplotlib.pyplot as plt |
17 | 6 | import os
|
| 7 | +import pandas as pd |
18 | 8 | from tensorflow.core.util import event_pb2
|
19 | 9 | from tensorflow.python.lib.io import tf_record
|
20 | 10 |
|
21 |
| -import datalab.storage as storage |
22 |
| - |
23 | 11 |
|
class Summary(object):
  """Represents TensorFlow summary events from files under specified directories.

  Only scalar summaries (values carrying a `simple_value`) are exposed. Events
  are grouped by the directory that physically contains the events file, since
  the same tag (e.g. 'loss') under two directories denotes two different series.
  """

  def __init__(self, paths):
    """Initializes an instance of a Summary.

    Args:
      paths: a single path or a list of paths to directories which hold
          TensorFlow events files. Can be local paths or GCS paths
          ('gs://...'). Wild cards allowed.
    """
    self._paths = self._as_list(paths)

  @staticmethod
  def _as_list(value):
    """Returns `value` as a list, wrapping a lone string instead of splitting it."""
    try:
      string_types = basestring  # Python 2: covers both str and unicode.
    except NameError:
      string_types = str  # Python 3 has no basestring.
    if isinstance(value, string_types):
      return [value]
    return list(value)

  def _glob_events_files(self, paths):
    """Expands directory patterns into the events files they contain.

    GCS patterns are matched one level deep ('<path>/*.tfevents.*'); local
    patterns are expanded and every matching directory is walked recursively.

    Args:
      paths: a list of local or GCS directory paths, wild cards allowed.

    Returns:
      A list of unique events file paths, in discovery order.
    """
    event_files = []
    seen = set()

    def _add(event_file):
      # Overlapping inputs (e.g. 'logs' and 'logs/train') would otherwise
      # yield the same file twice and duplicate every row read from it.
      if event_file not in seen:
        seen.add(event_file)
        event_files.append(event_file)

    for path in paths:
      if path.startswith('gs://'):
        for event_file in ml.util._file.glob_files(os.path.join(path, '*.tfevents.*')):
          _add(event_file)
      else:
        for directory in ml.util._file.glob_files(path):
          for root, _, filenames in os.walk(directory):
            for filename in fnmatch.filter(filenames, '*.tfevents.*'):
              _add(os.path.join(root, filename))
    return event_files

  def list_events(self):
    """Lists all scalar events found under the configured paths.

    Returns:
      A dictionary. Key is the name (tag) of an event. Value is the set of
      directories whose events files contain that event.
    """
    event_dir_dict = {}
    for event_file in self._glob_events_files(self._paths):
      directory = os.path.dirname(event_file)
      for record in tf_record.tf_record_iterator(event_file):
        event = event_pb2.Event.FromString(record)
        # Protobuf fields are never None; presence must be tested with HasField.
        if not event.HasField('summary'):
          continue
        for value in event.summary.value:
          # Skip untagged values and non-scalar summaries (histograms, images...).
          if not value.tag or not value.HasField('simple_value'):
            continue
          event_dir_dict.setdefault(value.tag, set()).add(directory)
    return event_dir_dict

  def get_events(self, event_names):
    """Gets all events as pandas DataFrames given a list of names.

    Args:
      event_names: a single event name or a list of event names to get.

    Returns:
      A list with the same length as event_names. Each element is a dictionary
      {dir1: DataFrame1, dir2: DataFrame2, ...} where each DataFrame has the
      columns 'time', 'step' and 'value' and is sorted by 'time'.
      Multiple directories may contain events with the same name, but they are
      different events (i.e. 'loss' under train_set/, and 'loss' under eval_set/).
      Names with no matching events map to an empty dictionary.
    """
    event_names = self._as_list(event_names)

    # First occurrence wins, matching list.index() for duplicated names.
    name_to_index = {}
    for index, name in enumerate(event_names):
      name_to_index.setdefault(name, index)

    # rows_by_event[i][directory] -> list of [time, step, value] rows. Rows are
    # accumulated in plain lists and turned into DataFrames once at the end;
    # growing a DataFrame row by row is quadratic.
    rows_by_event = [{} for _ in event_names]
    if name_to_index:
      for event_file in self._glob_events_files(self._paths):
        # Key on the file's own directory so that events found while walking a
        # parent directory are not attributed to that parent.
        directory = os.path.dirname(event_file)
        for record in tf_record.tf_record_iterator(event_file):
          event = event_pb2.Event.FromString(record)
          if not event.HasField('summary'):
            continue
          event_time = datetime.datetime.fromtimestamp(event.wall_time)
          for value in event.summary.value:
            index = name_to_index.get(value.tag)
            if index is None or not value.HasField('simple_value'):
              continue
            rows_by_event[index].setdefault(directory, []).append(
                [event_time, event.step, value.simple_value])

    ret_events = []
    for rows_by_dir in rows_by_event:
      dir_event_dict = {}
      for directory, rows in rows_by_dir.items():
        df = pd.DataFrame(rows, columns=['time', 'step', 'value'])
        df.sort_values(by=['time'], inplace=True)
        dir_event_dict[directory] = df
      ret_events.append(dir_event_dict)
    return ret_events

  def plot(self, event_names, x_axis='step'):
    """Plots a list of events. Each event (a dir+event_name) is represented as a line
    in the graph.

    Args:
      event_names: A list of events to plot. Each event_name may correspond to multiple events,
          each in a different directory.
      x_axis: whether to use step or time as x axis. Any value other than 'step'
          selects time.
    """
    event_names = self._as_list(event_names)
    events_list = self.get_events(event_names)
    x_column = 'step' if x_axis == 'step' else 'time'
    for event_name, dir_event_dict in zip(event_names, events_list):
      for directory, df in dir_event_dict.items():
        label = event_name + ':' + directory
        # Order points along the chosen axis so the line does not zigzag when
        # step order and wall-clock order disagree (e.g. after a restart).
        ordered = df.sort_values(by=[x_column])
        plt.plot(ordered[x_column], ordered['value'], label=label)
    plt.legend(loc='best')
    plt.show()
| 127 | + |
0 commit comments