Lots of changes. New custom VGG16 object-oriented class definitions in models/vgg16.py for both 3-channel and 1-channel VGG16 pretrained on RGB ImageNet (input weights averaged across channels).

Added a callback for intermittently plotting images from a separately provided tf.data iterator in TensorBoard.
JacobARose committed Mar 4, 2020
1 parent be9a745 commit 5cc0b69
Showing 17 changed files with 1,113 additions and 189 deletions.
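The two headline changes in the commit message, a channel-averaged 1-channel VGG16 and a TensorBoard image-plotting callback, live in files not expanded below. The sketches that follow are illustrative reconstructions based only on the commit description, not the actual contents of models/vgg16.py; the names build_grayscale_vgg16 and TensorboardImageCallback and all default values are assumptions.

# Hedged sketch: build a 1-channel VGG16 by averaging the pretrained RGB
# weights of the first conv layer across the input-channel axis, as the
# commit message describes. Not the repository's actual implementation.
import tensorflow as tf

def build_grayscale_vgg16(input_shape=(224, 224, 1), num_classes=1000):
    rgb_vgg = tf.keras.applications.VGG16(weights='imagenet', include_top=False,
                                          input_shape=(224, 224, 3))
    gray_vgg = tf.keras.applications.VGG16(weights=None, include_top=False,
                                           input_shape=input_shape)
    for rgb_layer, gray_layer in zip(rgb_vgg.layers, gray_vgg.layers):
        weights = rgb_layer.get_weights()
        if weights and weights[0].ndim == 4 and weights[0].shape[2] == 3:
            kernel, bias = weights
            # Average the (3, 3, 3, 64) kernel over its 3 input channels -> (3, 3, 1, 64)
            weights = [kernel.mean(axis=2, keepdims=True), bias]
        gray_layer.set_weights(weights)
    x = tf.keras.layers.GlobalAveragePooling2D()(gray_vgg.output)
    outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
    return tf.keras.Model(gray_vgg.input, outputs)

Similarly, a minimal sketch of the described TensorBoard callback, assuming it is handed a Python iterator that yields (images, labels) batches:

class TensorboardImageCallback(tf.keras.callbacks.Callback):
    '''Illustrative sketch of the image-plotting callback described above.'''
    def __init__(self, log_dir, data_iter, max_images=4):
        super().__init__()
        self.writer = tf.summary.create_file_writer(log_dir)
        self.data_iter = data_iter
        self.max_images = max_images

    def on_epoch_end(self, epoch, logs=None):
        images, _ = next(self.data_iter)
        with self.writer.as_default():
            tf.summary.image('validation_samples', images[:self.max_images], step=epoch)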
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@
.ipynb_checkpoints
*/.ipynb_checkpoints/*
*/*/.ipynb_checkpoints/*
*/*/*/.ipynb_checkpoints/*

*/dask-worker-space/*
*/*/dask-worker-space/*
148 changes: 117 additions & 31 deletions pyleaves/analysis/img_utils.py
@@ -286,10 +286,9 @@ def subset(self, new_subset):
self.output_dir = join(self.root_dir,self._subset)
ensure_dir_exists(self.output_dir)


def gen_shard_filepath(self, shard_key, output_dir):
'''
e.g. shard_filepath = self.gen_shard_filepth(shard_key=0, output_dir, output_base_name='train', num_shards=10)
e.g. shard_filepath = self.gen_shard_filepath(shard_key=0, output_dir=output_dir)
'''
shard_fname = f'{self.subset}-{str(shard_key).zfill(5)}-of-{str(self.num_shards).zfill(5)}.tfrecord'
shard_filepath = os.path.join(output_dir,shard_fname)
@@ -299,10 +298,8 @@ def parse_image(self, src_filepath, label):

img = tf.io.read_file(src_filepath)
img = tf.image.decode_image(img, channels=3)
img = tf.image.resize_image_with_pad(img, *self.target_size)
img = tf.compat.v1.image.resize_image_with_pad(img, *self.target_size)
return img, label
# img = tf.image.resize_with_crop_or_pad(img, *self.target_size)
# return img, label

def encode_example(self, img, label):
img = tf.image.encode_jpeg(img, optimize_size=True, chroma_downsampling=False)
@@ -314,7 +311,6 @@ def encode_example(self, img, label):
example_proto = tf.train.Example(features=tf.train.Features(feature=features))
return example_proto.SerializeToString()


def decode_example(self, example):
feature_description = {
'image/bytes': tf.io.FixedLenFeature([], tf.string),
@@ -323,18 +319,13 @@ def decode_example(self, example):
features = tf.io.parse_single_example(example,features=feature_description)

img = tf.image.decode_jpeg(features['image/bytes'], channels=3) # * 255.0
img = tf.image.convert_image_dtype(img, dtype=tf.uint8)
#img = tf.image.convert_image_dtype(img, dtype=tf.float32)
img = tf.image.resize_image_with_pad(img, *self.target_size)
if self.num_channels==1:
img = tf.image.rgb_to_grayscale(img)
img = tf.compat.v1.image.resize_image_with_pad(img, *self.target_size)

label = tf.cast(features['label'], tf.int32)
label = tf.one_hot(label, depth=self.num_classes)

return img, label
# img = preprocess_input(img)

return img, label

def _get_sharded_dataset(self, paths, labels, shard_size):
return tf.data.Dataset.from_tensor_slices((paths, labels)) \
.map(self.parse_image,num_parallel_calls=AUTOTUNE) \
@@ -348,7 +339,6 @@ def stage_dataset(self, data):
shard_size = self.num_samples//self.num_shards
print('self.num_shards',self.num_shards)
return self._get_sharded_dataset(paths, labels, shard_size)
# return tf.data.Dataset.from_tensor_slices((paths, labels))

def execute_batch(self, shard_id, images, labels):
try:
@@ -371,7 +361,7 @@ def execute_batch(self, shard_id, images, labels):
print("Unexpected error:", sys.exc_info())
print(f'[ERROR] {e}')
raise
# @tf.function

def execute_convert(self):
print(f"converting {self.num_samples} images to tfrecord")
staged_data = self.stage_dataset(data=self.data)
@@ -444,37 +434,82 @@ def get_keras_preprocessing_function(model_name: str, input_format=tuple):

if input_format==dict:
def preprocess_func(input_example):

x, y = input_example['image'], input_example['label']
x = input_example['image']
y = input_example['label']
return preprocess_input(x), y
_temp = {'image':tf.zeros([4, 32, 32, 3]), 'label':tf.zeros(())}

preprocess_func(_temp)

elif input_format==tuple:
_temp = ( tf.zeros([4, 32, 32, 3]), tf.zeros(()) )
def preprocess_func(x, y):
return preprocess_input(x), y

_temp = ( tf.zeros([4, 32, 32, 3]), tf.zeros(()) )
preprocess_func(*_temp)
else:
print('''input_format must be either dict or tuple, corresponding to data organized as:
tuple: (x, y)
or
dict: {'image':x, 'label':y}
''')
return None
preprocess_func(_temp)

return preprocess_func
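A short usage sketch for the function above, applying the returned preprocess_func to a batched tf.data pipeline in tuple format; the placeholder tensors and batch size are illustrative only.

import tensorflow as tf

preprocess = get_keras_preprocessing_function('vgg16', input_format=tuple)

# Placeholder data standing in for a real (image, label) dataset.
images = tf.zeros([8, 224, 224, 3])
labels = tf.zeros([8], dtype=tf.int32)

dataset = tf.data.Dataset.from_tensor_slices((images, labels)) \
                         .batch(4) \
                         .map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)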






# def get_keras_preprocessing_function(model_name: str, input_format=tuple):
# '''
# if input_dict_format==True:
# Includes value unpacking in preprocess function to accommodate TFDS {'image':...,'label':...} format
# '''
# if model_name == 'vgg16':
# from tensorflow.keras.applications.vgg16 import preprocess_input
# elif model_name == 'xception':
# from tensorflow.keras.applications.xception import preprocess_input
# elif model_name in ['resnet_50_v2','resnet_101_v2']:
# from tensorflow.keras.applications.resnet_v2 import preprocess_input
# else:
# preprocess_input = lambda x: x

# if input_format==dict:
# def preprocess_func(input_example):

# x = tf.cast(input_example['image'],tf.float32)
# y = input_example['label']
# return preprocess_input(x), y
# _temp = {'image':tf.zeros([4, 32, 32, 3]), 'label':tf.zeros(())}
# preprocess_func(_temp)

# elif input_format==tuple:
# def preprocess_func(x, y):
# x = tf.cast(x, tf.float32)
# return preprocess_input(x), y
# _temp = ( tf.zeros([4, 32, 32, 3]), tf.zeros(()) )
# preprocess_func(*_temp)
# else:
# print('''input_format must be either dict or tuple, corresponding to data organized as:
# tuple: (x, y)
# or
# dict: {'image':x, 'label':y}
# ''')
# return None

# return preprocess_func



class ImageAugmentor:

def __init__(self,
augmentations=['rotate',
'flip',
'color']):
'color'],
seed=12):
self.augmentations = augmentations

self.seed = seed
def rotate(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
"""Rotation augmentation
@@ -486,7 +521,7 @@ def rotate(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
"""

# Rotate 0, 90, 180, 270 degrees
return tf.image.rot90(x, tf.random_uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)), label
return tf.image.rot90(x, tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32,seed=self.seed)), label


def flip(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
@@ -498,8 +533,8 @@ def flip(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
Returns:
Augmented image
"""
x = tf.image.random_flip_left_right(x)
x = tf.image.random_flip_up_down(x)
x = tf.image.random_flip_left_right(x, seed=self.seed)
x = tf.image.random_flip_up_down(x, seed=self.seed)

return x, label

@@ -512,12 +547,28 @@ def color(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
Returns:
Augmented image
"""
x = tf.image.random_hue(x, 0.08)
x = tf.image.random_saturation(x, 0.6, 1.6)
x = tf.image.random_brightness(x, 0.05)
x = tf.image.random_contrast(x, 0.7, 1.3)
x = tf.image.random_hue(x, 0.08, seed=self.seed)
x = tf.image.random_saturation(x, 0.6, 1.6, seed=self.seed)
x = tf.image.random_brightness(x, 0.05, seed=self.seed)
x = tf.image.random_contrast(x, 0.7, 1.3, seed=self.seed)
return x, label



def rgb2gray_3channel(img, label):
'''
Convert rgb image to grayscale, but keep num_channels=3
'''
img = tf.image.rgb_to_grayscale(img)
img = tf.image.grayscale_to_rgb(img)
return img, label

def rgb2gray_1channel(img, label):
'''
Convert rgb image to grayscale, num_channels from 3 to 1
'''
img = tf.image.rgb_to_grayscale(img)
return img, label
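As a quick illustration of how the ImageAugmentor methods and the grayscale helpers above might be chained onto a tf.data pipeline; the dataset contents below are placeholders, and this wiring is a sketch rather than the repository's actual training pipeline.

import tensorflow as tf

augmentor = ImageAugmentor(augmentations=['rotate', 'flip', 'color'], seed=12)

images = tf.random.uniform([16, 224, 224, 3])
labels = tf.zeros([16], dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

# Apply each requested augmentation as its own map stage.
for aug_name in augmentor.augmentations:
    aug_fn = getattr(augmentor, aug_name)   # e.g. augmentor.rotate
    dataset = dataset.map(aug_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Drop color information but keep a [h, w, 3] shape, e.g. for a 3-channel model.
dataset = dataset.map(rgb2gray_3channel)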



@@ -538,6 +589,41 @@ def color(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
44 changes: 40 additions & 4 deletions pyleaves/config.py
@@ -76,18 +76,25 @@ def __init__(self,
label_col='family',
target_size=(224,224),
num_channels=3,
grayscale=False,
low_class_count_thresh=3,
data_splits={'val_size':0.2,'test_size':0.2},
tfrecord_root_dir=r'/media/data/jacob/Fossil_Project/tfrecord_data',
num_shards=10,
input_format=tuple):
'''
if grayscale==True and num_channels==3:
Convert to grayscale 1 channel then duplicate to 3 channels for full [batch,h,w,3] shape
'''

self.dirs = {'tfrecord_root_dir':tfrecord_root_dir}
self.init_directories(self.dirs)

super().__init__(dataset_name=dataset_name,
label_col=label_col,
target_size=target_size,
num_channels=num_channels,
grayscale=grayscale,
low_class_count_thresh=low_class_count_thresh,
data_splits=data_splits,
tfrecord_root_dir=tfrecord_root_dir,
@@ -107,7 +114,9 @@ def __init__(self,
preprocessing=None,
augment_images=False,
augmentations=['rotate','flip','color'],
seed=3):
regularization=None,
seed=3,
verbose=True):
super().__init__(model_name=model_name,
batch_size=batch_size,
frozen_layers=frozen_layers,
@@ -117,20 +126,47 @@ def __init__(self,
preprocessing=preprocessing,
augment_images=augment_images,
augmentations=augmentations,
seed=seed)
regularization=regularization,
seed=seed,
verbose=verbose)
'''
preprocessing : Can be any of [None, 'imagenet']
If 'imagenet', subtract hard-coded imagenet mean from each of the RGB channels
'''


class ModelConfig(BaseConfig):

def __init__(self,
model_name='vgg16',
num_classes=1000,
frozen_layers=(0,-4),
input_shape=(224,224,3),
base_learning_rate=0.0001,
grayscale=False,
regularization=None,
seed=3,
verbose=True):
super().__init__(model_name=model_name,
num_classes=num_classes,
frozen_layers=frozen_layers,
input_shape=input_shape,
base_learning_rate=base_learning_rate,
regularization=regularization,
seed=seed,
verbose=verbose)




class ExperimentConfig(BaseConfig):

def __init__(self,
dataset_config,
train_config):
dataset_config={},
train_config={}):

self.dataset_config = dataset_config
self.train_config = train_config
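A hedged sketch of how these config classes might be composed, using only keyword arguments visible in this diff; the pyleaves.config import path and the DatasetConfig/TrainConfig names for the first two __init__ signatures are assumptions, and the argument values are placeholders.

# Assumed import path and class names; argument values are placeholders.
from pyleaves.config import DatasetConfig, TrainConfig, ExperimentConfig

dataset_config = DatasetConfig(dataset_name='example_dataset',   # placeholder dataset name
                               label_col='family',
                               target_size=(224, 224),
                               num_channels=3,
                               grayscale=True,   # 1-channel conversion, duplicated back to 3 channels
                               num_shards=10)

train_config = TrainConfig(model_name='vgg16',
                           augment_images=True,
                           augmentations=['rotate', 'flip', 'color'],
                           regularization=None,
                           seed=3,
                           verbose=True)

experiment_config = ExperimentConfig(dataset_config=dataset_config,
                                     train_config=train_config)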
