Skip to content

Commit

Permalink
Merge pull request #6 from chm123/feature/metrics
Browse files Browse the repository at this point in the history
Add confusion matrix to minist example
  • Loading branch information
vbvg2008 authored Jul 8, 2019
2 parents 80a59a2 + b4e3315 commit 5456f81
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions image_classification/lenet_mnist.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from fastestimator.pipeline.static.preprocess import Minmax
import numpy as np
import tensorflow as tf

from fastestimator.architecture.lenet import LeNet
from fastestimator.estimator.estimator import Estimator
from fastestimator.estimator.trace import Accuracy, ConfusionMatrix
from fastestimator.pipeline.pipeline import Pipeline
from fastestimator.architecture.lenet import LeNet
from fastestimator.estimator.trace import Accuracy
import tensorflow as tf
import numpy as np
from fastestimator.pipeline.static.preprocess import Minmax


class Network:
def __init__(self):
Expand Down Expand Up @@ -35,12 +37,12 @@ def get_estimator(epochs=2, batch_size=32, optimizer="adam"):
feature_name=["x", "y"],
train_data={"x": x_train, "y": y_train},
validation_data={"x": x_eval, "y": y_eval},
transform_train= [[Minmax()], []])
transform_train=[[Minmax()], []])

traces = [Accuracy(y_true_key="y")]
traces = [Accuracy(y_true_key="y"), ConfusionMatrix(y_true_key="y", num_classes=10)]

estimator = Estimator(network= Network(),
estimator = Estimator(network=Network(),
pipeline=pipeline,
epochs= epochs,
traces= traces)
return estimator
epochs=epochs,
traces=traces)
return estimator

0 comments on commit 5456f81

Please sign in to comment.