-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
removed duplicate functions from util
- Loading branch information
rmdocherty
committed
Jul 30, 2024
1 parent
2c14f99
commit 7c6bfda
Showing
1 changed file
with
180 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
import numpy as np | ||
import torch | ||
import slicegan | ||
from scipy import stats, ndimage | ||
from matplotlib import pyplot as plt | ||
|
||
import core | ||
|
||
|
||
def load_generator(Project_path): | ||
img_size, img_channels, scale_factor = 64, 1, 1 | ||
z_channels = 16 | ||
lays = 6 | ||
dk, gk = [4] * lays, [4] * lays | ||
ds, gs = [2] * lays, [2] * lays | ||
df, gf = [img_channels, 64, 128, 256, 512, 1], [ | ||
z_channels, | ||
512, | ||
256, | ||
128, | ||
64, | ||
img_channels, | ||
] | ||
dp, gp = [1, 1, 1, 1, 0], [2, 2, 2, 2, 3] | ||
|
||
## Create Networks | ||
netD, netG = slicegan.networks.slicegan_nets( | ||
Project_path, False, "grayscale", dk, ds, df, dp, gk, gs, gf, gp | ||
) | ||
netG = netG() | ||
netG = netG.cuda() | ||
return netG | ||
|
||
|
||
def generate_image(netG, slice_dim=0, lf=50, threed=False, reps=50): | ||
|
||
netG.eval() | ||
imgs = [] | ||
|
||
for _ in range(reps): | ||
if (_ % 50) == 0 and _ != 0: | ||
print(f"generating image {_}") | ||
noise = torch.randn(1, 16, lf if threed else 4, lf, lf) | ||
noise.transpose_(2, slice_dim + 2) | ||
noise = noise.cuda() | ||
img = netG(noise, threed, slice_dim) | ||
img = slicegan.util.post_proc(img) | ||
img.transpose_(0, slice_dim) | ||
if not threed: | ||
imgs.append(img[0]) | ||
else: | ||
imgs.append(img.cpu()) | ||
img = torch.stack(imgs, 0) | ||
return img.float() | ||
|
||
|
||
def angular_img(img): | ||
base_len, l = img.shape[0:2] | ||
img = img.cpu().numpy() | ||
plt.imshow(img[0, :100, :100]) | ||
plt.show() | ||
img_rot = ndimage.rotate(img, base_len / l * 90, axes=(1, 0), reshape=False) | ||
for i in range(img_rot.shape[0]): | ||
print(f"slice {i}") | ||
plt.imshow(img_rot[i, :100, :100]) | ||
plt.show() | ||
plt.imshow(img_rot[i, -100:, -100:]) | ||
plt.show() | ||
return img_rot | ||
|
||
|
||
def stat_analysis_error(img, pf, edge_lengths): # TODO see if to delete this or not | ||
img_dims = [np.array((l,) * (len(img.shape) - 1)) for l in edge_lengths] | ||
err_exp = real_image_stats(img, edge_lengths, pf) | ||
real_cls = core.fit_statisical_cls_from_errors(err_exp, img_dims, pf) | ||
# TODO different size image 1000 vs 1500 | ||
return real_cls | ||
|
||
|
||
def real_image_stats(img, ls, pf, repeats=4000, conf=0.95): | ||
"""Calculates the error of the stat. analysis for different edge lengths. | ||
The error is calculated by the std of the mean of the subimages divided by the pf. | ||
params: | ||
img: the image to calculate the error for (Should be a stack of images). | ||
ls: the edge lengths to calculate the error for. | ||
pf: the phase fraction of the image. | ||
repeats: the number of repeats for each edge length. | ||
conf: the confidence level for the error.""" | ||
dims = len(img[0].shape) | ||
errs = [] | ||
for l in ls: | ||
pfs = [] | ||
n_pos_ims = int(np.prod(img.shape) / l**dims) | ||
repeats = n_pos_ims * 2 | ||
# print(f'one im repeats = {repeats} for l = {l}') | ||
if dims == 1: | ||
for _ in range(repeats): | ||
bm, xm = img.shape | ||
x = torch.randint(0, xm - l, (1,)) | ||
b = torch.randint(0, bm, (1,)) | ||
crop = img[b, x : x + l] | ||
pfs.append(torch.mean(crop).cpu()) | ||
elif dims == 2: | ||
for _ in range(repeats): | ||
bm, xm, ym = img.shape | ||
x = torch.randint(0, xm - l, (1,)) | ||
y = torch.randint(0, ym - l, (1,)) | ||
b = torch.randint(0, bm, (1,)) | ||
crop = img[b, x : x + l, y : y + l] | ||
pfs.append(torch.mean(crop).cpu()) | ||
else: # 3D | ||
for _ in range(repeats): | ||
bm, xm, ym, zm = img.shape | ||
x = torch.randint(0, xm - l, (1,)) | ||
y = torch.randint(0, ym - l, (1,)) | ||
z = torch.randint(0, zm - l, (1,)) | ||
b = torch.randint(0, bm, (1,)) | ||
crop = img[b, x : x + l, y : y + l, z : z + l] | ||
pfs.append(torch.mean(crop).cpu()) | ||
pfs = np.array(pfs) | ||
ddof = 1 # for unbiased std | ||
std = np.std(pfs, ddof=ddof) | ||
errs.append(100 * (stats.norm.interval(conf, scale=std)[1] / pf)) | ||
return errs | ||
|
||
|
||
def bernouli_from_cls(cls, pf, img_size, conf=0.95): | ||
ns = core.n_samples_from_dims([np.array(img_size)], cls) | ||
return core.bernouli(pf, ns, conf) | ||
|
||
|
||
# fit_cls now fit_statistical_cls_from_errors | ||
|
||
# ns_from_dims now n_samples_from_dims | ||
|
||
# test_cls_set now test_all_cls_in_range | ||
|
||
|
||
def tpc_fit(x, a, b, c): | ||
return a * np.e ** (-b * x) + c | ||
|
||
|
||
def percentage_error(y_true, y_pred): | ||
return (y_true - y_pred) / y_true | ||
|
||
|
||
def mape(y_true, y_pred): # mean absolute percentage error | ||
return np.mean(np.abs(percentage_error(y_true, y_pred))) | ||
|
||
|
||
def mape_linear_objective(params, y_pred, y_true): | ||
y_pred_new = linear_fit(y_pred, *params) | ||
return mape(y_true, y_pred_new) | ||
|
||
|
||
def linear_fit(x, m, b): | ||
return m * x + b | ||
|
||
|
||
def optimize_error_conf_pred(bern_conf, total_conf, std_bern, std_model, pf): | ||
model_conf = total_conf / bern_conf | ||
err_bern = stats.norm.interval(bern_conf, scale=std_bern)[1] | ||
one_side_error_model = model_conf * 2 - 1 | ||
err_model = stats.norm.interval(one_side_error_model, scale=std_model)[1] | ||
return err_bern * (1 + err_model) | ||
|
||
|
||
def optimize_error_n_pred(bern_conf, total_conf, std_model, pf, err_targ): | ||
model_conf = total_conf / bern_conf | ||
z1 = stats.norm.interval(bern_conf)[1] | ||
one_side_error_model = model_conf * 2 - 1 | ||
err_model = stats.norm.interval(one_side_error_model, scale=std_model)[1] | ||
num = (err_model + 1) ** 2 * (1 - pf) * z1**2 * pf | ||
den = (err_targ) ** 2 # TODO go over the calcs and see if this is right | ||
return num / den | ||
|
||
|
||
# renamed calc_autocorrelation_orthant -> autocorrelation_orthant | ||
|
||
# renamed one_img_stat_analysis_error -> stat_analysis_error_classic |