Skip to content

Convolutional network slower than TensorFlow on CPU #2350

Open
@alerem18

Description

@alerem18
using Flux
using MLDatasets: MNIST, CIFAR10, CIFAR100
using Flux: logitcrossentropy, setup, Adam, train!
using Flux.OneHotArrays: onehotbatch, onecold
using Statistics: mean
using Flux.MLUtils: DataLoader
using ProgressBars: tqdm, set_postfix
using Flux.Zygote: ignore

# ------------------------ DATA ---------------------
# MNIST train/test splits, loaded through MLDatasets.
TRAIN, TEST = MNIST(split=:train), MNIST(split=:test)

x_train, y_train = TRAIN.features, TRAIN.targets
x_test, y_test = TEST.features, TEST.targets

# Insert a singleton channel axis at dimension 3 so the arrays match the
# (28, 28, 1, batch) input size the model below is built for.
x_train, x_test = map(x -> Flux.unsqueeze(x; dims=3), (x_train, x_test))

# One-hot targets are only needed for the training loss; the test labels stay
# as raw digits because `accuracy` compares decoded predictions against them.
y_train_encoded = onehotbatch(y_train, 0:9)

TRAIN_LOADER = DataLoader((x_train, y_train_encoded); batchsize=128, shuffle=true)
TEST_LOADER = DataLoader((x_test, y_test); batchsize=128, shuffle=false)

# -------------------- MODEL ------------------------------
# Two conv/pool stages followed by a dense classifier. `@autosize` infers the
# input width of the final Dense layer (`_`) from a 28×28×1×1 probe input.
# The last layer has no activation, so it emits logits; the training loss is
# `logitcrossentropy`, which expects exactly that.
model = Flux.@autosize (28, 28, 1, 1) Chain(
    Conv((3, 3), 1 => 32, relu),
    MaxPool((2, 2)),
    Conv((3, 3), 32 => 64, relu),
    MaxPool((2, 2)),
    Flux.flatten,
    Dropout(0.5),
    Dense(_, 10),
)

# Adam with learning rate 1e-3; `setup` builds the per-parameter optimiser state.
optimizer = setup(Adam(1e-3), model)



# --------------------- HELPER ----------------------------------
"""
    accuracy(m, LOADER; labels=0:9)

Return the fraction of samples in `LOADER` that the model `m` classifies correctly.

`LOADER` must yield `(X, Y)` batches where `Y` holds raw class labels (not one-hot).
The model output `m(X)` is decoded with `onecold` over `labels` and compared
elementwise with `Y`.

# Keywords
- `labels=0:9`: the class labels, in the order of the model's output rows.

# Returns
A `Float64` in `[0, 1]`, or `NaN` when `LOADER` yields no samples (`0 / 0`).
"""
function accuracy(m, LOADER; labels=0:9)
    correct = 0
    total = 0
    for (X, Y) in LOADER
        total += length(Y)
        # NOTE(review): assumes `m` runs in inference mode here; Flux disables
        # Dropout automatically outside gradient calls — confirm for custom layers.
        correct += count(onecold(m(X), labels) .== Y)
    end
    return correct / total
end

# ------------------- TRAIN ----------------------------------------
"""
    train_loop(model, optimizer, train_loader, test_loader; epochs=5)

Train `model` in place with `Flux.train!` for `epochs` passes over `train_loader`.

`train_loader` must yield `(X, Y)` batches with one-hot `Y` over `0:9`. `optimizer`
is the state returned by `Flux.setup` and is updated in place.

A progress bar shows the running training accuracy. After each epoch, the accuracy
on `test_loader` (raw-label batches, see `accuracy`) is logged with `@info`.

Returns `nothing`.
"""
function train_loop(model, optimizer, train_loader, test_loader; epochs=5)
    for epoch in 1:epochs
        iter = tqdm(train_loader)
        # `Ref`s rather than reassigned locals: the closures below would otherwise
        # capture reassigned variables, which Julia boxes (type-unstable).
        total = Ref(0)
        correct = Ref(0)
        for (X, Y) in iter
            train!(model, [(X, Y)], optimizer) do m, features, labels
                # One forward pass, reused for both the loss and the running accuracy.
                # A second `m(features)` would add another forward pass under AD and,
                # with Dropout active, would draw a different dropout mask than the
                # one the accuracy is measured on.
                predicted = m(features)
                ignore() do
                    correct[] += sum(onecold(predicted, 0:9) .== onecold(labels, 0:9))
                    total[] += size(features)[end]  # batch size (last dimension)
                end
                return logitcrossentropy(predicted, labels)
            end
            set_postfix(iter, accuracy=correct[] / total[])
        end

        val_accuracy = accuracy(model, test_loader)
        @info "Epoch $epoch/$epochs | Accuracy : $val_accuracy"
    end
    return nothing
end


train_loop(model, optimizer, TRAIN_LOADER, TEST_LOADER; epochs=5)  # five epochs (the default)
`

Each epoch in Flux takes about 1 minute 10 seconds, while each epoch in TensorFlow takes about 15 seconds. The Keras script used for comparison:





```python

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Model / data parameters.
# MNIST (28x28 grayscale digits, 10 classes) is used so this script trains on the
# same data and input shape as the Flux script it is compared against.
num_classes = 10
input_shape = (28, 28, 1)

# Load the data and split it between train and test sets.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range.
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# MNIST arrives as (N, 28, 28); add a trailing channel axis so each image has
# shape (28, 28, 1), matching input_shape.
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")


# Convert class vectors to binary (one-hot) class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)


model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()


# Same batch size and epoch count as the Flux run.
batch_size = 128
epochs = 5

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

# Train on the full training set and evaluate on the test set every epoch. This
# mirrors the Flux loop, so per-epoch times cover the same amount of work.
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(x_test, y_test),
)

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions