Skip to content

Commit

Permalink
Enable aug search
Browse files Browse the repository at this point in the history
  • Loading branch information
daochenzha committed Mar 5, 2022
1 parent 1447859 commit 42b6e47
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
Inputs = container.DataFrame

class Hyperparams(hyperparams.Hyperparams):
scale = hyperparams.Set[float](
scale = hyperparams.Hyperparameter[tuple](
default=(0, 15),
description='Standard deviation of the normal distribution that generates the noise. Must be >=0. If 0 then loc will simply be added to all pixels.',
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
)

per_channel = hyperparams.Constant[bool](
per_channel = hyperparams.Hyperparameter[bool](
default=True,
description='Whether to use (imagewise) the same sample(s) for all channels (False) or to sample value(s) for each channel (True). Setting this to True will therefore lead to different transformations per image and channel, otherwise only per image.',
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
Expand Down Expand Up @@ -61,4 +61,4 @@ def _get_function(self):
scale = self.hyperparams["scale"]
per_channel = self.hyperparams["per_channel"]
seed = self.hyperparams["seed"]
return iaa.AdditiveGaussianNoise(scale=scale, per_channel=per_channel, seed=seed)
return iaa.AdditiveGaussianNoise(scale=scale, per_channel=per_channel, seed=seed)
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
Inputs = container.DataFrame

class Hyperparams(hyperparams.Hyperparams):
scale = hyperparams.Set[float](
scale = hyperparams.Hyperparameter[tuple](
default=(0, 15),
description='Standard deviation of the normal distribution that generates the noise. Must be >=0. If 0 then loc will simply be added to all pixels.',
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
)

per_channel = hyperparams.Constant[bool](
per_channel = hyperparams.Hyperparameter[bool](
default=True,
description='Whether to use (imagewise) the same sample(s) for all channels (False) or to sample value(s) for each channel (True). Setting this to True will therefore lead to different transformations per image and channel, otherwise only per image.',
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
Expand Down Expand Up @@ -61,4 +61,4 @@ def _get_function(self):
scale = self.hyperparams["scale"]
per_channel = self.hyperparams["per_channel"]
seed = self.hyperparams["seed"]
return iaa.AdditiveLaplaceNoise(scale=scale, per_channel=per_channel, seed=seed)
return iaa.AdditiveLaplaceNoise(scale=scale, per_channel=per_channel, seed=seed)
6 changes: 3 additions & 3 deletions autovideo/augmentation/geometric/Jigsaw_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@

class Hyperparams(hyperparams.Hyperparams):

nb_rows = hyperparams.Set[int](
nb_rows = hyperparams.Hyperparameter[tuple](
default=(3, 10),
description="How many rows the jigsaw pattern should have.",
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
)

nb_cols = hyperparams.Set[int](
nb_cols = hyperparams.Hyperparameter[tuple](
default=(3, 10),
description="How many cols the jigsaw pattern should have.",
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
)

max_steps = hyperparams.Set[int](
max_steps = hyperparams.Hyperparameter[tuple](
default=(1, 5),
description="How many steps each jigsaw cell may be moved.",
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
Expand Down
2 changes: 1 addition & 1 deletion autovideo/augmentation/geometric/Rotate_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

class Hyperparams(hyperparams.Hyperparams):

rotate = hyperparams.Set[int](
rotate = hyperparams.Hyperparameter[tuple](
default=(-45, 45),
description="See Affine.",
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
Expand Down
4 changes: 2 additions & 2 deletions autovideo/entry_points.ini
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ autovideo.augmentation.geometric_PiecewiseAffine = autovideo.augmentation.geomet
autovideo.augmentation.geometric_PerspectiveTransform = autovideo.augmentation.geometric.PerspectiveTransform_primitive:PerspectiveTransformPrimitive
autovideo.augmentation.geometric_ElasticTransformation = autovideo.augmentation.geometric.ElasticTransformation_primitive:ElasticTransformationPrimitive
autovideo.augmentation.geometric_Rot90 = autovideo.augmentation.geometric.Rot90_primitive:Rot90Primitive
autovideo.augmentation.WithPolarWarping = autovideo.augmentation.geometric.WithPolarWarping_primitive:WithPolarWarpingPrimitive
autovideo.augmentation.Jigsaw = autovideo.augmentation.geometric.Jigsaw_primitive:JigsawPrimitive
autovideo.augmentation.geometric_WithPolarWarping = autovideo.augmentation.geometric.WithPolarWarping_primitive:WithPolarWarpingPrimitive
autovideo.augmentation.geometric_Jigsaw = autovideo.augmentation.geometric.Jigsaw_primitive:JigsawPrimitive

autovideo.augmentation.imgcorruptlike_GaussianNoise = autovideo.augmentation.imgcorruptlike.GaussianNoise_primitive:GaussianNoisePrimitive
autovideo.augmentation.imgcorruptlike_ShotNoise = autovideo.augmentation.imgcorruptlike.ShotNoise_primitive:ShotNoisePrimitive
Expand Down
28 changes: 27 additions & 1 deletion autovideo/searcher/ray_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def search(self, search_space, config):

self.valid_labels = self.valid_dataset['label']
self.valid_dataset = self.valid_dataset.drop(['label'], axis=1)


search_space = flatten_search_space(search_space)

analysis = tune.run(
self._evaluate,
config=search_space,
Expand All @@ -43,10 +45,12 @@ def search(self, search_space, config):
name=config["searching_algorithm"]+"_"+str(config["num_samples"])
)
best_config = analysis.get_best_config(metric="accuracy")
best_config = unflatten_config(best_config)

return best_config

def _evaluate(self, config):
config = unflatten_config(config)
pipeline = build_pipeline(config)

# Fit and produce
Expand All @@ -63,3 +67,25 @@ def _evaluate(self, config):
valid_acc = compute_accuracy_with_preds(predictions['label'], self.valid_labels)
tune.report(accuracy=valid_acc)

def flatten_search_space(search_space):
flattened_search_space = {}
augmentation = search_space.pop("augmentation", {})
for key in augmentation:
flattened_search_space["augmentation:"+key] = augmentation[key]
for key in search_space:
flattened_search_space[key] = search_space[key]

return flattened_search_space

def unflatten_config(config):
unflattened_config = {}
for key in config:
if key.startswith("augmentation"):
if "augmentation" not in unflattened_config:
unflattened_config["augmentation"] = []
unflattened_config["augmentation"].append(config[key])
else:
unflattened_config[key] = config[key]

return unflattened_config

44 changes: 14 additions & 30 deletions autovideo/utils/d3m_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,40 +63,26 @@ def _update_predictions_metadata(inputs_metadata: metadata_base.DataMetadata, ou
def build_pipeline(config):
"""Build a pipline based on the config
"""
default_config = {
"transformation": [],
"augmentation": [],
"multi_aug": None,
"algorithm": "tsn",
}
for key in config:
default_config[key] = config[key]
config = default_config

from d3m import index
from d3m.metadata.base import ArgumentType
from d3m.metadata.pipeline import Pipeline, PrimitiveStep
algorithm = config.pop('algorithm', None)
transformation = config.pop('transformation', None)
algorithm = config.pop('algorithm', "tsn")
transformation = config.pop('transformation', [])
transformation_methods = [transformation[i][0] for i in range(len(transformation))]
augmentation = config.pop('augmentation', None)
augmentation = config.pop('augmentation', [])
augmentation_methods = [augmentation[i][0] for i in range(len(augmentation))]
if len(augmentation) > 0 and len(augmentation[0]) > 1:
augmentation_configs = []
for i in range(len(augmentation)):
try:
augmentation_configs.append(augmentation[i][1])
except:
augmentation_configs.append(None)
#augmentation_configs = [augmentation[i][1] for i in range(len(augmentation))]
else:
augmentation_configs = None
multi_aug = config.pop('multi_aug', 'meta_Sequential')

if len(transformation) > 0 and len(transformation[0]) > 1:
transformation_configs = [transformation[i][1] for i in range(len(transformation))]
else:
transformation_configs = None
# Read augmentation hyperparameters
augmentation_configs = []
for i in range(len(augmentation)):
if len(augmentation[i]) > 1:
augmentation_configs.append(augmentation[i][1])
else:
augmentation_configs.append(None)
multi_aug = config.pop('multi_aug', None)

# Read transformation hyperparameters
transformation_configs = [transformation[i][1] for i in range(len(transformation))]

# Creating pipeline
pipeline_description = Pipeline()
Expand Down Expand Up @@ -176,8 +162,6 @@ def build_pipeline(config):
alg_python_path = 'd3m.primitives.autovideo.augmentation.'+multi_aug
step_7 = PrimitiveStep(primitive=index.get_primitive(alg_python_path))
step_7.add_argument(name='inputs', argument_type=ArgumentType.CONTAINER, data_reference='steps.'+str(curr_step_no)+'.produce')
#for key, value in config.items():
# step_6.add_hyperparameter(name=key, argument_type=ArgumentType.VALUE, data=value)
step_7.add_output('produce')
pipeline_description.add_step(step_7)
curr_step_no += 1
Expand Down
15 changes: 15 additions & 0 deletions examples/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,22 @@ def run(args):

#Search Space
search_space = {
"augmentation": {
"aug_0": tune.choice([
("arithmetic_AdditiveGaussianNoise",),
("arithmetic_AdditiveLaplaceNoise",),
]),
"aug_1": tune.choice([
("geometric_Rotate",),
("geometric_Jigsaw",),
]),
},
"multi_aug": tune.choice([
"meta_Sometimes",
"meta_Sequential",
]),
"algorithm": tune.choice(["tsn"]),
"epochs": tune.choice([1]),
"learning_rate": tune.uniform(0.0001, 0.001),
"momentum": tune.uniform(0.9,0.99),
"weight_decay": tune.uniform(5e-4,1e-3),
Expand Down

0 comments on commit 42b6e47

Please sign in to comment.