Skip to content

Commit

Permalink
clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
timojl committed Apr 28, 2022
1 parent 992398d commit e21a947
Showing 1 changed file with 1 addition and 31 deletions.
32 changes: 1 addition & 31 deletions score.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,37 +353,7 @@ def score(config, train_checkpoint_id, train_config):

return {key_prefix: {k: scores[k] for k in ['seen', 'unseen', 'harmonic', 'overall']}}

elif config.test_dataset == 'fss1000':
from datasets.fss1000 import FSS1000
dataset_cls = FSS1000
_, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
dataset = dataset_cls(**dataset_args)

loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
metric = FixedIntervalMetrics(resize_pred=True, **metric_args)

with torch.no_grad():

i, losses = 0, []
for i_all, (data_x, data_y) in enumerate(loader):
data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]

pred, = model(data_x[0], data_x[1], data_x[2])
metric.add([pred], data_y)

i += 1
if config.max_iterations and i >= config.max_iterations:
break

prefix = 'fss'
if 'ratio_negative' in dataset_args:
prefix += f'-neg{dataset_args["ratio_negative"]}'

return {prefix: metric.scores()}


elif config.test_dataset in {'same_as_training', 'lvis', 'affordance', 'fss1000'}:
elif config.test_dataset in {'same_as_training', 'affordance'}:
loss_fn = get_attribute(train_config.loss)

metric_cls = get_attribute(config.metric)
Expand Down

0 comments on commit e21a947

Please sign in to comment.