adapted Nodes to be compatible with inference
a1302z committed Aug 24, 2020
1 parent 24ef01f commit 42ab442
Showing 6 changed files with 208 additions and 128 deletions.
5 changes: 5 additions & 0 deletions Node/__main__.py
@@ -68,6 +68,10 @@
"--config", type=str, help="Path to config",
)

parser.add_argument(
"--mean_std_file", type=str, help="Path to mean std file for inference data."
)

parser.set_defaults(use_test_config=False)

if __name__ == "__main__":
@@ -82,6 +86,7 @@
test_config={"SQLALCHEMY_DATABASE_URI": db_path},
data_dir=args.data_directory,
config_file=args.config,
mean_std_file=args.mean_std_file,
)
# else:
# app = create_app(
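With this change, a data-owner Node is pointed at its inference images and the normalization statistics from the command line. A hypothetical invocation (only --config and --mean_std_file are confirmed by this diff; the --data_directory flag name is inferred from args.data_directory, "python -m Node" from the presence of Node/__main__.py, and all paths are placeholders):

    python -m Node --config configs/torch/pneumonia-resnet-pretrained.ini --data_directory data/inference_images --mean_std_file data/val_mean_std.pt

The parsed value is forwarded unchanged as mean_std_file=args.mean_std_file to create_app, where the new inference setup below consumes it.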
251 changes: 152 additions & 99 deletions Node/app/__init__.py
@@ -155,6 +155,7 @@ def create_app(
test_config=None,
data_dir=None,
config_file=None,
mean_std_file: str = None,
) -> Flask:
"""Create flask application.
@@ -199,7 +200,11 @@
# add data
if data_dir:
import configparser
import albumentations as a
from inference import PathDataset
from torchlib.utils import Arguments
from torchlib.dicomtools import CombinedLoader
from torchlib.dataloader import AlbumentationsTorchTransform
from os import path
from argparse import Namespace
from random import seed as r_seed
@@ -227,111 +232,159 @@
np.random.seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if args.data_dir == "mnist":
node_id = local_worker.id
KEEP_LABELS_DICT = {
"alice": [0, 1, 2, 3],
"bob": [4, 5, 6],
"charlie": [7, 8, 9],
None: list(range(10)),
}
dataset = LabelMNIST(
labels=KEEP_LABELS_DICT[node_id]
if node_id in KEEP_LABELS_DICT
else KEEP_LABELS_DICT[None],
root="./data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)),]
),
if node_id == "data_owner":
print("setting up an data to be remotely classified.")
tf = [
a.Resize(args.inference_resolution, args.inference_resolution),
a.CenterCrop(args.inference_resolution, args.inference_resolution),
]
if hasattr(args, "clahe") and args.clahe:
tf.append(a.CLAHE(always_apply=True, clip_limit=(1, 1)))
if mean_std_file:
mean_std = torch.load(mean_std_file)
if "val_mean_std" in mean_std:
mean_std = mean_std["val_mean_std"]
mean, std = mean_std
else:
raise RuntimeError(
"To set up a data owner for inference we need a file which tells"
" us how to normalize the data."
)
tf.extend(
[
a.ToFloat(max_value=255.0),
a.Normalize(
mean.cpu().numpy()[None, None, :],
std.cpu().numpy()[None, None, :],
max_pixel_value=1.0,
),
]
)
tf = AlbumentationsTorchTransform(a.Compose(tf))
loader = CombinedLoader()

dataset = PathDataset(data_dir, transform=tf, loader=loader,)
data = []
for d in tqdm(dataset, total=len(dataset), leave=False, desc="load data"):
data.append(d)
data = torch.stack(data) # pylint:disable=no-member
data.tag("#inference_data")
local_worker.load_data([data])
else:
stats_dataset = ImageFolder(
data_dir,
loader=loader,
transform=transforms.Compose(
[
transforms.Resize(args.train_resolution),
transforms.CenterCrop(args.train_resolution),
transforms.ToTensor(),
if args.data_dir == "mnist":
node_id = local_worker.id
KEEP_LABELS_DICT = {
"alice": [0, 1, 2, 3],
"bob": [4, 5, 6],
"charlie": [7, 8, 9],
None: list(range(10)),
}
dataset = LabelMNIST(
labels=KEEP_LABELS_DICT[node_id]
if node_id in KEEP_LABELS_DICT
else KEEP_LABELS_DICT[None],
root="./data",
train=True,
download=True,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
)
else:
stats_dataset = ImageFolder(
data_dir,
loader=loader,
transform=transforms.Compose(
[
transforms.Resize(args.train_resolution),
transforms.CenterCrop(args.train_resolution),
transforms.ToTensor(),
]
),
)
assert (
len(stats_dataset.classes) == 3
), "We can only handle data that has 3 classes: normal, bacterial and viral"
mean, std = calc_mean_std(stats_dataset, save_folder=data_dir,)
del stats_dataset
target_tf = None
if args.mixup or args.weight_classes:
target_tf = [
lambda x: torch.tensor(x), # pylint:disable=not-callable
To_one_hot(3),
]
),
)
assert (
len(stats_dataset.classes) == 3
), "We can only handle data that has 3 classes: normal, bacterial and viral"
mean, std = calc_mean_std(stats_dataset, save_folder=data_dir,)
del stats_dataset
target_tf = None
if args.mixup or args.weight_classes:
target_tf = [
lambda x: torch.tensor(x), # pylint:disable=not-callable
To_one_hot(3),
]
dataset = ImageFolder(
# path.join("data/server_simulation/", "validation")
# if worker.id == "validation"
# else
data_dir,
loader=loader,
transform=create_albu_transform(args, mean, std),
target_transform=transforms.Compose(target_tf) if target_tf else None,
)
assert (
len(dataset.classes) == 3
), "We can only handle data that has 3 classes: normal, bacterial and viral"
mean.tag("#datamean")
std.tag("#datastd")
local_worker.load_data([mean, std])
data, targets = [], []
# repetitions = 1 if worker.id == "validation" else args.repetitions_dataset
if args.mixup:
dataset = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
mixup = MixUp(λ=args.mixup_lambda, p=args.mixup_prob)
last_set = None
for j in tqdm(
range(args.repetitions_dataset),
total=args.repetitions_dataset,
leave=False,
desc="register data on {:s}".format(local_worker.id),
):
for d, t in tqdm(
dataset,
total=len(dataset),
dataset = ImageFolder(
# path.join("data/server_simulation/", "validation")
# if worker.id == "validation"
# else
data_dir,
loader=loader,
transform=create_albu_transform(args, mean, std),
target_transform=transforms.Compose(target_tf)
if target_tf
else None,
)
assert (
len(dataset.classes) == 3
), "We can only handle data that has 3 classes: normal, bacterial and viral"
mean.tag("#datamean")
std.tag("#datastd")
local_worker.load_data([mean, std])
data, targets = [], []
# repetitions = 1 if worker.id == "validation" else args.repetitions_dataset
if args.mixup:
dataset = torch.utils.data.DataLoader(
dataset, batch_size=1, shuffle=True
)
mixup = MixUp(λ=args.mixup_lambda, p=args.mixup_prob)
last_set = None
for j in tqdm(
range(args.repetitions_dataset),
total=args.repetitions_dataset,
leave=False,
desc="register data {:d}. time".format(j + 1),
desc="register data on {:s}".format(local_worker.id),
):
if args.mixup:
original_set = (d, t)
if last_set:
# pylint:disable=unsubscriptable-object
d, t = mixup(((d, last_set[0]), (t, last_set[1])))
last_set = original_set
data.append(d)
targets.append(t)
selected_data = torch.stack(data) # pylint:disable=no-member
selected_targets = (
torch.stack(targets) # pylint:disable=no-member
if args.mixup or args.weight_classes
else torch.tensor(targets) # pylint:disable=not-callable
)
if args.mixup:
selected_data = selected_data.squeeze(1)
selected_targets = selected_targets.squeeze(1)
del data, targets
selected_data.tag("#traindata",) # "#valdata" if worker.id == "validation" else
selected_targets.tag(
# "#valtargets" if worker.id == "validation" else
"#traintargets",
)
local_worker.load_data([selected_data, selected_targets])
print(
"registered {:d} samples of {:s} data".format(
selected_data.size(0), args.data_dir
for d, t in tqdm(
dataset,
total=len(dataset),
leave=False,
desc="register data {:d}. time".format(j + 1),
):
if args.mixup:
original_set = (d, t)
if last_set:
# pylint:disable=unsubscriptable-object
d, t = mixup(((d, last_set[0]), (t, last_set[1])))
last_set = original_set
data.append(d)
targets.append(t)
selected_data = torch.stack(data) # pylint:disable=no-member
selected_targets = (
torch.stack(targets) # pylint:disable=no-member
if args.mixup or args.weight_classes
else torch.tensor(targets) # pylint:disable=not-callable
)
)
del selected_data, selected_targets
if args.mixup:
selected_data = selected_data.squeeze(1)
selected_targets = selected_targets.squeeze(1)
del data, targets
selected_data.tag(
"#traindata",
) # "#valdata" if worker.id == "validation" else
selected_targets.tag(
# "#valtargets" if worker.id == "validation" else
"#traintargets",
)
local_worker.load_data([selected_data, selected_targets])
print(
"registered {:d} samples of {:s} data".format(
selected_data.size(0), args.data_dir
)
)
del selected_data, selected_targets

# Register app blueprints
app.register_blueprint(main_routes, url_prefix=r"/")
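For context, a minimal sketch of a mean/std file the new data-owner branch accepts. The format (a dict carrying a per-channel (mean, std) pair under the key "val_mean_std") is inferred from the loading code in the hunk above; the numeric values and the file name are placeholders, not taken from this repository.

    import torch

    # Placeholder per-channel statistics (the code above indexes
    # mean.cpu().numpy()[None, None, :], so per-channel tensors are assumed).
    mean = torch.tensor([0.5, 0.5, 0.5])
    std = torch.tensor([0.25, 0.25, 0.25])
    torch.save({"val_mean_std": (mean, std)}, "val_mean_std.pt")

    # The Node unwraps the file the same way the diff does before building
    # the albumentations Normalize transform:
    mean_std = torch.load("val_mean_std.pt")
    if "val_mean_std" in mean_std:
        mean_std = mean_std["val_mean_std"]
    mean, std = mean_std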
8 changes: 4 additions & 4 deletions configs/torch/pneumonia-resnet-pretrained.ini
@@ -1,11 +1,11 @@
[config]
batch_size = 1
batch_size = 200
train_resolution = 224
;inference_resolution = 512
test_batch_size = 1
test_interval = 1
validation_split = 10
epochs = 1
epochs = 40
lr = 1e-4
end_lr = 1e-5
restarts = 0
@@ -55,8 +55,8 @@ grid_dropout = no


[federated]
sync_every_n_batch = 1
sync_every_n_batch = 5
wait_interval = 0.1
keep_optim_dict = no
repetitions_dataset = 1
repetitions_dataset = 5
weighted_averaging = yes
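The updated values can be read back with the standard-library configparser that Node/app/__init__.py already imports. A minimal sketch under that assumption (the project's own torchlib.utils.Arguments wrapper, which actually consumes these values, is not reproduced here):

    import configparser

    config = configparser.ConfigParser()
    config.read("configs/torch/pneumonia-resnet-pretrained.ini")

    batch_size = config.getint("config", "batch_size")                       # 200 after this commit
    epochs = config.getint("config", "epochs")                               # 40 after this commit
    sync_every_n_batch = config.getint("federated", "sync_every_n_batch")    # 5 after this commit
    repetitions_dataset = config.getint("federated", "repetitions_dataset")  # 5 after this commit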