Skip to content

Commit

Permalink
Merge pull request #391 from OFA-Sys/feature/add_text_recognition
Browse files Browse the repository at this point in the history
Feature/add text recognition
  • Loading branch information
logicwong committed May 11, 2023
2 parents 9ed0646 + 408d466 commit bf9b9f3
Show file tree
Hide file tree
Showing 6 changed files with 588 additions and 169 deletions.
2 changes: 1 addition & 1 deletion data/mm_data/ocr_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
import random
import functools

import numpy as np
import torch
import base64
from torchvision import transforms
Expand Down
44 changes: 31 additions & 13 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3 -u
# Copyright 2022 The OFA-Sys Team.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
Expand Down Expand Up @@ -42,7 +42,7 @@ def main(cfg: DictConfig, **kwargs):
logger.info(cfg)

assert (
cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
), "Must specify batch size either with --max-tokens or --batch-size"

# Fix seed for stochastic decoding
Expand Down Expand Up @@ -82,7 +82,8 @@ def main(cfg: DictConfig, **kwargs):
num_shards=cfg.checkpoint.checkpoint_shard_count,
)

# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
# loading the dataset should happen after the checkpoint has been loaded
# so we can give it the saved task config
task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

if cfg.generation.lm_path is not None:
Expand All @@ -104,10 +105,13 @@ def main(cfg: DictConfig, **kwargs):
lms = [None]

# Move models to GPU
for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)):
for model, ckpt_path in zip(
models, utils.split_paths(
cfg.common_eval.path)):
if kwargs['ema_eval']:
logger.info("loading EMA weights from {}".format(ckpt_path))
model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
model.load_state_dict(
checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
model.eval()
if use_fp16:
model.half()
Expand Down Expand Up @@ -135,7 +139,8 @@ def main(cfg: DictConfig, **kwargs):
itr,
log_format=cfg.common.log_format,
log_interval=cfg.common.log_interval,
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
default_log_format=(
"tqdm" if not cfg.common.no_progress_bar else "simple"),
)

# Initialize generator
Expand All @@ -149,12 +154,15 @@ def main(cfg: DictConfig, **kwargs):
if "net_input" not in sample:
continue
sample = utils.move_to_cuda(sample) if use_cuda else sample
sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
sample = utils.apply_to_sample(
apply_half, sample) if cfg.common.fp16 else sample
with torch.no_grad():
if kwargs["zero_shot"]:
result, scores = zero_shot_step(task, generator, models, sample)
result, scores = zero_shot_step(
task, generator, models, sample)
else:
result, scores = eval_step(task, generator, models, sample, **kwargs)
result, scores = eval_step(
task, generator, models, sample, **kwargs)
results += result
if scores and isinstance(scores[0], tuple):
score_sum += sum([s[0] for s in scores])
Expand All @@ -170,13 +178,23 @@ def main(cfg: DictConfig, **kwargs):

def cli_main():
parser = options.get_generation_parser()
parser.add_argument("--ema-eval", action='store_true', help="Use EMA weights to make evaluation.")
parser.add_argument("--beam-search-vqa-eval", action='store_true', help="Use beam search for vqa evaluation (faster inference speed but sub-optimal result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.")
parser.add_argument(
"--ema-eval",
action='store_true',
help="Use EMA weights to make evaluation.")
parser.add_argument(
"--beam-search-vqa-eval",
action='store_true',
help="Use beam search for vqa evaluation (faster inference speed but sub-optimal result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.")
parser.add_argument("--zero-shot", action='store_true')
args = options.parse_args_and_arch(parser)
cfg = convert_namespace_to_omegaconf(args)
distributed_utils.call_main(
cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval, zero_shot=args.zero_shot
cfg,
main,
ema_eval=args.ema_eval,
beam_search_vqa_eval=args.beam_search_vqa_eval,
zero_shot=args.zero_shot,
)


Expand Down
Loading

0 comments on commit bf9b9f3

Please sign in to comment.