Skip to content

Image Classification Always Returns Cat.  #11299

Closed
@jetfuel

Description

@jetfuel

Issue: Image classification with ResNet always infers category 3 (cat)

The training code is as follows. The resnet_cifar10 network is based on image_classification_test, with some minor changes.

After changing the training to run for about 10 epochs, the accuracy during training is about 95%. When tested against the test set using trainer.test, the accuracy is about 79%.

However, when using the Inferencer to load the saved params and run inference, the result is always category 3 (cat). I even used the original training dataset for inference, but the result is still always 3.

Could this mean that there might be a bug within the inferencer or saving the params?
Note: The test from image_classification_test does produce correct inference results.

# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

from __future__ import print_function

import paddle
import paddle.fluid as fluid
import numpy
import sys

from vgg import vgg_bn_drop
from resnet import resnet_cifar10


def inference_network():
    """Build the forward-only classification network.

    Declares the ``pixel`` input layer and feeds it to ``resnet_cifar10``
    together with the constant 32.

    Returns:
        Whatever ``resnet_cifar10`` returns for the declared input layer.
    """
    # 3 channels (RGB), 32 x 32 pixels.
    pixel_input = fluid.layers.data(
        name='pixel', shape=[3, 32, 32], dtype='float32')

    # NOTE: to try the VGG net instead, build the output with
    # vgg_bn_drop(pixel_input).
    return resnet_cifar10(pixel_input, 32)


def train_network():
    """Build the training network on top of the inference network.

    Returns:
        A two-element list ``[avg_cost, accuracy]``: the mean of the
        cross-entropy between the prediction and the ``label`` input, and
        the accuracy layer computed from the same prediction and label.
    """
    prediction = inference_network()
    target = fluid.layers.data(name='label', shape=[1], dtype='int64')

    loss = fluid.layers.cross_entropy(input=prediction, label=target)
    mean_loss = fluid.layers.mean(loss)
    acc = fluid.layers.accuracy(input=prediction, label=target)
    return [mean_loss, acc]


def optimizer_program():
    """Create the training optimizer: Adam with a learning rate of 0.001."""
    adam = fluid.optimizer.Adam(learning_rate=0.001)
    return adam


def train(use_cuda, train_program, params_dirname):
    """Train on CIFAR-10, evaluating and saving parameters after every epoch.

    Args:
        use_cuda: When true, run on ``fluid.CUDAPlace(0)``; otherwise on
            ``fluid.CPUPlace()``.
        train_program: Callable handed to ``fluid.Trainer`` as ``train_func``.
        params_dirname: Directory passed to ``trainer.save_params`` at the
            end of each epoch. When ``None``, parameters are not saved.
    """
    BATCH_SIZE = 128
    EPOCH_NUM = 2

    train_reader = paddle.batch(
        paddle.reader.shuffle(paddle.dataset.cifar.train10(), buf_size=50000),
        batch_size=BATCH_SIZE)

    test_reader = paddle.batch(
        paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE)

    def event_handler(event):
        # `trainer` is assigned further down, before trainer.train() starts
        # dispatching events, so the closure resolves it at call time.
        if isinstance(event, fluid.EndStepEvent):
            if event.step % 100 == 0:
                # "Pass" is the epoch and "Batch" is the step within it; the
                # arguments must be given in that order to match the labels.
                print("\nPass %d, Batch %d, Cost %f, Acc %f" %
                      (event.epoch, event.step, event.metrics[0],
                       event.metrics[1]))
            else:
                # Lightweight progress indicator between the full reports.
                sys.stdout.write('.')
                sys.stdout.flush()

        if isinstance(event, fluid.EndEpochEvent):
            avg_cost, accuracy = trainer.test(
                reader=test_reader, feed_order=['pixel', 'label'])

            print('\nTest with Pass {0}, Loss {1:2.2}, Acc {2:2.2}'.format(
                event.epoch, avg_cost, accuracy))
            if params_dirname is not None:
                trainer.save_params(params_dirname)

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    trainer = fluid.Trainer(
        train_func=train_program, optimizer_func=optimizer_program, place=place)

    trainer.train(
        reader=train_reader,
        num_epochs=EPOCH_NUM,
        event_handler=event_handler,
        feed_order=['pixel', 'label'])


def infer(use_cuda, inference_program, params_dirname=None):
    """Run inference on CIFAR-10 training samples and print the results.

    For every batch of 1000 shuffled training samples, only the first
    sample is fed to the inferencer; its label and the index ranked last by
    ``argsort`` on the first result are printed.

    Args:
        use_cuda: When true, run on ``fluid.CUDAPlace(0)``; otherwise on
            ``fluid.CPUPlace()``.
        inference_program: Callable handed to ``fluid.Inferencer`` as
            ``infer_func``.
        params_dirname: Passed to ``fluid.Inferencer`` as ``param_path``.
    """
    if use_cuda:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()
    inferencer = fluid.Inferencer(
        infer_func=inference_program, param_path=params_dirname, place=place)

    shuffled_samples = paddle.reader.shuffle(
        paddle.dataset.cifar.train10(), buf_size=50000)
    batched_reader = paddle.batch(shuffled_samples, batch_size=1000)

    for batch in batched_reader():
        # Only the first sample of each batch is used.
        first_sample = batch[0]
        image = first_sample[0]
        label = first_sample[1]

        # Add a leading batch dimension of 1 for the single image.
        single_image = image.reshape(1, 3, 32, 32)

        results = inferencer.infer({'pixel': single_image})
        # argsort is ascending, so the last index has the highest score.
        ranking = results[0].argsort()
        print("Label--------: ", label)
        print("infer results: ", ranking[0][-1])


def main(use_cuda):
    """Train the network, then run inference with the saved parameters.

    Returns immediately, doing nothing, when ``use_cuda`` is true but
    ``fluid.core.is_compiled_with_cuda()`` reports false.

    Args:
        use_cuda: Forwarded to both ``train`` and ``infer``.
    """
    cuda_requested_but_missing = (
        use_cuda and not fluid.core.is_compiled_with_cuda())
    if cuda_requested_but_missing:
        return

    # The same directory is written by train() and read back by infer().
    params_dir = "image_classification_resnet.inference.model"

    train(
        use_cuda=use_cuda,
        train_program=train_network,
        params_dirname=params_dir)
    infer(
        use_cuda=use_cuda,
        inference_program=inference_network,
        params_dirname=params_dir)


if __name__ == '__main__':
    # For demo purposes, training and inference run on the CPU.
    # Pass use_cuda=True to run on CUDA device 0 instead.
    main(use_cuda=False)

Metadata

Metadata

Assignees

No one assigned

    Labels

    Bug预测原名Inference,包含Capi预测问题等

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions