astronomystuff/astronn_test.py
2025-04-21 16:36:49 -04:00

68 lines
2.4 KiB
Python

# import everything we need first
from tensorflow.keras import utils
import numpy as np
from sklearn.model_selection import train_test_split
import pylab as plt
from astroNN.models import Galaxy10CNN
from astroNN.datasets import galaxy10
from astroNN.datasets.galaxy10 import galaxy10cls_lookup, galaxy10_confusion
# To load images and labels (will download automatically at the first time)
# First time downloading location will be ~/.astroNN/datasets/
images, labels = galaxy10.load_data()
# To convert the labels to categorical 10 classes
labels = utils.to_categorical(labels, 10)
# Select 10 of the images to inspect
img = None
plt.ion()
print('===================Data Inspection===================')
for counter, i in enumerate(range(np.random.randint(0, labels.shape[0], size=10).shape[0])):
img = plt.imshow(images[i])
plt.title('Class {}: {} \n Random Demo images {} of 10'.format(np.argmax(labels[i]), galaxy10cls_lookup(labels[i]), counter+1))
plt.draw()
plt.pause(2.)
plt.close('all')
print('===============Data Inspection Finished===============')
# To convert to desirable type
labels = labels.astype(np.float32)
images = images.astype(np.float32)
# Split the dataset into training set and testing set
train_idx, test_idx = train_test_split(np.arange(labels.shape[0]), test_size=0.1)
train_images, train_labels, test_images, test_labels = images[train_idx], labels[train_idx], images[test_idx], labels[test_idx]
# To create a neural network instance
galaxy10net = Galaxy10CNN()
# set maximium epochs the neural network can run, set 5 to get quick result
galaxy10net.max_epochs = 5
# To train the nerual net
# astroNN will normalize the data by default
galaxy10net.train(train_images, train_labels)
# print model summary before training
galaxy10net.keras_model.summary()
# After the training, you can test the neural net performance
# Please notice predicted_labels are labels predicted from neural network. test_labels are ground truth from the dataset
predicted_labels = galaxy10net.test(test_images)
# Convert predicted_labels to class
prediction_class = np.argmax(predicted_labels, axis=1)
# Convert test_labels to class
test_class = np.argmax(test_labels, axis=1)
# Prepare a confusion matrix
confusion_matrix = np.zeros((10,10))
# create the confusion matrix
for counter, i in enumerate(prediction_class):
confusion_matrix[i, test_class[counter]] += 1
# Plot the confusion matrix
galaxy10_confusion(confusion_matrix)