Make the PyTorch and TF versions follow the same coding/naming conventions + add support for variable ratios and high resolutions
dorarad authored Feb 7, 2022
1 parent 8961ea2 commit 55194b7
Showing 36 changed files with 2,428 additions and 651 deletions.
11 changes: 10 additions & 1 deletion README.md
@@ -87,7 +87,16 @@ python generate.py --gpus 0 --model gdrive:bedrooms-snapshot.pkl --output-dir im
**You can use `--truncation-psi` to control the quality/diversity trade-off of the generated images.
We recommend trying out different values in the range of `0.6-1.0`.**

We currently provide pretrained models for resolution 256×256 but keep training them and will release newer checkpoints as well as pretrained models for resolution 1024×1024 soon!
### Pretrained models and high resolutions
We provide pretrained models for resolution 256×256 for all datasets, as well as 1024×1024 for FFHQ and 1024×2048 for Cityscapes.

To generate images with the high-resolution models, run the following commands
(we reduce the batch size to 1 so that the models fit on a single GPU):

```python
python generate.py --gpus 0 --model gdrive:ffhq-snapshot-1024.pkl --output-dir ffhq_images --images-num 32 --batch-size 1
python generate.py --gpus 0 --model gdrive:cityscapes-snapshot-2048.pkl --output-dir cityscapes_images --images-num 32 --batch-size 1 --ratio 0.5
```

We can train and evaluate new or pretrained models, both quantitatively and qualitatively, with [`run_network.py`](run_network.py) ([TF](run_network.py) / [Pytorch](pytorch_version/run_network.py)).
The model architecture can be found at [`network.py`](training/network.py) ([TF](training/network.py) / [Pytorch](pytorch_version/training/network.py)). The training procedure is implemented at [`training_loop.py`](training/training_loop.py) ([TF](training/training_loop.py) / [Pytorch](pytorch_version/training/training_loop.py)).
21 changes: 9 additions & 12 deletions dataset_tool.py
@@ -37,9 +37,7 @@ def __init__(self, tfrecord_dir, expected_imgs, verbose = False, progress_interv

if self.verbose:
print("Creating dataset %s" % tfrecord_dir)
if not os.path.isdir(self.tfrecord_dir):
os.makedirs(self.tfrecord_dir)
assert os.path.isdir(self.tfrecord_dir)
os.makedirs(self.tfrecord_dir, exist_ok = True)

def close(self):
if self.verbose:
@@ -207,7 +205,7 @@ def display(tfrecord_dir):
idx = 0
while True:
try:
imgs, labels = dset.get_minibatch_np(1)
imgs, labels = dset.get_batch_np(1)
except tf.errors.OutOfRangeError:
break
if idx == 0:
@@ -229,14 +227,13 @@ def extract(tfrecord_dir, output_dir):
tflib.init_uninitialized_vars()

print("Extracting images to %s" % output_dir)
if not os.path.isdir(output_dir):
os.makedirs(output_dir)
os.makedirs(output_dir, exist_ok = True)
idx = 0
while True:
if idx % 10 == 0:
print("%d\r" % idx, end = "", flush = True)
try:
imgs, _labels = dset.get_minibatch_np(1)
imgs, _labels = dset.get_batch_np(1)
except tf.errors.OutOfRangeError:
break
if imgs.shape[1] == 1:
@@ -264,11 +261,11 @@ def compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels):
if idx % 100 == 0:
print("%d\r" % idx, end = "", flush = True)
try:
imgs_a, labels_a = dset_a.get_minibatch_np(1)
imgs_a, labels_a = dset_a.get_batch_np(1)
except tf.errors.OutOfRangeError:
imgs_a, labels_a = None, None
try:
imgs_b, labels_b = dset_b.get_minibatch_np(1)
imgs_b, labels_b = dset_b.get_batch_np(1)
except tf.errors.OutOfRangeError:
imgs_b, labels_b = None, None
if imgs_a is None or imgs_b is None:
@@ -660,7 +657,7 @@ def process_func(idx):
def create_from_imgs(tfrecord_dir, img_dir, format = "png", shuffle = False, ratio = None,
max_imgs = None, shards_num = 5):
print("Loading images from %s" % img_dir)
img_filenames = sorted(glob.glob("{}/**/*.{}".format(img_dir, format), recursive = True))
img_filenames = sorted(glob.glob(f"{img_dir}/**/*.{format}", recursive = True))
if len(img_filenames) == 0:
error("No input images found")
if max_imgs is None:
@@ -699,7 +696,7 @@ def create_from_tfds(tfrecord_dir, dataset_name, ratio = None, max_imgs = None,
import tensorflow_datasets as tfds

print("Loading dataset %s" % dataset_name)
ds = tfds.load(dataset_name, split = "train", data_dir = "{}/tfds".format(tfrecord_dir))
ds = tfds.load(dataset_name, split = "train", data_dir = f"{tfrecord_dir}/tfds")
with TFRecordExporter(tfrecord_dir, 0, shards_num = shards_num) as tfr:
for i, ex in tqdm(enumerate(tfds.as_numpy(ds))):
img = PIL.Image.fromarray(ex["image"])
@@ -745,7 +742,7 @@ def create_from_lmdb(tfrecord_dir, lmdb_dir, ratio = None, max_imgs = None, shar
break

if bad_imgs > 0:
print("Couldn't read {} out of {} images".format(bad_imgs, max_imgs))
print(f"Couldn't read {bad_imgs} out of {max_imgs} images")

def create_from_npy(tfrecord_dir, npy_filename, shuffle = False, max_imgs = None, shards_num = 5):
print("Loading NPY archive from %s" % npy_filename)
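
The f-string in `create_from_imgs` above searches `img_dir` recursively for images of the given format; a minimal standalone sketch of that pattern (the directory and format here are hypothetical):

```python
import glob

img_dir, fmt = "datasets/ffhq", "png"   # hypothetical directory and image format
# "**" with recursive = True matches files in img_dir and all of its subdirectories
img_filenames = sorted(glob.glob(f"{img_dir}/**/*.{fmt}", recursive = True))
print(len(img_filenames), "images found")
```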
10 changes: 5 additions & 5 deletions dnnlib/submission/submit.py
@@ -178,7 +178,7 @@ def _create_run_dir_local(submit_config: SubmitConfig, resume: bool, create_new:

if not resume:
if os.path.exists(run_dir) and create_new:
raise RuntimeError("The run dir already exists! ({0})".format(run_dir))
raise RuntimeError(f"The run dir already exists! ({run_dir})")
if not os.path.exists(run_dir):
os.makedirs(run_dir)

@@ -242,7 +242,7 @@ def run_wrapper(submit_config: SubmitConfig) -> None:

exit_with_errcode = False
try:
print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
print(f"dnnlib: Running {submit_config.run_func_name}() on {submit_config.host_name}...")
start_time = time.time()

run_func_obj = util.get_obj_by_name(submit_config.run_func_name)
@@ -253,15 +253,15 @@ def run_wrapper(submit_config: SubmitConfig) -> None:
else:
run_func_obj(**submit_config.run_func_kwargs)

print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
print(f"dnnlib: Finished {submit_config.run_func_name}() in {util.format_time(time.time() - start_time)}.")
except:
if is_local:
raise
else:
traceback.print_exc()

log_src = os.path.join(submit_config.run_dir, "log.txt")
log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), f"{submit_config.run_name}-error.txt")
shutil.copyfile(log_src, log_dst)

# Defer sys.exit(1) to happen after we close the logs and create a _finished.txt
@@ -318,7 +318,7 @@ def submit_run(submit_config: SubmitConfig, run_func_name: str, create_newdir: b
#--------------------------------------------------------------------
host_run_dir = _create_run_dir_local(submit_config, resume, create_new = create_newdir)

submit_config.task_name = "{}-{:05d}-{}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
submit_config.task_name = f"{submit_config.user_name}-{submit_config.run_id:05d}-{submit_config.run_desc}"
docker_valid_name_regex = "^[a-zA-Z0-9][a-zA-Z0-9_.-]+$"
if not re.match(docker_valid_name_regex, submit_config.task_name):
raise RuntimeError("Invalid task name. Probable reason: unacceptable characters in your submit_config.run_desc. Task name must be accepted by the following regex: " + docker_valid_name_regex + ", got " + submit_config.task_name)
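
The task name assembled in `submit_run` above must pass the Docker name check; a small self-contained sketch with hypothetical values:

```python
import re

user_name, run_id, run_desc = "alice", 42, "ffhq-256"   # hypothetical values
task_name = f"{user_name}-{run_id:05d}-{run_desc}"       # -> "alice-00042-ffhq-256"
docker_valid_name_regex = "^[a-zA-Z0-9][a-zA-Z0-9_.-]+$"
assert re.match(docker_valid_name_regex, task_name)      # passes: only allowed characters
```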
2 changes: 1 addition & 1 deletion dnnlib/tflib/custom_ops.py
@@ -102,7 +102,7 @@ def get_plugin(cuda_file):
# Compile if not already compiled
tf_ver = float(".".join(tf.__version__.split(".")[:-1]))
bin_file_ext = '.dll' if os.name == 'nt' else '.so'
bin_file = os.path.join(cuda_cache_path, cuda_file_name + "_{}_".format(tf_ver) + bin_file_ext) # + '_' + md5.hexdigest()
bin_file = os.path.join(cuda_cache_path, f"{cuda_file_name}_{tf_ver}_{bin_file_ext}") # + '_' + md5.hexdigest()

if not os.path.isfile(bin_file):
# Hash headers included by the CUDA code by running it through the preprocessor
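
For reference, `tf_ver` above keeps only the major.minor part of the TensorFlow version string and embeds it in the compiled plugin's file name; a standalone sketch (the plugin name and version string are stand-ins):

```python
version = "1.15.0"                                  # stand-in for tf.__version__
tf_ver = float(".".join(version.split(".")[:-1]))   # drop the patch component -> 1.15
cuda_file_name, bin_file_ext = "some_op", ".so"     # hypothetical plugin name; ".dll" on Windows
print(f"{cuda_file_name}_{tf_ver}_{bin_file_ext}")  # some_op_1.15_.so
```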
26 changes: 13 additions & 13 deletions dnnlib/tflib/network.py
@@ -61,8 +61,8 @@ class Network:
# components: Container for sub-networks. Passed to the build func, and retained between calls
# num_inputs: Number of input tensors
# num_outputs: Number of output tensors
# input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension
# output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension
# input_shapes: Input tensor shapes (NC or NCHW), including batch dimension
# output_shapes: Output tensor shapes (NC or NCHW), including batch dimension
# input_shape: Short-hand for input_shapes[0]
# output_shape: Short-hand for output_shapes[0]
# input_templates: Input placeholders in the template graph
@@ -346,7 +346,7 @@ def copy_vars_from(self, src_net: "Network") -> None:
if len(uninitialized) > 0:
print(bcolored("Uninitialized variables:", "red"))
for name in uninitialized:
print("{}: {}".format(bold(name), self.vars[name].shape))
print(f"{bold(name)}: {self.vars[name].shape}")

for name in names:
if self.vars[name].shape == src_net.vars[self.translate(name)].shape:
@@ -355,7 +355,7 @@ def copy_vars_from(self, src_net: "Network") -> None:
if not mismatch:
mismatch = True
print(bcolored("Variables shape mismatching:", "red"))
print("{}: {}, {}".format(bold(name), self.vars[name].shape, src_net.vars[self.translate(name)].shape))
print(f"{bold(name)}: {self.vars[name].shape}, {src_net.vars[self.translate(name)].shape}")
names = names_new
tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[self.translate(name)] for name in names}))

@@ -392,7 +392,7 @@ def run(self,
output_transform: dict = None,
return_as_list: bool = False,
print_progress: bool = False,
minibatch_size: int = None,
batch_size: int = None,
num_gpus: int = 1,
assume_frozen: bool = False,
verbose: bool = False,
@@ -407,7 +407,7 @@
# TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
# return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
# print_progress: Print progress to the console? Useful for very large input arrays.
# minibatch_size: Maximum minibatch size to use, None = disable batching.
# batch_size: Maximum batch size to use, None = disable batching.
# num_gpus: Number of GPUs to use.
# assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
# dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
@@ -424,8 +424,8 @@
assert output_transform is None or util.is_top_level_function(output_transform["func"])
output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
num_items = in_arrays[0].shape[0]
if minibatch_size is None:
minibatch_size = num_items
if batch_size is None:
batch_size = num_items

# Construct unique hash key from all arguments that affect the TensorFlow graph
key = dict(input_transform = input_transform, output_transform = output_transform, num_gpus = num_gpus, assume_frozen = assume_frozen, dynamic_kwargs = dynamic_kwargs)
@@ -469,20 +469,20 @@ def unwind_key(obj):
out_expr = [tf.concat(outputs, axis = 0) for outputs in zip(*out_split)]
self._run_cache[key] = in_expr, out_expr

# Run minibatches
# Run batches
in_expr, out_expr = self._run_cache[key]
out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]

range_fn = range
if verbose:
range_fn = lambda start, end, step: trange(start, end, step, unit_scale = minibatch_size,
unit = "image ({} batches of {} images)".format(len(range(0, num_items, minibatch_size)), minibatch_size))
range_fn = lambda start, end, step: trange(start, end, step, unit_scale = batch_size,
unit = f"image ({len(range(0, num_items, batch_size))} batches of {batch_size} images)")

for mb_begin in range_fn(0, num_items, minibatch_size):
for mb_begin in range_fn(0, num_items, batch_size):
if print_progress:
print("\r%d / %d" % (mb_begin, num_items), end="")

mb_end = min(mb_begin + minibatch_size, num_items)
mb_end = min(mb_begin + batch_size, num_items)
mb_num = mb_end - mb_begin
mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
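
The renamed `batch_size` controls how `run()` slices its inputs into consecutive batches; the slicing logic shown above reduces to the following standalone sketch:

```python
# 10 items with batch_size = 4 are processed as slices [0, 4), [4, 8), [8, 10)
num_items, batch_size = 10, 4
for mb_begin in range(0, num_items, batch_size):
    mb_end = min(mb_begin + batch_size, num_items)
    print(mb_begin, mb_end)
```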
20 changes: 10 additions & 10 deletions dnnlib/tflib/optimizer.py
@@ -1,6 +1,6 @@
# Tensorflow optimizer, supports:
### Gradient averaging for multi-GPU training
### Gradient accumulation for arbitrarily large minibatches
### Gradient accumulation for arbitrarily large batches
### Dynamic loss scaling and typecasts for FP16 training
### Ignoring corrupted gradients that contain NaNs/Infs
### Reporting statistics
@@ -28,20 +28,20 @@ def __init__(self,
name: str = "Train", # Name string that will appear in TensorFlow graph
tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class
learning_rate: TfExpressionEx = 0.001, # Learning rate, can vary over time
minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients
batch_multiplier: TfExpressionEx = None, # Treat N consecutive batches as one by accumulating gradients
share: "Optimizer" = None, # Share internal state with a previously created optimizer?
use_loss_scaling: bool = False, # Enable dynamic loss scaling for robust mixed-precision training?
loss_scaling_init: float = 64.0, # Log2 of initial loss scaling factor
loss_scaling_inc: float = 0.0005, # Log2 of per-minibatch loss scaling increment when there is no overflow
loss_scaling_dec: float = 1.0, # Log2 of per-minibatch loss scaling decrement when there is an overflow
loss_scaling_inc: float = 0.0005, # Log2 of per-batch loss scaling increment when there is no overflow
loss_scaling_dec: float = 1.0, # Log2 of per-batch loss scaling decrement when there is an overflow
report_mem_usage: bool = False, # Report fine-grained memory usage statistics in TensorBoard?
clip: float = None,
**kwargs):

# Public fields
self.name = name
self.learning_rate = learning_rate
self.minibatch_multiplier = minibatch_multiplier
self.batch_multiplier = batch_multiplier
self.id = self.name.replace("/", ".")
self.scope = tf.get_default_graph().unique_name(self.id)
self.optimizer_class = util.get_obj_by_name(tf_optimizer)
@@ -191,8 +191,8 @@ def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
# Scale as needed
scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
scale = tf.constant(scale, dtype = tf.float32, name = "scale")
if self.minibatch_multiplier is not None:
scale /= tf.cast(self.minibatch_multiplier, tf.float32)
if self.batch_multiplier is not None:
scale /= tf.cast(self.batch_multiplier, tf.float32)
scale = self.undo_loss_scaling(scale)
device.grad_clean[var] = grad * scale

@@ -210,7 +210,7 @@ def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
for device_idx, device in enumerate(self._devices.values()):
with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
# Accumulate gradients over time
if self.minibatch_multiplier is None:
if self.batch_multiplier is None:
acc_ok = tf.constant(True, name = "acc_ok")
device.grad_acc = OrderedDict(device.grad_clean)
else:
@@ -224,7 +224,7 @@ def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
count_cur = device.grad_acc_count + 1.0
count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
acc_ok = (count_cur >= tf.cast(self.batch_multiplier, tf.float32))
all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))

# Track gradients
@@ -260,7 +260,7 @@ def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
self.reset_optimizer_state()
if self.use_loss_scaling:
tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
if self.minibatch_multiplier is not None:
if self.batch_multiplier is not None:
tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])

# Group everything into a single op
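
The renamed `batch_multiplier` implements gradient accumulation: gradients from N consecutive batches are each scaled by 1/N, summed, and applied as one optimizer step once the counter reaches N. A minimal framework-agnostic sketch of that idea (the helper and its arguments are illustrative, not part of the repo):

```python
def accumulate_and_apply(grads_per_batch, batch_multiplier, apply_update):
    # Average gradients over batch_multiplier consecutive batches,
    # then perform a single update, mirroring the accumulation logic above
    acc = None
    for step, grads in enumerate(grads_per_batch, start = 1):
        scaled = [g / batch_multiplier for g in grads]
        acc = scaled if acc is None else [a + s for a, s in zip(acc, scaled)]
        if step % batch_multiplier == 0:
            apply_update(acc)   # one update standing in for batch_multiplier batches
            acc = None

# Example: four single-value "gradients" accumulated in pairs -> updates with 2.0 and 6.0
accumulate_and_apply([[1.0], [3.0], [5.0], [7.0]], 2, lambda g: print("update", g))
```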
8 changes: 4 additions & 4 deletions dnnlib/tflib/tfutil.py
@@ -213,16 +213,16 @@ def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwar
# images conversions
#----------------------------------------------------------------------------

def convert_imgs_from_uint8(images, drange=[-1,1], nhwc_to_nchw = False):
# Convert a minibatch of images from uint8 to float32 with configurable dynamic range
def convert_imgs_from_uint8(images, drange = [-1,1], nhwc_to_nchw = False):
# Convert a batch of images from uint8 to float32 with configurable dynamic range
# Can be used as an input transformation for Network.run()
images = tf.cast(images, tf.float32)
if nhwc_to_nchw:
images = tf.transpose(images, [0, 3, 1, 2])
return images * ((drange[1] - drange[0]) / 255) + drange[0]

def convert_imgs_to_uint8(images, drange=[-1,1], nchw_to_nhwc = False, shrink = 1, lst = False):
# Convert a minibatch of images from float32 to uint8 with configurable dynamic range
def convert_imgs_to_uint8(images, drange = [-1,1], nchw_to_nhwc = False, shrink = 1, lst = False):
# Convert a batch of images from float32 to uint8 with configurable dynamic range
# Can be used as an output transformation for Network.run()
if lst:
images = images[0]
Expand Down
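
The two conversions above map linearly between uint8 `[0, 255]` and the configurable `drange`; a small NumPy check of the forward mapping with the default `drange = [-1, 1]`:

```python
import numpy as np

drange = [-1, 1]
imgs_uint8 = np.array([0, 127, 255], dtype = np.uint8)
imgs_float = imgs_uint8.astype(np.float32) * ((drange[1] - drange[0]) / 255) + drange[0]
print(imgs_float)   # [-1. -0.00392157  1.] -- 0 maps to -1, 255 maps to +1
```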