Skip to content

Commit

Permalink
[MXNET-470] Gluon Style Transfer Example (apache#11044)
Browse files Browse the repository at this point in the history
* fix style transfer example training

* update model path

* rm test

* update path

* model and hybrid block

* hybrid training

* Revert "hybrid training"

This reverts commit 5bf39c1.

* Revert "model and hybrid block"

This reverts commit f50115a.

* download

* rm comments

* new model path
  • Loading branch information
zhanghang1989 authored and szha committed May 30, 2018
1 parent 92286c9 commit 714e296
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 47 deletions.
13 changes: 8 additions & 5 deletions example/gluon/style_transfer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def train(args):
utils.ToTensor(ctx),
])
train_dataset = data.ImageFolder(args.dataset, transform)
train_loader = gluon.data.DataLoader(train_dataset, batch_size=args.batch_size, last_batch='discard')
train_loader = gluon.data.DataLoader(train_dataset, batch_size=args.batch_size,
last_batch='discard')
style_loader = utils.StyleLoader(args.style_folder, args.style_size, ctx=ctx)
print('len(style_loader):',style_loader.size())
# models
Expand Down Expand Up @@ -79,7 +80,7 @@ def train(args):
xc = utils.subtract_imagenet_mean_preprocess_batch(x.copy())
f_xc_c = vgg(xc)[1]
with autograd.record():
style_model.setTarget(style_image)
style_model.set_target(style_image)
y = style_model(x)

y = utils.subtract_imagenet_mean_batch(y)
Expand All @@ -92,7 +93,8 @@ def train(args):
gram_y = net.gram_matrix(features_y[m])
_, C, _ = gram_style[m].shape
gram_s = F.expand_dims(gram_style[m], 0).broadcast_to((args.batch_size, 1, C, C))
style_loss = style_loss + 2 * args.style_weight * mse_loss(gram_y, gram_s[:n_batch, :, :])
style_loss = style_loss + 2 * args.style_weight * \
mse_loss(gram_y, gram_s[:n_batch, :, :])

total_loss = content_loss + style_loss
total_loss.backward()
Expand All @@ -115,7 +117,8 @@ def train(args):

if (batch_id + 1) % (4 * args.log_interval) == 0:
# save model
save_model_filename = "Epoch_" + str(e) + "iters_" + str(count) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
save_model_filename = "Epoch_" + str(e) + "iters_" + \
str(count) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
args.content_weight) + "_" + str(args.style_weight) + ".params"
save_model_path = os.path.join(args.save_model_dir, save_model_filename)
style_model.save_params(save_model_path)
Expand All @@ -142,7 +145,7 @@ def evaluate(args):
style_model = net.Net(ngf=args.ngf)
style_model.load_params(args.model, ctx=ctx)
# forward
style_model.setTarget(style_image)
style_model.set_target(style_image)
output = style_model(content_image)
utils.tensor_save_bgrimage(output[0], args.output_image, args.cuda)

Expand Down
12 changes: 11 additions & 1 deletion example/gluon/style_transfer/models/download_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@
# specific language governing permissions and limitations
# under the License.

import os
import zipfile
import shutil
from mxnet.test_utils import download

download('https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/models/21styles-32f7205c.params', 'models/21styles.params')
zip_file_path = 'models/msgnet_21styles.zip'
download('https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/models/msgnet_21styles-2cb88353.zip', zip_file_path)

with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall()

os.remove(zip_file_path)

shutil.move('msgnet_21styles-2cb88353.params', 'models/21styles.params')
48 changes: 7 additions & 41 deletions example/gluon/style_transfer/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,6 @@ def __init__(self, in_channels, out_channels, kernel_size,
stride, upsample=None):
super(UpsampleConvLayer, self).__init__()
self.upsample = upsample
"""
if upsample:
self.upsample_layer = torch.nn.UpsamplingNearest2d(scale_factor=upsample)
"""
self.reflection_padding = int(np.floor(kernel_size / 2))
self.conv2d = nn.Conv2D(in_channels=in_channels,
channels=out_channels,
Expand All @@ -165,10 +161,6 @@ def __init__(self, in_channels, out_channels, kernel_size,
def forward(self, x):
if self.upsample:
x = F.UpSampling(x, scale=self.upsample, sample_type='nearest')
"""
if self.reflection_padding != 0:
x = self.reflection_pad(x)
"""
out = self.conv2d(x)
return out

Expand Down Expand Up @@ -222,16 +214,16 @@ def __init__(self, input_nc=3, output_nc=3, ngf=64,
self.model.add(ConvLayer(16*expansion, output_nc, kernel_size=7, stride=1))


def setTarget(self, Xs):
def set_target(self, Xs):
F = self.model1(Xs)
G = self.gram(F)
self.ins.setTarget(G)
self.ins.set_target(G)

def forward(self, input):
return self.model(input)


class Inspiration(HybridBlock):
class Inspiration(Block):
""" Inspiration Layer (from MSG-Net paper)
tuning the featuremap with target Gram Matrix
ref https://arxiv.org/abs/1703.06953
Expand All @@ -243,17 +235,14 @@ def __init__(self, C, B=1):
self.weight = self.params.get('weight', shape=(1,C,C),
init=mx.initializer.Uniform(),
allow_deferred_init=True)
self.gram = self.params.get('gram', shape=(B,C,C),
init=mx.initializer.Uniform(),
allow_deferred_init=True,
lr_mult=0)
self.gram = F.random.uniform(shape=(B, C, C))

def setTarget(self, target):
self.gram.set_data(target)
def set_target(self, target):
self.gram = target

def forward(self, X):
# input X is a 3D feature map
self.P = F.batch_dot(F.broadcast_to(self.weight.data(), shape=(self.gram.shape)), self.gram.data())
self.P = F.batch_dot(F.broadcast_to(self.weight.data(), shape=(self.gram.shape)), self.gram)
return F.batch_dot(F.SwapAxis(self.P,1,2).broadcast_to((X.shape[0], self.C, self.C)), X.reshape((0,0,X.shape[2]*X.shape[3]))).reshape(X.shape)

def __repr__(self):
Expand Down Expand Up @@ -305,26 +294,3 @@ def forward(self, X):
relu4_3 = h

return [relu1_2, relu2_2, relu3_3, relu4_3]


def test_InstanceNorm():
import torch
from torch import nn as nn2
from torch.autograd import Variable
tx = Variable(torch.Tensor(1, 2, 200, 300).uniform_(0,1))
tlayer = nn2.InstanceNorm2d(2)
ty = tlayer(tx)

mlayer = InstanceNorm(2)
ctx = mx.cpu(0)
mlayer.initialize(ctx=ctx)
mmx = (mx.nd.array(tx.data.numpy())).as_in_context(ctx)
my = mlayer(mmx)
print('tx',tx)
print('mmx',mmx)
print('ty',ty)
print('my',my)

if __name__ == "__main__":
test_InstanceNorm()

0 comments on commit 714e296

Please sign in to comment.