Conversation

@sayakpaul (Contributor)

There are many subtleties involved in training a well-performing video classifier, and there are many ways to train one. This example walks through one of them. Hopefully, it will be helpful for the community.

@google-cla google-cla bot added the cla: yes label May 28, 2021
@fchollet (Contributor) left a comment

Thanks for the PR. This is great stuff! Do you have a notebook version so we can take a look at the visualizations?

@sayakpaul (Contributor, Author)

Thank you, @fchollet, for the review. I have addressed all your comments. Here's a Colab Notebook for the visualizations.

@fchollet (Contributor) left a comment

Thanks for the update!

@sayakpaul (Contributor, Author)

@fchollet, I tried to incorporate your suggestions, but doing so weirdly affects the performance quite a bit. I am unsure why this might be the case. Would you be able to take a look at the notebook?

@fchollet (Contributor) left a comment

I tried running the notebook with a few simplifications:

  • apply pooling in the feature extractor to limit the number of features
  • simplify the classification model
  • shuffle the training data and use a validation split to monitor overfitting

Notebook: https://colab.research.google.com/drive/1sfkoKxEF_kfGU1vi-_rNOQeu_Ydx1EtF?usp=sharing

The results I'm seeing are consistent with what I'd expect given the small number of samples and the large model sizes: quick overfitting and low generalization.

To train a model with this many parameters (the classification model you had has 10M parameters) you'd need tens of thousands of samples at the very least. Right now there are 152 training samples. This is just impossible. I recommend both simplifying the model (e.g. some of the changes outlined above) and increasing the size of the training data by at least 10x.

If you get to 1000 training samples you can probably manage to train a classification model with 10-100k parameters.



train_labels = prepare_labels(train_labels)
test_labels = prepare_labels(test_labels)
@fchollet (Contributor) commented on the lines above:

I had not noticed we had 152 training samples and 34 validation samples, for 5 classes. This seems very insufficient. You'd need at least 10x more.

@sayakpaul (Contributor, Author)

sayakpaul commented May 31, 2021

Thank you for your inputs, @fchollet. I realized I had introduced a pesky bug while preparing the frame sequences, which was corrupting the mappings between the videos and labels. I took care of it and incorporated the following suggestions:

  • Used pooling inside the feature_extractor (see the sketch after this list).
  • Simplified the sequence models (now 99,099 parameters).
  • Added an extensive note on using a small dataset with a relatively large model.
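
For reference, a pooling-based feature extractor along these lines could be sketched as follows. This is only an illustrative sketch: the InceptionV3 backbone, IMG_SIZE, and the preprocess_input call are assumptions, not necessarily what the notebook uses.

from tensorflow import keras

IMG_SIZE = 224  # assumed frame size

def build_feature_extractor():
    # pooling="avg" collapses the spatial dimensions, so every frame is mapped
    # to a single feature vector instead of a large feature map.
    backbone = keras.applications.InceptionV3(
        weights="imagenet",
        include_top=False,
        pooling="avg",
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
    )
    inputs = keras.Input((IMG_SIZE, IMG_SIZE, 3))
    preprocessed = keras.applications.inception_v3.preprocess_input(inputs)
    outputs = backbone(preprocessed)
    return keras.Model(inputs, outputs, name="feature_extractor")

feature_extractor = build_feature_extractor()  # each frame -> one 2048-d vector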

With these changes, the performance is good enough. Here's the notebook.

Additionally, your contributions to this would-be example are non-trivial IMHO. They go way beyond just code reviews and improvement suggestions. Therefore, I would very much welcome the idea of adding you as a co-author for the example. LMK.

@fchollet (Contributor)

Glad you were able to fix the bug!

I've been thinking about this example some more. This would be our first video classification example, which is great. But that means it should also exemplify best practices -- techniques that will work well across most video classification problems, not just this specific dataset. Right now, some things are very good, like the general idea of using a CNN for feature extraction followed by an RNN. However, some things could be improved. Here are some things I'd like to see in a generic video classification example:

  1. Use a fixed frame subsampling rate, and allow videos of different lengths.

There's no real justification for using exactly N frames per video. In the real world your videos will have different lengths. The justification "we need to batch images" doesn't work because you can just pad shorter videos with zeros and generate a padding mask, which the RNN will take into account (it will skip masked frames, so it won't even slow down training).

How that would work in practice: use a batch size of 1 during feature extraction, then pad the vector sequences with zeros and add a mask input to the classification model, which would go in the mask argument in the first RNN call.

  2. Use a higher frame count.

5 frames is very low and doesn't justify the use of an RNN, as the importance of order will be extremely marginal. We need the inputs to be actual videos, not a small collection of pictures, so more like 20+ frames.

  3. Use more training samples.

As mentioned, at least a few hundred per class. If you have 5 balanced classes then 1000-2000 samples is a reasonable number.

Therefore, I would very much welcome the idea of adding you as a co-author for the example.

Don't worry about it, I'm just the janitor here.

Let me know.

@sayakpaul (Contributor, Author)

Thanks!

1> Could you provide a minimal code example / relevant resource for this?

2> Alright.

3> Alright. But also note that, because of broken videos (videos that OpenCV could not capture), the number of available videos was reduced further.

@fchollet (Contributor)

Could you provide a minimal code example

I imagine you could do something like (untested code, just writing it up here):

import numpy as np
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import GRU

frame_masks = np.zeros(shape=(num_samples, max_seq_length), dtype='uint8')
frame_features = np.zeros(shape=(num_samples, max_seq_length, num_features), dtype='float32')
for i, batch in enumerate(dataset):  # Dataset is batched with size 1
    length = batch.shape[1]  # This is different from video to video
    for j in range(length):
        frame_features[i, j, :] = feature_extractor(batch[:, j, :])  # feature_extractor is just a CNN that returns vectors of features
    frame_masks[i, :length] = 1  # 1 = not masked, 0 = masked

...  # later

frame_features_input = Input((max_seq_length, num_features))
mask_input = Input((max_seq_length,), dtype='uint8')
x = GRU(...)(frame_features_input, mask=mask_input)

...  # later
model = Model([frame_features_input, mask_input], output)
model.fit([frame_features, frame_masks], labels, ...)

Does that make sense? Can't guarantee it will work exactly as is, but it gives you an idea of what is supposed to happen.

@sayakpaul (Contributor, Author)

sayakpaul commented Jun 1, 2021

@fchollet, I ended up developing a keras.utils.Sequence class for the datasets. Each entry in the dataset returns [(frame_features, frame_masks), labels] with shapes being (batch_size, max_seq_length, num_features), (batch_size, max_seq_length), and (batch_size, num_classes) respectively. It seems to work but it's very slow (~11 mins/epoch on a single V100). Looking forward to hearing your suggestions to improve it.

Here's the full notebook.
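
For context, such a Sequence could look roughly like the sketch below. The load_video helper, the feature_extractor model, and the exact shapes are stand-in assumptions rather than the notebook's actual code.

import numpy as np
from tensorflow import keras

class VideoSequence(keras.utils.Sequence):
    """Yields ((frame_features, frame_masks), labels) batches."""

    def __init__(self, video_paths, labels, batch_size, max_seq_length, num_features):
        self.video_paths = video_paths
        self.labels = labels
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length
        self.num_features = num_features

    def __len__(self):
        return int(np.ceil(len(self.video_paths) / self.batch_size))

    def __getitem__(self, idx):
        paths = self.video_paths[idx * self.batch_size : (idx + 1) * self.batch_size]
        labels = self.labels[idx * self.batch_size : (idx + 1) * self.batch_size]

        features = np.zeros((len(paths), self.max_seq_length, self.num_features), dtype="float32")
        masks = np.zeros((len(paths), self.max_seq_length), dtype="uint8")

        for i, path in enumerate(paths):
            frames = load_video(path)[: self.max_seq_length]  # hypothetical OpenCV-based helper
            for j, frame in enumerate(frames):
                # Running the CNN here, inside __getitem__, repeats the feature
                # extraction on every epoch, which is the expensive part.
                features[i, j] = feature_extractor(frame[None, ...])[0]
            masks[i, : len(frames)] = 1  # 1 = real frame, 0 = padding

        return (features, masks), np.asarray(labels)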

Additionally, given that the example now covers the best practices you mentioned above, would it be okay to keep the number of samples at the current setting in the interest of runtime (there will be extensive notes about the data regime, though)?

Update

Logs after training for 5 epochs:

Epoch 1/5
9/9 [==============================] - 823s 94s/step - loss: 1.6176 - accuracy: 0.2088 - val_loss: 1.6072 - val_accuracy: 0.2604

Epoch 00001: val_loss improved from inf to 1.60717, saving model to /tmp/video_classifier
Epoch 2/5
9/9 [==============================] - 755s 86s/step - loss: 1.5981 - accuracy: 0.2432 - val_loss: 1.6252 - val_accuracy: 0.1458

Epoch 00002: val_loss did not improve from 1.60717
Epoch 3/5
9/9 [==============================] - 743s 85s/step - loss: 1.5508 - accuracy: 0.3631 - val_loss: 1.6148 - val_accuracy: 0.2396

Epoch 00003: val_loss did not improve from 1.60717
Epoch 4/5
9/9 [==============================] - 740s 85s/step - loss: 1.4979 - accuracy: 0.4382 - val_loss: 1.6423 - val_accuracy: 0.2344

Epoch 00004: val_loss did not improve from 1.60717
Epoch 5/5
9/9 [==============================] - 743s 85s/step - loss: 1.4766 - accuracy: 0.4166 - val_loss: 1.5973 - val_accuracy: 0.2760

Epoch 00005: val_loss improved from 1.60717 to 1.59728, saving model to /tmp/video_classifier

@fchollet (Contributor)

fchollet commented Jun 1, 2021

It seems to work but it's very slow

I think this might be because you won't get multiprocessing on Colab. But regardless of the environment I strongly recommend using a Dataset to avoid such performance issues (or just NumPy arrays, which amounts to the same). It will also enable you to train on TPU, etc. All the data is available as arrays so it seems doable.

A random thing you can try to diagnose the issue is to use from_generator to make a Dataset and then call cache() on it. See if it improves performance for the next epoch.
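
For example, something along these lines might work (a rough sketch, where sequence is the keras.utils.Sequence instance and MAX_SEQ_LENGTH, NUM_FEATURES, and NUM_CLASSES are assumed constants):

import tensorflow as tf

def gen():
    for i in range(len(sequence)):  # `sequence` is the keras.utils.Sequence instance
        yield sequence[i]

dataset = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        (
            tf.TensorSpec(shape=(None, MAX_SEQ_LENGTH, NUM_FEATURES), dtype=tf.float32),
            tf.TensorSpec(shape=(None, MAX_SEQ_LENGTH), dtype=tf.uint8),
        ),
        tf.TensorSpec(shape=(None, NUM_CLASSES), dtype=tf.float32),
    ),
)
# cache() keeps the elements produced during the first epoch, so the expensive
# frame reading and feature extraction is not repeated in later epochs.
dataset = dataset.cache().prefetch(tf.data.AUTOTUNE)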

is it possible to limit the number of examples to the current setting in the interest of the runtime

Is there an obstacle to getting more data, or is the issue the runtime?

@sayakpaul (Contributor, Author)

I think this might be because you won't get multiprocessing on Colab.

I tried it on a GCP VM (N1-standard-8) too but the performance was roughly the same.

or just NumPy arrays, which amounts to the same

I am more inclined toward creating a standard dataset for this since we are trying to showcase a couple of best practices here. Besides, showing the readers how to create a standard video classification dataset as a part of the example would be beneficial too.

Is there an obstacle with getting more data, or the issue the run time?

Runtime is currently the main issue.

Additionally, if you could check the notebook in my previous comment and let me know if it's close to what you had envisioned that would be helpful.

@fchollet (Contributor)

fchollet commented Jun 1, 2021

I checked out the notebook and the model looks great!

I am more inclined toward creating a standard dataset for this

Do you mean an end-to-end tf.data dataset where you provide video filenames and labels and it gives you an iterable with encoded video frames, masks, and encoded labels? This is certainly doable. To avoid doing redundant feature extraction work you're going to want to use caching though. I would encourage you to go this route (in particular to preserve TPU compatibility).

Actually, this is currently a major issue with your Sequence class that may explain most of the slowdown: you apply feature extraction for every batch, while training proceeds, so you end up redoing the same expensive work in every epoch. In the previous workflow, you did feature extraction once, beforehand. This represents a large difference in GPU compute load.

@fchollet (Contributor)

fchollet commented Jun 1, 2021

Runtime is currently the main issue.

Then let's increase the data size and let's reduce the number of epochs (while documenting performance achieved for the full number of epochs).

@sayakpaul (Contributor, Author)

sayakpaul commented Jun 2, 2021

Okay let me look into these.

TPU compatibility is questionable if we want to extract the frames using OpenCV, though. I would do that using tf.py_function(), which is unsupported on TPUs as far as I remember.

@fchollet (Contributor)

fchollet commented Jun 2, 2021

TPU compatibility is questionable if we want to extract the frames using OpenCV, though

You could do this, but I would recommend doing the Python-only parts of the preprocessing beforehand, then using from_tensor_slices to create a dataset (if the data is too large you can save processed files to disk first). The CNN feature extraction part can either be done beforehand or as part of the dataset, though beforehand is actually better because otherwise you need to remember to call .cache() if you don't want to do dramatically more work (kind of a trap).
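
Concretely, this could look something like the sketch below, assuming the frame_features, frame_masks, and labels arrays have already been computed beforehand (BATCH_SIZE, EPOCHS, val_ds, and the model are placeholders):

import tensorflow as tf

# frame_features: (num_samples, MAX_SEQ_LENGTH, NUM_FEATURES), float32
# frame_masks:    (num_samples, MAX_SEQ_LENGTH), uint8 (1 = real frame, 0 = padding)
# labels:         (num_samples, NUM_CLASSES), float32
train_ds = (
    tf.data.Dataset.from_tensor_slices(((frame_features, frame_masks), labels))
    .shuffle(len(frame_features))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)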

@sayakpaul (Contributor, Author)

sayakpaul commented Jun 2, 2021

You could do this, but I would recommend doing the Python-only parts of the preprocessing beforehand, then using from_tensor_slices to create a dataset (if the data is too large you can save processed files to disk first).

This sounds like what we were doing previously: reading the videos, capping their frames at a predefined max_seq_length, and then serializing those frames for later preprocessing. Or are you suggesting we serialize all the frames of the videos?

Update

With the current dataset (594 videos in the train set and 224 in the test) it takes a total of 14 minutes and 37 seconds to fully prepare the arrays beforehand. Another point to note is that I did this in the "High RAM" setting of Colab Pro. I am not sure the runtime will be the same for non-pro Colab users. However, training for 5 epochs is now extremely fast. Here's the full notebook.

The data preprocessing time will naturally increase if we increase the number of samples. This may actually go well beyond the permitted runtime. @fchollet, let me know how you'd want us to proceed from here.

@fchollet (Contributor)

fchollet commented Jun 2, 2021

This sounds like what we were doing previously.

Yes, indeed, and that strategy was completely fine.

With the current dataset (594 videos in the train set and 224 in the test) it takes a total of 14 minutes and 37 seconds to fully prepare the arrays beforehand. Another point to note is that I did this in the "High RAM" setting of Colab Pro. I am not sure the runtime will be the same for non-pro Colab users. However, training for 5 epochs is now extremely fast.

Does it make sense in this case to attempt to parallelize the preprocessing? If most of the time is spent in OpenCV, and if that algo is already parallel, then there's not much we can do.

Also, it seems we can increase the epoch number if training is so fast.

I think if we keep the total runtime below 30min it will be a success. How many training samples would that get us?

@sayakpaul (Contributor, Author)

sayakpaul commented Jun 2, 2021

Does it make sense in this case to attempt to parallelize the preprocessing? If most of the time is spent in OpenCV, and if that algo is already parallel, then there's not much we can do.

I was going to suggest this but the OpenCV backend utilities are already threaded. So, there's nothing much we can do here.

Also, it seems we can increase the epoch number if training is so fast.

Yes, totally.

I think if we keep the total runtime below 30min it will be a success. How many training samples would that get us?

A few, maybe 30 or 50. But that risks the runtime being overloaded by the preprocessing we are doing.

@sayakpaul (Contributor, Author)

@fchollet let me know how you'd want us to proceed from here.

@fchollet (Contributor)

fchollet commented Jun 3, 2021

@sayakpaul I'll let you make the call with regard to how many samples to use -- more data is better, but we should try to keep the runtime below 30 min. Up to you at this point!

A random thought: by using a smaller/faster CNN you'll be able to cut some processing time (but not too much if it's mostly OpenCV time).

@sayakpaul (Contributor, Author)

Thanks, @fchollet. Let me incorporate these points in the next commit. This PR has been an amazing learning experience for me. Thank you for pushing it forward :)

@sayakpaul (Contributor, Author)

sayakpaul commented Jun 3, 2021

@fchollet, I added the changes to the example. Here is the summary:

  • Currently, the example operates on the videos belonging to the top-5 classes of the UCF101 dataset. Increasing this to the top 10 classes gives 1171 training and 459 test videos, but the preprocessing step then takes 40 minutes to complete. So, I decided to stick with the previously subsampled dataset (see the subsampling sketch after this list). To make readers fully aware of the low-data regime this example operates in, I have included an extensive note.
  • I tried smaller architectures like MobileNetV3, but that did not result in much speed-up since most of the time is spent on frame extraction (handled by OpenCV).
  • Upon trying smaller but better-performing models like EfficientNetB0, I observed something weird: the performance dropped consistently (+/- 10%). The data regime is too small to say anything definitive, but here are some thoughts:
    • We are primarily interested in extracting meaningful features. Networks like VGG16 and InceptionV3 have a good history of being used in works that deal with image feature spaces. Some relevant works: LPIPS, perceptual loss, and the famous Deep Dream.
    • The EfficientNet family of models, on the other hand, was found via neural architecture search constrained by aspects like minimal FLOP count, a speedy training schedule, and low inference latency. I suspect this may have something to do with the disparity in the feature representations they produce.
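
For reference, the top-N subsampling mentioned in the first point could be done roughly like this. The train.csv metadata file with video_name and tag columns is an assumption about how the annotations are laid out, not necessarily the notebook's exact format.

import pandas as pd

TOP_N = 5  # going to 10 classes pushes preprocessing to ~40 minutes, as noted above

df = pd.read_csv("train.csv")  # hypothetical metadata file listing video_name and tag
top_classes = df["tag"].value_counts().nlargest(TOP_N).index
subset_df = df[df["tag"].isin(top_classes)].reset_index(drop=True)
print(f"Keeping {len(subset_df)} videos across {TOP_N} classes.")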

Here's the full notebook. Let me know your thoughts.

@fchollet (Contributor)

fchollet commented Jun 3, 2021

Fantastic! Everything is looking good. Please add the generated files 👍

@sayakpaul (Contributor, Author)

Thanks for all your help and feedback, @fchollet. I have added the files. The visualization GIF was probably embedded in the notebook but if you find anything off let me know.

@fchollet (Contributor) left a comment

LGTM, thank you for the great contribution! I confirm the gif is being captured -- some guy shaving his beard.

@fchollet fchollet merged commit 317590e into keras-team:master Jun 4, 2021
@fchollet (Contributor)

fchollet commented Jun 4, 2021

Actually, there's a small issue -- it seems these copyedits were dropped. Can you please create a new PR to add them back?
