Skip to content

Commit

Permalink
better evals
Browse files Browse the repository at this point in the history
  • Loading branch information
osmarks committed Apr 22, 2024
1 parent 58ce70b commit cebb4f9
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
69 changes: 69 additions & 0 deletions meme-rater/auroc_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import torch.nn
import torch.nn.functional as F
import torch
import sqlite3
import random
import numpy
import json
import time
from tqdm import tqdm
import torch

from model import Config, BradleyTerry
import shared

batch_size = 128
device = "cuda"

config = Config(
d_emb=1152,
n_hidden=1,
n_ensemble=16,
device=device,
dtype=torch.float32,
dropout=0.1
)
model = BradleyTerry(config).float()
modelc, _ = shared.checkpoint_for(1500)
model.load_state_dict(torch.load(modelc))
params = sum(p.numel() for p in model.parameters())
print(f"{params/1e6:.1f}M parameters")
print(model)

files = shared.fetch_all_files()
ratings = {}

model.eval()
with torch.inference_mode():
for bstart in tqdm(range(0, len(files), batch_size)):
batch = files[bstart:bstart + batch_size]
filenames = [ filename for filename, embedding in batch ]
embs = torch.stack([ torch.Tensor(embedding) for filename, embedding in batch ])
inputs = embs.unsqueeze(0).expand((config.n_ensemble, len(batch), config.d_emb)).to(device)
scores = model.ensemble(inputs).float()
mscores = torch.median(scores, dim=0).values
for filename, mscore in zip(filenames, mscores):
ratings[filename] = float(mscore)

ratings = sorted(ratings.items(), key=lambda x: x[1])
random.shuffle(ratings)

N = 150

buf = f"""<!DOCTYPE html>
<div>
{''.join(f'<div><img src="{"images/" + f}" width="30%"><br><input type=checkbox data-score="{s}"></div>' for i, (f, s) in enumerate(ratings[:N]))}
</div>
<script>
const dump = () => {{
const data = []
for (const x of document.querySelectorAll("input[type=checkbox]")) {{
data.push([parseFloat(x.getAttribute("data-score")), x.checked])
}}
console.log(JSON.stringify(data))
}}
</script>
"""

with open("eval.html", "w") as f:
f.write(buf)
32 changes: 32 additions & 0 deletions meme-rater/roc_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import matplotlib.pyplot as plt
import json

data = json.loads("[[1.2792096138000488,true],[1.1153279542922974,true],[0.9720794558525085,true],[-0.5180545449256897,false],[1.4547114372253418,true],[1.3289614915847778,true],[1.8748269081115723,true],[0.05465051531791687,false],[0.7888763546943665,true],[1.368210792541504,true],[1.4808461666107178,true],[0.9501181244850159,true],[1.2592355012893677,true],[1.0127032995224,true],[-0.8805797100067139,false],[-0.08946493268013,true],[0.4224545955657959,false],[1.0051900148391724,true],[0.5121232271194458,false],[1.0876282453536987,false],[1.5552432537078857,true],[-0.3680466413497925,false],[0.45498305559158325,true],[1.3851803541183472,true],[-0.8842921853065491,false],[2.6869430541992188,false],[1.6892706155776978,false],[0.7087478637695312,false],[-0.5138207077980042,false],[0.16498255729675293,false],[1.265992283821106,true],[0.47311416268348694,false],[0.04918492212891579,false],[1.283980369567871,true],[1.0510015487670898,false],[1.6323922872543335,false],[0.4570896625518799,true],[1.5262614488601685,true],[1.4057230949401855,true],[1.0391144752502441,true],[0.9190238118171692,true],[1.2970502376556396,true],[2.025949478149414,true],[0.6396026611328125,true],[2.3505871295928955,true],[1.0854156017303467,false],[1.0216373205184937,true],[-1.163207769393921,false],[1.8854788541793823,true],[0.249663308262825,false],[-0.8619526028633118,false],[1.9995672702789307,true],[1.0939114093780518,false],[0.6106101870536804,false],[1.8383781909942627,false],[-0.0637127161026001,false],[-0.34953051805496216,false],[0.988452672958374,false],[0.5209289193153381,false],[-0.4708566963672638,false],[0.4715256690979004,false],[-0.7905446887016296,false],[2.0255637168884277,true],[0.8488644361495972,false],[1.6645262241363525,true],[1.0948383808135986,true],[-0.8315924406051636,false],[1.5533114671707153,true],[0.9333463907241821,true],[-0.5723654627799988,false],[1.9510998725891113,true],[0.2842162549495697,false],[1.1901239156723022,false],[1.5058742761611938,false],[0.7622374296188354,false],[0.2894713282585144,false],[0.0965774804353714,false],[0.6335093379020691,false],[-0.7369110584259033,false],[1.2673722505569458,true],[0.9775630235671997,false],[0.7889275550842285,false],[-0.9432369470596313,false],[0.24122865498065948,false],[1.075297474861145,false],[0.545269250869751,false],[-0.1398508995771408,false],[-0.31118375062942505,false],[1.47971510887146,false],[0.5115379691123962,true],[0.8894630074501038,true],[0.4365079700946808,true],[2.5944597721099854,true],[0.8613907694816589,false],[1.1540073156356812,false],[1.6798168420791626,true],[1.5266021490097046,true],[0.2556634545326233,false],[0.90388423204422,false],[0.36393579840660095,false],[1.297504186630249,true],[1.091887354850769,true],[0.931088924407959,true],[0.8854649066925049,true],[0.0385725162923336,false],[1.5259686708450317,true],[-0.725635826587677,false],[-1.72086501121521,false],[1.9044498205184937,true],[-0.10369344800710678,false],[-0.5889104604721069,true],[0.2478746473789215,false],[1.4628609418869019,false],[1.1434470415115356,false],[0.20635242760181427,false],[0.8324120044708252,false],[0.676543653011322,false],[1.1111537218093872,true],[0.0488731786608696,false],[0.8705015182495117,true],[0.5464357733726501,true],[0.6190940737724304,true],[0.33756133913993835,false],[0.8019527196884155,true],[1.1540179252624512,true],[-1.4343260526657104,true],[1.4069069623947144,true],[0.5078597664833069,true],[0.1831521838903427,false],[-0.5352457761764526,false],[1.3706591129302979,true],[-0.8636290431022644,false],[0.8164027333259583,false],[0.6665022969245911,false],[0.5028047561645508,false],[-0.7765756845474243,false],[1.204775333404541,false],[1.2527906894683838,false],[0.7420544028282166,false],[1.0363034009933472,true],[1.0559784173965454,false],[-0.72457355260849,false],[1.9217685461044312,true],[0.9770780205726624,false],[0.8808136582374573,true],[1.0174754858016968,false],[0.4287119507789612,false],[1.0718724727630615,true],[0.8409612774848938,true],[-1.3366127014160156,false]]")
data = sorted(data, reverse=True)

tprs, fprs = [], []
positives = sum(1 for _, ground_truth in data if ground_truth)
negatives = len(data) - positives

for threshold, _ in data:
tp = sum(1 for score, ground_truth in data if ground_truth and score >= threshold)
fp = sum(1 for score, ground_truth in data if not ground_truth and score >= threshold)
tpr = tp / positives
fpr = fp / negatives
tprs.append(tpr)
fprs.append(fpr)

auroc = 0
for i in range(len(fprs) - 1):
auroc += (fprs[i+1] - fprs[i]) * (tprs[i+1] + tprs[i]) / 2

print(f"AUROC: {auroc}")

plt.plot(fprs, tprs)

plt.xlabel("FPR")
plt.ylabel("TPR")
plt.title("ROC")

plt.tight_layout()
plt.show()

0 comments on commit cebb4f9

Please sign in to comment.