Example: Character Recognition

Example: Character Recognition

We’ll apply the ideas we just learned to a neural network that does character recognition using the MNIST database. This is a set of handwritten digits (0–9), each represented as a 28×28 pixel grayscale image.

There are two datasets: a training set with 60,000 images and a test set with 10,000 images. We will use a version of the data that is provided as CSV files (mnist_train.csv and mnist_test.csv).

Each line of these files gives the answer (i.e., which digit it is) in the first column, followed by 784 columns (28 × 28) of pixel values.

We’ll write a class to manage this data:

import random
import numpy as np
import matplotlib.pyplot as plt

A TrainingDigit provides a scaled floating-point representation of the image as a 1D array (.scaled), with pixel values mapped from 0–255 to 0.01–1.00. It also stores the correct answer (.num) and the categorical data used to represent the answer from the neural network (.out). That is a 10-element array with 0.99 at the index of the correct digit and 0.01 everywhere else. Finally, it provides a method to plot the data.

class TrainingDigit:
    """A handwritten digit from the MNIST training set.

    Attributes
    ----------
    raw_string : str
        The CSV line this digit was built from.
    scaled : numpy.ndarray
        The pixel values as a 1-D float array, mapped from [0, 255]
        to [0.01, 1.00].
    num : int
        The correct digit (the first CSV column).
    out : numpy.ndarray
        The expected network output: 10 entries, 0.99 at index ``num``
        and 0.01 everywhere else.
    """

    def __init__(self, raw_string):
        """Build a digit from a single line of the MNIST CSV data set.

        Parameters
        ----------
        raw_string : str
            Comma-separated values: the correct digit followed by the
            pixel values.
        """
        self.raw_string = raw_string
        fields = raw_string.split(",")

        # make the data range from 0.01 to 1.00 (for pixels in [0, 255]).
        # np.asarray(..., dtype=float) replaces np.asfarray, which was
        # removed in NumPy 2.0.
        self.scaled = np.asarray(fields[1:], dtype=float) / 255.0 * 0.99 + 0.01

        # the correct answer
        self.num = int(fields[0])

        # the output for the NN as a bit array -- make this lie in [0.01, 0.99]
        self.out = np.full(10, 0.01)
        self.out[self.num] = 0.99

    def plot(self, ax=None, output=None):
        """Plot the digit as a 28x28 grayscale image.

        Parameters
        ----------
        ax : matplotlib Axes, optional
            Axes to draw on; a new figure and axes are created if None.
        output : sequence of float, optional
            Network output for this digit.  If given, the correct digit and
            every output value are shown in the title of ``ax``.
        """
        if ax is None:
            _, ax = plt.subplots()
        ax.imshow(self.scaled.reshape((28, 28)),
                  cmap="Greys", interpolation="nearest")
        if output is not None:
            dstr = [f"{n}: {v:6.4f}" for n, v in enumerate(output)]
            ostr = f"correct digit: {self.num}\n"
            ostr += "  ".join(dstr[0:5]) + "\n" + "  ".join(dstr[5:])
            # label *ax* itself: plt.title would target pyplot's current
            # axes, which need not be ax (e.g. when ax is a grid cell)
            ax.set_title(ostr, fontsize="x-small")

An UnknownDigit is like a TrainingDigit, but it does not keep the categorical output. Instead it provides methods to turn the network’s output into a predicted digit and to check whether that prediction is correct.

class UnknownDigit(TrainingDigit):
    """A digit from the MNIST test set.

    The correct answer is still parsed into ``num``, but the categorical
    target ``out`` is discarded.  In its place this class offers helpers
    that turn a network output into a digit and grade it against ``num``.
    """

    def __init__(self, raw_string):
        super().__init__(raw_string)
        # test digits carry no training target
        self.out = None

    def interpret_output(self, out):
        """Return the predicted digit: the index of the largest entry of out."""
        predicted = np.argmax(out)
        return predicted

    def check_output(self, out):
        """Return True when the digit predicted from out equals self.num."""
        predicted = self.interpret_output(out)
        return predicted == self.num

Now we’ll read in the data and store the training and test sets in separate lists. The CSV files are stored as zip archives, so we read them through Python’s zipfile module.

import zipfile

# each CSV row of the zipped training file becomes one TrainingDigit
with zipfile.ZipFile("mnist_train.csv.zip") as archive:
    with archive.open("mnist_train.csv") as csv_rows:
        training_set = [TrainingDigit(row.decode("utf8").strip("\n"))
                        for row in csv_rows]
len(training_set)
60000
# each CSV row of the zipped test file becomes one UnknownDigit
with zipfile.ZipFile("mnist_test.csv.zip") as archive:
    with archive.open("mnist_test.csv") as csv_rows:
        test_set = [UnknownDigit(row.decode("utf8").strip("\n"))
                    for row in csv_rows]
len(test_set)
10000

Let’s look at the first 16 digits in the training set:

from mpl_toolkits.axes_grid1 import ImageGrid

# lay out a 4x4 grid and draw one training digit in each cell
fig = plt.figure(1)
grid = ImageGrid(fig, 111, nrows_ncols=(4, 4), axes_pad=0.1)

for i, ax in enumerate(grid):
    digit = training_set[i]
    digit.plot(ax=ax)
../_images/472ec725dd13ef5785254aaeac8db853c2d2ce8383558da083e2e47db296600e.png

Here’s what the scaled pixel values look like—this is what will be fed into the network as input

# display the scaled pixel values (range 0.01-1.00) that the network
# receives as input for the first training digit
training_set[0].scaled
array([0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.02164706, 0.07988235, 0.07988235,
       0.07988235, 0.49917647, 0.538     , 0.68941176, 0.11094118,
       0.65447059, 1.        , 0.96894118, 0.50305882, 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.12647059, 0.14976471, 0.37494118, 0.60788235,
       0.67      , 0.99223529, 0.99223529, 0.99223529, 0.99223529,
       0.99223529, 0.88352941, 0.67776471, 0.99223529, 0.94952941,
       0.76705882, 0.25847059, 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.20023529, 0.934     ,
       0.99223529, 0.99223529, 0.99223529, 0.99223529, 0.99223529,
       0.99223529, 0.99223529, 0.99223529, 0.98447059, 0.37105882,
       0.32835294, 0.32835294, 0.22741176, 0.16141176, 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.07988235, 0.86023529, 0.99223529, 0.99223529,
       0.99223529, 0.99223529, 0.99223529, 0.77870588, 0.71658824,
       0.96894118, 0.94564706, 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.32058824, 0.61564706, 0.42541176, 0.99223529, 0.99223529,
       0.80588235, 0.05270588, 0.01      , 0.17694118, 0.60788235,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.06435294,
       0.01388235, 0.60788235, 0.99223529, 0.35941176, 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.54964706,
       0.99223529, 0.74764706, 0.01776471, 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.05270588, 0.74764706, 0.99223529,
       0.28176471, 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.14588235, 0.94564706, 0.88352941, 0.63117647,
       0.42929412, 0.01388235, 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.32447059, 0.94176471, 0.99223529, 0.99223529, 0.472     ,
       0.10705882, 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.18470588,
       0.73211765, 0.99223529, 0.99223529, 0.59235294, 0.11482353,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.07211765, 0.37105882,
       0.98835294, 0.99223529, 0.736     , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.97670588, 0.99223529,
       0.97670588, 0.25847059, 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.18858824, 0.51470588,
       0.72047059, 0.99223529, 0.99223529, 0.81364706, 0.01776471,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.16141176,
       0.58458824, 0.89905882, 0.99223529, 0.99223529, 0.99223529,
       0.98058824, 0.71658824, 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.10317647, 0.45258824, 0.868     , 0.99223529, 0.99223529,
       0.99223529, 0.99223529, 0.79035294, 0.31282353, 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.09929412, 0.26623529, 0.83694118, 0.99223529,
       0.99223529, 0.99223529, 0.99223529, 0.77870588, 0.32447059,
       0.01776471, 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.07988235, 0.67388235, 0.86023529,
       0.99223529, 0.99223529, 0.99223529, 0.99223529, 0.76705882,
       0.32058824, 0.04494118, 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.22352941, 0.67776471,
       0.88741176, 0.99223529, 0.99223529, 0.99223529, 0.99223529,
       0.95729412, 0.52635294, 0.05270588, 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.538     , 0.99223529, 0.99223529, 0.99223529,
       0.83305882, 0.53411765, 0.52247059, 0.07211765, 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      , 0.01      ,
       0.01      , 0.01      , 0.01      , 0.01      ])

and here’s what the categorical output looks like—this will be what we expect the network to return

# display the expected network output for the first training digit:
# 0.99 at the index of the correct digit, 0.01 everywhere else
training_set[0].out
array([0.01, 0.01, 0.01, 0.01, 0.01, 0.99, 0.01, 0.01, 0.01, 0.01])

Now we can write our neural network class. We will include a single hidden layer.

class NeuralNetwork:
    """A fully connected neural network with a single hidden layer.

    An input vector x (length ``N_in``) is mapped to an output vector z
    (length ``N_out``) by

        z_tilde = g(B @ x)        # hidden layer, length N_hidden
        z       = g(A @ z_tilde)  # output layer

    where g is the sigmoid function.
    """

    def __init__(self, input_size=1, output_size=1, hidden_layer_size=1):
        """Create the network with randomly initialized weights.

        Parameters
        ----------
        input_size : int
            Number of nodes/neurons on the input layer.
        output_size : int
            Number of nodes/neurons on the output layer.
        hidden_layer_size : int
            Number of nodes/neurons on the hidden layer.
        """
        # the number of nodes/neurons on the output layer
        self.N_out = output_size

        # the number of nodes/neurons on the input layer
        self.N_in = input_size

        # the number of nodes/neurons on the hidden layer
        self.N_hidden = hidden_layer_size

        # we initialize each weight matrix with Gaussian normal random
        # numbers centered on 0 with a width of 1/sqrt(n), where n is the
        # number of nodes feeding into that layer

        # A is the set of weights between the hidden layer and output layer
        self.A = np.random.normal(0.0, 1.0/np.sqrt(self.N_hidden),
                                  (self.N_out, self.N_hidden))

        # B is the set of weights between the input layer and hidden layer
        self.B = np.random.normal(0.0, 1.0/np.sqrt(self.N_in),
                                  (self.N_hidden, self.N_in))

    def g(self, xi):
        """The sigmoid activation function, applied on both layers."""
        return 1.0/(1.0 + np.exp(-xi))

    def train(self, training_data, n_epochs=1, learning_rate=0.1):
        """Train the network by gradient descent with backpropagation.

        This sets the matrix elements of B (weights between the input and
        hidden layer) and A (weights between the hidden and output layer),
        updating them after every training example.

        Parameters
        ----------
        training_data : sequence
            Objects with a ``scaled`` array (the input, N_in values) and an
            ``out`` array (the expected output, N_out values).  The sequence
            itself is not modified: the examples are visited in a shuffled
            private copy.
        n_epochs : int
            Number of passes through the training data.
        learning_rate : float
            Step size of the gradient-descent updates.
        """
        print(f"size of training data = {len(training_data)}")

        # shuffle a copy so the caller's list keeps its order (and so any
        # sequence, e.g. a tuple, can be used)
        order = list(training_data)

        for i in range(n_epochs):
            print(f"epoch {i+1} of {n_epochs}")

            random.shuffle(order)
            for model in order:

                # make the input and output data column vectors
                x = model.scaled.reshape(self.N_in, 1)
                y = model.out.reshape(self.N_out, 1)

                # propagate the input through the network
                z_tilde = self.g(self.B @ x)
                z = self.g(self.A @ z_tilde)

                # compute the errors (backpropagate to the hidden layer);
                # z * (1 - z) is the derivative of the sigmoid g.  This uses
                # A *before* it is updated below.
                e = z - y
                e_tilde = self.A.T @ (e * z * (1 - z))

                # gradient-descent corrections for the squared error |z - y|^2
                dA = (-2 * learning_rate * e * z * (1 - z)) @ z_tilde.T
                dB = (-2 * learning_rate * e_tilde * z_tilde * (1 - z_tilde)) @ x.T

                self.A += dA
                self.B += dB

    def predict(self, model):
        """Propagate ``model.scaled`` through the network (B, then A).

        Returns the output-layer activations, each in (0, 1).  Assuming
        ``model.scaled`` is 1-D, the result is a 1-D array of length N_out.
        """
        z_tilde = self.g(self.B @ model.scaled)
        return self.g(self.A @ z_tilde)

Create our neural network

# size the network from the first training digit: one input node per
# pixel value and one output node per entry of the categorical answer
sample = training_set[0]
input_size, output_size = len(sample.scaled), len(sample.out)
net = NeuralNetwork(input_size=input_size,
                    output_size=output_size,
                    hidden_layer_size=50)

Now we can train

# five passes over the training data at the default learning rate
net.train(training_set, n_epochs=5, learning_rate=0.1)
size of training data = 60000
epoch 1 of 5
epoch 2 of 5
epoch 3 of 5
epoch 4 of 5
epoch 5 of 5

Let’s see what our accuracy rate is

# count how many test digits the trained network classifies correctly
n_correct = 0
for model in test_set:
    res = net.predict(model)
    n_correct += int(model.check_output(res))

print(f"accuracy is {n_correct / len(test_set)}")
accuracy is 0.9654

So we are about 96% accurate. We can try to improve this by training with more epochs or using a bigger hidden layer. We might also try experimenting with other activation functions.

Let’s look at some of the digits we get wrong

from mpl_toolkits.axes_grid1 import ImageGrid

# a 4x4 grid to hold the first 16 misclassified test digits
fig = plt.figure(1, (8, 8))
grid = ImageGrid(fig, 111, nrows_ncols=(4, 4), axes_pad=0.1)

num_wrong = 0
for model in test_set:
    res = net.predict(model)
    if model.check_output(res):
        continue

    # draw the digit and annotate it with the true and predicted answers
    cell = grid[num_wrong]
    model.plot(ax=cell)
    cell.text(0.05, 0.05,
              f"ans: {model.num}\npred: {model.interpret_output(res)}",
              transform=cell.transAxes, color="C1", zorder=100)
    num_wrong += 1

    # stop once every cell of the grid is filled
    if num_wrong == len(grid):
        break
../_images/8c23451553b93f97e6ec64c472e89a3c853bf4f6d45f16511b06f07f0614d80a.png