Message decoding problem using weight provided #29

@LiRunyi2001

Description

Hi there! I have tried the latent decoder weights you provided here:
WM weights of latent decoder
and I generated an image using the code provided in README.md:

```python
import torch
from omegaconf import OmegaConf
from diffusers import DiffusionPipeline

from utils_model import load_model_from_config

ldm_config = "/gdata/cold1/lirunyi/model-watermark/v2-inference.yaml"
ldm_ckpt = "/gdata/cold1/lirunyi/model-watermark/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt"

print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')
config = OmegaConf.load(ldm_config)
ldm_ae = load_model_from_config(config, ldm_ckpt)
ldm_aef = ldm_ae.first_stage_model
ldm_aef.eval()

# load the watermarked decoder weights into the first-stage autoencoder
state_dict = torch.load("sd2_decoder.pth")
unexpected_keys = ldm_aef.load_state_dict(state_dict, strict=False)
print(unexpected_keys)
print("you should check that the decoder keys are correctly matched")

# replace the pipeline's VAE decoder with the watermarked one
pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
pipeline = pipeline.to('cuda')
pipeline.vae.decode = (lambda x, *args, **kwargs: ldm_aef.decode(x).unsqueeze(0))

# run inference
prompt = "a cat and a dog"
img = pipeline(prompt).images[0]
img.save(f"./{prompt}.png")
```

Then I used this image to try to extract the message in decoding.ipynb, but the message cannot be extracted correctly: the bit accuracy is only about 50% to 60%, i.e. close to random chance. Is there anything wrong with my usage? Thanks a lot!
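For context, the bit accuracy I mean is the fraction of extracted bits that match the hidden key, after thresholding the extractor's raw outputs to hard bits. A minimal sketch of how I compute it (the `logits` tensor and the 48-bit `key` here are hypothetical stand-ins for the watermark extractor's output and the signing key; the threshold-at-zero convention is my assumption):

```python
import torch

def bit_accuracy(logits: torch.Tensor, key: torch.Tensor) -> float:
    """Fraction of decoded bits that match the ground-truth key.

    logits: raw watermark-extractor outputs, one value per bit
    key:    the binary message (0/1) embedded at fine-tuning time
    """
    bits = (logits > 0).long()                  # threshold soft outputs to hard bits
    return (bits == key.long()).float().mean().item()

# toy example with a hypothetical 48-bit key
key = torch.randint(0, 2, (48,))
perfect = torch.where(key == 1, torch.ones(48), -torch.ones(48))
print(bit_accuracy(perfect, key))               # 1.0 when decoding is perfect
```

An accuracy near 0.5 on such a measure means the extracted bits carry no information about the key, which is why I suspect the watermarked decoder is not actually being used at generation time.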
