Clustering#

Clustering seeks to group data into clusters based on their properties, and then lets us predict which cluster a new data point belongs to.

import numpy as np
import matplotlib.pyplot as plt

Preparing the data#

We’ll use a dataset generator that is part of scikit-learn called make_moons. This generates data that falls into 2 different sets, each with a shape that looks like a half-moon.

from sklearn import datasets
def generate_data():
    xvec, val = datasets.make_moons(200, noise=0.15)

    # collect the points and their labels into separate arrays
    x = []
    v = []
    for xv, vv in zip(xvec, val):
        x.append(np.array(xv))
        v.append(vv)

    return np.array(x), np.array(v)

Tip

By adjusting the noise parameter, we can blur the boundary between the two datasets, making the classification harder.
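For instance, here is a rough sketch (the noise values are arbitrary) that regenerates the moons at a few noise levels; plotting each set with the plot_data helper defined below would show the two groups blurring together as the noise grows.

# compare a few (arbitrary) noise levels -- larger noise mixes the two moons more
for noise_level in (0.05, 0.15, 0.30):
    xs, vs = datasets.make_moons(200, noise=noise_level)
    print(f"noise = {noise_level}: {len(xs)} points in {len(set(vs))} groups")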

x, v = generate_data()

Let’s look at a point and its value

print(f"x = {x[0]}, value = {v[0]}")
x = [ 1.04628964 -0.4585966 ], value = 1

Now let’s plot the data

def plot_data(x, v):
    xpt = [q[0] for q in x]
    ypt = [q[1] for q in x]

    fig, ax = plt.subplots()
    ax.scatter(xpt, ypt, s=40, c=v, cmap="viridis")
    ax.set_aspect("equal")
    return fig
fig = plot_data(x, v)
[Figure: scatter plot of the two half-moon point sets, colored by value]

We want to partition this domain into 2 regions, such that when we come in with a new point, we know which group it belongs to.

Constructing the network#

First we set up our network

from keras.models import Sequential
from keras.layers import Input, Dense, Dropout, Activation
from keras.optimizers import RMSprop
model = Sequential()
model.add(Input((2,)))
model.add(Dense(50, activation="relu"))
model.add(Dense(20, activation="relu"))
model.add(Dense(1, activation="sigmoid"))
rms = RMSprop()
model.compile(loss='binary_crossentropy',
              optimizer=rms, metrics=['accuracy'])
model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense)                   │ (None, 50)             │           150 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 20)             │         1,020 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 1)              │            21 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 1,191 (4.65 KB)
 Trainable params: 1,191 (4.65 KB)
 Non-trainable params: 0 (0.00 B)
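The parameter counts in the summary follow from each Dense layer having one weight per input-unit pair plus one bias per unit: 2 × 50 + 50 = 150, 50 × 20 + 20 = 1020, and 20 × 1 + 1 = 21, for 1,191 in total. A quick check of that arithmetic:

# dense layer parameters = inputs * units (weights) + units (biases)
layer_sizes = [(2, 50), (50, 20), (20, 1)]
print([n_in * n_out + n_out for n_in, n_out in layer_sizes])      # [150, 1020, 21]
print(sum(n_in * n_out + n_out for n_in, n_out in layer_sizes))   # 1191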

Training#

Important

We seem to need a lot of epochs here to get a good result
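If we would rather not hand-tune the epoch count, Keras provides an EarlyStopping callback. Here is a minimal sketch (the patience value is an arbitrary choice, and it is not used in the run below):

from keras.callbacks import EarlyStopping

# stop training once the loss has not improved for 20 consecutive epochs (arbitrary patience)
early_stop = EarlyStopping(monitor="loss", patience=20)
# results = model.fit(x, v, batch_size=50, epochs=500, callbacks=[early_stop], verbose=2)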

results = model.fit(x, v, batch_size=50, epochs=200, verbose=2)
Epoch 1/200
4/4 - 0s - 6ms/step - accuracy: 0.6450 - loss: 0.6669
Epoch 2/200
4/4 - 0s - 6ms/step - accuracy: 0.7000 - loss: 0.6363
Epoch 3/200
4/4 - 0s - 6ms/step - accuracy: 0.7250 - loss: 0.6152
Epoch 4/200
4/4 - 0s - 6ms/step - accuracy: 0.7450 - loss: 0.5966
Epoch 5/200
4/4 - 0s - 6ms/step - accuracy: 0.7650 - loss: 0.5778
...
Epoch 196/200
4/4 - 0s - 5ms/step - accuracy: 1.0000 - loss: 0.0254
Epoch 197/200
4/4 - 0s - 16ms/step - accuracy: 1.0000 - loss: 0.0250
Epoch 198/200
4/4 - 0s - 6ms/step - accuracy: 1.0000 - loss: 0.0243
Epoch 199/200
4/4 - 0s - 6ms/step - accuracy: 1.0000 - loss: 0.0241
Epoch 200/200
4/4 - 0s - 9ms/step - accuracy: 1.0000 - loss: 0.0238
score = model.evaluate(x, v, verbose=0)
print(f"score = {score[0]}")
print(f"accuracy = {score[1]}")
score = 0.023014333099126816
accuracy = 1.0
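The History object returned by fit() stores the loss and accuracy for each epoch, so we can also visualize how training progressed (a quick sketch using the results object from above):

fig, ax = plt.subplots()
ax.plot(results.history["loss"], label="loss")
ax.plot(results.history["accuracy"], label="accuracy")
ax.set_xlabel("epoch")
ax.legend()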

Predicting#

Let’s look at a prediction. The model expects an array of shape (N, 2), where N is the number of points, so even a single point must be passed in as a 2-d array

res = model.predict(np.array([[-2, 2]]))
res
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
array([[3.4656636e-16]], dtype=float32)

We see that we get a floating point number. We will need to convert this to 0 or 1 by rounding.
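For example, a minimal way to turn this single prediction into a class label:

# threshold the sigmoid output at 0.5 to get a label of 0 or 1
label = int(res[0, 0] > 0.5)
print(label)   # 0 for this point, since 3.5e-16 < 0.5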

Let’s plot the partitioning

M = 256
N = 256

xmin = -1.75
xmax = 2.5
ymin = -1.25
ymax = 1.75

xpt = np.linspace(xmin, xmax, M)
ypt = np.linspace(ymin, ymax, N)

To make the prediction go faster, we want to feed in all of the grid points at once, as an array of (x, y) pairs covering every combination of xpt and ypt:

[[xpt[0], ypt[0]],
 [xpt[0], ypt[1]],
 ...
]

The meshgrid construction below packs the grid into this form; we can check the first element

pairs = np.array(np.meshgrid(xpt, ypt)).T.reshape(-1, 2)
pairs[0]
array([-1.75, -1.25])
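Element i * N + j of pairs holds the point (xpt[i], ypt[j]); that ordering is what lets us reshape the predictions back onto the grid below. A quick sanity check (the indices are arbitrary):

i, j = 10, 20   # arbitrary grid indices
assert np.allclose(pairs[i * N + j], [xpt[i], ypt[j]])
print(pairs.shape)   # (M*N, 2) = (65536, 2)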

Now we do the prediction. We will get a vector out, which we reshape to match the original domain.

res = model.predict(pairs, verbose=0)
res.shape = (M, N)

Finally, round to 0 or 1

domain = np.where(res > 0.5, 1, 0)

and we can plot the data

fig, ax = plt.subplots()
ax.imshow(domain.T, origin="lower",
          extent=[xmin, xmax, ymin, ymax], alpha=0.25)
xpt = [q[0] for q in x]
ypt = [q[1] for q in x]

ax.scatter(xpt, ypt, s=40, c=v, cmap="viridis")
<matplotlib.collections.PathCollection at 0x7fc42e0496d0>
[Figure: decision regions predicted by the network, shaded beneath the scatter plot of the training data]