Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
LearningHx committed Aug 8, 2023
0 parents commit 70b58df
Show file tree
Hide file tree
Showing 53 changed files with 2,230 additions and 0 deletions.
52 changes: 52 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
## StyA2K

This repository is an implementation of the ICCV 2023 paper "All-to-key Attention for Arbitrary Style Transfer".

### Requirements

+ Ubuntu 18.04
+ Anaconda (Python, Numpy, PIL, etc.)
+ PyTorch 1.9.0
+ torchvision 0.10.0

### Getting Started

* Inference:

* Download [vgg_normalised.pth](https://drive.google.com/file/d/1BinnwM5AmIcVubr16tPTqxMjUCE8iu5M/view?usp=sharing).

* The pre-trained models are right in the ./checkpoints/A2K directory, including: latest_net_A2K.pth, latest_net_decoder.pth, and latest_net_transform.pth

* Configure content_path and style_path in test_A2K.sh to specify the paths to testing content and style images folders, respectively.

* Run:

```shell
bash test_A2K.sh
```

* Check the results under the ./results/A2K directory.

* Train:

* Download [vgg_normalised.pth](https://drive.google.com/file/d/1BinnwM5AmIcVubr16tPTqxMjUCE8iu5M/view?usp=sharing).

* Download [COCO dataset](http://images.cocodataset.org/zips/train2014.zip) and [WikiArt dataset](http://web.fsktm.um.edu.my/~cschan/source/ICIP2017/wikiart.zip).

* Configure content_path, style_path, and image_encoder_path in train_A2K.sh to specify the paths to training content images folders, training style images folders, and "vgg_normalised.pth", respectively.


* Then, simply run:

```shell
bash train_A2K.sh
```

* Monitor the training status at http://localhost:8097/. Trained models would be saved in the ./checkpoints/A2k folder.

* Try other training options in train_A2K.sh.


### Acknowledgments

* This code builds heavily on **[pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)** and **[AdaAttN](https://github.com/Huage001/AdaAttN)**. Thanks for open-sourcing!
Binary file added checkpoints/A2K/latest_net_A2K.pth
Binary file not shown.
Binary file added checkpoints/A2K/latest_net_decoder.pth
Binary file not shown.
Binary file added checkpoints/A2K/latest_net_transform.pth
Binary file not shown.
86 changes: 86 additions & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""This package includes all the modules related to data loading and preprocessing
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point from data loader.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
"""Import the module "data/[dataset_name]_dataset.py".
In the file, the class called DatasetNameDataset() will
be instantiated. It has to be a subclass of BaseDataset,
and it is case-insensitive.
"""
dataset_filename = "data." + dataset_name + "_dataset"
datasetlib = importlib.import_module(dataset_filename)

dataset = None
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
for name, cls in datasetlib.__dict__.items():
if name.lower() == target_dataset_name.lower() \
and issubclass(cls, BaseDataset):
dataset = cls

if dataset is None:
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

return dataset


def get_option_setter(dataset_name):
"""Return the static method <modify_commandline_options> of the dataset class."""
dataset_class = find_dataset_using_name(dataset_name)
return dataset_class.modify_commandline_options


class CustomDatasetDataLoader:

def __init__(self, opt, sampler=None):
self.opt = opt
dataset_class = find_dataset_using_name(opt.dataset_mode)
self.dataset = dataset_class(opt)
phase = 'training' if opt.isTrain else 'test'
print("[%s] dataset [%s] was created" % (phase, type(self.dataset).__name__))
if sampler is None:
self.data_loader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batch_size,
shuffle= not opt.serial_batches,
num_workers=int(opt.num_threads))
else:
self.data_loader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batch_size,
num_workers=int(opt.num_threads),
sampler=sampler(self.dataset))

def load_data(self):
return self

def __len__(self):
"""Return the number of data in the dataset"""
return min(len(self.dataset), self.opt.max_dataset_size)

def __iter__(self):
"""Return a batch of data"""
for i, data in enumerate(self.data_loader):
if i * self.opt.batch_size >= self.opt.max_dataset_size:
break
yield data


def create_dataset(opt, sampler=None):
data_loader = CustomDatasetDataLoader(opt, sampler)
dataset = data_loader.load_data()
return dataset
Binary file added data/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file added data/__pycache__/base_dataset.cpython-38.pyc
Binary file not shown.
Binary file added data/__pycache__/image_folder.cpython-38.pyc
Binary file not shown.
Binary file added data/__pycache__/unaligned_dataset.cpython-38.pyc
Binary file not shown.
147 changes: 147 additions & 0 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import os
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
"""This class is an abstract base class (ABC) for datasets.
To create a subclass, you need to implement the following four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
"""

def __init__(self, opt):
self.opt = opt

@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new dataset-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser

@abstractmethod
def __len__(self):
"""Return the total number of images in the dataset."""
return 0

@abstractmethod
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index - - a random integer for data indexing
Returns:
a dictionary of data with their names. It ususally contains the data itself and its metadata information.
"""
pass


def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.preprocess == 'resize_and_crop':
new_h = new_w = opt.load_size
elif opt.preprocess == 'scale_width_and_crop':
new_w = opt.load_size
new_h = opt.load_size * h // w

x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))

flip = random.random() > 0.5

return {'crop_pos': (x, y), 'flip': flip}


def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC):
transform_list = []
if grayscale:
transform_list.append(transforms.Grayscale(1))
if 'resize' in opt.preprocess:
osize = [int(opt.load_size / opt.load_ratio), opt.load_size]
transform_list.append(transforms.Resize(osize, method))
elif 'scale_width' in opt.preprocess:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))

if 'crop' in opt.preprocess:
if params is None:
transform_list.append(transforms.RandomCrop((int(opt.crop_size / opt.crop_ratio), opt.crop_size)))
else:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

if opt.preprocess == 'none':
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

if not opt.no_flip:
if params is None:
transform_list.append(transforms.RandomHorizontalFlip())
elif params['flip']:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

transform_list += [transforms.ToTensor()]
return transforms.Compose(transform_list)


def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if h == oh and w == ow:
return img

__print_size_warning(ow, oh, w, h)
return img.resize((w, h), method)


def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
ow, oh = img.size
if ow == target_size and oh >= crop_size:
return img
w = target_size
h = int(max(target_size * oh / ow, crop_size))
return img.resize((w, h), method)


def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
if (ow > tw or oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th))
return img


def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img


def __print_size_warning(ow, oh, w, h):
"""Print warning information about image size(only print once)"""
if not hasattr(__print_size_warning, 'has_printed'):
print("The image size needs to be a multiple of 4. "
"The loaded image size was (%d, %d), so it was adjusted to "
"(%d, %d). This adjustment will be done to all images "
"whose sizes are not multiples of 4" % (ow, oh, w, h))
__print_size_warning.has_printed = True
64 changes: 64 additions & 0 deletions data/image_folder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""A modified image folder class
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""

import torch.utils.data as data

from PIL import Image
import os

IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
'.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir

for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)
return images[:min(max_dataset_size, len(images))]


def default_loader(path):
return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader

def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img

def __len__(self):
return len(self.imgs)
Loading

0 comments on commit 70b58df

Please sign in to comment.