In this project, we create convolutional neural networks for image classification and segmentation.
We create a CNN to classify the Fashion MNIST database using PyTorch. The Fashion MNIST database consists of labelled images of clothing, shoes and accessories, split into 10 classes: top, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag, and ankle boot. Below are some random labelled images from the database.
The network architecture is similar to the one suggested in the spec, with some changes: (128-channel convolution -> Batch Norm -> ReLU -> Max Pool) -> (128-channel convolution -> Batch Norm -> ReLU -> Max Pool) -> fully connected linear layer -> ReLU -> fully connected linear layer -> LogSoftmax. The network is trained for 20 epochs with a batch size of 50. Below are the accuracy plot for the training and validation datasets and the per-class accuracy table for the validation and test datasets.
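The layer sequence described above can be sketched in PyTorch roughly as follows. This is a minimal sketch, not the author's code: the 3x3 kernels with padding 1, the 2x2 pooling windows, the hidden width of 128 units, and the class name `FashionCNN` are all assumptions, since the text does not state them.

```python
import torch
from torch import nn


class FashionCNN(nn.Module):
    """Hypothetical sketch of the two-block CNN described in the text."""

    def __init__(self, hidden_units: int = 128, num_classes: int = 10):
        super().__init__()
        # Kernel size, padding, and pooling window are assumptions.
        self.features = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 28x28 -> 14x14
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14 -> 7x7
        )
        # hidden_units is an assumed value; the text gives no layer width.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


model = FashionCNN()
batch = torch.randn(50, 1, 28, 28)  # one batch of 50 grayscale 28x28 images
log_probs = model(batch)            # shape: (50, 10)
```

Because the final layer is LogSoftmax, the output holds log-probabilities, which would typically be paired with `nn.NLLLoss` during training.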
Split | Top | Trouser | Pullover | Dress | Coat | Sandal | Shirt | Sneaker | Bag | Ankle boot |
---|---|---|---|---|---|---|---|---|---|---|
Validation | 96 | 99 | 96 | 99 | 94 | 100 | 96 | 99 | 100 | 99 |
Test | 82 | 97 | 84 | 90 | 79 | 98 | 76 | 97 | 97 | 95 |
From the table above, the CNN's lowest test accuracies are for shirts (76%), coats (79%), tops (82%), and pullovers (84%). Below, we can see which images are labelled correctly and which are labelled incorrectly. Pullovers, shirts, and tops are all worn on the upper half of the body, and the model tends to confuse these three categories with one another.
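Per-class accuracy and the pattern of confusions between classes can be computed from the true and predicted labels along these lines. This is a sketch in plain Python under assumed names (`per_class_accuracy`, `confusion_matrix`); the labels in the usage lines are made-up toy values, not results from the model.

```python
CLASSES = ["top", "trouser", "pullover", "dress", "coat",
           "sandal", "shirt", "sneaker", "bag", "ankle boot"]


def confusion_matrix(y_true, y_pred, num_classes=len(CLASSES)):
    """matrix[t][p] counts samples of true class t predicted as class p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def per_class_accuracy(y_true, y_pred, num_classes=len(CLASSES)):
    """Percentage of correctly labelled samples per class (None if class absent)."""
    matrix = confusion_matrix(y_true, y_pred, num_classes)
    accuracies = []
    for c in range(num_classes):
        total = sum(matrix[c])
        accuracies.append(100.0 * matrix[c][c] / total if total else None)
    return accuracies


# Toy labels for illustration only: 0 = top, 2 = pullover, 6 = shirt.
y_true = [0, 0, 2, 6, 6, 6]
y_pred = [0, 6, 2, 6, 0, 2]
matrix = confusion_matrix(y_true, y_pred)
accuracy = per_class_accuracy(y_true, y_pred)
```

Off-diagonal entries such as `matrix[6][0]` (shirts predicted as tops) show which pairs of classes are being confused.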
Lastly, let's look at what the learned filters look like.
For the segmentation task, the network architecture is: Conv(3, 64), BatchNorm, ReLU, Conv(64, 64), BatchNorm, ReLU, MaxPool, Conv(64, 128), BatchNorm, ReLU, Conv(128, 128), BatchNorm, ReLU, Upsample, MaxPool, Conv(128, 64), BatchNorm, ReLU, Conv(64, 64), BatchNorm, ReLU, Conv(64, 5), Upsample x2, MaxPool. This architecture is inspired by the U-Net architecture. Below is the loss for the training and validation sets.
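The layer list above can be written as a sequential PyTorch model roughly as follows. This is a sketch under several assumptions that the text does not confirm: 3x3 kernels with padding 1, 2x2 max pooling, nearest-neighbour upsampling by a factor of 2, and "Upsample x2" read as two consecutive upsampling layers (which brings the output back to the input resolution). The helper `conv_block` and the name `seg_net` are hypothetical.

```python
import torch
from torch import nn


def conv_block(in_ch: int, out_ch: int) -> list:
    """Conv -> BatchNorm -> ReLU; kernel size and padding are assumptions."""
    return [
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    ]


seg_net = nn.Sequential(
    *conv_block(3, 64),
    *conv_block(64, 64),
    nn.MaxPool2d(2),                # H -> H/2
    *conv_block(64, 128),
    *conv_block(128, 128),
    nn.Upsample(scale_factor=2),    # H/2 -> H
    nn.MaxPool2d(2),                # H -> H/2
    *conv_block(128, 64),
    *conv_block(64, 64),
    nn.Conv2d(64, 5, kernel_size=3, padding=1),  # 5 output classes
    nn.Upsample(scale_factor=2),    # H/2 -> H   (assumed: two upsampling layers)
    nn.Upsample(scale_factor=2),    # H -> 2H
    nn.MaxPool2d(2),                # 2H -> H
)

images = torch.randn(2, 3, 64, 64)  # dummy RGB batch
logits = seg_net(images)            # per-pixel class scores, shape (2, 5, 64, 64)
```

The sketch is purely sequential, as the layer list is; any skip connections of the kind U-Net uses are not shown because the text does not mention them.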
The average precision per class is as follows:
The mean AP across all classes is 0.5573.
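For reference, average precision for one class and its mean over classes can be computed as sketched below. This uses the common non-interpolated definition (the mean of the precision values at the rank of each positive sample); the author may have used a different variant or a library routine, so the function names and the toy scores here are assumptions for illustration only.

```python
def average_precision(scores, labels):
    """Non-interpolated AP: mean precision at the rank of each positive sample."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits = 0
    precision_sum = 0.0
    for rank, i in enumerate(order, start=1):
        if labels[i]:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / hits if hits else 0.0


def mean_average_precision(per_class_scores, per_class_labels):
    """Unweighted mean of the per-class AP values."""
    aps = [average_precision(s, l)
           for s, l in zip(per_class_scores, per_class_labels)]
    return sum(aps) / len(aps)


# Toy example with two classes and four pixels each (made-up scores).
scores = [[0.9, 0.8, 0.7, 0.6], [0.9, 0.8, 0.2, 0.1]]
labels = [[1, 0, 1, 0], [1, 1, 0, 0]]
ap_first = average_precision(scores[0], labels[0])   # (1/1 + 2/3) / 2
ap_second = average_precision(scores[1], labels[1])  # perfect ranking
mean_ap = mean_average_precision(scores, labels)
```

For segmentation, each class would be scored one-vs-rest, with `scores` holding the predicted per-pixel score for that class and `labels` marking whether each pixel truly belongs to it.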
Below are some example segmentations created by the network.