Keras and the Last Number Problem#
Let’s see if we can do better than our simple hidden layer NN with the last number problem.
import numpy as np
import keras
from keras.utils import to_categorical
We’ll use the same data class
class ModelDataCategorical:
    """this is the model data for our "last number" training set.  We
    produce input of length N, consisting of numbers 0-9 and store
    the result in a 10-element array as categorical data.

    """

    def __init__(self, N=10):
        self.N = N

        # our model input data
        self.x = np.random.randint(0, high=10, size=N)
        self.x_scaled = self.x / 10 + 0.05

        # our scaled model output data
        self.y = np.array([self.x[-1]])
        self.y_scaled = np.zeros(10) + 0.01
        self.y_scaled[self.x[-1]] = 0.99

    def interpret_result(self, out):
        """take the network output and return the number we predict"""
        return np.argmax(out)
For Keras, we need to pack the data into arrays: the scaled inputs go into a 2-d array,
and the integer labels are converted into one-hot (categorical) form with the Keras to_categorical() function.
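As a quick illustration (a plain numpy sketch, not the Keras internals), one-hot encoding maps each integer label to a row with a single 1; indexing an identity matrix has the same effect as to_categorical(labels, 10):

```python
import numpy as np

# one-hot encode integer labels 0-9, mimicking keras.utils.to_categorical
labels = np.array([8, 3, 0])
one_hot = np.eye(10)[labels]

print(one_hot[0])   # the row for label 8: 1.0 in slot 8, zeros elsewhere
```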
Let’s make both a training set and a test set
x_train = []
y_train = []
for _ in range(10000):
    m = ModelDataCategorical()
    x_train.append(m.x_scaled)
    y_train.append(m.y)

x_train = np.asarray(x_train)
y_train = to_categorical(y_train, 10)

x_test = []
y_test = []
for _ in range(1000):
    m = ModelDataCategorical()
    x_test.append(m.x_scaled)
    y_test.append(m.y)

x_test = np.asarray(x_test)
y_test = to_categorical(y_test, 10)
Check to make sure the data looks like we expect:
x_train[0]
array([0.75, 0.35, 0.55, 0.55, 0.35, 0.05, 0.95, 0.15, 0.05, 0.85])
y_train[0]
array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0.])
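We can also check the round trip by hand (a plain numpy sketch; the 0.05 offset and factor of 10 come from the class above): inverting the scaling recovers the original digits, and np.argmax on the one-hot label recovers the last number:

```python
import numpy as np

# the scaled input printed above
x_scaled = np.array([0.75, 0.35, 0.55, 0.55, 0.35, 0.05, 0.95, 0.15, 0.05, 0.85])

# invert x / 10 + 0.05 to get the original digits back
digits = np.round((x_scaled - 0.05) * 10).astype(int)
print(digits)                # -> [7 3 5 5 3 0 9 1 0 8]

# the one-hot label encodes the last digit
y_onehot = np.zeros(10)
y_onehot[digits[-1]] = 1.0
print(np.argmax(y_onehot))   # -> 8, the last number
```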
Creating the network#
Now let’s build our network. We’ll again use just a single hidden layer, but instead of the sigmoid used before, we’ll use the ReLU activation in the hidden layer and softmax on the output layer.
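For reference, both activations are simple to write down (a numpy sketch, not the Keras implementation): ReLU zeros out negative inputs, while softmax exponentiates and normalizes so the 10 outputs form a probability distribution over the classes:

```python
import numpy as np

def relu(x):
    # pass positive values through, zero out the rest
    return np.maximum(0.0, x)

def softmax(z):
    # subtract the max for numerical stability before exponentiating
    e = np.exp(z - np.max(z))
    return e / e.sum()

z = np.array([-1.0, 0.5, 2.0])
print(relu(z))             # -> [0.  0.5 2. ]
print(softmax(z).sum())    # sums to 1 (up to floating point)
```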
from keras.models import Sequential
from keras.layers import Input, Dense, Dropout, Activation
from keras.optimizers import RMSprop
model = Sequential()
model.add(Input((10,)))
model.add(Dense(100, activation="relu"))
model.add(Dropout(0.1))
model.add(Dense(10, activation="softmax"))
rms = RMSprop()
model.compile(loss='categorical_crossentropy',
              optimizer=rms, metrics=['accuracy'])
model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense)                   │ (None, 100)            │         1,100 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 100)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 10)             │         1,010 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 2,110 (8.24 KB)
Trainable params: 2,110 (8.24 KB)
Non-trainable params: 0 (0.00 B)
Now we have ~ 2k parameters to fit.
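That count is easy to verify by hand: each Dense layer has (inputs × outputs) weights plus one bias per output, and Dropout adds no parameters:

```python
n_in, n_hidden, n_out = 10, 100, 10

hidden_params = n_in * n_hidden + n_hidden    # weights + biases: 1100
output_params = n_hidden * n_out + n_out      # weights + biases: 1010

print(hidden_params + output_params)          # -> 2110, matching the summary
```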
Training#
Now we can train, validating against the test set after each epoch to see how we do.
epochs = 100
batch_size = 256
model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size,
          validation_data=(x_test, y_test), verbose=2)
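One thing to note in the output: the 40/40 at the start of each epoch line is the number of batches per epoch, which follows directly from the training set size and the batch size:

```python
import math

n_train = 10000
batch_size = 256

print(math.ceil(n_train / batch_size))   # -> 40 batches per epoch
```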
Epoch 1/100
40/40 - 1s - 16ms/step - accuracy: 0.1731 - loss: 2.2571 - val_accuracy: 0.2100 - val_loss: 2.1967
Epoch 2/100
40/40 - 0s - 2ms/step - accuracy: 0.2418 - loss: 2.1488 - val_accuracy: 0.2240 - val_loss: 2.0932
Epoch 3/100
40/40 - 0s - 2ms/step - accuracy: 0.2619 - loss: 2.0475 - val_accuracy: 0.2560 - val_loss: 1.9885
Epoch 4/100
40/40 - 0s - 2ms/step - accuracy: 0.2886 - loss: 1.9456 - val_accuracy: 0.2990 - val_loss: 1.8905
Epoch 5/100
40/40 - 0s - 2ms/step - accuracy: 0.3201 - loss: 1.8474 - val_accuracy: 0.3440 - val_loss: 1.7918
...
Epoch 96/100
40/40 - 0s - 2ms/step - accuracy: 0.9996 - loss: 0.0841 - val_accuracy: 1.0000 - val_loss: 0.0648
Epoch 97/100
40/40 - 0s - 2ms/step - accuracy: 0.9995 - loss: 0.0817 - val_accuracy: 1.0000 - val_loss: 0.0638
Epoch 98/100
40/40 - 0s - 2ms/step - accuracy: 0.9992 - loss: 0.0781 - val_accuracy: 1.0000 - val_loss: 0.0632
Epoch 99/100
40/40 - 0s - 2ms/step - accuracy: 0.9994 - loss: 0.0759 - val_accuracy: 1.0000 - val_loss: 0.0561
Epoch 100/100
40/40 - 0s - 2ms/step - accuracy: 0.9998 - loss: 0.0713 - val_accuracy: 1.0000 - val_loss: 0.0578
<keras.src.callbacks.history.History at 0x7feec66fbad0>
As we can see, the network reaches essentially perfect accuracy on the validation set by the end of training.
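As a sanity check beyond the val_accuracy Keras reports, we could compare the argmax of the network's predictions against the labels ourselves. In the notebook this would start from preds = model.predict(x_test); here the comparison is sketched with stand-in arrays (the random preds and labels below are assumptions purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# stand-ins for model.predict(x_test) and the one-hot y_test
preds = rng.random((1000, 10))
y_test = np.eye(10)[rng.integers(0, 10, size=1000)]

# fraction of samples where the predicted class matches the label
accuracy = np.mean(np.argmax(preds, axis=1) == np.argmax(y_test, axis=1))
print(accuracy)   # ~0.1 for random stand-ins; ~1.0 for the trained model
```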