advanced_source/neural_style_tutorial.py (13 changes: 7 additions & 6 deletions)
@@ -14,7 +14,7 @@
developed by Leon A. Gatys, Alexander S. Ecker and Matthias Bethge.
Neural-Style, or Neural-Transfer, allows you to take an image and
reproduce it with a new artistic style. The algorithm takes three images,
-an input image, a content-image, and a style-image, and changes the input
+an input image, a content-image, and a style-image, and changes the input
to resemble the content of the content-image and the artistic style of the style-image.


@@ -70,6 +70,7 @@
# method is used to move tensors or modules to a desired device.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+torch.set_default_device(device)

######################################################################
# Loading the Images
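
Once ``torch.set_default_device(device)`` is in effect, tensor factory functions such as ``torch.randn`` and ``torch.tensor`` allocate on that device, which is why the explicit ``.to(device)`` calls below become unnecessary. A minimal sketch of the behavior::

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.set_default_device(device)

    t = torch.randn(3)   # allocated on the default device; no .to(device) needed
    print(t.device)      # prints cuda:0 when a GPU is available, else cpu
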
@@ -261,7 +262,7 @@ def forward(self, input):
# network to evaluation mode using ``.eval()``.
#

-cnn = models.vgg19(pretrained=True).features.to(device).eval()
+cnn = models.vgg19(pretrained=True).features.eval()



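As an aside, ``pretrained=True`` is the legacy torchvision argument; on torchvision 0.13+ the equivalent call under the weights API (not part of this change, shown only for reference) would be::

    from torchvision import models

    # VGG19_Weights.DEFAULT resolves to the same ImageNet weights
    cnn = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
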
@@ -271,8 +272,8 @@ def forward(self, input):
# We will use them to normalize the image before sending it into the network.
#

-cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
-cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
+cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
+cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])

# create a module to normalize input image so we can easily put it in a
# ``nn.Sequential``
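
For context, the ``Normalization`` module referenced here reshapes the mean and std so they broadcast over ``[B, C, H, W]`` image batches; roughly (a sketch of the tutorial's module, not the verbatim code)::

    import torch.nn as nn

    class Normalization(nn.Module):
        def __init__(self, mean, std):
            super().__init__()
            # reshape to [C, 1, 1] so the stats broadcast across H and W
            self.mean = mean.view(-1, 1, 1)
            self.std = std.view(-1, 1, 1)

        def forward(self, img):
            return (img - self.mean) / self.std

With the default device set, the ``mean`` and ``std`` tensors passed in already live on the right device, so no ``.to(device)`` is needed inside the module.
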
@@ -308,7 +309,7 @@ def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
content_layers=content_layers_default,
style_layers=style_layers_default):
# normalization module
-normalization = Normalization(normalization_mean, normalization_std).to(device)
+normalization = Normalization(normalization_mean, normalization_std)

# just so that we have iterable access to the lists of content/style
# losses
@@ -373,7 +374,7 @@ def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
#
# ::
#
-#    input_img = torch.randn(content_img.data.size(), device=device)
+#    input_img = torch.randn(content_img.data.size())

# add the original input image to the figure:
plt.figure()
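The same reasoning applies to the input image: under ``torch.set_default_device``, ``torch.randn`` already allocates on the GPU when one is available, so the explicit ``device=device`` argument is redundant. Assuming ``content_img`` was loaded earlier in the tutorial::

    input_img = torch.randn(content_img.data.size())  # lands on the default device
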
beginner_source/examples_autograd/polynomial_autograd.py (14 changes: 7 additions & 7 deletions)
@@ -18,23 +18,23 @@
import math

dtype = torch.float
-device = torch.device("cpu")
-# device = torch.device("cuda:0")  # Uncomment this to run on GPU
+device = "cuda" if torch.cuda.is_available() else "cpu"
+torch.set_default_device(device)

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
-x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
+x = torch.linspace(-math.pi, math.pi, 2000, dtype=dtype)
y = torch.sin(x)

# Create random Tensors for weights. For a third order polynomial, we need
# 4 weights: y = a + b x + c x^2 + d x^3
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
-a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
-b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
-c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
-d = torch.randn((), device=device, dtype=dtype, requires_grad=True)
+a = torch.randn((), dtype=dtype, requires_grad=True)
+b = torch.randn((), dtype=dtype, requires_grad=True)
+c = torch.randn((), dtype=dtype, requires_grad=True)
+d = torch.randn((), dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(2000):
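For reference, the training loop that follows this setup is unchanged; a rough sketch of it (abridged from the tutorial) shows that it only touches the tensors created above, so setting the default device is enough to move the whole example to GPU::

    for t in range(2000):
        # Forward pass: compute predicted y from the current coefficients.
        y_pred = a + b * x + c * x ** 2 + d * x ** 3

        # Compute the scalar loss.
        loss = (y_pred - y).pow(2).sum()

        # Backward pass: populates a.grad, b.grad, c.grad, d.grad.
        loss.backward()

        # Update the weights with plain gradient descent, outside autograd,
        # then reset the gradients for the next iteration.
        with torch.no_grad():
            for w in (a, b, c, d):
                w -= learning_rate * w.grad
                w.grad = None
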
recipes_source/recipes/amp_recipe.py (15 changes: 9 additions & 6 deletions)
@@ -76,11 +76,14 @@ def make_model(in_size, out_size, num_layers):
num_batches = 50
epochs = 3

+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.set_default_device(device)
+
# Creates data in default precision.
# The same data is used for both default and mixed precision trials below.
# You don't need to manually change inputs' ``dtype`` when enabling mixed precision.
-data = [torch.randn(batch_size, in_size, device="cuda") for _ in range(num_batches)]
-targets = [torch.randn(batch_size, out_size, device="cuda") for _ in range(num_batches)]
+data = [torch.randn(batch_size, in_size) for _ in range(num_batches)]
+targets = [torch.randn(batch_size, out_size) for _ in range(num_batches)]

loss_fn = torch.nn.MSELoss().cuda()

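The ``loss_fn`` line above kept its CUDA-specific spelling. Since ``MSELoss`` holds no parameters or buffers, ``.cuda()`` is a no-op either way, but a fully device-agnostic spelling (a stylistic suggestion, not part of this change) would be::

    loss_fn = torch.nn.MSELoss().to(device)
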
@@ -116,7 +119,7 @@ def make_model(in_size, out_size, num_layers):
for epoch in range(0): # 0 epochs, this section is for illustration only
for input, target in zip(data, targets):
# Runs the forward pass under ``autocast``.
-with torch.autocast(device_type='cuda', dtype=torch.float16):
+with torch.autocast(device_type=device, dtype=torch.float16):
output = net(input)
# output is float16 because linear layers ``autocast`` to float16.
assert output.dtype is torch.float16
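
Passing ``device_type=device`` makes the recipe device-agnostic, with one caveat: CPU autocast has historically supported ``torch.bfloat16`` rather than ``torch.float16``, so a defensive variant might choose the dtype per device (an illustrative pattern, not part of this change)::

    amp_dtype = torch.float16 if device == 'cuda' else torch.bfloat16
    with torch.autocast(device_type=device, dtype=amp_dtype):
        output = net(input)
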
@@ -151,7 +154,7 @@ def make_model(in_size, out_size, num_layers):

for epoch in range(0): # 0 epochs, this section is for illustration only
for input, target in zip(data, targets):
-with torch.autocast(device_type='cuda', dtype=torch.float16):
+with torch.autocast(device_type=device, dtype=torch.float16):
output = net(input)
loss = loss_fn(output, target)

@@ -184,7 +187,7 @@ def make_model(in_size, out_size, num_layers):
start_timer()
for epoch in range(epochs):
for input, target in zip(data, targets):
-with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
+with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
output = net(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
@@ -202,7 +205,7 @@

for epoch in range(0): # 0 epochs, this section is for illustration only
for input, target in zip(data, targets):
-with torch.autocast(device_type='cuda', dtype=torch.float16):
+with torch.autocast(device_type=device, dtype=torch.float16):
output = net(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
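A final note on the scaler: gradient scaling targets CUDA float16, and a ``GradScaler`` constructed with ``enabled=False`` turns ``scale``, ``step``, and ``update`` into no-ops, so the same loop runs unmodified on CPU. A sketch of the guard, assuming the recipe's ``net``, ``opt``, ``data``, and ``targets``::

    use_amp = device == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()  # plain backward() when scaling is disabled
        scaler.step(opt)               # falls back to opt.step() when disabled
        scaler.update()
        opt.zero_grad(set_to_none=True)
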
recipes_source/recipes/tuning_guide.py (2 changes: 1 addition & 1 deletion)
@@ -357,7 +357,7 @@ def fused_gelu(x):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Instead of calling ``torch.rand(size).cuda()`` to generate a random tensor,
# produce the output directly on the target device:
-# ``torch.rand(size, device=torch.device('cuda'))``.
+# ``torch.rand(size, device='cuda')``.
#
# This is applicable to all functions that create new tensors and accept a
# ``device`` argument:
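A quick illustration of the difference (illustrative only): the first form allocates on the CPU and then copies to the GPU; the second allocates on the GPU directly, skipping the intermediate tensor and the host-to-device copy::

    x = torch.rand(1024, 1024).cuda()          # CPU allocation, then copy to GPU
    y = torch.rand(1024, 1024, device='cuda')  # allocated directly on the GPU
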