load_processor is not using local_files_only flag #96

Open
w4ffl35 opened this issue Mar 24, 2024 · 1 comment · May be fixed by #97

w4ffl35 commented Mar 24, 2024

Describe the bug

There is no way to pass the local_files_only flag to the load_processor method of the controlnet_aux Processor class (see controlnet_aux/processor.py).

Code:

    def load_processor(self, processor_id: str) -> 'Processor':
        """Load controlnet aux processors

        Args:
            processor_id (str): processor name

        Returns:
            Processor: controlnet aux processor
        """
        processor = MODELS[processor_id]['class']

        # check if the processor is a checkpoint model
        if MODELS[processor_id]['checkpoint']:
            processor = processor.from_pretrained("lllyasviel/Annotators")
        else:
            processor = processor()
        return processor

That function in turn calls methods such as LeresDetector.from_pretrained, which looks like this:

class LeresDetector:
    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, pix2pix_filename=None, cache_dir=None):
        filename = filename or "res101.pth"
        pix2pix_filename = pix2pix_filename or "latest_net_G.pth"

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
            
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

        model = RelDepthModel(backbone='resnext101')
        model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True)
        del checkpoint

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, pix2pix_filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, pix2pix_filename, cache_dir=cache_dir)

        opt = TestOptions().parse()
        if not torch.cuda.is_available():
            opt.gpu_ids = []  # cpu mode
        pix2pixmodel = Pix2Pix4DepthModel(opt)
        pix2pixmodel.save_dir = os.path.dirname(model_path)
        pix2pixmodel.load_networks('latest')
        pix2pixmodel.eval()

        return cls(model, pix2pixmodel)

LeresDetector.from_pretrained in turn calls hf_hub_download to fetch each checkpoint whenever pretrained_model_or_path is not a local directory. hf_hub_download accepts a local_files_only flag, but it is not passed here.
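A minimal sketch of the path-resolution step in LeresDetector.from_pretrained with the flag forwarded. resolve_checkpoint and its download parameter are hypothetical helpers, not part of controlnet_aux; download defaults to huggingface_hub.hf_hub_download and can be replaced by a stub for testing.

```python
import os
from typing import Callable, Optional

# Any callable with the same calling convention as
# huggingface_hub.hf_hub_download(repo_id, filename, *, cache_dir, local_files_only).
Downloader = Callable[..., str]


def resolve_checkpoint(
    pretrained_model_or_path: str,
    filename: str,
    cache_dir: Optional[str] = None,
    local_files_only: bool = False,
    download: Optional[Downloader] = None,
) -> str:
    """Return a local path to `filename`, forwarding local_files_only to the downloader."""
    if os.path.isdir(pretrained_model_or_path):
        # A local folder: no hub call is needed at all.
        return os.path.join(pretrained_model_or_path, filename)
    if download is None:
        # Imported lazily so the helper can be exercised without huggingface_hub installed.
        from huggingface_hub import hf_hub_download
        download = hf_hub_download
    # Forward the flag instead of dropping it (the omission described above).
    return download(
        pretrained_model_or_path,
        filename,
        cache_dir=cache_dir,
        local_files_only=local_files_only,
    )
```

Routing both lookups (res101.pth and latest_net_G.pth) through one helper like this would also remove the duplicated isdir/hf_hub_download branches, so the flag only has to be forwarded in one place.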

As a result, huggingface_hub still tries to reach huggingface.co, which makes the application hang when offline.
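Until the flag is threaded through, a possible workaround is huggingface_hub's HF_HUB_OFFLINE environment variable, which makes hub downloads resolve from the local cache only. A minimal sketch, assuming a huggingface_hub version that honours HF_HUB_OFFLINE and reads it when the library is first imported; enable_hf_offline_mode is a hypothetical helper:

```python
import os
import sys
import warnings


def enable_hf_offline_mode() -> bool:
    """Set HF_HUB_OFFLINE=1 so huggingface_hub only reads from its local cache.

    Returns True if huggingface_hub had not been imported yet (the setting should
    take effect), False if it was already imported (the setting may be ignored).
    """
    os.environ["HF_HUB_OFFLINE"] = "1"
    if "huggingface_hub" in sys.modules:
        # Assumption: the hub library reads its configuration at import time,
        # so setting the variable afterwards may have no effect in this process.
        warnings.warn(
            "huggingface_hub was imported before HF_HUB_OFFLINE was set; "
            "the setting may not take effect in this process."
        )
        return False
    return True
```

Call it at the very top of the entry script, before importing controlnet_aux, whose detector modules import huggingface_hub.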

Proposed fix

Add a local_files_only parameter to Processor.load_processor and to the from_pretrained method of each detector class, and forward it down to hf_hub_download:

    def load_processor(self, processor_id: str, local_files_only: bool = False) -> 'Processor':
        """Load controlnet aux processors

        Args:
            processor_id (str): processor name
            local_files_only (bool): if True, only use files already in the local Hugging Face cache

        Returns:
            Processor: controlnet aux processor
        """
        processor = MODELS[processor_id]['class']

        # check if the processor is a checkpoint model
        if MODELS[processor_id]['checkpoint']:
            processor = processor.from_pretrained("lllyasviel/Annotators", local_files_only=local_files_only)
        else:
            processor = processor()
        return processor

class LeresDetector:
    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, pix2pix_filename=None, cache_dir=None, local_files_only: bool = False):
        filename = filename or "res101.pth"
        pix2pix_filename = pix2pix_filename or "latest_net_G.pth"

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)
            
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

        model = RelDepthModel(backbone='resnext101')
        model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True)
        del checkpoint

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, pix2pix_filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, pix2pix_filename, cache_dir=cache_dir, local_files_only=local_files_only)

        opt = TestOptions().parse()
        if not torch.cuda.is_available():
            opt.gpu_ids = []  # cpu mode
        pix2pixmodel = Pix2Pix4DepthModel(opt)
        pix2pixmodel.save_dir = os.path.dirname(model_path)
        pix2pixmodel.load_networks('latest')
        pix2pixmodel.eval()

        return cls(model, pix2pixmodel)
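One caveat with the proposed load_processor change: if any detector's from_pretrained does not yet accept local_files_only, passing it raises TypeError. A minimal sketch of a guard, assuming the check is done with inspect.signature; load_with_optional_offline is a hypothetical helper, not part of controlnet_aux:

```python
import inspect
from typing import Any


def load_with_optional_offline(cls: Any, repo_id: str, local_files_only: bool = False) -> Any:
    """Call cls.from_pretrained, forwarding local_files_only only where it is accepted."""
    params = inspect.signature(cls.from_pretrained).parameters
    accepts_flag = "local_files_only" in params or any(
        p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if accepts_flag:
        return cls.from_pretrained(repo_id, local_files_only=local_files_only)
    if local_files_only:
        # Fail loudly rather than silently falling back to a network lookup.
        raise TypeError(f"{cls.__name__}.from_pretrained does not support local_files_only")
    return cls.from_pretrained(repo_id)
```

Raising instead of dropping the flag keeps an offline caller from silently reaching for huggingface.co and hanging, which is the failure this issue reports.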

Reproduction

Without an internet connection, try to initialize a processor with the load_processor function; the script hangs.

Logs

No response

System Info

controlnet_aux >=0.0.7

w4ffl35 commented Mar 24, 2024

It appears that some of this has already been solved in a recent PR. However, when I install from main, the reported version is 0.0.6 rather than something higher than 0.0.7. Additionally, the load_processor function still does not pass the flag through, so I will be fixing that and the version.
