Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

predict.py #144

Open
CHSLAM opened this issue Sep 23, 2022 · 3 comments
Open

predict.py #144

CHSLAM opened this issue Sep 23, 2022 · 3 comments

Comments

@CHSLAM
Copy link

CHSLAM commented Sep 23, 2022

I have downloaded fpn_inception.h5.
Why does nothing happen when I run `python predict.py 0.png`?

@ulucsahin
Copy link

ulucsahin commented Oct 26, 2022

Because the code is not correct.
In predict.py, replace the `if __name__ == "__main__"` code block with:

if __name__ == "__main__":
    # Hand the image location straight to main(); replace the placeholder
    # string with the path of the image to process before running.
    main("path-to-your-image.jpg")

@fighting666777
Copy link

I have downloaded fpn_inception.h5. Why does nothing happen when I run `python predict.py 0.png`?

I have the same problem. Do you know how to solve it?

@baselqt
Copy link

baselqt commented Dec 27, 2023

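Here is my fix: place all the JPG files you want to process in the test_img folder, then run `python predict.py` on its own, with no arguments.
Before running, replace the `PATH\TO\TEST_IMG` placeholder at the bottom of the script with the actual path. Note that this value is passed to `glob`, so it should be a pattern that matches the image files themselves (for example, one ending in `\*.jpg`).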

import os
from glob import glob
from typing import Optional

import cv2
import numpy as np
import torch
import yaml
from tqdm import tqdm

from aug import get_normalize
from models.networks import get_generator


class Predictor:
    """Callable wrapper that runs a generator network on one image at a time.

    The generator is built by ``get_generator`` and its weights come from a
    checkpoint file. Calling the instance with an ``(H, W, C)`` array returns
    an ``(H, W, C)`` ``uint8`` array of the same height and width.
    """

    def __init__(self, weights_path: str, model_name: str = ''):
        """Build the generator and load its weights.

        Args:
            weights_path: Checkpoint readable by ``torch.load``; its
                ``'model'`` entry is handed to ``load_state_dict``.
            model_name: Value forwarded to ``get_generator``. When empty, the
                ``model`` entry of ``config/config.yaml`` is used instead.
        """
        with open('config/config.yaml', encoding='utf-8') as cfg:
            config = yaml.load(cfg, Loader=yaml.FullLoader)
        # Use the GPU when there is one, otherwise fall back to the CPU so the
        # predictor still works (slowly) on machines without CUDA.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = get_generator(model_name or config['model'])
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        checkpoint = torch.load(weights_path, map_location=self.device)
        model.load_state_dict(checkpoint['model'])
        self.model = model.to(self.device)
        # NOTE(review): the network is left in training mode on purpose, as in
        # the code this was taken from - presumably so normalization layers use
        # the statistics of the current input. Confirm before switching to
        # eval().
        self.model.train(True)
        self.normalize_fn = get_normalize()

    @staticmethod
    def _array_to_batch(x):
        """Turn an ``(H, W, C)`` array into a ``(1, C, H, W)`` tensor."""
        x = np.transpose(x, (2, 0, 1))
        x = np.expand_dims(x, 0)
        return torch.from_numpy(x)

    def _preprocess(self, x: np.ndarray, mask: Optional[np.ndarray]):
        """Normalize and pad an image (and mask) and convert both to batches.

        Args:
            x: Image array of shape ``(H, W, C)``.
            mask: Optional array with values in 0..255. It is padded with the
                same three-axis ``pad_width`` as the image, so it must also be
                three-dimensional. ``None`` means an all-ones mask.

        Returns:
            ``(batches, h, w)`` where ``batches`` is a lazy iterator yielding
            the image tensor and then the mask tensor, and ``h``/``w`` are the
            height and width before padding (used to crop the prediction).
        """
        x, _ = self.normalize_fn(x, x)
        if mask is None:
            mask = np.ones_like(x, dtype=np.float32)
        else:
            # Binarize: values are scaled to 0..1 and rounded to 0 or 1.
            mask = np.round(mask.astype('float32') / 255)

        h, w, _ = x.shape
        block_size = 32
        # Pad bottom/right up to the next multiple of block_size that is
        # strictly greater than the current size (an exact multiple still gets
        # one extra block); the extra area is cropped off after prediction.
        min_height = (h // block_size + 1) * block_size
        min_width = (w // block_size + 1) * block_size

        pad_params = {'mode': 'constant',
                      'constant_values': 0,
                      'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
                      }
        x = np.pad(x, **pad_params)
        mask = np.pad(mask, **pad_params)

        return map(self._array_to_batch, (x, mask)), h, w

    @staticmethod
    def _postprocess(x: torch.Tensor) -> np.ndarray:
        """Convert a one-element ``(1, C, H, W)`` batch to an ``(H, W, C)`` image.

        Values are mapped linearly from ``[-1, 1]`` to ``[0, 255]``.
        """
        x, = x  # unpack the single image; raises if the batch size is not 1
        x = x.detach().cpu().float().numpy()
        x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
        # Clip before the cast: converting out-of-range floats to uint8 wraps
        # around instead of saturating.
        return np.clip(x, 0, 255).astype('uint8')

    def __call__(self, img: np.ndarray, mask: Optional[np.ndarray], ignore_mask=True) -> np.ndarray:
        """Run the model on one image.

        Args:
            img: Image array of shape ``(H, W, C)``.
            mask: Optional mask, see ``_preprocess``.
            ignore_mask: When true (the default) only the image is passed to
                the model; otherwise the mask is passed as a second argument.

        Returns:
            ``uint8`` array of shape ``(H, W, C)`` with the input's height and
            width.
        """
        (img, mask), h, w = self._preprocess(img, mask)
        with torch.no_grad():
            inputs = [img.to(self.device)]
            if not ignore_mask:
                # The mask must live on the same device as the image and model.
                inputs += [mask.to(self.device)]
            pred = self.model(*inputs)
        # Drop the padding added in _preprocess.
        return self._postprocess(pred)[:h, :w, :]


def main(img_pattern: str,
         weights_path='fpn_inception.h5',
         out_dir='submit/',
         side_by_side: bool = False):
    """Run the predictor on every image matching a pattern and save the results.

    Args:
        img_pattern: Glob pattern selecting the input images. A path to an
            existing directory is also accepted and means "every file directly
            inside that directory".
        weights_path: Checkpoint file passed to ``Predictor``.
        out_dir: Directory the results are written to; created if missing.
            Each result is saved under the input file's base name, so inputs
            from different directories that share a name overwrite each other.
        side_by_side: When true, the saved image is the input and the
            prediction stacked horizontally instead of the prediction alone.
    """
    # glob() on a bare directory path matches only the directory itself, which
    # cannot be read as an image; expand it to its contents instead.
    if os.path.isdir(img_pattern):
        img_pattern = os.path.join(img_pattern, '*')

    # Glob once and derive the names from that same list so every name stays
    # paired with its own path (sorting names and paths separately can
    # mis-pair them when the pattern spans several directories).
    imgs = sorted(glob(img_pattern))
    names = [os.path.basename(path) for path in imgs]
    print(f"Total images to process: {len(names)}")
    if not imgs:
        # Say so explicitly and skip loading the model for nothing.
        print(f"No files matched: {img_pattern}")
        return

    predictor = Predictor(weights_path=weights_path)
    os.makedirs(out_dir, exist_ok=True)

    for name, img_path in tqdm(zip(names, imgs), total=len(names)):
        print(f"Processing: {name}")
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports an unreadable or non-image file by returning
            # None rather than raising.
            print(f"Skipping unreadable image: {img_path}")
            continue
        # OpenCV loads BGR; the predictor is fed RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        pred = predictor(img, None)
        if side_by_side:
            pred = np.hstack((img, pred))
        # Back to BGR for cv2.imwrite.
        pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(out_dir, name), pred)


if __name__ == '__main__':
    import sys

    # Take the glob pattern (or image path) from the command line when one is
    # given, e.g. ``python predict.py 0.png``; otherwise use the placeholder
    # below, which must be edited to point at the images.
    # The raw string stops Windows backslashes from being read as escape
    # sequences (a real path containing ``\t`` or ``\n`` would be corrupted).
    image_pattern = sys.argv[1] if len(sys.argv) > 1 else r"PATH\TO\TEST_IMG"
    main(img_pattern=image_pattern)


Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants