Focus on the Discriminator
In the last post, we took a look at a simple autoencoder. The autoencoder is a deep learning model that takes in an image and, through an encoder and decoder, works to produce the same image. In short:
- Autoencoder: image -> image
For a discriminator, we are going to focus on only the first half of the autoencoder.
Why only half? We want a different transformation: take an image as input, do some discrimination on it, and classify what type of image it is. In our case, the model is going to take an image of a handwritten digit as input and attempt to decide which number it is.
- Discriminator: image -> label
As always with deep learning, to do anything we need data.
MNIST Data
Nothing changes here from the autoencoder code. We are still using the MNIST dataset for handwritten digits.
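The data-loading code itself isn't reproduced here, but as a rough sketch of what a batch looks like, here are stand-in NumPy arrays with MNIST-style shapes (the batch size of 100 and the flattened 784-pixel images are illustrative assumptions, not taken from the post's code):

```python
import numpy as np

# Illustrative stand-in for one MNIST batch: the real code loads the
# actual dataset, but the shapes are the important part.
batch_size = 100
rng = np.random.default_rng(0)

# A batch of flattened 28x28 grayscale digits, pixel values in [0, 1].
data = rng.random((batch_size, 784)).astype(np.float32)
# One integer label per image, each in 0-9.
labels = rng.integers(0, 10, size=batch_size)

print(data.shape, labels.shape)  # (100, 784) (100,)
```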
The model will change since we want a different output.
The Model
We are still taking in the image as input and using the same encoder layers from the autoencoder model. However, at the end, we use a fully connected layer with 10 output nodes - one for each of the digit labels 0-9. Then we use a softmax for the classification output.
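To make the final step concrete, here is a minimal NumPy sketch of a fully connected layer feeding a softmax. The encoder output size of 50 and the weight values are made up for illustration; only the 10-way output matches the description above:

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    # Subtract the row max for numerical stability before exponentiating.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Pretend output of the encoder layers: 100 images -> 50 features each.
encoded = rng.random((100, 50)).astype(np.float32)

# Fully connected layer with 10 output nodes, one per digit 0-9.
W = rng.normal(scale=0.01, size=(50, 10))
b = np.zeros(10)

probs = softmax(encoded @ W + b)
print(probs.shape)  # (100, 10) - one probability per digit, per image
```

Each row of `probs` sums to 1, so the largest entry in a row is the model's best guess for that image's digit.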
In the autoencoder, we never actually used the label, but we will certainly need it this time. It is reflected in the model's bindings with the data and label shapes.
For the evaluation metric, we are also going to use an accuracy metric instead of a mean squared error (MSE) metric.
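Accuracy here just means: of all the images in a batch, what fraction did the model label correctly? A small sketch of the idea (the function name and sample values are illustrative):

```python
import numpy as np

def accuracy(labels, probs):
    """Fraction of rows where the argmax of the predicted
    probabilities matches the true label."""
    return float(np.mean(np.argmax(probs, axis=1) == labels))

labels = np.array([6, 1, 0, 0])
probs = np.array([
    [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],  # predicts 6 - correct
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],  # predicts 1 - correct
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # predicts 0 - correct
    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],  # predicts 5 - wrong
], dtype=float)

print(accuracy(labels, probs))  # 0.75
```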
With these items in place, we are ready to train the model.
Training
The training from the autoencoder needs to change to use the real label for the forward pass and for updating the metric.
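As a rough NumPy sketch of that loop - forward pass on a batch, update the accuracy metric against the true labels, then a backward pass and parameter update - here is a plain softmax classifier trained with SGD. The synthetic batches, learning rate, and single weight matrix are stand-ins for illustration, not the post's actual model:

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Toy stand-in for the batched training data: 10 batches of 100 images.
batches = [(rng.random((100, 784)).astype(np.float32),
            rng.integers(0, 10, size=100)) for _ in range(10)]

W = rng.normal(scale=0.01, size=(784, 10))
b = np.zeros(10)
lr = 0.1

for epoch in range(3):
    correct = total = 0
    for data, label in batches:
        # Forward pass: predicted probabilities for each digit.
        probs = softmax(data @ W + b)
        # Update the accuracy metric with the real labels.
        correct += int((probs.argmax(axis=1) == label).sum())
        total += len(label)
        # Backward: gradient of softmax cross-entropy, then an SGD step.
        grad = probs.copy()
        grad[np.arange(len(label)), label] -= 1.0
        grad /= len(label)
        W -= lr * (data.T @ grad)
        b -= lr * grad.sum(axis=0)
    print(f"epoch {epoch}: accuracy {correct / total:.3f}")
```

On real MNIST data the per-epoch accuracy climbs as the weights learn; on the random stand-in data above, the numbers are only there to show where the metric update fits in the loop.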
Let’s Run Things
It’s always a good idea to take a look at things before you start training.
The first batch of the training data looks like:
Before training, if we take the first batch from the test data and predict what the labels are:
Yeah, not even close. The actual labels for the first line of images are 6 1 0 0 3 1 4 8 0 9.
Let’s Train!
After the training, let’s have another look at the predicted labels.
- Predicted =
(6.0 1.0 0.0 0.0 3.0 1.0 4.0 8.0 0.0 9.0)
- Actual =
6 1 0 0 3 1 4 8 0 9
Rock on!
Closing
In this post, we focused on the first half of the autoencoder and made a discriminator model that took in an image and gave us a label.
Don’t forget to save the trained model for later; we’ll be using it.
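The saving step boils down to writing the learned parameters to disk so a later post can load them back. A minimal sketch with NumPy (the file name and parameter names are placeholders, not the post's real API):

```python
import numpy as np

# Placeholder trained parameters; in practice these come from training.
W = np.zeros((784, 10))
b = np.zeros(10)

# Persist the parameters to a file for the next post.
np.savez("discriminator-params.npz", W=W, b=b)

# Reloading later:
params = np.load("discriminator-params.npz")
print(params["W"].shape)  # (784, 10)
```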
Until then, here is a picture of the cat in a basket to keep you going.
P.S. If you want to run all the code yourself, it is here