Skip to content

Commit 01d353e

Browse files
committed
reformat with black.
1 parent 1aa4d4b commit 01d353e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+2195
-1018
lines changed

docs/source/conf.py

+45-31
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121
import sys
2222
import shutil
2323
from unittest.mock import MagicMock
24+
2425
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..", "..")))
2526

2627
# Mechanism to mock out modules
2728
class ModuleMock(object):
2829
def __init__(self, *args, **kwargs):
2930
pass
3031

32+
3133
# Putting all of our dirty hacks together
3234
class Mock(MagicMock):
3335
__metaclass__ = type
@@ -59,13 +61,18 @@ def __getattr__(cls, name):
5961
"torch.utils",
6062
"torch.utils.data",
6163
"torch.utils.cpp_extension",
62-
"numpy"
64+
"numpy",
6365
]
6466
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
6567

6668
import sphinx_rtd_theme
67-
examples_source = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "examples", "tutorial"))
68-
examples_dest = os.path.abspath(os.path.join(os.path.dirname(__file__), "examples", "tutorial"))
69+
70+
examples_source = os.path.abspath(
71+
os.path.join(os.path.dirname(__file__), "..", "..", "examples", "tutorial")
72+
)
73+
examples_dest = os.path.abspath(
74+
os.path.join(os.path.dirname(__file__), "examples", "tutorial")
75+
)
6976

7077
if os.path.exists(examples_dest):
7178
shutil.rmtree(examples_dest)
@@ -91,36 +98,38 @@ def __getattr__(cls, name):
9198
# Add any Sphinx extension module names here, as strings. They can be
9299
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
93100
# ones.
94-
extensions = ['sphinx.ext.autodoc',
95-
'sphinx.ext.mathjax',
96-
'sphinx.ext.viewcode',
97-
'nbsphinx']
101+
extensions = [
102+
"sphinx.ext.autodoc",
103+
"sphinx.ext.mathjax",
104+
"sphinx.ext.viewcode",
105+
"nbsphinx",
106+
]
98107

99108
# Add any paths that contain templates here, relative to this directory.
100-
templates_path = ['ntemplates']
109+
templates_path = ["ntemplates"]
101110

102111
# The suffix(es) of source filenames.
103112
# You can specify multiple suffix as a list of string:
104113
#
105114
# source_suffix = ['.rst', '.md']
106-
source_suffix = '.rst'
115+
source_suffix = ".rst"
107116

108117
# The master toctree document.
109-
master_doc = 'index'
118+
master_doc = "index"
110119

111120
# General information about the project.
112-
project = 'QPyTorch'
113-
copyright = '2019, Tianyi Zhang, Zhiqiu Lin, Guandao Yang, Christopher De Sa'
114-
author = 'Tianyi Zhang, Zhiqiu Lin, Guandao Yang, Christopher De Sa'
121+
project = "QPyTorch"
122+
copyright = "2019, Tianyi Zhang, Zhiqiu Lin, Guandao Yang, Christopher De Sa"
123+
author = "Tianyi Zhang, Zhiqiu Lin, Guandao Yang, Christopher De Sa"
115124

116125
# The version info for the project you're documenting, acts as replacement for
117126
# |version| and |release|, also used in various other places throughout the
118127
# built documents.
119128
#
120129
# The short X.Y version.
121-
version = '0.0.1'
130+
version = "0.0.1"
122131
# The full version, including alpha/beta/rc tags.
123-
release = '0.0.1 alpha'
132+
release = "0.0.1 alpha"
124133

125134
# The language for content autogenerated by Sphinx. Refer to documentation
126135
# for a list of supported languages.
@@ -135,7 +144,7 @@ def __getattr__(cls, name):
135144
exclude_patterns = []
136145

137146
# The name of the Pygments (syntax highlighting) style to use.
138-
pygments_style = 'sphinx'
147+
pygments_style = "sphinx"
139148

140149
# If true, `todo` and `todoList` produce output, else they produce nothing.
141150
todo_include_todos = False
@@ -146,7 +155,7 @@ def __getattr__(cls, name):
146155
# The theme to use for HTML and HTML Help pages. See the documentation for
147156
# a list of builtin themes.
148157
#
149-
html_theme = 'sphinx_rtd_theme'
158+
html_theme = "sphinx_rtd_theme"
150159
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
151160
html_theme_options = {
152161
"collapse_navigation": False,
@@ -162,13 +171,13 @@ def __getattr__(cls, name):
162171
# Add any paths that contain custom static files (such as style sheets) here,
163172
# relative to this directory. They are copied after the builtin static files,
164173
# so a file named "default.css" will overwrite the builtin "default.css".
165-
html_static_path = ['nstatic']
174+
html_static_path = ["nstatic"]
166175

167176

168177
# -- Options for HTMLHelp output ------------------------------------------
169178

170179
# Output file base name for HTML help builder.
171-
htmlhelp_basename = 'QPyTorchdoc'
180+
htmlhelp_basename = "QPyTorchdoc"
172181

173182

174183
# -- Options for LaTeX output ---------------------------------------------
@@ -177,15 +186,12 @@ def __getattr__(cls, name):
177186
# The paper size ('letterpaper' or 'a4paper').
178187
#
179188
# 'papersize': 'letterpaper',
180-
181189
# The font size ('10pt', '11pt' or '12pt').
182190
#
183191
# 'pointsize': '10pt',
184-
185192
# Additional stuff for the LaTeX preamble.
186193
#
187194
# 'preamble': '',
188-
189195
# Latex figure (float) alignment
190196
#
191197
# 'figure_align': 'htbp',
@@ -195,19 +201,21 @@ def __getattr__(cls, name):
195201
# (source start file, target name, title,
196202
# author, documentclass [howto, manual, or own class]).
197203
latex_documents = [
198-
(master_doc, 'QPyTorch.tex', 'QPyTorch Documentation',
199-
'Tianyi Zhang, Zhiqiu Lin, Christopher De Sa', 'manual'),
204+
(
205+
master_doc,
206+
"QPyTorch.tex",
207+
"QPyTorch Documentation",
208+
"Tianyi Zhang, Zhiqiu Lin, Christopher De Sa",
209+
"manual",
210+
),
200211
]
201212

202213
autodoc_inherit_docstrings = False
203214
# -- Options for manual page output ---------------------------------------
204215

205216
# One entry per manual page. List of tuples
206217
# (source start file, name, description, authors, manual section).
207-
man_pages = [
208-
(master_doc, 'qpytorch', 'QPyTorch Documentation',
209-
[author], 1)
210-
]
218+
man_pages = [(master_doc, "qpytorch", "QPyTorch Documentation", [author], 1)]
211219

212220

213221
# -- Options for Texinfo output -------------------------------------------
@@ -216,7 +224,13 @@ def __getattr__(cls, name):
216224
# (source start file, target name, title, author,
217225
# dir menu entry, description, category)
218226
texinfo_documents = [
219-
(master_doc, 'QPyTorch', 'QPyTorch Documentation',
220-
author, 'QPyTorch', 'One line description of project.',
221-
'Miscellaneous'),
227+
(
228+
master_doc,
229+
"QPyTorch",
230+
"QPyTorch Documentation",
231+
author,
232+
"QPyTorch",
233+
"One line description of project.",
234+
"Miscellaneous",
235+
),
222236
]

examples/IBM8/data.py

+55-35
Original file line numberDiff line numberDiff line change
@@ -4,65 +4,85 @@
44
import torchvision.datasets as datasets
55
import os
66

7+
78
def get_data(dataset, data_path, batch_size, val_ratio, num_workers):
8-
assert dataset in ["CIFAR10", "IMAGENET12"], "dataset not supported {}".format(dataset)
9+
assert dataset in ["CIFAR10", "IMAGENET12"], "dataset not supported {}".format(
10+
dataset
11+
)
912
assert val_ratio >= 0, "invalid validation ratio: {}".format(val_ratio)
10-
print('Loading dataset {} from {}'.format(dataset, data_path))
11-
if dataset=="CIFAR10":
13+
print("Loading dataset {} from {}".format(dataset, data_path))
14+
if dataset == "CIFAR10":
1215
ds = getattr(datasets, dataset)
1316
path = os.path.join(data_path, dataset.lower())
14-
transform_train = transforms.Compose([
15-
transforms.RandomCrop(32, padding=4),
16-
transforms.RandomHorizontalFlip(),
17-
transforms.ToTensor(),
18-
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
19-
])
20-
transform_test = transforms.Compose([
21-
transforms.ToTensor(),
22-
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
23-
])
17+
transform_train = transforms.Compose(
18+
[
19+
transforms.RandomCrop(32, padding=4),
20+
transforms.RandomHorizontalFlip(),
21+
transforms.ToTensor(),
22+
transforms.Normalize(
23+
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
24+
),
25+
]
26+
)
27+
transform_test = transforms.Compose(
28+
[
29+
transforms.ToTensor(),
30+
transforms.Normalize(
31+
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
32+
),
33+
]
34+
)
2435
train_set = ds(path, train=True, download=True, transform=transform_train)
2536
val_set = ds(path, train=True, download=True, transform=transform_test)
2637
test_set = ds(path, train=False, download=True, transform=transform_test)
2738
train_sampler = None
2839
val_sampler = None
2940
num_classes = 10
30-
elif dataset=="IMAGENET12":
31-
traindir = os.path.join(data_path, dataset.lower(), 'train')
32-
valdir = os.path.join(data_path, dataset.lower(), 'val')
33-
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
34-
std=[0.229, 0.224, 0.225])
41+
elif dataset == "IMAGENET12":
42+
traindir = os.path.join(data_path, dataset.lower(), "train")
43+
valdir = os.path.join(data_path, dataset.lower(), "val")
44+
normalize = transforms.Normalize(
45+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
46+
)
3547
train_set = datasets.ImageFolder(
3648
traindir,
37-
transforms.Compose([
38-
transforms.RandomResizedCrop(224),
39-
transforms.RandomHorizontalFlip(),
40-
transforms.ToTensor(),
41-
normalize,
42-
]))
43-
test_set = datasets.ImageFolder(valdir, transforms.Compose([
44-
transforms.Resize(256),
45-
transforms.CenterCrop(224),
46-
transforms.ToTensor(),
47-
normalize,
48-
]))
49+
transforms.Compose(
50+
[
51+
transforms.RandomResizedCrop(224),
52+
transforms.RandomHorizontalFlip(),
53+
transforms.ToTensor(),
54+
normalize,
55+
]
56+
),
57+
)
58+
test_set = datasets.ImageFolder(
59+
valdir,
60+
transforms.Compose(
61+
[
62+
transforms.Resize(256),
63+
transforms.CenterCrop(224),
64+
transforms.ToTensor(),
65+
normalize,
66+
]
67+
),
68+
)
4969
train_sampler = None
5070
val_sampler = None
5171
num_classes = 1000
5272
loaders = {
53-
'train': torch.utils.data.DataLoader(
73+
"train": torch.utils.data.DataLoader(
5474
train_set,
5575
batch_size=batch_size,
5676
shuffle=True,
5777
num_workers=num_workers,
58-
pin_memory=True
78+
pin_memory=True,
5979
),
60-
'test': torch.utils.data.DataLoader(
80+
"test": torch.utils.data.DataLoader(
6181
test_set,
6282
batch_size=batch_size,
6383
shuffle=False,
6484
num_workers=num_workers,
65-
pin_memory=True
66-
)
85+
pin_memory=True,
86+
),
6787
}
6888
return loaders

0 commit comments

Comments
 (0)