Skip to content

Commit

Permalink
++
Browse files Browse the repository at this point in the history
  • Loading branch information
achanhon committed Feb 12, 2024
1 parent 7107ec8 commit 885022d
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions geometry/firsttest/entropy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy
import PIL
from PIL import Image
import torch
import torchvision


class FeatureExtractor(torch.nn.Module):
def __init__(self):
super(FeatureExtractor, self).__init__()
tmp = torchvision.models.efficientnet_v2_s(weights="DEFAULT").features
del tmp[7]
del tmp[6]
del tmp[5]
self.features = tmp.cuda().half()

def forward(self, x):
x = x.cuda().half() / 255.0
x = (x - 0.5) / 0.25
return self.features(x)


with torch.no_grad():
net = FeatureExtractor()

image = PIL.Image.open("/scratchf/DFC2015/BE_ORTHO_27032011_315130_56865.tif")
image = numpy.asarray(image.convert("RGB").copy())
image = torch.Tensor(image)[0:9920, 0:9920, :].clone()
image = torch.stack([image[:, :, 0], image[:, :, 1], image[:, :, 2]], dim=0)
image = image.unsqueeze(0)

print("extract data")
allfeatures = torch.zeros(128, 620, 620)
for r in range(10):
for c in range(10):
x = image[:, :, 992 * r : 992 * (r + 1), 992 * c : 992 * (c + 1)]
z = net(x)[0].cpu()
allfeatures[:, 62 * r : 62 * (r + 1), 62 * c : 62 * (c + 1)] = z

allfeatures = torch.nn.functional.avg_pool2d(
allfeatures.unsqueeze(0), kernel_size=2, stride=2
)
allfeatures = allfeatures[0]

print("extract stats")
allfeatures = allfeatures.cuda().half()

GRAM = torch.nn.functional.softmax(torch.matmul(allfeatures.transpose(0, 1), allfeatures),1)

del allfeatures
torch.diagonal(GRAM).fill_(-10)
assert GRAM.shape == (310 * 310, 310 * 310)

maxGRAM, _ = GRAM.max(1)
assert GRAM.shape[0] == 310 * 310
del GRAM

seuil = sorted(list(maxGRAM.cpu().numpy()))
seuil = float(seuil[100])
maxGRAM = maxGRAM.view(310, 310).cpu()
print((maxGRAM < seuil).float().sum())

image620 = torch.nn.functional.interpolate(image, size=310, mode="bilinear")
image620 = image620[0] / 255
torchvision.utils.save_image(image620, "build/image.png")

image620 *= (maxGRAM <= seuil).float().unsqueeze(0)
torchvision.utils.save_image(image620, "build/amer.png")

0 comments on commit 885022d

Please sign in to comment.