
Conversation

@gjamesgoenawan (Contributor) commented Oct 29, 2025

What does this PR do?

This PR proposes changing the default padding value from 0.5 to 0.0 in OWLv2. While OWLv1 originally used a padding value of 0.5 (gray) as described in its paper [1], OWLv2 adopts 0.0 instead [2], consistent with its official implementation [3]. Using the incorrect padding value (0.5) leads to degraded performance on the LVIS dataset.

Implementation               LVIS mAP
Scenic                       43.9
Transformers (0.5 padding)   43.4
Transformers (0.0 padding)   44.0
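
For context, the padding in question is the pad-to-square step applied before the image is resized to the model input size. Below is a minimal sketch of that step, assuming pixel values scaled to [0, 1] and the image pasted at the top-left (the function name is illustrative, not an existing API):

import numpy as np

def pad_to_square(image: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
    """Pad an HxWx3 float image (values in [0, 1]) to a square canvas.

    The remainder of the canvas is filled with `pad_value`:
    0.0 (black) matches the official OWLv2 preprocessing, while
    0.5 (gray) is the OWLv1 behaviour that Transformers previously used.
    """
    side = max(image.shape[:2])
    canvas = np.full((side, side, 3), pad_value, dtype=np.float32)
    canvas[: image.shape[0], : image.shape[1]] = image
    return canvas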

Reproducing the results

Testing script:
The following script explicitly resizes and pads the image beforehand, so no additional padding is applied inside the processor.

import os
import re
import torch
import argparse
import warnings
import numpy as np
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from transformers import Owlv2Processor, Owlv2ForObjectDetection
from PIL import Image
from lvis import LVIS, LVISResults, LVISEval
from tqdm import tqdm


warnings.filterwarnings("ignore")

NOT_PROMPTABLE_MARKER = '#'
PROMPT_TEMPLATES = [
    'itap of a {}.',
    'a bad photo of the {}.',
    'a origami {}.',
    'a photo of the large {}.',
    'a {} in a video game.',
    'art of the {}.',
    'a photo of the small {}.',
]

def _canonicalize_string(string: str) -> str:
    string = string.lower()
    string = re.sub(f'[^a-z0-9-{NOT_PROMPTABLE_MARKER} ]', ' ', string)
    string = re.sub(r'\s+', ' ', string)
    string = re.sub(r'-+', '-', string)
    string = string.strip()
    string = re.sub(f'([^^]){NOT_PROMPTABLE_MARKER}+', r'\1', string)
    return string

class LVISDataset(Dataset):
    def __init__(self, ann_file, img_dir, processor, pad_value):
        self.lvis = LVIS(ann_file)
        self.img_ids = sorted(self.lvis.imgs.keys())
        self.img_dir = img_dir
        self.processor = processor
        self.img_size = self.processor.image_processor.size['height']
        self.pad_value = pad_value

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_info = self.lvis.imgs[img_id]
        img_path = os.path.join(self.img_dir, os.path.basename(img_info['coco_url']))
        # Load image
        image = Image.open(img_path).convert("RGB")
        image = np.array(image).astype(np.float32) / 255.0  # scale to [0,1]

        # Determine square size
        max_side = max(image.shape[1], image.shape[0])

        # Create padded square with floating-point pad value
        pad_value = np.array(self.pad_value, dtype=np.float32)  # e.g., [0.5,0.5,0.5]
        padded_image = np.ones((max_side, max_side, 3), dtype=np.float32) * pad_value

        # Paste original image at top-left
        padded_image[:image.shape[0], :image.shape[1], :] = image

        # Convert back to PIL for resizing
        padded_image = Image.fromarray((padded_image * 255).astype(np.uint8))

        # Resize to target size
        resized_image = padded_image.resize((self.img_size, self.img_size), Image.Resampling.BILINEAR)

        # Process image
        pixel_values = self.processor.image_processor(
            images=resized_image,
            return_tensors="pt"
        )['pixel_values']
        return img_id, image, img_info['width'], img_info['height'], pixel_values


def collate_fn(batch):
    img_ids, images, widths, heights, pixel_values = zip(*batch)
    return list(img_ids), list(images), list(widths), list(heights), torch.cat(list(pixel_values), axis=0)


def main():
    parser = argparse.ArgumentParser(description="Evaluate OWLv2 on LVIS dataset")
    parser.add_argument("--dataset_dir", default="/path/to/lvis/dataset")
    parser.add_argument("--pad_value", type=float, default=0.5)
    parser.add_argument("--local_rank", default=int(os.getenv('LOCAL_RANK', -1)), type=int)
    parser.add_argument("--topk", type=int, default=300)
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)

    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=int(os.getenv("WORLD_SIZE", 1)),
        rank=int(os.getenv("RANK", 0)),
        device_id=torch.device(f'cuda:{args.local_rank}'),
    )
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    print(f'Using Pad Value : {args.pad_value}')
    device = torch.device(f"cuda:{args.local_rank}" if args.local_rank >= 0 else "cuda")
    if rank == 0:
        print(f"Running evaluation on {world_size} GPUs, device={device}")

    processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble",  use_fast=True)
    model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device).eval()

    ann_file = os.path.join(args.dataset_dir, "lvis_v1_val.json")
    img_dir = os.path.join(args.dataset_dir, "val2017")

    dataset = LVISDataset(ann_file, img_dir, processor=processor, pad_value=args.pad_value)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        sampler=sampler,
        collate_fn=collate_fn,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=(args.num_workers > 0)
    )

    lvis_gt = dataset.lvis
    cats = sorted(lvis_gt.cats.items(), key=lambda x: x[0])
    class_names = [cat['name'] for _, cat in cats]

    texts_ens = []
    for template in PROMPT_TEMPLATES:
        texts_ens += [_canonicalize_string(template.format(name)) for name in class_names]

    # Cache the text embeddings once for all prompt-template x class combinations;
    # they are reused for every image below.
    with torch.no_grad():
        text_inputs = processor.tokenizer(
            texts_ens, padding=True, truncation=True, max_length=16, return_tensors="pt"
        ).to(device)
        text_outputs = model.owlv2.text_model(**text_inputs)
        text_embeds = model.owlv2.text_projection(text_outputs[1])
        text_embeds = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
        input_ids = text_inputs['input_ids'].reshape(1, -1, text_inputs['input_ids'].shape[-1])
        query_mask = input_ids[..., 0] > 0

    print(f'RANK {rank}, Ready!')
    dist.barrier()

    raw_predictions = []
    progress_bar = tqdm(dataloader, desc="Evaluating") if rank == 0 else dataloader
    
    
    for n, batch in enumerate(progress_bar):
        img_ids, images, widths, heights, pixel_values = batch
        with torch.no_grad():
            num_patches_height = model.num_patches_height
            num_patches_width = model.num_patches_width

            vision_outputs = model.owlv2.vision_model(pixel_values=pixel_values.to(device))
            last_hidden_state = vision_outputs[0]
            image_embeds = model.owlv2.vision_model.post_layernorm(last_hidden_state)
            class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
            image_embeds = image_embeds[:, 1:, :] * class_token_out
            image_embeds = model.layer_norm(image_embeds)
            image_embeds = image_embeds.reshape(
                image_embeds.shape[0], num_patches_height, num_patches_width, image_embeds.shape[-1]
            )

            image_feats = image_embeds.view(image_embeds.shape[0], -1, image_embeds.shape[-1])
            (pred_logits, _) = model.class_predictor(image_feats, text_embeds, query_mask)
            pred_boxes = model.box_predictor(image_feats, image_embeds, False)

            num_templates = len(PROMPT_TEMPLATES)
            num_classes = len(class_names)
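            # Prompt ensembling: average the logits over the prompt templates for each class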
            scores = pred_logits.reshape(1, -1, num_templates, num_classes).mean(2)

            bsz, num_patches, num_classes = scores.shape
            k = min(args.topk, num_patches * num_classes)
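            # Take the top-k over the flattened (patch, class) scores, so a single
            # object proposal can contribute detections for several classes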
            scores_flat = scores.view(bsz, -1)
            topk_scores, topk_inds = torch.topk(scores_flat, k, dim=1)

            patch_inds = topk_inds // num_classes
            label_inds = topk_inds % num_classes
            batch_idx = torch.arange(bsz, device=pred_boxes.device).unsqueeze(-1)
            selected_boxes = pred_boxes[batch_idx, patch_inds]

            raw_predictions.append([
                img_ids, widths, heights,
                topk_scores.cpu(), label_inds.cpu(), selected_boxes.cpu()
            ])
        
    torch.cuda.synchronize()

    predictions = []
    for img_ids, widths, heights, topk_scores_cpu, label_inds_cpu, selected_boxes_cpu in raw_predictions:
        image_id = img_ids[0]
        w, h = float(widths[0]), float(heights[0])
        scale = max(w, h)
        scores_np = topk_scores_cpu[0].numpy()
        labels_np = label_inds_cpu[0].numpy()
        boxes_np = selected_boxes_cpu[0].numpy()

        cx, cy, bw, bh = boxes_np[:, 0], boxes_np[:, 1], boxes_np[:, 2], boxes_np[:, 3]
        x, y = (cx - bw / 2) * scale, (cy - bh / 2) * scale
        width, height = bw * scale, bh * scale

        preds_img = [
            {
                "image_id": image_id,
                "category_id": cats[label][0],
                "bbox": [float(x[i]), float(y[i]), float(width[i]), float(height[i])],
                "score": float(scores_np[i]),
            }
            for i, label in enumerate(labels_np)
        ]
        predictions.extend(preds_img)

    print(f'RANK {rank}, Done!')
    all_predictions = [None] * world_size
    dist.all_gather_object(all_predictions, predictions)

    if rank == 0:
        full_predictions = [p for sublist in all_predictions for p in sublist]
        lvis_dt = LVISResults(lvis_gt, full_predictions)
        lvis_eval = LVISEval(lvis_gt, lvis_dt, iou_type='bbox')
        lvis_eval.evaluate()
        lvis_eval.accumulate()
        lvis_eval.summarize()
        lvis_eval.print_results()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Commands:

# 0.5 padding:
torchrun --nproc-per-node=NUM_GPUS myscript.py --pad_value 0.5 --dataset_dir /path/to/lvis/

# 0.0 padding:
torchrun --nproc-per-node=NUM_GPUS myscript.py --pad_value 0.0 --dataset_dir /path/to/lvis/

Please prepare the LVIS dataset beforehand with the following structure:

/path/to/lvis/
├── val2017
│   ├── 000000062833.jpg
│   └── ...
└── lvis_v1_val.json

After running the script, the following logs should be printed:

0.5 padding

Using Pad Value : 0.5
Running evaluation on 1 GPUs, device=cuda:0
RANK 0, Ready!
RANK 0, Done!
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=300 catIds=all] = 0.434
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=300 catIds=all] = 0.600
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=300 catIds=all] = 0.473
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=     s | maxDets=300 catIds=all] = 0.330
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=     m | maxDets=300 catIds=all] = 0.533
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=     l | maxDets=300 catIds=all] = 0.652
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=300 catIds=  r] = 0.403
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=300 catIds=  c] = 0.430
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=300 catIds=  f] = 0.451
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=300 catIds=all] = 0.563
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=     s | maxDets=300 catIds=all] = 0.406
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=     m | maxDets=300 catIds=all] = 0.672
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=     l | maxDets=300 catIds=all] = 0.805

0.0 padding

Using Pad Value : 0.0
Running evaluation on 1 GPUs, device=cuda:0
RANK 0, Ready!
RANK 0, Done!
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=300 catIds=all] = 0.440
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=300 catIds=all] = 0.602
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=300 catIds=all] = 0.482
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=     s | maxDets=300 catIds=all] = 0.333
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=     m | maxDets=300 catIds=all] = 0.540
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=     l | maxDets=300 catIds=all] = 0.664
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=300 catIds=  r] = 0.406
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=300 catIds=  c] = 0.438
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=300 catIds=  f] = 0.458
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=300 catIds=all] = 0.570
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=     s | maxDets=300 catIds=all] = 0.411
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=     m | maxDets=300 catIds=all] = 0.678
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=     l | maxDets=300 catIds=all] = 0.815

References:
[1] OWLv1 paper (Figure A4).
[2] OWLv2 paper (Figure A3).
[3] OWLv2 original implementation (scenic/projects/owl_vit/evaluator.py, line 158), whose padding value this PR adopts.

@gjamesgoenawan changed the title Fixed padding value of OWLv2 Fixed missed padding value in modular_owlv2 Oct 29, 2025
@gjamesgoenawan changed the title Fixed missed padding value in modular_owlv2 Fixed wrong padding value in OWLv2 Oct 29, 2025
@Rocketknight1 (Member)

cc @yonigozlan

@yonigozlan (Member) left a comment

Hello @gjamesgoenawan! Thanks a lot for investigating this and providing sources! Ok to merge for me, just waiting on the CI and the slow CI to pass, as this might change some integration tests. If it does, I'll push the new (correct) results directly on this PR.

@yonigozlan (Member)

run-slow: owlv2

@github-actions bot (Contributor) commented Nov 3, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: owlv2

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gjamesgoenawan (Contributor, Author) commented Nov 3, 2025

@yonigozlan Thanks for reviewing!

I think it would be helpful to mention some inference optimizations from the jax implementation (text embedding caching and prompt ensembling) as these aren’t immediately obvious from the current documentation. Including brief references or examples for these techniques would make it clearer for users.
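
For illustration, here is a minimal sketch of the text-embedding caching pattern, following the same steps as the evaluation script above (the helper name is mine, not an existing API):

import torch

@torch.no_grad()
def cache_text_embeddings(model, processor, prompts, device):
    # Tokenize the full prompt set (all classes x all templates) once.
    text_inputs = processor.tokenizer(
        prompts, padding=True, truncation=True, max_length=16, return_tensors="pt"
    ).to(device)
    text_outputs = model.owlv2.text_model(**text_inputs)
    # Project the pooled output and L2-normalize it.
    text_embeds = model.owlv2.text_projection(text_outputs[1])
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    # Mark real queries (rows whose first token id is 0 are treated as padding).
    query_mask = text_inputs["input_ids"][None, ..., 0] > 0
    return text_embeds, query_mask

The cached (text_embeds, query_mask) pair can then be passed to model.class_predictor for every image, avoiding repeated text-encoder passes.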

@yonigozlan (Member)

> I think it would be helpful to mention some inference optimizations from the jax implementation (text embedding caching and prompt ensembling) as these aren’t immediately obvious from the current documentation. Including brief references or examples for these techniques would make it clearer for users.

Indeed, at least for text embedding caching, it seems there's no easy way to use it with the current API. I'm working on a refactor of vision models, and I'll add this to the list of to-dos.
Not sure about prompt ensembling though, this seems to be a technique used for training? Or am I misunderstanding something?

@yonigozlan merged commit 64397a8 into huggingface:main Nov 3, 2025
15 checks passed
@gjamesgoenawan (Contributor, Author)

Prompt ensembling is explored in CLIP [1], Section 3.1.4 and Figure 4.

TL;DR: it is an inference-time technique that significantly improves zero-shot performance by averaging logits over multiple text prompts.

You can find traces of this in the original JAX evaluator code and the code snippet I shared.

I define the number of prompt templates as n_prompts, which in this case is 7. text_embeds has a shape of (n_prompts*n_classes, hidden_dim), and hence pred_logits has a shape of (bs, n_objects, n_prompts*n_classes). This is then reshaped and averaged by:

scores = pred_logits.reshape(1, -1, num_templates, num_classes).mean(2)
# scores [bs, n_objects, n_classes]

Effectively averaging logits across all prompt templates.

Additionally, the top-k operation differs from Transformers' implementation as well. Specifically, this line in the post-processing strictly enforces a maximum of one detection per object proposal.

My implementation follows the original closely, which doesn't have this restriction. Instead, it takes the top-k from all logits, meaning one object proposal can result in more than one detection.

Without these two methods, evaluation performance differs significantly from the original.
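
For comparison, a rough sketch of the two selection strategies (function names and the score threshold are illustrative, not an existing API), assuming logits of shape (num_proposals, num_classes) and boxes of shape (num_proposals, 4):

import torch

def select_one_per_proposal(logits, boxes, threshold=0.1):
    # Transformers-style: keep only the best class per proposal,
    # then filter by a score threshold.
    scores, labels = torch.sigmoid(logits).max(dim=-1)
    keep = scores > threshold
    return boxes[keep], scores[keep], labels[keep]

def select_topk_overall(logits, boxes, k=300):
    # Scenic-style: top-k over the flattened (proposal, class) grid,
    # so one proposal may yield detections for several classes.
    num_classes = logits.shape[-1]
    flat = logits.reshape(-1)
    top_scores, top_inds = torch.topk(flat, min(k, flat.numel()))
    proposal_inds = top_inds // num_classes
    labels = top_inds % num_classes
    return boxes[proposal_inds], top_scores, labels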

[1] CLIP

@yonigozlan changed the title Fixed wrong padding value in OWLv2 🚨Fixed wrong padding value in OWLv2 Nov 4, 2025
@yonigozlan (Member)

Thanks for the explanation @gjamesgoenawan, it does sound like some elements were overlooked when adding this model. Feel free to open PRs to add/fix these two features, I'll make sure to review them quickly.

yonigozlan added a commit to yonigozlan/transformers that referenced this pull request Nov 7, 2025
* Update image_processing_owlv2_fast.py

fixed padding value

* fixed padding value

* Change padding constant value from 0.5 to 0.0

* Fixed missed padding value in modular_owlv2.py

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Abdennacer-Badaoui pushed a commit to Abdennacer-Badaoui/transformers that referenced this pull request Nov 10, 2025
