Understanding neural networks is one thing, but seeing them in action is quite another. The MNIST database, a large collection of handwritten digits, is the perfect playground to train and test a neural network for image recognition. In this blog post, we will build upon our previous post, “Crafting a Neural Network in Just One File! Java with JBang”, and use the MNIST dataset described in “Exploring the Classic MNIST: A Benchmark for Machine Learning Models” to develop a fully functioning neural network for digit classification.
- Prerequisites
- The MNIST Playground
- Crafting the Neural Network with Java and JBang
- Conclusion
Prerequisites
Before we dive into the development process, ensure you have:
- JBang installed on your system. You can install it from JBang’s official website.
- You can clone the https://github.com/dmakariev/examples repository:
git clone https://github.com/dmakariev/examples.git
cd examples/artificial-intelligence/neural-network-compare
The MNIST Playground
The MNIST dataset is a cornerstone of machine learning that contains 70,000 images of handwritten digits, each 28x28 pixels. Each image is monochrome, with pixel values ranging from 0 to 255 representing the intensity of black. Training a model on such a dataset provides a straightforward benchmark for learning and performance assessment. For simplicity, we are going to use the CSV version provided on the Kaggle website:
https://www.kaggle.com/datasets/oddrationale/mnist-in-csv
The dataset consists of two files:
- mnist_train.csv
- mnist_test.csv
The mnist_train.csv file contains the 60,000 training examples and labels. The mnist_test.csv file contains the 10,000 test examples and labels. Each row consists of 785 values: the first value is the label (a number from 0 to 9) and the remaining 784 values are the pixel values (each a number from 0 to 255).
Because of the size of the files (mnist_train.csv is ~110 MB and mnist_test.csv is ~18 MB), we are going to read them directly from inside the zip file dataset/dataset-MNIST.zip (~16 MB).
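Reading the CSVs from inside the archive keeps the repository small. The sketch below is only an illustration of what that involves (the class name, the zip entry name, and the presence of a header row are assumptions; the actual loading logic lives in TrainingData.java): it opens the zip, reads one row, and splits it into a label plus 784 pixel values.
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

// Illustrative sketch only: the real loading code lives in TrainingData.java.
public class MnistZipSketch {
    public static void main(String[] args) throws Exception {
        try (ZipFile zip = new ZipFile("dataset/dataset-MNIST.zip")) {
            final ZipEntry entry = zip.getEntry("mnist_train.csv"); // assumed entry name
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(zip.getInputStream(entry)))) {
                reader.readLine(); // skip the header row, if the CSV has one
                final String[] row = reader.readLine().split(",");
                final int label = Integer.parseInt(row[0]);          // first value: the digit 0-9
                final double[] pixels = new double[row.length - 1];  // remaining 784 values: 0-255
                for (int i = 1; i < row.length; i++) {
                    pixels[i - 1] = Double.parseDouble(row[i]);
                }
                System.out.println("label=" + label + ", pixel count=" + pixels.length);
            }
        }
    }
}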
Crafting the Neural Network with Java and JBang
We’ll use Java as our language of choice, combined with the JBang scripting tool, to construct our neural network from scratch. JBang allows us to run our Java programs with a script-like feel, removing much of the boilerplate associated with traditional Java applications.
Step 1: Setting Up the Neural Network Architecture
Using the NeuralNetTutorial class we’ve crafted before, we establish the neural network’s architecture. For the MNIST dataset, we choose a simple yet effective layout:
- An input layer with 784 neurons, corresponding to the 28x28 pixels of the MNIST images.
- A hidden layer with 64 neurons, a figure that offers a balance between complexity and computational efficiency.
- An output layer with 10 neurons, representing the digits 0 through 9.
We apply activation functions like the leaky ReLU to introduce non-linearity, which helps the network learn complex patterns in the dataset.
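Expressed in code, this layout is simply a list of layer sizes passed to the Net constructor, exactly as the training script in Step 2 does:
// 784 input neurons, one hidden layer of 64 neurons, 10 output neurons
final NeuralNetTutorial.Net myNet = new NeuralNetTutorial.Net(Arrays.asList(784, 64, 10));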
In the updated version of our NeuralNetTutorial, we’ve modularized the code by extracting the TrainingData into a distinct utility class. Additionally, we have shifted from Random Initialization to the He Initialization method for setting up our network’s weights. Initialization, in the context of neural networks, refers to the process of setting the initial weights and biases of the network before training begins. Proper initialization is crucial because it can significantly affect the convergence rate and the quality of the final solution that the training process yields.
He Initialization is named after Kaiming He. It is specifically designed for layers with ReLU activation functions, but it can also be used with other activation functions. The weights are initialized with the size of the previous layer in mind, which helps keep gradients well-scaled and aids convergence during training.
double stddev = Math.sqrt(2.0 / numOutputs); // Standard deviation for He initialization
for (int c = 0; c < numOutputs; ++c) {
    outputWeights.add(new Connection(rand.nextGaussian() * stddev, 0));
}
While the default activation function is now the leaky ReLU, the code allows for easy substitution of this function, enabling users to experiment with different activation functions and observe their corresponding outcomes.
private final static ActivationFunction ACTIVATION_FUNCTION = ActivationFunction.leakyReLU();
//private final static ActivationFunction ACTIVATION_FUNCTION = ActivationFunction.ReLU();
//private final static ActivationFunction ACTIVATION_FUNCTION = ActivationFunction.tanh();
//private final static ActivationFunction ACTIVATION_FUNCTION = ActivationFunction.sigmoid();
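For reference, the leaky ReLU itself is a simple piecewise function: positive inputs pass through unchanged, while negative inputs are scaled by a small slope. Here is a minimal sketch of the idea, assuming a negative slope of 0.01 (the constant used inside the actual ActivationFunction implementation may differ):
// Minimal sketch of a leaky ReLU and its derivative; the 0.01 slope is an assumed value.
final class LeakyReLUSketch {

    static double activate(double x) {
        return x > 0 ? x : 0.01 * x; // small, non-zero output for negative inputs
    }

    static double derivative(double x) {
        return x > 0 ? 1.0 : 0.01; // gradient never fully dies, unlike plain ReLU
    }
}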
Step 2: Training on the MNIST Data
The trainNNT_Mnist.java file is our training ground.
//usr/bin/env jbang "$0" "$@" ; exit $?
//SOURCES NeuralNetTutorial.java
//SOURCES TrainingData.java
package com.makariev.examples.ai.neuralnet;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class trainNNT_Mnist {

    public static void main(String[] args) {
        final long startTime = System.currentTimeMillis();

        // Example: Assuming the data has 784 inputs, 64 hidden neurons, and 10 outputs
        final NeuralNetTutorial.Net myNet = new NeuralNetTutorial.Net(Arrays.asList(784, 64, 10));

        final TrainingData trainData = TrainingData.mnistTrainData();

        // Train
        for (int epoch = 0; epoch < 5; epoch++) {
            trainData.trainLine((inputVals, targetVals) -> {
                double[] input = inputVals;
                for (int n = 0; n < input.length; n++) {
                    // normalization
                    // Scale the pixel values to the range [0,1]
                    input[n] = input[n] / 255;
                }

                double[] target = new double[10];
                target[(int) targetVals[0]] = 1; // One-hot encoding

                final List<Double> inputValues = new ArrayList<>();
                for (int i = 0; i < input.length; i++) {
                    inputValues.add(input[i]);
                }

                // Train the MLP with the current sample
                // Get new input data and feed it forward:
                myNet.feedForward(inputValues);

                final List<Double> targetValues = new ArrayList<>();
                for (int i = 0; i < target.length; i++) {
                    targetValues.add(target[i]);
                }

                // Train the net what the outputs should have been:
                myNet.backProp(targetValues);

                return true;
            });

            // Test and Calculate Accuracy
            trainData.testPredictChunk(
                    10_000,
                    "Epoch: %d, ".formatted(epoch),
                    (inputVals, targetVals) -> {
                        final List<Double> inputValues = new ArrayList<>();
                        for (int n = 0; n < inputVals.length; n++) {
                            // normalization
                            // Scale the pixel values to the range [0,1]
                            inputValues.add(inputVals[n] / 255);
                        }
                        myNet.feedForward(inputValues);
                        final double[] predictions = myNet.getResults().stream().mapToDouble(Double::doubleValue).toArray();
                        final int predictedLabel = TrainingData.getMaxIndex(predictions);
                        return predictedLabel == (int) targetVals[0];
                    }
            );
        }

        System.out.println();

        trainData.testPredictChunk(10_000, (inputVals, targetVals) -> {
            final List<Double> inputValues = new ArrayList<>();
            for (int n = 0; n < inputVals.length; n++) {
                // normalization
                // Scale the pixel values to the range [0,1]
                inputValues.add(inputVals[n] / 255);
            }
            myNet.feedForward(inputValues);
            final double[] predictions = myNet.getResults().stream().mapToDouble(Double::doubleValue).toArray();
            final int predictedLabel = TrainingData.getMaxIndex(predictions);
            return predictedLabel == (int) targetVals[0];
        });

        System.out.println("\nexecution time: " + (System.currentTimeMillis() - startTime) + "ms\n");
    }
}
Here we load the MNIST training data, normalize it by scaling pixel values to the range [0,1], and then feed it into our network. We use the backpropagation algorithm to adjust the weights of the network, using the mean squared error as our loss function to gauge performance.
Training is done over multiple epochs, with the network learning a bit more about the handwritten digits it sees on each pass. We also implement a mechanism to track the error over time so we can watch it shrink as training progresses.
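For context, the mean squared error mentioned above simply averages the squared differences between the network’s outputs and the one-hot target vector. A standalone sketch of the calculation (the real error tracking happens inside NeuralNetTutorial.java during backProp; this helper is only for illustration):
// Illustrative only: mean squared error between the network outputs and a one-hot target.
static double meanSquaredError(double[] outputs, double[] targets) {
    double sum = 0.0;
    for (int i = 0; i < outputs.length; i++) {
        final double delta = targets[i] - outputs[i];
        sum += delta * delta;
    }
    return sum / outputs.length;
}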
Step 3: Testing and Validating the Network
After training, we don’t just want a network that memorizes the digits; we want one that generalizes well to new, unseen data. This is where our testing phase comes in. We test the network’s accuracy against a set of data it hasn’t seen before, making predictions and comparing them to the true values.
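The predicted digit is simply the index of the largest of the ten output values, which is what TrainingData.getMaxIndex provides. One possible argmax implementation (the actual method in TrainingData.java may differ in detail):
// Returns the index of the largest value in the array (argmax).
static int getMaxIndex(double[] values) {
    int maxIndex = 0;
    for (int i = 1; i < values.length; i++) {
        if (values[i] > values[maxIndex]) {
            maxIndex = i;
        }
    }
    return maxIndex;
}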
Seeing the Results in Action
As we run our training script, we print out the network’s accuracy at the end of each epoch, observing as it (hopefully) increases with each pass. The excitement comes in watching the network’s predictions improve, transforming from random guesses to confident, accurate classification.
using Random Initialization with hyperbolic tangent activation function
jbang trainNNT_Mnist.java
[jbang] Building jar for trainNNT_Mnist.java...
finished initialization
Epoch: 0, Accuracy: 63.04%
Epoch: 1, Accuracy: 75.37%
Epoch: 2, Accuracy: 79.64%
Epoch: 3, Accuracy: 81.57%
Epoch: 4, Accuracy: 82.38%
Accuracy: 82.38%
execution time: 145985ms
using He Initialization with hyperbolic tangent activation function
jbang trainNNT_Mnist.java
[jbang] Building jar for trainNNT_Mnist.java...
finished initialization
Epoch: 0, Accuracy: 82.02%
Epoch: 1, Accuracy: 87.44%
Epoch: 2, Accuracy: 88.69%
Epoch: 3, Accuracy: 89.21%
Epoch: 4, Accuracy: 89.81%
Accuracy: 89.81%
execution time: 145787ms
using He Initialization with Leaky ReLU
jbang trainNNT_Mnist.java
[jbang] Building jar for trainNNT_Mnist.java...
finished initialization
Epoch: 0, Accuracy: 81.50%
Epoch: 1, Accuracy: 89.11%
Epoch: 2, Accuracy: 90.78%
Epoch: 3, Accuracy: 91.88%
Epoch: 4, Accuracy: 92.33%
Accuracy: 92.33%
execution time: 147053ms
As you can see from the results above, the neural network is quite sensitive to both the weight initialization and the choice of activation function.
Conclusion
This foray into neural networks with Java and JBang takes you through the process of building, training, and validating a neural network for a classic machine learning task. It showcases the strength of Java in handling complex tasks like neural network computation, while also demonstrating how tools like JBang can streamline the process, making it accessible and manageable.
As you dive into the code, experiment, and tweak the parameters, you’ll gain a deeper understanding of neural networks’ inner workings and the beauty of machine learning. The MNIST dataset offers a fantastic playground for beginners and seasoned practitioners alike, providing immediate visual feedback and a benchmark that’s stood the test of time.
So grab your favorite Java IDE, install JBang, and get ready to embark on an exciting journey into the world of neural networks and handwritten digit classification.
Happy coding!