-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 70b58df
Showing
53 changed files
with
2,230 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
## StyA2K | ||
|
||
This repository is an implementation of the ICCV 2023 paper "All-to-key Attention for Arbitrary Style Transfer". | ||
|
||
### Requirements | ||
|
||
+ Ubuntu 18.04 | ||
+ Anaconda (Python, Numpy, PIL, etc.) | ||
+ PyTorch 1.9.0 | ||
+ torchvision 0.10.0 | ||
|
||
### Getting Started | ||
|
||
* Inference: | ||
|
||
* Download [vgg_normalised.pth](https://drive.google.com/file/d/1BinnwM5AmIcVubr16tPTqxMjUCE8iu5M/view?usp=sharing). | ||
|
||
* The pre-trained models are right in the ./checkpoints/A2K directory, including: latest_net_A2K.pth, latest_net_decoder.pth, and latest_net_transform.pth | ||
|
||
* Configure content_path and style_path in test_A2K.sh to specify the paths to testing content and style images folders, respectively. | ||
|
||
* Run: | ||
|
||
```shell | ||
bash test_A2K.sh | ||
``` | ||
|
||
* Check the results under the ./results/A2K directory. | ||
|
||
* Train: | ||
|
||
* Download [vgg_normalised.pth](https://drive.google.com/file/d/1BinnwM5AmIcVubr16tPTqxMjUCE8iu5M/view?usp=sharing). | ||
|
||
* Download [COCO dataset](http://images.cocodataset.org/zips/train2014.zip) and [WikiArt dataset](http://web.fsktm.um.edu.my/~cschan/source/ICIP2017/wikiart.zip). | ||
|
||
* Configure content_path, style_path, and image_encoder_path in train_A2K.sh to specify the paths to training content images folders, training style images folders, and "vgg_normalised.pth", respectively. | ||
|
||
|
||
* Then, simply run: | ||
|
||
```shell | ||
bash train_A2K.sh | ||
``` | ||
|
||
* Monitor the training status at http://localhost:8097/. Trained models would be saved in the ./checkpoints/A2k folder. | ||
|
||
* Try other training options in train_A2K.sh. | ||
|
||
|
||
### Acknowledgments | ||
|
||
* This code builds heavily on **[pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)** and **[AdaAttN](https://github.com/Huage001/AdaAttN)**. Thanks for open-sourcing! |
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
"""This package includes all the modules related to data loading and preprocessing | ||
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. | ||
You need to implement four functions: | ||
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). | ||
-- <__len__>: return the size of dataset. | ||
-- <__getitem__>: get a data point from data loader. | ||
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. | ||
Now you can use the dataset class by specifying flag '--dataset_mode dummy'. | ||
See our template dataset class 'template_dataset.py' for more details. | ||
""" | ||
import importlib | ||
import torch.utils.data | ||
from data.base_dataset import BaseDataset | ||
|
||
|
||
def find_dataset_using_name(dataset_name): | ||
"""Import the module "data/[dataset_name]_dataset.py". | ||
In the file, the class called DatasetNameDataset() will | ||
be instantiated. It has to be a subclass of BaseDataset, | ||
and it is case-insensitive. | ||
""" | ||
dataset_filename = "data." + dataset_name + "_dataset" | ||
datasetlib = importlib.import_module(dataset_filename) | ||
|
||
dataset = None | ||
target_dataset_name = dataset_name.replace('_', '') + 'dataset' | ||
for name, cls in datasetlib.__dict__.items(): | ||
if name.lower() == target_dataset_name.lower() \ | ||
and issubclass(cls, BaseDataset): | ||
dataset = cls | ||
|
||
if dataset is None: | ||
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) | ||
|
||
return dataset | ||
|
||
|
||
def get_option_setter(dataset_name): | ||
"""Return the static method <modify_commandline_options> of the dataset class.""" | ||
dataset_class = find_dataset_using_name(dataset_name) | ||
return dataset_class.modify_commandline_options | ||
|
||
|
||
class CustomDatasetDataLoader: | ||
|
||
def __init__(self, opt, sampler=None): | ||
self.opt = opt | ||
dataset_class = find_dataset_using_name(opt.dataset_mode) | ||
self.dataset = dataset_class(opt) | ||
phase = 'training' if opt.isTrain else 'test' | ||
print("[%s] dataset [%s] was created" % (phase, type(self.dataset).__name__)) | ||
if sampler is None: | ||
self.data_loader = torch.utils.data.DataLoader( | ||
self.dataset, | ||
batch_size=opt.batch_size, | ||
shuffle= not opt.serial_batches, | ||
num_workers=int(opt.num_threads)) | ||
else: | ||
self.data_loader = torch.utils.data.DataLoader( | ||
self.dataset, | ||
batch_size=opt.batch_size, | ||
num_workers=int(opt.num_threads), | ||
sampler=sampler(self.dataset)) | ||
|
||
def load_data(self): | ||
return self | ||
|
||
def __len__(self): | ||
"""Return the number of data in the dataset""" | ||
return min(len(self.dataset), self.opt.max_dataset_size) | ||
|
||
def __iter__(self): | ||
"""Return a batch of data""" | ||
for i, data in enumerate(self.data_loader): | ||
if i * self.opt.batch_size >= self.opt.max_dataset_size: | ||
break | ||
yield data | ||
|
||
|
||
def create_dataset(opt, sampler=None): | ||
data_loader = CustomDatasetDataLoader(opt, sampler) | ||
dataset = data_loader.load_data() | ||
return dataset |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. | ||
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. | ||
""" | ||
import random | ||
import os | ||
import numpy as np | ||
import torch.utils.data as data | ||
from PIL import Image | ||
import torchvision.transforms as transforms | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class BaseDataset(data.Dataset, ABC): | ||
"""This class is an abstract base class (ABC) for datasets. | ||
To create a subclass, you need to implement the following four functions: | ||
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). | ||
-- <__len__>: return the size of dataset. | ||
-- <__getitem__>: get a data point. | ||
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. | ||
""" | ||
|
||
def __init__(self, opt): | ||
self.opt = opt | ||
|
||
@staticmethod | ||
def modify_commandline_options(parser, is_train): | ||
"""Add new dataset-specific options, and rewrite default values for existing options. | ||
Parameters: | ||
parser -- original option parser | ||
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. | ||
Returns: | ||
the modified parser. | ||
""" | ||
return parser | ||
|
||
@abstractmethod | ||
def __len__(self): | ||
"""Return the total number of images in the dataset.""" | ||
return 0 | ||
|
||
@abstractmethod | ||
def __getitem__(self, index): | ||
"""Return a data point and its metadata information. | ||
Parameters: | ||
index - - a random integer for data indexing | ||
Returns: | ||
a dictionary of data with their names. It ususally contains the data itself and its metadata information. | ||
""" | ||
pass | ||
|
||
|
||
def get_params(opt, size): | ||
w, h = size | ||
new_h = h | ||
new_w = w | ||
if opt.preprocess == 'resize_and_crop': | ||
new_h = new_w = opt.load_size | ||
elif opt.preprocess == 'scale_width_and_crop': | ||
new_w = opt.load_size | ||
new_h = opt.load_size * h // w | ||
|
||
x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) | ||
y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) | ||
|
||
flip = random.random() > 0.5 | ||
|
||
return {'crop_pos': (x, y), 'flip': flip} | ||
|
||
|
||
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC): | ||
transform_list = [] | ||
if grayscale: | ||
transform_list.append(transforms.Grayscale(1)) | ||
if 'resize' in opt.preprocess: | ||
osize = [int(opt.load_size / opt.load_ratio), opt.load_size] | ||
transform_list.append(transforms.Resize(osize, method)) | ||
elif 'scale_width' in opt.preprocess: | ||
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) | ||
|
||
if 'crop' in opt.preprocess: | ||
if params is None: | ||
transform_list.append(transforms.RandomCrop((int(opt.crop_size / opt.crop_ratio), opt.crop_size))) | ||
else: | ||
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) | ||
|
||
if opt.preprocess == 'none': | ||
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) | ||
|
||
if not opt.no_flip: | ||
if params is None: | ||
transform_list.append(transforms.RandomHorizontalFlip()) | ||
elif params['flip']: | ||
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) | ||
|
||
transform_list += [transforms.ToTensor()] | ||
return transforms.Compose(transform_list) | ||
|
||
|
||
def __make_power_2(img, base, method=Image.BICUBIC): | ||
ow, oh = img.size | ||
h = int(round(oh / base) * base) | ||
w = int(round(ow / base) * base) | ||
if h == oh and w == ow: | ||
return img | ||
|
||
__print_size_warning(ow, oh, w, h) | ||
return img.resize((w, h), method) | ||
|
||
|
||
def __scale_width(img, target_size, crop_size, method=Image.BICUBIC): | ||
ow, oh = img.size | ||
if ow == target_size and oh >= crop_size: | ||
return img | ||
w = target_size | ||
h = int(max(target_size * oh / ow, crop_size)) | ||
return img.resize((w, h), method) | ||
|
||
|
||
def __crop(img, pos, size): | ||
ow, oh = img.size | ||
x1, y1 = pos | ||
tw = th = size | ||
if (ow > tw or oh > th): | ||
return img.crop((x1, y1, x1 + tw, y1 + th)) | ||
return img | ||
|
||
|
||
def __flip(img, flip): | ||
if flip: | ||
return img.transpose(Image.FLIP_LEFT_RIGHT) | ||
return img | ||
|
||
|
||
def __print_size_warning(ow, oh, w, h): | ||
"""Print warning information about image size(only print once)""" | ||
if not hasattr(__print_size_warning, 'has_printed'): | ||
print("The image size needs to be a multiple of 4. " | ||
"The loaded image size was (%d, %d), so it was adjusted to " | ||
"(%d, %d). This adjustment will be done to all images " | ||
"whose sizes are not multiples of 4" % (ow, oh, w, h)) | ||
__print_size_warning.has_printed = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
"""A modified image folder class | ||
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) | ||
so that this class can load images from both current directory and its subdirectories. | ||
""" | ||
|
||
import torch.utils.data as data | ||
|
||
from PIL import Image | ||
import os | ||
|
||
IMG_EXTENSIONS = [ | ||
'.jpg', '.JPG', '.jpeg', '.JPEG', | ||
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', | ||
'.tif', '.TIF', '.tiff', '.TIFF', | ||
] | ||
|
||
|
||
def is_image_file(filename): | ||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) | ||
|
||
|
||
def make_dataset(dir, max_dataset_size=float("inf")): | ||
images = [] | ||
assert os.path.isdir(dir), '%s is not a valid directory' % dir | ||
|
||
for root, _, fnames in sorted(os.walk(dir)): | ||
for fname in fnames: | ||
if is_image_file(fname): | ||
path = os.path.join(root, fname) | ||
images.append(path) | ||
return images[:min(max_dataset_size, len(images))] | ||
|
||
|
||
def default_loader(path): | ||
return Image.open(path).convert('RGB') | ||
|
||
|
||
class ImageFolder(data.Dataset): | ||
|
||
def __init__(self, root, transform=None, return_paths=False, loader=default_loader): | ||
imgs = make_dataset(root) | ||
if len(imgs) == 0: | ||
raise(RuntimeError("Found 0 images in: " + root + "\n" | ||
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) | ||
|
||
self.root = root | ||
self.imgs = imgs | ||
self.transform = transform | ||
self.return_paths = return_paths | ||
self.loader = loader | ||
|
||
def __getitem__(self, index): | ||
path = self.imgs[index] | ||
img = self.loader(path) | ||
if self.transform is not None: | ||
img = self.transform(img) | ||
if self.return_paths: | ||
return img, path | ||
else: | ||
return img | ||
|
||
def __len__(self): | ||
return len(self.imgs) |
Oops, something went wrong.