Hands-On Neural Networks with Keras

Evaluating model performance

Whenever we evaluate a network, what we actually care about is how accurately it classifies images in the test set. This holds for any ML model, because accuracy on the training set is not a reliable indicator of how well the model generalizes.
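As a minimal sketch of what test set accuracy measures, the snippet below computes it by hand with NumPy: take the most probable class for each sample and count how often it matches the true label. The arrays `probs` and `labels` are made-up stand-ins for a model's output and the test labels, not values from our network.

```python
import numpy as np

def accuracy(probabilities, true_labels):
    """Fraction of samples whose most probable class equals the true label."""
    predicted = np.argmax(probabilities, axis=1)
    return float(np.mean(predicted == true_labels))

# Hypothetical class probabilities for four samples and three classes
probs = np.array([
    [0.1, 0.7, 0.2],   # predicted class 1
    [0.8, 0.1, 0.1],   # predicted class 0
    [0.3, 0.3, 0.4],   # predicted class 2
    [0.6, 0.3, 0.1],   # predicted class 0
])
labels = np.array([1, 0, 2, 1])   # the last sample is misclassified

test_accuracy = accuracy(probs, labels)
print(test_accuracy)   # 0.75
```

Running the same function on training data and on held-out test data gives the two numbers that are compared when judging generalization.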

In our case, the test set accuracy is 95.78%, which is marginally lower than our training set accuracy of 96%. A lower score on the test set than on the training set is the classic signature of overfitting, where the model captures irrelevant noise in the training data to predict the training images. Since that noise is different in our randomly selected test set, the network cannot rely on the useless representations it picked up earlier, and so performs worse during testing. Here the gap is small, so any overfitting is mild. As we will see throughout this book, when testing a neural network, it is important to ensure that it has learned correct and efficient representations of our data. In other words, we need to ensure that our network is not overfitting our training data.
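One simple way to keep an eye on overfitting is to compare the two accuracies directly. The sketch below does this for the figures quoted above; the helper name `generalization_gap` and the one-point threshold are illustrative assumptions, not a rule from Keras or from this book.

```python
def generalization_gap(train_accuracy, test_accuracy):
    """Training accuracy minus test accuracy, in percentage points."""
    return train_accuracy - test_accuracy

train_acc = 96.0    # training set accuracy, in percent
test_acc = 95.78    # test set accuracy, in percent

gap = generalization_gap(train_acc, test_acc)
print(f"gap: {gap:.2f} percentage points")   # gap: 0.22 percentage points

# Hypothetical threshold: flag the model when the gap exceeds one point
THRESHOLD = 1.0
print("large gap" if gap > THRESHOLD else "small gap")   # small gap
```

A much larger gap would suggest that the network has memorized details of the training set that do not carry over to unseen images.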

By the way, you can always visualize your predictions on the test set by printing the label with the highest probability for a given test subject and plotting that subject using Matplotlib. Here, we print the label with the maximum probability for test subject 110. Our model thinks it is an 8. By plotting the subject, we see that our model is right in this case:

import numpy as np
import matplotlib.pyplot as plt

# predict() runs the trained network on the test images and returns
# one row of class probabilities per image
predictions = model.predict(x_test)

# Print the label with the highest probability for test subject no. 110
print(np.argmax(predictions[110]))
-------------------------------------------
8
------------------------------------------
# Plot the same test subject to check the prediction by eye
plt.imshow(x_test[110])
<matplotlib.image.AxesImage at 0x174dd374240>

The preceding code plots test subject 110, which shows a handwritten 8.

Once satisfied, you can save and load your model for later use, as follows:

# Save the trained model to disk
model.save('mnist_nn.model')
# Load it back; this assumes keras was imported earlier
load_model = keras.models.load_model('mnist_nn.model')