Application: Galaxy Classification

Galaxies come in many different shapes. Large surveys can image hundreds of thousands or even millions of galaxies, and we want to be able to classify them into types automatically. We can use a neural network for this.

We will use the Galaxy10 DECaLS dataset (Leung & Bovy 2024), which contains 17736 images of galaxies, each 256x256 pixels in 3 color channels (the g, r, and z bands), classified into 10 categories. It also provides RA and dec, redshift, and pixel scale, but we’ll ignore those.

You can get the dataset from Zenodo: Galaxy10_DECals.h5

This is an HDF5 file, a self-describing binary format that is popular in scientific computing. We can read it in Python using the h5py library.

Warning

This is a big dataset (~2.5 GB), so we need to be careful with memory.
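
To see why memory matters, a rough estimate of the in-memory size of the full image array helps. The sketch below uses only the shape quoted above (17736 images of 256x256 pixels in 3 bands); the helper name `array_size_gb` is ours for illustration and is not part of any library.

```python
import numpy as np

def array_size_gb(shape, dtype):
    """Return the in-memory size of an array with this shape and dtype, in GB (10^9 bytes)."""
    n_items = int(np.prod(shape, dtype=np.int64))
    return n_items * np.dtype(dtype).itemsize / 1.0e9

shape = (17736, 256, 256, 3)

# size if every image is held as stored (uint8) or converted to float32
size_uint8 = array_size_gb(shape, np.uint8)
size_float32 = array_size_gb(shape, np.float32)

print(f"uint8:   {size_uint8:.2f} GB")
print(f"float32: {size_float32:.2f} GB")
```

Holding every image as float32 at once takes four times the uint8 footprint, which is one reason to read and convert a single galaxy (or a small batch) at a time.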

Note

There is not a separate training and test set, so we will partition this dataset ourselves.

import numpy as np
import matplotlib.pyplot as plt
import h5py

If we look at the dataset, we can see the fields:

ds = h5py.File("Galaxy10_DECals.h5", "r")  # open read-only
ds.keys()
<KeysViewHDF5 ['ans', 'dec', 'images', 'pxscale', 'ra', 'redshift']>

Let’s look at the answers:

ans = np.array(ds["ans"])
ans.shape, ans.dtype
((17736,), dtype('uint8'))

We see that there are 17736 answers, one per galaxy, now held in memory as a NumPy array of uint8 values.

Important

The galaxies are stored in the dataset sorted by type. So all of the “disturbed galaxies” (type 0) come first, then the “merging galaxies”, …
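
One way to confirm this ordering, and to see how many galaxies fall in each class, is sketched below. To stay self-contained it uses a small synthetic label array (`labels`, a stand-in for `ans`); the same two checks should apply to the real answers. Neighbors are compared directly instead of with `np.diff`, since subtraction on unsigned `uint8` values wraps around rather than going negative.

```python
import numpy as np

# synthetic stand-in for `ans`: sorted uint8 labels with unequal class sizes
labels = np.repeat(np.arange(10, dtype=np.uint8),
                   [3, 5, 2, 4, 1, 6, 2, 3, 4, 2])

# True if every label is <= the one that follows it
is_sorted = bool(np.all(labels[:-1] <= labels[1:]))

# number of galaxies of each type (array index = type)
counts = np.bincount(labels, minlength=10)

print(is_sorted)
print(counts)
```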

Exploring the data

Let’s define descriptive names for the galaxy types

galaxy_types = {0: "disturbed galaxies",
                1: "merging galaxies",
                2: "round smooth galaxies",
                3: "in-between round smooth galaxies",
                4: "cigar shaped smooth galaxies",
                5: "barred spiral galaxies",
                6: "unbarred tight spiral galaxies",
                7: "unbarred loose spiral galaxies",
                8: "edge-on galaxies without bulge",
                9: "edge-on galaxies with bulge"}

We’ll also create a simple class (Galaxy) to manage the data.

Here we pass in the image data for a single galaxy as a NumPy uint8 array, the same datatype used in the file. The class converts it to a single-precision floating-point array, rescaled from [0, 255] to [0.01, 1]. We also convert the answer to a categorical representation: an array of length 10 that is 0.99 at the index of the correct type and 0.01 everywhere else.

class Galaxy:
    def __init__(self, data, answer, *, index=-1):
        self.data = np.array(data, dtype=np.float32) / 255.0 * 0.99 + 0.01
        self.answer = answer

        self.out = np.zeros(10, dtype=np.float32) + 0.01
        self.out[self.answer] = 0.99

        self.index = index

    def plot(self, ax=None):
        if ax is None:
            fig, ax = plt.subplots()
        ax.imshow(self.data, interpolation="nearest")
        ax.text(0.025, 0.95, f"answer: {self.answer}",
                color="white", transform=ax.transAxes)

    def validate(self, prediction):
        """check if a categorical prediction matches the answer"""
        return np.argmax(prediction) == self.answer
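
As a quick check of the scaling and encoding described above, the sketch below repeats the two conversions outside the class on made-up data. The function names `scale_image` and `encode_answer` are illustrative only; they mirror the arithmetic in `Galaxy.__init__` but are not part of the original code.

```python
import numpy as np

def scale_image(data):
    """Map uint8 pixel values in [0, 255] to float32 values in [0.01, 1]."""
    return np.array(data, dtype=np.float32) / 255.0 * 0.99 + 0.01

def encode_answer(answer, n_types=10):
    """Return a length-n_types array: 0.99 at index `answer`, 0.01 elsewhere."""
    out = np.zeros(n_types, dtype=np.float32) + 0.01
    out[answer] = 0.99
    return out

# a fake 2x2 "image" with 3 bands, spanning the full uint8 range
fake = np.array([[[0, 0, 0], [255, 255, 255]],
                 [[128, 64, 32], [10, 20, 30]]], dtype=np.uint8)

scaled = scale_image(fake)
target = encode_answer(np.uint8(3))

print(scaled.dtype, scaled.min(), scaled.max())
print(np.argmax(target))  # index of the largest entry, as used in validate()
```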

Now let’s find the first example of each type and plot it

images = ds["images"]
images.dtype
dtype('uint8')

Note

Accessing ds["images"] doesn’t read the image data into memory. It just returns a Dataset object that reads from the file on demand, so we can load a single galaxy at a time.

from mpl_toolkits.axes_grid1 import ImageGrid
fig = plt.figure()
grid = ImageGrid(fig, 111,
                 nrows_ncols=(4, 3),
                 axes_pad=(0.25, 0.5))

for i in range(10):
    idx = np.where(ans == i)[0][0]
    g = Galaxy(images[idx, :, :, :], ans[idx])
    g.plot(ax=grid[i])
    grid[i].set_title(f"{idx}: {galaxy_types[i]}", fontsize="small")

grid[10].set_axis_off()
grid[11].set_axis_off()
fig.set_size_inches(10, 15)
[Figure: the first example of each of the 10 galaxy types, each panel titled with its index in the dataset and its type name]

Important

Since each class of galaxy is stored together in the file, we will want to randomize the order before splitting into training and testing sets.
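
A minimal sketch of that shuffle-then-split step is shown below. The 80/20 split fraction, the fixed seed, and the names `train_ids` and `test_ids` are assumptions made for illustration, not choices taken from the text. It shuffles indices rather than images, so nothing large has to be loaded, and a synthetic sorted label array stands in for `ans`.

```python
import numpy as np

# synthetic stand-in for `ans`: 1000 labels sorted by type, 100 per class
labels = np.repeat(np.arange(10, dtype=np.uint8), 100)
n = len(labels)

rng = np.random.default_rng(12345)   # fixed seed so the split is reproducible
shuffled = rng.permutation(n)        # random ordering of the indices 0..n-1

n_train = int(0.8 * n)               # assumed 80/20 split
train_ids = shuffled[:n_train]
test_ids = shuffled[n_train:]

# each set should now contain a mix of galaxy types
print(np.bincount(labels[train_ids], minlength=10))
print(np.bincount(labels[test_ids], minlength=10))
```

Each index can then be used to pull one galaxy at a time from the file, for example `Galaxy(images[i, :, :, :], ans[i], index=i)`.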