-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodelcard.py
240 lines (197 loc) · 11.1 KB
/
modelcard.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Model card base class and utilities."""
import copy
import json
import logging
import os
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .file_utils import (
CONFIG_NAME,
MODEL_CARD_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
cached_path,
hf_bucket_url,
is_remote_url,
)
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ModelCard(object):
    r""" Model Card class.

    Stores the sections of a model card and provides methods for loading, downloading
    and saving model cards as JSON.

    Please read the following paper for details and explanation on the sections:
    "Model Cards for Model Reporting" by Margaret Mitchell, Simone Wu, Andrew Zaldivar,
    Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer, Inioluwa Deborah Raji
    and Timnit Gebru for the proposal behind model cards.
    Link: https://arxiv.org/abs/1810.03993

    Note:
        A model card can be loaded and saved to disk.

    Parameters:
        kwargs: section contents keyed by section name. The recognised sections are
            ``model_details``, ``intended_use``, ``factors``, ``metrics``,
            ``evaluation_data``, ``training_data``, ``quantitative_analyses``,
            ``ethical_considerations`` and ``caveats_and_recommendations``; each one
            defaults to an empty dict. Any other keyword is stored on the instance as
            an additional attribute of the same name.
    """

    def __init__(self, **kwargs):
        # Recommended attributes from https://arxiv.org/abs/1810.03993 (see papers)
        self.model_details = kwargs.pop("model_details", {})
        self.intended_use = kwargs.pop("intended_use", {})
        self.factors = kwargs.pop("factors", {})
        self.metrics = kwargs.pop("metrics", {})
        self.evaluation_data = kwargs.pop("evaluation_data", {})
        self.training_data = kwargs.pop("training_data", {})
        self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
        self.ethical_considerations = kwargs.pop("ethical_considerations", {})
        self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})

        # Open additional attributes: whatever is left in kwargs is stored as-is.
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError:
                logger.error("Can't set {} with value {} for {}".format(key, value, self))
                # Bare raise keeps the original traceback intact.
                raise

    def save_pretrained(self, save_directory_or_file):
        """ Save a model card object to the directory or file `save_directory_or_file`.

        If the argument is an existing directory, the card is written inside it under
        the predefined file name (``MODEL_CARD_NAME``); otherwise the argument is used
        as the output file path itself.
        """
        if os.path.isdir(save_directory_or_file):
            # If we save using the predefined names, we can load using `from_pretrained`
            output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
        else:
            output_model_card_file = save_directory_or_file

        self.to_json_file(output_model_card_file)
        logger.info("Model card saved in {}".format(output_model_card_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r""" Instantiate a :class:`~transformers.ModelCard` from a pre-trained model's model card.

        If the model card file cannot be fetched or is not valid JSON, a warning is
        logged and an empty model card is returned instead of raising.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model card to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model card that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing a model card file saved using the :func:`~transformers.ModelCard.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - a path or url to a saved model card JSON `file`, e.g.: ``./my_model_directory/modelcard.json``.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                card should be cached if the standard cache should not be used.

            kwargs: (`optional`) dict: key/value pairs with which to update the ModelCard object after loading.

                - The values in kwargs of any keys which are model card attributes will be used to override the loaded values.
                - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the `return_unused_kwargs` keyword parameter.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            find_from_standard_name: (`optional`) boolean, default True:
                If the pretrained_model_name_or_path ends with our standard model or config filenames, replace them with our standard modelcard filename.
                Can be used to directly feed a model/config url and access the colocated modelcard.

            return_unused_kwargs: (`optional`) bool, default False:

                - If False, then this function returns just the final model card object.
                - If True, then this function returns a tuple `(model card, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not model card attributes: ie the part of kwargs which has not been used to update `ModelCard` and is otherwise ignored.

        Examples::

            modelcard = ModelCard.from_pretrained('bert-base-uncased')    # Download model card from S3 and cache.
            modelcard = ModelCard.from_pretrained('./test/saved_model/')  # E.g. model card was saved using `save_pretrained('./test/saved_model/')`
            modelcard = ModelCard.from_pretrained('./test/saved_model/modelcard.json')
            modelcard, unused = ModelCard.from_pretrained(
                'bert-base-uncased', model_details={'name': 'bert'}, foo=False, return_unused_kwargs=True
            )

        """
        cache_dir = kwargs.pop("cache_dir", None)
        proxies = kwargs.pop("proxies", None)
        find_from_standard_name = kwargs.pop("find_from_standard_name", True)
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        is_known_shortcut = pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP

        if is_known_shortcut:
            # For simplicity we use the same pretrained url than the configuration files
            # but with a different suffix (modelcard.json). This suffix is replaced below.
            model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            model_card_file = pretrained_model_name_or_path
        else:
            model_card_file = hf_bucket_url(pretrained_model_name_or_path, postfix=MODEL_CARD_NAME)

        if find_from_standard_name or is_known_shortcut:
            # Point a config/weights location at the model card stored next to it.
            model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
            model_card_file = model_card_file.replace(WEIGHTS_NAME, MODEL_CARD_NAME)
            model_card_file = model_card_file.replace(TF2_WEIGHTS_NAME, MODEL_CARD_NAME)

        # Bound up front so the JSON error handler below can always format it.
        resolved_model_card_file = None
        try:
            # Resolve the card location to a local file.
            # NOTE(review): force_download=True is passed, so presumably an already
            # cached copy is refreshed rather than reused — confirm against cached_path.
            resolved_model_card_file = cached_path(
                model_card_file, cache_dir=cache_dir, force_download=True, proxies=proxies, resume_download=False
            )
            if resolved_model_card_file == model_card_file:
                logger.info("loading model card file {}".format(model_card_file))
            else:
                logger.info(
                    "loading model card file {} from cache at {}".format(model_card_file, resolved_model_card_file)
                )
            # Load model card
            modelcard = cls.from_json_file(resolved_model_card_file)

        except EnvironmentError:
            if is_known_shortcut:
                logger.warning("Couldn't reach server at '{}' to download model card file.".format(model_card_file))
            else:
                logger.warning(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url to a model card file named {} or "
                    "a directory containing such a file but couldn't find any such file at this path or url.".format(
                        pretrained_model_name_or_path,
                        ", ".join(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
                        model_card_file,
                        MODEL_CARD_NAME,
                    )
                )
            logger.warning("Creating an empty model card.")

            # We fall back on creating an empty model card
            modelcard = cls()

        except json.JSONDecodeError:
            logger.warning(
                "Couldn't reach server at '{}' to download model card file or "
                "model card file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(model_card_file, resolved_model_card_file)
            )
            logger.warning("Creating an empty model card.")

            # We fall back on creating an empty model card
            modelcard = cls()

        # Update model card with kwargs if needed: keys that name an existing attribute
        # override it and are consumed; everything else stays in kwargs.
        # Iterate over a snapshot of the keys because kwargs is mutated in the loop.
        for key in list(kwargs):
            if hasattr(modelcard, key):
                setattr(modelcard, key, kwargs.pop(key))

        logger.info("Model card: %s", str(modelcard))
        if return_unused_kwargs:
            return modelcard, kwargs
        return modelcard

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `ModelCard` from a Python dictionary of parameters."""
        return cls(**json_object)

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `ModelCard` from a json file of parameters.

        The file is read as UTF-8 and its top-level JSON object is handed to
        :meth:`from_dict`.
        """
        with open(json_file, "r", encoding="utf-8") as reader:
            dict_obj = json.load(reader)
        return cls.from_dict(dict_obj)

    def __eq__(self, other):
        """Two model cards are equal when all of their attributes are equal."""
        if not isinstance(other, ModelCard):
            # Let Python fall back to the other operand / identity comparison instead of
            # failing on objects that have no ``__dict__`` (e.g. ints, strings, None).
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy of its attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (2-space indent, sorted keys, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """ Save this instance to a json file, encoded as UTF-8."""
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())