Recognizing Handwritten Digits Using a Neural Network in Python

My last post described a neural network written entirely in Python, which performed reasonably well on a dummy data set. In that script the sklearn.datasets.make_moons() function generated random points with two features each, and the neural network managed to classify those points into one of two possible y values. The decision boundary was easy to visualize since there were only two features, and it was clear that the neural network was dividing the data set well.

While this was great, it was not a particularly relevant use case for neural networks. So I grabbed the MNIST data set of handwritten digits and tried my hand at character recognition using a neural network. The MNIST data set is quite well known and very well studied (there is also a Kaggle competition). I only used a subset of the data, borrowed from the Coursera Machine Learning course I’m taking.

Essentially, the data set consists of 5000 rows, each corresponding to one handwritten digit. Each row holds 400 grayscale values that make up a 20 x 20 pixel image. Below is a random selection of 100 of these digits:

[Figure: a random selection of 100 digits from the data set]
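A grid like this can be built by reshaping each 400-value row back into a 20 x 20 image and tiling the images. The sketch below assumes the features are in a NumPy array `X` of shape (5000, 400); the helper name `digit_mosaic` is hypothetical and not taken from my repository. It also assumes the Coursera copy of the data was unrolled column by column (MATLAB order), hence the transpose; if your digits come out sideways, drop it.

```python
import numpy as np


def digit_mosaic(X, n_rows=10, n_cols=10, side=20, seed=None):
    """Tile a random sample of flattened side x side images into one 2-D array.

    X is assumed to have shape (n_examples, side * side), one image per row.
    Returns the mosaic and the row indices that were sampled.
    """
    rng = np.random.default_rng(seed)
    idx = rng.choice(X.shape[0], size=n_rows * n_cols, replace=False)
    # Assumed MATLAB (column-major) unrolling: reshape each row, then transpose it
    images = X[idx].reshape(-1, side, side).transpose(0, 2, 1)
    # Arrange as (grid row, pixel row, grid col, pixel col), then flatten to 2-D
    grid = images.reshape(n_rows, n_cols, side, side).transpose(0, 2, 1, 3)
    return grid.reshape(n_rows * side, n_cols * side), idx
```

Calling `plt.imshow(mosaic, cmap="gray")` from matplotlib on the returned 200 x 200 array then displays the grid.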

Adapting my previous code to accept this data set was relatively easy, because I had the good sense to parameterize most of the model variables. I did have to change a few things, though. First, I wanted to split the data into training and testing sets, so I used scikit-learn's train_test_split() (sklearn.cross_validation.train_test_split() at the time; newer scikit-learn releases provide it as sklearn.model_selection.train_test_split()) to hold out 25% of my data for testing the model at the end:

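```python
# train_test_split moved from sklearn.cross_validation to sklearn.model_selection
from sklearn.model_selection import train_test_split

# Split the data into testing and training sets (X: 5000 x 400 pixels, y: digit labels)
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, test_size=0.25, random_state=0)
```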

I made some subtle changes to the nn() function to account for the new variable names (for example, I no longer have just m, but rather mtrain and mtest). I also added the maxiter argument to scipy.optimize.fmin_cg(), since we’re now training on 3750 examples with 400 features each, which might take a while on my lowly 2009 MacBook.
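For anyone who hasn't read the previous post, the sketch below shows roughly what this training step involves: a regularized cost and its backpropagated gradient for a 400-10-10 network, handed to scipy.optimize.fmin_cg() with maxiter capping the run. It is a simplified illustration, not a copy of my nn() function. The helper names (`unpack`, `cost_and_grad`, `train`), the sigmoid activations, the initialization range and the assumption that labels run 0 to 9 are choices made for this example. (The Coursera file stores the digit 0 with the label 10, so you may need to remap it first.)

```python
import numpy as np
from scipy.optimize import fmin_cg


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def unpack(params, n_in, n_hidden, n_out):
    """Split the flat parameter vector into two weight matrices (bias in column 0)."""
    split = n_hidden * (n_in + 1)
    theta1 = params[:split].reshape(n_hidden, n_in + 1)
    theta2 = params[split:].reshape(n_out, n_hidden + 1)
    return theta1, theta2


def cost_and_grad(params, X, Y, n_hidden, lam):
    """Regularized cross-entropy cost and its gradient via backpropagation.

    X: (m, n_in) features, Y: (m, n_out) one-hot labels, lam: regularization strength.
    """
    m, n_in = X.shape
    n_out = Y.shape[1]
    theta1, theta2 = unpack(params, n_in, n_hidden, n_out)

    # Forward pass, prepending a column of ones (bias) at each layer
    a1 = np.hstack([np.ones((m, 1)), X])
    a2 = np.hstack([np.ones((m, 1)), sigmoid(a1 @ theta1.T)])
    h = sigmoid(a2 @ theta2.T)

    eps = 1e-12  # tiny offset keeps log() finite if h saturates at 0 or 1
    J = -np.sum(Y * np.log(h + eps) + (1 - Y) * np.log(1 - h + eps)) / m
    # Penalize every weight except the bias column
    J += lam / (2 * m) * (np.sum(theta1[:, 1:] ** 2) + np.sum(theta2[:, 1:] ** 2))

    # Backward pass: output error, then hidden error through the sigmoid derivative
    d3 = h - Y
    d2 = (d3 @ theta2)[:, 1:] * a2[:, 1:] * (1 - a2[:, 1:])
    grad1 = d2.T @ a1 / m
    grad2 = d3.T @ a2 / m
    grad1[:, 1:] += lam / m * theta1[:, 1:]
    grad2[:, 1:] += lam / m * theta2[:, 1:]
    return J, np.concatenate([grad1.ravel(), grad2.ravel()])


def train(X, y, n_hidden=10, n_labels=10, lam=3.0, maxiter=200, seed=0):
    """Fit the network with conjugate gradient; y holds integer labels 0..n_labels-1."""
    m, n_in = X.shape
    Y = np.eye(n_labels)[np.ravel(y).astype(int)]  # one-hot encode, works for (m,) or (m, 1)
    rng = np.random.default_rng(seed)
    n_params = n_hidden * (n_in + 1) + n_labels * (n_hidden + 1)
    init = rng.uniform(-0.12, 0.12, n_params)  # small random weights break symmetry
    params = fmin_cg(
        lambda p: cost_and_grad(p, X, Y, n_hidden, lam)[0],
        init,
        fprime=lambda p: cost_and_grad(p, X, Y, n_hidden, lam)[1],
        maxiter=maxiter,  # caps training time, at the cost of a less converged model
        disp=False,
    )
    return unpack(params, n_in, n_hidden, n_labels)
```

Passing the cost and gradient as two separate lambdas means each evaluation is computed twice; if that overhead matters, scipy.optimize.minimize(..., jac=True, method="CG") accepts a single function returning both.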

Finally, I printed the accuracy on the training set and also ran the predict() function on the testing set to see how the model fared on unseen data.
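Prediction is just the forward pass followed by picking the output node with the largest activation, and accuracy is the fraction of matches. A minimal sketch, assuming weight matrices with the bias weights in column 0 and labels 0 to 9; the predict() signature and the `accuracy` helper are illustrative assumptions rather than the exact functions in my code:

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def predict(theta1, theta2, X):
    """Return the most likely label for each row of X (bias weights in column 0)."""
    m = X.shape[0]
    a2 = sigmoid(np.hstack([np.ones((m, 1)), X]) @ theta1.T)
    h = sigmoid(np.hstack([np.ones((m, 1)), a2]) @ theta2.T)
    return np.argmax(h, axis=1)


def accuracy(theta1, theta2, X, y):
    """Fraction of rows whose predicted label matches y."""
    # Flatten y: comparing an (m,) array to an (m, 1) column would broadcast to (m, m)
    return np.mean(predict(theta1, theta2, X) == np.ravel(y))
```

With trained weights, `accuracy(theta1, theta2, Xtrain, ytrain)` and `accuracy(theta1, theta2, Xtest, ytest)` give the two numbers to compare. The `np.ravel` matters because labels loaded from a .mat file usually arrive as an (m, 1) column.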

The very last thing I did was add a printer() function, which randomly picks one row from the whole data set, calls predict() on it and draws the image with the predicted digit in the title:

[Figure: a randomly selected digit drawn by printer(), with the predicted digit in the title]
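A minimal version of printer() might look like the following. Unlike mine, this hypothetical variant returns the image and title instead of drawing them, which keeps it easy to test. `predict_fn` is an assumed callable that maps rows of pixels to predicted labels, and the transpose assumes the same column-major pixel order as the Coursera file.

```python
import numpy as np


def printer(X, predict_fn, seed=None, side=20):
    """Pick one random row of X, predict its digit, and return (image, title).

    predict_fn is assumed to map an (m, n_features) array to m predicted labels.
    """
    rng = np.random.default_rng(seed)
    i = rng.integers(X.shape[0])
    label = predict_fn(X[i:i + 1])[0]  # slice keeps the row 2-D for the predictor
    # Undo the assumed column-major unrolling; drop .T if the digit looks transposed
    image = X[i].reshape(side, side).T
    return image, f"Predicted digit: {label}"
```

To draw the result, pass it to matplotlib with `plt.imshow(image, cmap="gray")` followed by `plt.title(title)`; with the predict() from my code you could supply something like `lambda rows: predict(theta1, theta2, rows)` as `predict_fn`.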

I spent most of my time playing with the regularization parameter, the number of nodes in the hidden layer and the maxiter argument. The data set is far more complex than the make_moons() data in my original post, so limiting the model's complexity through regularization is valuable here. I settled on a lambda value of 3, though this could probably use more tuning. Lower values increased the training time, but I didn’t find much benefit on the testing set; all I got was a very high training set accuracy. In other words, with low lambda values the model overfits the training set, just as one would expect. Also as expected, increasing lambda too much made both the training and testing accuracy drop considerably.
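To make the role of lambda concrete, here is the standard regularized cross-entropy cost taught in the Coursera course, which I'm assuming as the objective (treat it as the textbook form rather than a transcript of my code). Here m is the number of training examples, K = 10 the number of classes, s_l the number of units in layer l, and column k = 0 of each weight matrix holds the bias weights, which are not penalized:

```latex
J(\Theta) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{k=1}^{K}
  \left[ y_k^{(i)} \log\big(h_\Theta(x^{(i)})_k\big)
       + \big(1 - y_k^{(i)}\big) \log\big(1 - h_\Theta(x^{(i)})_k\big) \right]
  + \frac{\lambda}{2m} \sum_{l=1}^{2} \sum_{j=1}^{s_{l+1}} \sum_{k=1}^{s_l} \big(\Theta_{j,k}^{(l)}\big)^2
```

With m = 3750 and λ = 3, each squared weight adds λ/(2m) = 0.0004 to the cost, so a larger λ pulls the weights toward zero, trading training set accuracy for a simpler model.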

Increasing the number of nodes in the hidden layer significantly increased the processing time, but I didn’t spend much effort tuning it; I just stuck with the 10 hidden nodes recommended in the Coursera notes.
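The slowdown makes sense once you count the weights. Assuming a 400-h-10 architecture with a bias unit feeding each layer, every hidden node brings 401 incoming weights and sends one weight to each of the 10 output nodes, so the parameter vector that fmin_cg() optimizes grows linearly with the hidden layer size. The helper below is a hypothetical back-of-the-envelope check:

```python
def n_weights(n_in=400, n_hidden=10, n_out=10):
    """Number of trainable weights (biases included) in a one-hidden-layer network."""
    return n_hidden * (n_in + 1) + n_out * (n_hidden + 1)


print(n_weights())             # → 4120
print(n_weights(n_hidden=25))  # → 10285
```

Each cost and gradient evaluation does work roughly proportional to the number of training examples times this count, so going from 10 to 25 hidden nodes makes every iteration about 2.5 times as expensive.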

The higher the maxiter value, the longer the model takes to train. I kept it low (200 or fewer) for most of my testing because I didn’t want to wait around for my project to run. I’m going to try loading this onto the supercomputer we have at work (I’m really excited about that!) and see what kind of results I can get.

I got a training set accuracy of 96% and a testing set accuracy of 92% using a lambda value of 3 and 200 iterations.

You can get my full code on GitHub.

That’s all for now – I mainly wanted to share this code. I was pleasantly surprised at just how easy it was to get this running. Have a look, and comment if you have questions or suggestions!

In my next post I’d like to transfer the neural network to a Python library (such as PyBrain) and compare performance. Look out for that post!
