Skip to content

Commit

Permalink
Update calc_metrics.py
Browse files Browse the repository at this point in the history
  • Loading branch information
thuanz123 authored May 25, 2023
1 parent c1957c6 commit 469b31c
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions calc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ def parse_comma_separated_list(s):
# ----------------------------------------------------------------------------


def make_coords(resolution: float, scale: float):
coords = torch.linspace(0, 1, int(resolution * scale))
coords = coords.reshape(1, -1, 1, 1)
coords = coords.repeat(1, 1, 2, 1)
return coords


# ----------------------------------------------------------------------------


@click.command()
@click.pass_context
@click.option(
Expand All @@ -114,6 +124,13 @@ def parse_comma_separated_list(s):
metavar="PATH",
required=True,
)
@click.option(
"--scale",
help="Scale of generated images",
type=float,
default=1,
show_default=True,
)
@click.option(
"--metrics",
help="Quality metrics",
Expand Down Expand Up @@ -149,7 +166,7 @@ def parse_comma_separated_list(s):
metavar="BOOL",
show_default=True,
)
def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
def calc_metrics(ctx, network_pkl, scale, metrics, data, mirror, gpus, verbose):
"""Calculate quality metrics for previous training run or pretrained network pickle.
Examples:
Expand Down Expand Up @@ -198,6 +215,10 @@ def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
network_dict = legacy.load_network_pkl(f)
args.G = network_dict["G_ema"] # subclass of torch.nn.Module

# Construct an input coordinates and pass to the creps generator.
if hasattr(args.G.synthesis.b4, "input"):
args.G.synthesis.b4.input.coords = make_coords(args.G.img_resolution, scale)

# Initialize dataset options.
if data is not None:
Expand All @@ -208,7 +229,7 @@ def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
ctx.fail("Could not look up dataset options; please specify --data")

# Finalize dataset options.
args.dataset_kwargs.resolution = args.G.img_resolution
args.dataset_kwargs.resolution = int(args.G.img_resolution * scale)
args.dataset_kwargs.use_labels = args.G.c_dim != 0
if mirror is not None:
args.dataset_kwargs.xflip = mirror
Expand Down

0 comments on commit 469b31c

Please sign in to comment.