Tutorial on training the MNN (wrapper ver)#

1. Quick start: four steps to run your first MNN model#

The following step-by-step instructions train an MNN on the MNIST image classification task using a multi-layer perceptron architecture.

  1. Clone the repository to your local drive.

  2. Copy the demo files, ./example/mnist/mnist.py and ./example/mnist/mnist_config.yaml, to the root directory.

  3. Create two directories, ./checkpoint/ for saving trained model results and ./data/ for downloading the MNIST dataset.

  4. Run the following command to call the script mnist.py, passing the config file through the --config option:

    python mnist.py --config=./mnist_config.yaml
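
Steps 2 and 3 above can also be scripted. The following is a minimal sketch using only the Python standard library; the helper name prepare_workspace is hypothetical and not part of the MNN repository, and the demo file location is taken from step 2 as written.

```python
import shutil
from pathlib import Path


def prepare_workspace(repo_root="."):
    """Copy the demo files to the repo root and create the two working directories.

    Hypothetical helper for illustration; it is not shipped with the MNN repo.
    """
    root = Path(repo_root)
    demo_dir = root / "example" / "mnist"  # demo location as given in step 2
    for name in ("mnist.py", "mnist_config.yaml"):
        src = demo_dir / name
        if src.exists():  # skip quietly if the repo layout differs
            shutil.copy(src, root / name)

    created = []
    for name in ("checkpoint", "data"):
        target = root / name
        target.mkdir(parents=True, exist_ok=True)  # no error if it already exists
        created.append(target)
    return created
```

Calling prepare_workspace(".") from the repository root should leave the tree ready for step 4. The function is safe to call repeatedly because existing directories are left untouched.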
    

After training is finished, you should find four files and, optionally, one directory in the ./checkpoint/mnist/ folder:

  • Two ‘.pth’ files which contain the trained model parameters.

  • One ‘.yaml’ file which is a copy of the config file used for training the model.

  • One ‘.txt’ log file that records the standard output during training (such as model performance).

  • One directory called mnn_net_snn_result that stores the simulation result of the SNN reconstructed from the trained MNN (if enabled).

2. Step-by-step explanation#

Here we illustrate how the above code works. Before we start, we need to load the required packages. MNN is still at an early stage of development and has not yet been published on PyPI, so you need to copy this notebook to the root directory of the repo (moment-neural-network).

import torch
from mnn import snn, models, utils
from mnn.utils.training_tools import general_train, general_prepare

When you call the script

python mnist.py --config=./mnist_config.yaml

it first loads all hyperparameters needed for training, which is equivalent to running the following code:

class TempArgs:
    def __init__(self):
        self.bs = 50 # batch size
        self.print_freq = 20 # print frequency
        self.dir = 'mnist' # directory name to save the model
        self.save_name = 'mnn_mnist' # name of the model to save
        self.use_cuda = True # whether to use cuda
        self.seed = None # random seed
        self.resume = False # whether to resume training from a checkpoint
        self.distributed = False # whether to use distributed training
        self.evaluate = False # whether to evaluate the model only
        self.start_epoch = 0 # starting epoch
        self.local_rank = 0 # local rank for distributed training

args = TempArgs()
setattr(args, 'config', './examples/mnist/mnist_config.yaml') # path to the config file
args = general_prepare.set_config2args(args)
for key, value in args.__dict__.items():
    print(f'{key}: {value}')
bs: 50
print_freq: 20
dir: mnist
save_name: mnn_mnist
use_cuda: True
seed: None
resume: False
distributed: False
evaluate: False
start_epoch: 0
local_rank: 0
config: ./examples/mnist/mnist_config.yaml
LR_SCHEDULER: None
OPTIMIZER: {'name': 'AdamW', 'args': {'lr': 0.001, 'weight_decay': 0.01}}
DATASET: None
DATALOADER: None
MODEL: {'meta': {'arch': 'mnn_mlp', 'cnn_type': None, 'mlp_type': 'mnn_mlp'}, 'mnn_mlp': {'structure': [784, 100], 'num_class': 10, 'bn_bias_var': False, 'predict_bias': True, 'predict_bias_var': False, 'special_init': True, 'dropout': None, 'momentum': 0.9, 'eps': 1e-05}, 'snn_mlp': {'structure': [784, 800], 'num_class': 10, 'use_cov': False, 'bn_bias_var': False}}
CRITERION: {'name': 'CrossEntropyOnMean', 'source': 'mnn_core', 'args': {'reduction': 'mean'}}
DATAAUG_TRAIN: {'aug_order': ['ToTensor'], 'RandomCrop': {'size': 28, 'padding': 2}}
DATAAUG_VAL: {'aug_order': ['ToTensor']}
workers: 2
lr: 0.001
epochs: 1
pin_mem: True
world_size: 1
dataset: mnist
dataset_type: classic
input_prepare: flatten_poisson
save_epoch_state: False
scale_factor: 1.0
data_dir: ./data/
task_type: classification
background_noise: None
dump_path: ./checkpoint/

In this configuration,

  • MODEL specifies the network architecture and hyperparameters used to construct an MNN model. In this tutorial, the configuration creates a feedforward MNN with one hidden layer (100 neurons).

  • OPTIMIZER specifies the optimizer (provided by PyTorch) and its hyperparameters.

  • CRITERION specifies the criterion (loss function) for optimizing the model. Its source field specifies where the corresponding module is found. By default, we use CrossEntropyOnMean.

  • DATASET and DATALOADER are for those who want to customize their own dataset. By default, we use the dataset provided by torchvision, where dataset specifies which dataset is used and data_dir is the path where the dataset is stored. In this tutorial, we use MNIST. The batch size is specified by bs and the number of training epochs by epochs.

  • DATAAUG_TRAIN and DATAAUG_VAL specify the type of data augmentation.
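
The printout above shows the command-line defaults and the config entries living side by side on args. The following is a minimal sketch of that merge, assuming it simply copies each top-level config key onto args as an attribute; the real general_prepare.set_config2args may handle defaults and precedence differently. The helper merge_config_into_args is hypothetical, and a plain dict stands in for the parsed YAML file.

```python
from types import SimpleNamespace


def merge_config_into_args(args, config):
    """Copy every top-level config entry onto args as an attribute.

    Assumption for illustration: config values overwrite same-named attributes.
    """
    for key, value in config.items():
        setattr(args, key, value)
    return args


# Stand-in for the parsed YAML file; the values are taken from the printout above.
demo_config = {
    "OPTIMIZER": {"name": "AdamW", "args": {"lr": 0.001, "weight_decay": 0.01}},
    "dataset": "mnist",
    "epochs": 1,
    "data_dir": "./data/",
}
demo_args = merge_config_into_args(SimpleNamespace(bs=50, dir="mnist"), demo_config)
print(demo_args.bs, demo_args.dataset, demo_args.OPTIMIZER["name"])  # 50 mnist AdamW
```

Keeping nested sections such as OPTIMIZER as dictionaries means their name and args entries can later be passed on to whatever builds the corresponding object.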

After setting up all necessary hyperparameters, we can run the following code to start training:

args.print_freq = 100 # set print frequency to 100 for this example
# Similarly, you can set other hyperparameters as needed
general_train.general_train_pipeline(args, train_func=general_train.TrainProcessCollections)
Epoch: [0][   0/1200]	Time  0.024 ( 0.024)	Data  0.007 ( 0.007)	Loss 2.3007e+00 (2.3007e+00)	Acc@1   8.00 (  8.00)
Epoch: [0][ 100/1200]	Time  0.030 ( 0.028)	Data  0.004 ( 0.004)	Loss 1.9582e+00 (2.0891e+00)	Acc@1  70.00 ( 64.28)
Epoch: [0][ 200/1200]	Time  0.031 ( 0.029)	Data  0.004 ( 0.004)	Loss 1.6539e+00 (1.9580e+00)	Acc@1  88.00 ( 71.00)
Epoch: [0][ 300/1200]	Time  0.030 ( 0.030)	Data  0.004 ( 0.004)	Loss 1.4744e+00 (1.8345e+00)	Acc@1  82.00 ( 74.86)
Epoch: [0][ 400/1200]	Time  0.030 ( 0.030)	Data  0.004 ( 0.004)	Loss 1.2726e+00 (1.7258e+00)	Acc@1  90.00 ( 76.98)
Epoch: [0][ 500/1200]	Time  0.030 ( 0.030)	Data  0.004 ( 0.004)	Loss 1.2473e+00 (1.6237e+00)	Acc@1  78.00 ( 78.71)
Epoch: [0][ 600/1200]	Time  0.030 ( 0.030)	Data  0.004 ( 0.004)	Loss 1.0248e+00 (1.5336e+00)	Acc@1  90.00 ( 79.94)
Epoch: [0][ 700/1200]	Time  0.030 ( 0.030)	Data  0.004 ( 0.004)	Loss 9.2382e-01 (1.4497e+00)	Acc@1  92.00 ( 81.05)
Epoch: [0][ 800/1200]	Time  0.030 ( 0.030)	Data  0.004 ( 0.004)	Loss 7.5740e-01 (1.3746e+00)	Acc@1  90.00 ( 81.96)
Epoch: [0][ 900/1200]	Time  0.029 ( 0.030)	Data  0.004 ( 0.004)	Loss 8.4661e-01 (1.3089e+00)	Acc@1  76.00 ( 82.67)
Epoch: [0][1000/1200]	Time  0.029 ( 0.030)	Data  0.004 ( 0.004)	Loss 7.5224e-01 (1.2495e+00)	Acc@1  84.00 ( 83.27)
Epoch: [0][1100/1200]	Time  0.029 ( 0.030)	Data  0.004 ( 0.004)	Loss 6.3211e-01 (1.1956e+00)	Acc@1  88.00 ( 83.74)
Epoch: [0] * Time 0.030 Data 0.004 Loss 1.149 Acc@1 84.182
Test [Epoch:0]:  * Time 0.014 Data 0.000 Loss 0.574 Acc@1 90.040
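
The iteration counter in the log runs to 1200 because the MNIST training set contains 60,000 images and bs is 50. A small sketch of that arithmetic follows; the function name steps_per_epoch and the drop_last flag are illustrative and not part of the MNN code base.

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of mini-batches needed to see every sample once."""
    if drop_last:
        # An incomplete final batch is discarded.
        return num_samples // batch_size
    # An incomplete final batch still counts as one step.
    return math.ceil(num_samples / batch_size)


print(steps_per_epoch(60000, 50))  # 1200
```

Doubling bs to 100 would halve the number of steps per epoch to 600.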

When training is finished, you will find a directory named after dir in the path specified by dump_path. The directory will contain four files named after save_name with different suffixes:

  • *_config.yaml records all hyperparameters used in training so you can reproduce the experiments.

  • *_log.txt records the loss and accuracy of the model during the training process.

  • *.pth contains the model parameters at the last epoch.

  • *_best_model.pth contains the model parameters that hit the highest accuracy on the validation set during training.
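
To check that a run produced these outputs, the expected paths can be assembled from dump_path, dir and save_name. This is a minimal sketch that assumes the "*" in each pattern is exactly save_name; the helper expected_outputs is hypothetical.

```python
from pathlib import Path


def expected_outputs(dump_path, run_dir, save_name):
    """Build the four output paths described in the list above."""
    base = Path(dump_path) / run_dir
    return {
        "config": base / f"{save_name}_config.yaml",
        "log": base / f"{save_name}_log.txt",
        "last": base / f"{save_name}.pth",
        "best": base / f"{save_name}_best_model.pth",
    }


# Values from this tutorial: dump_path='./checkpoint/', dir='mnist', save_name='mnn_mnist'.
for label, path in expected_outputs("./checkpoint/", "mnist", "mnn_mnist").items():
    status = "found" if path.exists() else "missing"
    print(f"{label:>6}: {path} ({status})")
```

Any entry reported as missing suggests that training did not finish or that dump_path, dir or save_name differ from the values assumed here.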

Reconstruct SNN based on trained MNN#

The parameters of the MNN can be used directly in an SNN without further fine-tuning. We also provide a pipeline to reconstruct an SNN from a trained MNN and run simulations with the following code:

dt = 1 # time step for simulation
input_type = 'poisson' # Using Poisson process to generate input spikes
num_trial = 100 # number of trials for validation
running_time = 100 # running time for each trial in ms
pregenerate = False # whether to pregenerate the input spikes
m = snn.functional.MnnSnnValidate(args, running_time=running_time, dt=dt, num_trials=num_trial,
                                  pregenerate=pregenerate, resume_best=False, input_type=input_type)
for index in range(5): # run simulations with the first 5 samples in the validation set
    m.validate_one_sample(index, do_reset=True, dump_spike_train=True, record=True)
test set, Img idx: 0, target: 7, pred: tensor([7])
test set, Img idx: 1, target: 2, pred: tensor([2])
test set, Img idx: 2, target: 1, pred: tensor([1])
test set, Img idx: 3, target: 0, pred: tensor([0])
test set, Img idx: 4, target: 4, pred: tensor([4])
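
For readers unfamiliar with Poisson input, the sketch below shows one common way to discretise a Poisson spike train: at each time step of width dt, a spike is emitted with probability rate × dt, which is a good approximation when that product is small. It is a generic illustration and not the implementation used by snn.functional.MnnSnnValidate; the function name and the rate value are assumptions.

```python
import random


def poisson_spike_train(rate_per_ms, running_time, dt, seed=None):
    """Bernoulli approximation of a Poisson spike train.

    rate_per_ms: mean firing rate in spikes per millisecond (illustrative unit).
    running_time, dt: simulation length and time step, both in milliseconds.
    Returns a list of 0/1 values, one entry per time step.
    """
    rng = random.Random(seed)
    num_steps = int(running_time / dt)
    p_spike = min(rate_per_ms * dt, 1.0)  # spike probability per step, capped at 1
    return [1 if rng.random() < p_spike else 0 for _ in range(num_steps)]


train = poisson_spike_train(rate_per_ms=0.1, running_time=100, dt=1, seed=0)
print(len(train), sum(train))  # number of time steps, number of spikes
```

With running_time = 100 and dt = 1 the train has 100 entries, one per simulation step, and the spike count fluctuates around rate × running_time from trial to trial.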

You will find another directory under dir, named save_name with the suffix _snn_validate_result. It contains two types of files:

  • *.snnval stores the information of running the simulation such as spike count and simulation duration.

  • *.spt stores the spike trains of hidden neurons during the simulation, which is stored as a sparse tensor with the shape (int(running_time/dt), num_trial, hidden_neurons).
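
As a quick check of the shape formula, the sketch below evaluates it for the settings used in this tutorial (running_time = 100, dt = 1, num_trial = 100, and 100 hidden neurons from the MODEL config) and converts a spike count into a firing rate. Both helper names are hypothetical, and the rate conversion assumes running_time is given in milliseconds, as stated in the code comments above.

```python
def spike_train_shape(running_time, dt, num_trial, hidden_neurons):
    """Shape of the stored spike-train tensor: (time steps, trials, neurons)."""
    return (int(running_time / dt), num_trial, hidden_neurons)


def firing_rate_hz(spike_count, running_time_ms):
    """Mean firing rate in Hz from a spike count over a window given in ms."""
    return 1000.0 * spike_count / running_time_ms


shape = spike_train_shape(running_time=100, dt=1, num_trial=100, hidden_neurons=100)
print(shape)  # (100, 100, 100)
print(firing_rate_hz(spike_count=12, running_time_ms=100))  # 120.0
```

A dense tensor of this shape has one million entries, most of them zero, which is presumably why the spike trains are stored in sparse form.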