diff --git a/keras/preprocessing/image.py b/keras/preprocessing/image.py index 784401e2fe12..3bcf0cfc0ba4 100644 --- a/keras/preprocessing/image.py +++ b/keras/preprocessing/image.py @@ -527,11 +527,12 @@ def standardize(self, x): 'first by calling `.fit(numpy_data)`.') return x - def random_transform(self, x): + def random_transform(self, x, seed=None): """Randomly augment a single image tensor. # Arguments x: 3D tensor, single image. + seed: random seed. # Returns A randomly transformed version of the input (same shape). @@ -541,6 +542,9 @@ def random_transform(self, x): img_col_axis = self.col_axis - 1 img_channel_axis = self.channel_axis - 1 + if seed is not None: + np.random.seed(seed) + # use composition of homographies # to generate final transform that needs to be applied if self.rotation_range: