Tutorial on training the MNN (vanilla ver)#
Here, we show a minimal example of training the MNN in the standard PyTorch style, without the wrapper. This approach suits those who need to make under-the-hood modifications to the model.
If you don’t already have PyTorch installed, install it by following the instructions on this page: https://pytorch.org/get-started/locally/
You need to copy this notebook to the root directory (under moment-neural-network).
First, the necessary imports.
from mnn.mnn_core.mnn_pytorch import *
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
A quick check of your PyTorch version and GPU availability.
print('Using PyTorch version:', torch.__version__)
if torch.cuda.is_available():
    print('Using GPU, device name:', torch.cuda.get_device_name(0))
    device = torch.device('cuda')
else:
    print('No GPU found, using CPU instead.')
    device = torch.device('cpu')
Using PyTorch version: 2.7.0+cu126
Using GPU, device name: NVIDIA GeForce RTX 4090
Loading the data & input encoding#
PyTorch has two classes from torch.utils.data to work with data:
Dataset which represents the actual data items, such as images or pieces of text, and their labels
DataLoader which is used for processing the dataset in batches during training.
Here we will use TorchVision and torchvision.datasets to access the MNIST dataset. (By setting download=True, the code below will attempt to download the dataset if it doesn’t already exist locally.)
batch_size = 32
train_dataset = datasets.MNIST('./datasets/', train=True, download=True,
                               transform=transforms.Compose([ToTensor(),
                                                             transforms.Normalize((0,), (1,))]))
test_dataset = datasets.MNIST('./datasets/', train=False, download=True,
                              transform=transforms.Compose([ToTensor(),
                                                            transforms.Normalize((0,), (1,))]))
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
The data loaders provide a way of iterating through the datasets in batches.
# load the first batch of data
for (data, target) in train_loader:
    print('data:', data.size(), 'type:', data.type())
    print('target:', target.size(), 'type:', target.type())
    break
data: torch.Size([32, 1, 28, 28]) type: torch.FloatTensor
target: torch.Size([32]) type: torch.LongTensor
We need to specify an appropriate encoding scheme of the input data to pass it to the MNN. Here, we suppose that the inputs are the statistical moments of independent Poisson spike trains whose firing rates are proportional to the image pixel values. Below is a helper function that implements this input encoding. The scale parameter describes how input pixel values should be converted to firing rates in sp/ms.
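def input_encoder(data, scale=1):
    data = torch.flatten(data, start_dim=1)   # flatten each image into a vector
    input_mean = data*scale                   # mean firing rate of each input neuron
    input_cov = torch.diag_embed(input_mean)  # independent Poisson inputs: variance equals mean, no cross-covariance
    return input_mean, input_cov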
# load the first batch of data
for (data, target) in train_loader:
    input_mean, input_cov = input_encoder(data)
    print('input_mean:', input_mean.size(), 'type:', input_mean.type())
    print('input_cov:', input_cov.size(), 'type:', input_cov.type())
    print('target:', target.size(), 'type:', target.type())
    break
input_mean: torch.Size([32, 784]) type: torch.FloatTensor
input_cov: torch.Size([32, 784, 784]) type: torch.FloatTensor
target: torch.Size([32]) type: torch.LongTensor
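The diagonal covariance used by input_encoder follows from the Poisson assumption: the variance of a Poisson count equals its mean, and independent inputs have zero cross-covariance. The sketch below checks this numerically by sampling. The firing rates, the 1 ms counting window, and the number of trials are hypothetical values chosen only for illustration; they are not part of the MNN library.

```python
import torch

torch.manual_seed(0)

# Hypothetical firing rates (sp/ms) for three input neurons
rates = torch.tensor([0.2, 0.5, 1.0])
window_ms = 1.0        # assumed counting window
num_trials = 200_000   # number of independent samples

# Draw independent Poisson spike counts, one column per neuron
counts = torch.poisson((rates * window_ms).repeat(num_trials, 1))

# Empirical moments of the sampled counts
empirical_mean = counts.mean(dim=0)
empirical_cov = torch.cov(counts.T)  # torch.cov expects rows = variables

# Moments produced by the encoding scheme: mean = rate, covariance = diag(rate)
expected_mean = rates
expected_cov = torch.diag_embed(rates)

print('empirical mean:', empirical_mean)
print('empirical covariance:\n', empirical_cov)
```

Up to sampling noise, the empirical mean matches the rates and the empirical covariance is diagonal with the same entries.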
Building a feedforward MNN#
A single feedforward layer of MNN consists of the following components:
linear (bilinear) layer: outputs the synaptic current mean/covariance given the pre-synaptic neuron spike mean/covariance. Accessed through the LinearDuo class under mnn.mnn_core.nn.linear.
moment batch normalization: outputs the batch-normalized synaptic current mean/covariance. This is a generalization of standard batchnorm to second-order moments and is required to avoid the vanishing gradient problem. Accessed through the CustomBatchNorm1D class under mnn.mnn_core.nn.custom_batch_norm.
moment activation: outputs the post-synaptic neuron spike mean/covariance given the input current mean/covariance. Accessed through the OriginMnnActivation class under mnn.mnn_core.nn.activation.
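To make the role of the linear (bilinear) layer concrete, here is a minimal sketch of how a weight matrix W acts on a mean and a covariance. It uses the textbook rule for a linear map of a random vector (mean W μ, covariance W C Wᵀ) and assumes no bias terms. The helper propagate_linear is hypothetical and is not the LinearDuo implementation, which also exposes options such as bias_mean, bias_var, dropout, and scale.

```python
import torch

def propagate_linear(weight, mean, cov):
    """Moments of y = W x for a random vector x with the given mean and covariance.

    weight: (n_out, n_in), mean: (batch, n_in), cov: (batch, n_in, n_in)
    """
    out_mean = mean @ weight.T         # linear in the mean: (batch, n_out)
    out_cov = weight @ cov @ weight.T  # bilinear in the weights: (batch, n_out, n_out)
    return out_mean, out_cov

torch.manual_seed(0)
batch, n_in, n_out = 4, 6, 3
weight = torch.randn(n_out, n_in)
mean = torch.rand(batch, n_in)
cov = torch.diag_embed(mean)  # same form as the Poisson input encoding

out_mean, out_cov = propagate_linear(weight, mean, cov)
print(out_mean.shape, out_cov.shape)
```

Note that even when the input covariance is diagonal, the output covariance generally has non-zero off-diagonal entries, because different output units share the same inputs.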
A single feedforward layer can be stacked multiple times to form a deep MNN. For illustrative purposes, here we show an example consisting of a single hidden layer followed by a linear readout.
from mnn.mnn_core.nn.activation import OriginMnnActivation
from mnn.mnn_core.nn.linear import LinearDuo
from mnn.mnn_core.nn.custom_batch_norm import CustomBatchNorm1D
class SimpleMNN(torch.nn.Module):
    def __init__(self, hidden_size = 64, input_size = 2, output_size = 1):
        super(SimpleMNN, self).__init__()
        self.linear = LinearDuo(input_size, hidden_size)
        self.batchnorm = CustomBatchNorm1D(hidden_size)
        self.activate = OriginMnnActivation()
        self.readout = LinearDuo(hidden_size, output_size)

    def forward(self, input_mean, input_cov):
        curr_mean, curr_cov = self.linear(input_mean, input_cov)
        bn_mean, bn_cov = self.batchnorm(curr_mean, curr_cov)
        hidden_mean, hidden_cov = self.activate(bn_mean, bn_cov)
        readout_mean, readout_cov = self.readout(hidden_mean, hidden_cov)
        return readout_mean, readout_cov
The following script creates an instance of the model and prints the names and shapes of its trainable parameters.
model = SimpleMNN(hidden_size=100, input_size=28*28,output_size=10)
# you can inspect the trainable parameters
print('Linear layer: ', model.linear)
for name, param in model.linear.named_parameters():
    print(' Name: {}, Shape: {}'.format(name, param.shape))
print('Moment batchnorm: ', model.batchnorm)
for name, param in model.batchnorm.named_parameters():
    print(' Name: {}, Shape: {}'.format(name, param.shape))
print('Moment activation', model.activate)
print('Readout layer:', model.readout)
for name, param in model.readout.named_parameters():
    print(' Name: {}, Shape: {}'.format(name, param.shape))
Linear layer: LinearDuo(in_features: 784, out_features: 100, bias_mean: False, bias_var: False, dropout: False, scale: None)
Name: weight, Shape: torch.Size([100, 784])
Moment batchnorm: CustomBatchNorm1D(num_features: 100, bias_std=False, special_init=True, momentum=0.9, eps=1e-05, affine=True)
Name: weight, Shape: torch.Size([100])
Name: bias, Shape: torch.Size([100])
Moment activation OriginMnnActivation()
Readout layer: LinearDuo(in_features: 100, out_features: 10, bias_mean: False, bias_var: False, dropout: False, scale: None)
Name: weight, Shape: torch.Size([10, 100])
Training the MNN#
So far we have defined the dataset, the data loader, and the model. To train the model, we need to specify the loss function and the optimizer.
For classification problems, we provide CrossEntropyOnMean and GaussianSamplingCrossEntropyLoss under mnn.mnn_core.nn.criterion. The former is identical to the standard cross-entropy loss in PyTorch, whereas the latter is a generalized cross-entropy that takes into account the second-order moments of the output.
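To illustrate the difference between the two kinds of loss, the sketch below contrasts cross-entropy on the output mean with a Monte Carlo cross-entropy averaged over Gaussian samples drawn from the output mean and covariance. This is only one plausible reading of a sampling-based loss. Both helper functions, their signatures, and the jitter term are assumptions made for illustration and are not the library's CrossEntropyOnMean or GaussianSamplingCrossEntropyLoss.

```python
import torch
import torch.nn.functional as F

def cross_entropy_on_mean(output_mean, target):
    # Standard cross-entropy with the output mean used as logits
    return F.cross_entropy(output_mean, target)

def sampled_cross_entropy(output_mean, output_cov, target, num_samples=100, jitter=1e-4):
    # Hypothetical helper: average cross-entropy over Gaussian samples of the readout.
    # A small jitter on the diagonal keeps the covariance positive definite.
    num_classes = output_cov.size(-1)
    eye = torch.eye(num_classes, dtype=output_cov.dtype, device=output_cov.device)
    dist = torch.distributions.MultivariateNormal(
        output_mean, covariance_matrix=output_cov + jitter * eye)
    samples = dist.rsample((num_samples,))  # (num_samples, batch, classes)
    logits = samples.flatten(0, 1)          # (num_samples * batch, classes)
    return F.cross_entropy(logits, target.repeat(num_samples))

torch.manual_seed(0)
batch, classes = 8, 10
output_mean = torch.randn(batch, classes)
output_cov = torch.diag_embed(torch.rand(batch, classes) + 0.1)  # toy diagonal covariance
target = torch.randint(0, classes, (batch,))

loss_mean = cross_entropy_on_mean(output_mean, target)
loss_sampled = sampled_cross_entropy(output_mean, output_cov, target)
print(loss_mean.item(), loss_sampled.item())
```

Because rsample uses the reparameterization trick, gradients in this sketch flow through both the mean and the covariance, so the sampled loss penalizes output variability as well as a wrong mean.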
Below is a minimal example using the standard cross-entropy and Adam optimizer.
from mnn.mnn_core.nn.criterion import CrossEntropyOnMean
batch_size = 32
num_epoch = 1
lr = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
model = SimpleMNN(hidden_size = hidden_size, input_size = input_size, output_size = output_size)
params = model.parameters()
optimizer = torch.optim.Adam(params, lr = lr, amsgrad = True)
criterion = CrossEntropyOnMean()
for epoch in range(num_epoch):
    model.train()
    print('Training epoch {}/{}...'.format(epoch,num_epoch))
    for i_batch, (images, target) in enumerate(train_loader):
        optimizer.zero_grad()
        input_mean, input_cov = input_encoder(images) # encode input data to moment representation
        output_mean, output_cov = model.forward(input_mean,input_cov) # run the forward pass
        loss = criterion((output_mean, output_cov), target) # calculate the loss function
        loss.backward() # backpropagation
        optimizer.step() # update model parameters

    # evaluate on the test set after each training epoch
    with torch.no_grad():
        model.eval()
        num_correct = 0
        for i_batch, (images, target) in enumerate(test_loader):
            input_mean, input_cov = input_encoder(images)
            output_mean, output_cov = model.forward(input_mean,input_cov)
            prediction = output_mean.argmax(1) # index of the largest entry in the output mean
            num_correct += torch.sum(prediction == target).item() # count correct predictions
        acc = np.round(num_correct/len(test_dataset)*100,2)
        print('Validation accuracy = {}%'.format(acc))
Training epoch 0/1...
Validation accuracy = 94.43%
We can access all the trained parameters of the model and the state of the optimizer using the following lines of code:
print('Model state dictionary: ', model.state_dict().keys())
print('Optimizer state dictionary: ', optimizer.state_dict().keys())
Model state dictionary: odict_keys(['linear.weight', 'batchnorm.weight', 'batchnorm.bias', 'batchnorm.running_mean', 'batchnorm.running_var', 'readout.weight'])
Optimizer state dictionary: dict_keys(['state', 'param_groups'])
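These state dictionaries are what you would normally write to disk to checkpoint a training run. The sketch below shows the usual torch.save / torch.load round trip. To stay self-contained it uses a torch.nn.Linear stand-in instead of SimpleMNN and an in-memory buffer instead of a file; both are assumptions for illustration, and the same calls apply to any torch.nn.Module and to a file path.

```python
import io
import torch

# Stand-in model and optimizer (assumption: SimpleMNN would be used the same way)
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, amsgrad=True)

# Take one optimization step so the optimizer holds non-trivial state
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Save both state dictionaries in a single checkpoint
buffer = io.BytesIO()  # a file path such as 'checkpoint.pt' works as well
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}, buffer)

# Restore into freshly constructed objects
buffer.seek(0)
checkpoint = torch.load(buffer)
restored_model = torch.nn.Linear(4, 2)
restored_optimizer = torch.optim.Adam(restored_model.parameters(), lr=0.01, amsgrad=True)
restored_model.load_state_dict(checkpoint['model'])
restored_optimizer.load_state_dict(checkpoint['optimizer'])
```

Restoring the optimizer state alongside the model matters when resuming training, since Adam keeps running moment estimates for every parameter.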
Reconstruct spiking neural network#
Because the MNN is derived in a mathematically rigorous way from its corresponding spiking neural network (SNN) model (of current-based leaky integrate-and-fire neurons), the trained parameters can be used to reconstruct the SNN without further tuning.
Note that the moment batchnorm is only required for training purposes, and we can simply absorb its parameters into the linear layer, using the following helper function:
@torch.no_grad()
def weight_fusion(ln, bn):
    ln_weight = ln.weight.detach()
    bn_weight = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # per-neuron batchnorm scaling
    bn_weight = bn_weight.detach()
    weight = ln_weight * bn_weight.unsqueeze(-1)                 # scale each row of the linear weight
    bias = -bn.running_mean * bn_weight + bn.bias                # shift absorbed into a bias term
    return weight, bias
weight, bias = weight_fusion(model.linear,model.batchnorm)
The fused weight and bias can then be used to reconstruct an SNN that generates the same firing statistics as the MNN.
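As a sanity check on the fusion formula, the sketch below verifies that a fused weight and bias reproduce the output of a linear layer followed by batchnorm in evaluation mode. It uses the standard torch.nn.Linear and torch.nn.BatchNorm1d as stand-ins, which expose the same weight, bias, running_mean, running_var, and eps attributes used above. This is an assumption for illustration: it checks only the mean pathway and says nothing about how CustomBatchNorm1D treats the covariance.

```python
import torch

@torch.no_grad()
def weight_fusion(ln, bn):
    # Same arithmetic as the helper above, repeated so this block runs on its own
    bn_weight = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    weight = ln.weight * bn_weight.unsqueeze(-1)
    bias = -bn.running_mean * bn_weight + bn.bias
    return weight, bias

torch.manual_seed(0)
ln = torch.nn.Linear(6, 3, bias=False)  # stand-in for LinearDuo (mean pathway only)
bn = torch.nn.BatchNorm1d(3)            # stand-in for CustomBatchNorm1D

# Give the batchnorm non-trivial running statistics and affine parameters
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
bn.eval()  # use running statistics, as at inference time

x = torch.rand(4, 6)
weight, bias = weight_fusion(ln, bn)
with torch.no_grad():
    unfused = bn(ln(x))            # linear layer followed by batchnorm
    fused = x @ weight.T + bias    # single affine map with fused parameters

print(torch.allclose(unfused, fused, atol=1e-5))
```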
Exercises#
Try modifying SimpleMNN by stacking multiple hidden layers to form a deep MNN.
Replace the task with a regression problem and change the loss function accordingly. Hint: see MSEOnMean and LikelihoodMSE under mnn.mnn_core.nn.criterion.