Skip to content

Commit

Permalink
updated FE 1.0 working examples
Browse files Browse the repository at this point in the history
  • Loading branch information
vbvg2008 committed Jun 16, 2019
1 parent 20d51f1 commit 99b05c0
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 0 deletions.
46 changes: 46 additions & 0 deletions image_classification/lenet_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from fastestimator.pipeline.static.preprocess import Minmax
from fastestimator.estimator.estimator import Estimator
from fastestimator.pipeline.pipeline import Pipeline
from fastestimator.architecture.lenet import LeNet
from fastestimator.estimator.trace import Accuracy
import tensorflow as tf
import numpy as np

class Network:
def __init__(self):
self.model = LeNet()
self.optimizer = tf.optimizers.Adam()
self.loss = tf.losses.SparseCategoricalCrossentropy()

def train_op(self, batch):
with tf.GradientTape() as tape:
predictions = self.model(batch["x"])
loss = self.loss(batch["y"], predictions)
gradients = tape.gradient(loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
return predictions, loss

def eval_op(self, batch):
predictions = self.model(batch["x"], training=False)
loss = self.loss(batch["y"], predictions)
return predictions, loss

def get_estimator(epochs=2, batch_size=32, optimizer="adam"):

(x_train, y_train), (x_eval, y_eval) = tf.keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, -1)
x_eval = np.expand_dims(x_eval, -1)

pipeline = Pipeline(batch_size=batch_size,
feature_name=["x", "y"],
train_data={"x": x_train, "y": y_train},
validation_data={"x": x_eval, "y": y_eval},
transform_train= [[Minmax()], []])

traces = [Accuracy(feature_true="y")]

estimator = Estimator(network= Network(),
pipeline=pipeline,
epochs= epochs,
traces= traces)
return estimator
102 changes: 102 additions & 0 deletions image_generation/dcgan_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from fastestimator.estimator.estimator import Estimator
from fastestimator.pipeline.pipeline import Pipeline
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

class Network:
def __init__(self):
self.discriminator = self.make_discriminator_model()
self.generator = self.make_generator_model()
self.cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
self.generator_optimizer = tf.keras.optimizers.Adam(1e-4)
self.discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

def make_generator_model(self):
model = tf.keras.Sequential()
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Reshape((7, 7, 256)))
assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
assert model.output_shape == (None, 7, 7, 128)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
assert model.output_shape == (None, 14, 14, 64)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
assert model.output_shape == (None, 28, 28, 1)
return model

def make_discriminator_model(self):
model = tf.keras.Sequential()
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Flatten())
model.add(layers.Dense(1))
return model

def discriminator_loss(self, real_output, fake_output):
real_loss = self.cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = self.cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss

def generator_loss(self, fake_output):
return self.cross_entropy(tf.ones_like(fake_output), fake_output)

def train_op(self, batch):
noise = tf.random.normal([32, 100])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = self.generator(noise, training=True)

real_output = self.discriminator(batch["x"], training=True)
fake_output = self.discriminator(generated_images, training=True)

gen_loss = self.generator_loss(fake_output)
disc_loss = self.discriminator_loss(real_output, fake_output)

gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

self.generator_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
self.discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))
return generated_images, (gen_loss, disc_loss)

def eval_op(self, batch):
noise = tf.random.normal([32, 100])
generated_images = self.generator(noise, training=False)
real_output = self.discriminator(batch["x"], training=False)
fake_output = self.discriminator(generated_images, training=False)
gen_loss = self.generator_loss(fake_output)
disc_loss = self.discriminator_loss(real_output, fake_output)
return generated_images, (gen_loss, disc_loss)

class Myrescale:
def transform(self, data, decoded_data=None):
data = tf.cast(data, tf.float32)
data = (data - 127.5) / 127.5
return data

def get_estimator():
(x_train, _), (x_eval, _) = tf.keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, -1)
x_eval = np.expand_dims(x_eval, -1)

pipeline = Pipeline(batch_size=32,
feature_name=["x"],
train_data={"x": x_train},
validation_data={"x": x_eval},
transform_train= [[Myrescale()], []])

estimator = Estimator(network= Network(),
pipeline=pipeline,
epochs= 2)
return estimator

0 comments on commit 99b05c0

Please sign in to comment.