-
Notifications
You must be signed in to change notification settings - Fork 58
/
pnp_utils.py
66 lines (55 loc) · 2.59 KB
/
pnp_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import numpy as np
from sklearn.decomposition import PCA
from PIL import Image
from torchvision import transforms as T
from math import sqrt
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image
# Load the Stable Diffusion safety checker and its matching feature extractor
# once, at import time. Both are fetched via `from_pretrained` using the same
# model id and are kept as module-level objects.
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
def visualize_and_save_features_pca(feature_maps_fit_data, feature_maps_transform_data, transform_experiments, t, save_dir):
    """Project feature maps onto three PCA components and save them as images.

    A 3-component PCA is fitted on ``feature_maps_fit_data`` and applied to
    ``feature_maps_transform_data``. The projected rows are split evenly
    across ``transform_experiments``; each chunk is reshaped into a square
    ``h x w x 3`` image, min-max normalised per channel, resized with
    nearest-neighbour interpolation via ``T.Resize(512)`` and written to
    ``<save_dir>/<experiment>_time_<t>.png``.

    Args:
        feature_maps_fit_data: Object exposing ``.cpu().numpy()`` (presumably
            a 2-D torch tensor of features -- confirm against the caller);
            used to fit the PCA.
        feature_maps_transform_data: Same kind of object; the rows that are
            projected and visualised.
        transform_experiments: Sequence of experiment names; its length is the
            number of images, and each name is used in the output filename.
        t: Value interpolated into the output filename.
        save_dir: Directory the PNG files are written into.

    Raises:
        ValueError: If the number of projected rows per experiment is not a
            perfect square (or the rows cannot be split evenly).
    """
    pca = PCA(n_components=3)
    pca.fit(feature_maps_fit_data.cpu().numpy())
    feature_maps_pca = pca.transform(feature_maps_transform_data.cpu().numpy())  # N x 3
    feature_maps_pca = feature_maps_pca.reshape(len(transform_experiments), -1, 3)  # B x (H * W) x 3

    for experiment, pca_img in zip(transform_experiments, feature_maps_pca):
        num_pixels = pca_img.shape[0]  # (H * W) rows of 3 components
        # Round before truncating so a float sqrt like 31.999999 still maps
        # to 32, then verify the layout really is square.
        h = w = int(round(sqrt(num_pixels)))
        if h * w != num_pixels:
            raise ValueError(
                f"cannot arrange {num_pixels} feature vectors into a square image"
            )
        pca_img = pca_img.reshape(h, w, 3)

        pca_img_min = pca_img.min(axis=(0, 1))
        pca_img_max = pca_img.max(axis=(0, 1))
        # A constant channel has zero range; dividing by it would yield
        # NaN/inf and an undefined uint8 cast. Use 1 so it maps to all zeros.
        value_range = pca_img_max - pca_img_min
        value_range = np.where(value_range == 0, 1.0, value_range)
        pca_img = (pca_img - pca_img_min) / value_range

        pca_img = Image.fromarray((pca_img * 255).astype(np.uint8))
        pca_img = T.Resize(512, interpolation=T.InterpolationMode.NEAREST)(pca_img)
        pca_img.save(os.path.join(save_dir, f"{experiment}_time_{t}.png"))
def load_replacement(x):
    """Return a placeholder image shaped like ``x``, or ``x`` itself on failure.

    The placeholder is read from ``assets/rick.jpeg``, converted to RGB,
    resized to the first two dimensions of ``x`` (treated as height and
    width), scaled to the ``[0, 1]`` range and cast to ``x.dtype``.

    Args:
        x: Array-like image exposing ``.shape`` and ``.dtype`` (presumably an
            H x W x 3 numpy array -- confirm against the caller).

    Returns:
        The placeholder array when it could be loaded and its shape equals
        ``x.shape``; otherwise the unmodified ``x``.
    """
    try:
        height, width = x.shape[0], x.shape[1]
        # Context manager releases the underlying file handle deterministically.
        with Image.open("assets/rick.jpeg") as img:
            resized = img.convert("RGB").resize((width, height))
        y = (np.array(resized) / 255.0).astype(x.dtype)
    except Exception:
        # Deliberate best-effort: any failure (missing asset, unexpected input
        # type, ...) falls back to returning the input unchanged.
        return x
    # Explicit comparison instead of `assert`, which is stripped under
    # `python -O` and would let a mismatched replacement through.
    if y.shape != x.shape:
        return x
    return y
def numpy_to_pil(images):
    """
    Turn a numpy image, or a batch of them, into a list of PIL images.

    A 3-dimensional input is treated as a single image and gets a leading
    batch axis. Values are multiplied by 255, rounded and cast to ``uint8``
    before conversion (so inputs are presumably in ``[0, 1]`` -- confirm
    against the caller).
    """
    batch = images[None, ...] if images.ndim == 3 else images
    as_uint8 = (batch * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in as_uint8]
def check_safety(x_image):
    """Run the safety checker on ``x_image`` and swap out flagged images.

    The images are converted to PIL via ``numpy_to_pil`` and passed through
    the module-level ``safety_feature_extractor``; the resulting pixel values
    and the original ``x_image`` are then given to ``safety_checker``. Every
    entry the checker flags is replaced by ``load_replacement`` of that entry.

    Returns:
        The checked image batch returned by ``safety_checker``, with flagged
        entries replaced.
    """
    pil_batch = numpy_to_pil(x_image)
    checker_inputs = safety_feature_extractor(pil_batch, return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(
        images=x_image,
        clip_input=checker_inputs.pixel_values,
    )
    # Sanity check: one flag per image in the batch.
    assert x_checked_image.shape[0] == len(has_nsfw_concept)
    for idx, flagged in enumerate(has_nsfw_concept):
        if flagged:
            x_checked_image[idx] = load_replacement(x_checked_image[idx])
    return x_checked_image