Tutorial on training the MNN (wrapper ver)#
1. Quick start: three steps to run your first MNN model#
The following steps show how to train an MNN on the MNIST image classification task using a multi-layer perceptron architecture.
Clone the repository to your local drive.
Copy the demo files, ./example/mnist/mnist.py and ./example/mnist/mnist_config.yaml, to the root directory.
Create two directories: ./checkpoint/ for saving trained model results and ./data/ for downloading the MNIST dataset.
Run the following command to call the script mnist.py, with the config file specified through the --config option:
python mnist.py --config=./mnist_config.yaml
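The copy and directory-creation steps can also be scripted. The following is a minimal sketch using only the Python standard library; the helper name prepare_workspace is hypothetical, and the demo path simply follows the step above, so adjust it if your checkout uses a different layout.

```python
import shutil
from pathlib import Path


def prepare_workspace(repo_root, demo_dir="example/mnist"):
    """Create ./checkpoint/ and ./data/ and copy the demo files to the repo root.

    `demo_dir` is taken from the instructions above (an assumption about the
    repository layout). Returns the names of the files that were copied.
    """
    root = Path(repo_root)

    # Directories for trained model results and for the MNIST download.
    for name in ("checkpoint", "data"):
        (root / name).mkdir(exist_ok=True)

    # Copy the demo script and its config file if they are present.
    copied = []
    for name in ("mnist.py", "mnist_config.yaml"):
        source = root / demo_dir / name
        if source.is_file():
            shutil.copy(source, root / name)
            copied.append(name)
    return copied
```

Calling prepare_workspace(".") from the repository root performs both steps; a file missing from the returned list means it was not found under demo_dir.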
After training is finished, you should find four files in the ./checkpoint/mnist/ folder:
Two ‘.pth’ files which contain the trained model parameters.
One ‘.yaml’ file, which is a copy of the config file used for training the model.
One ‘.txt’ log file that records the standard output printed during training (such as model performance).
One directory called mnn_net_snn_result that stores the simulation result of the SNN reconstructed from the trained MNN (if enabled).
2. Step-by-step explanation#
Here we illustrate how the above code works.
Before we start, we need to load the required packages.
MNN is still at an early stage of development and has not yet been published on PyPI, so you need to copy this notebook to the root directory of the repo (moment-neural-network).
import torch
from mnn import snn, models, utils
from mnn.utils.training_tools import general_train, general_prepare
When you call the script with
python mnist.py --config=./mnist_config.yaml
it first loads all the hyperparameters needed for training, which is equivalent to running the following code:
class TempArgs:
    def __init__(self):
        self.bs = 50  # batch size
        self.print_freq = 20  # print frequency
        self.dir = 'mnist'  # directory name to save the model
        self.save_name = 'mnn_mnist'  # name of the model to save
        self.use_cuda = True  # whether to use cuda
        self.seed = None  # random seed
        self.resume = False  # whether to resume training from a checkpoint
        self.distributed = False  # whether to use distributed training
        self.evaluate = False  # whether to evaluate the model only
        self.start_epoch = 0  # starting epoch
        self.local_rank = 0  # local rank for distributed training

args = TempArgs()
setattr(args, 'config', './examples/mnist/mnist_config.yaml')  # path to the config file
args = general_prepare.set_config2args(args)
for key, value in args.__dict__.items():
    print(f'{key}: {value}')
bs: 50
print_freq: 20
dir: mnist
save_name: mnn_mnist
use_cuda: True
seed: None
resume: False
distributed: False
evaluate: False
start_epoch: 0
local_rank: 0
config: ./examples/mnist/mnist_config.yaml
LR_SCHEDULER: None
OPTIMIZER: {'name': 'AdamW', 'args': {'lr': 0.001, 'weight_decay': 0.01}}
DATASET: None
DATALOADER: None
MODEL: {'meta': {'arch': 'mnn_mlp', 'cnn_type': None, 'mlp_type': 'mnn_mlp'}, 'mnn_mlp': {'structure': [784, 100], 'num_class': 10, 'bn_bias_var': False, 'predict_bias': True, 'predict_bias_var': False, 'special_init': True, 'dropout': None, 'momentum': 0.9, 'eps': 1e-05}, 'snn_mlp': {'structure': [784, 800], 'num_class': 10, 'use_cov': False, 'bn_bias_var': False}}
CRITERION: {'name': 'CrossEntropyOnMean', 'source': 'mnn_core', 'args': {'reduction': 'mean'}}
DATAAUG_TRAIN: {'aug_order': ['ToTensor'], 'RandomCrop': {'size': 28, 'padding': 2}}
DATAAUG_VAL: {'aug_order': ['ToTensor']}
workers: 2
lr: 0.001
epochs: 1
pin_mem: True
world_size: 1
dataset: mnist
dataset_type: classic
input_prepare: flatten_poisson
save_epoch_state: False
scale_factor: 1.0
data_dir: ./data/
task_type: classification
background_noise: None
dump_path: ./checkpoint/
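Judging from the printout, the entries of the YAML file end up as attributes on args, next to the ones defined in TempArgs. A minimal sketch of such a merge is shown below. It is an illustration under that assumption, not the actual implementation of general_prepare.set_config2args, which may also fill in defaults or resolve conflicts differently. The helper name merge_config_into_args is hypothetical, and a plain dict stands in for the parsed YAML file.

```python
from types import SimpleNamespace


def merge_config_into_args(args, config):
    """Copy every top-level config entry onto `args` as an attribute.

    Hypothetical helper: it only mimics the effect visible in the printout
    above and does not reproduce the library's own logic.
    """
    for key, value in config.items():
        setattr(args, key, value)
    return args


# A few values taken from the printout above; the dict stands in for the YAML file.
args = SimpleNamespace(bs=50, config="./examples/mnist/mnist_config.yaml")
config = {
    "OPTIMIZER": {"name": "AdamW", "args": {"lr": 0.001, "weight_decay": 0.01}},
    "epochs": 1,
    "dataset": "mnist",
}
merge_config_into_args(args, config)
print(args.OPTIMIZER["name"], args.epochs)  # AdamW 1
```

Nested sections such as OPTIMIZER stay dictionaries after the merge, which is why they print as dicts in the output above.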
In this configuration:
MODEL specifies the network architecture and hyperparameters used to construct an MNN model. In this tutorial, the configuration creates a feedforward MNN with one hidden layer (100 neurons).
OPTIMIZER specifies the optimizer (provided by PyTorch) and its hyperparameters.
CRITERION specifies the criterion (loss function) for optimizing the model; source further specifies where to get the corresponding module. By default, CrossEntropyOnMean is used.
DATASET and DATALOADER are for those who want to customize their own dataset. By default, we use the dataset provided by torchvision, where dataset specifies which dataset is used and data_dir is the path where the dataset is stored. In this tutorial, we use MNIST. The batch size is specified by bs and the number of training epochs by epochs.
DATAAUG_TRAIN and DATAAUG_VAL specify the type of data augmentation.
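To make these settings concrete, the sketch below derives two numbers from the configuration: the layer sizes implied by structure and num_class, and the number of iterations per epoch implied by bs. It assumes that the last entry of structure feeds a readout layer of num_class units and that the final, possibly smaller, batch is kept. The helper names are hypothetical and not part of the mnn package.

```python
import math


def mlp_layer_shapes(structure, num_class):
    """Return (inputs, outputs) pairs for each layer, readout layer last.

    Assumes the network maps structure[0] -> ... -> structure[-1] -> num_class.
    """
    sizes = list(structure) + [num_class]
    return list(zip(sizes[:-1], sizes[1:]))


def iterations_per_epoch(num_samples, batch_size):
    """Number of batches per epoch when the last partial batch is kept."""
    return math.ceil(num_samples / batch_size)


# Values from the 'mnn_mlp' section of MODEL and from bs.
print(mlp_layer_shapes([784, 100], 10))  # [(784, 100), (100, 10)]
# The MNIST training set has 60,000 images.
print(iterations_per_epoch(60_000, 50))  # 1200
```

The second number agrees with the 1200 iterations per epoch reported in the training log.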
After setting up all necessary hyperparameters, we can run the following code to start training:
args.print_freq = 100 # set print frequency to 100 for this example
# Similarly, you can set other hyperparameters as needed
general_train.general_train_pipeline(args, train_func=general_train.TrainProcessCollections)
Epoch: [0][ 0/1200] Time 0.024 ( 0.024) Data 0.007 ( 0.007) Loss 2.3007e+00 (2.3007e+00) Acc@1 8.00 ( 8.00)
Epoch: [0][ 100/1200] Time 0.030 ( 0.028) Data 0.004 ( 0.004) Loss 1.9582e+00 (2.0891e+00) Acc@1 70.00 ( 64.28)
Epoch: [0][ 200/1200] Time 0.031 ( 0.029) Data 0.004 ( 0.004) Loss 1.6539e+00 (1.9580e+00) Acc@1 88.00 ( 71.00)
Epoch: [0][ 300/1200] Time 0.030 ( 0.030) Data 0.004 ( 0.004) Loss 1.4744e+00 (1.8345e+00) Acc@1 82.00 ( 74.86)
Epoch: [0][ 400/1200] Time 0.030 ( 0.030) Data 0.004 ( 0.004) Loss 1.2726e+00 (1.7258e+00) Acc@1 90.00 ( 76.98)
Epoch: [0][ 500/1200] Time 0.030 ( 0.030) Data 0.004 ( 0.004) Loss 1.2473e+00 (1.6237e+00) Acc@1 78.00 ( 78.71)
Epoch: [0][ 600/1200] Time 0.030 ( 0.030) Data 0.004 ( 0.004) Loss 1.0248e+00 (1.5336e+00) Acc@1 90.00 ( 79.94)
Epoch: [0][ 700/1200] Time 0.030 ( 0.030) Data 0.004 ( 0.004) Loss 9.2382e-01 (1.4497e+00) Acc@1 92.00 ( 81.05)
Epoch: [0][ 800/1200] Time 0.030 ( 0.030) Data 0.004 ( 0.004) Loss 7.5740e-01 (1.3746e+00) Acc@1 90.00 ( 81.96)
Epoch: [0][ 900/1200] Time 0.029 ( 0.030) Data 0.004 ( 0.004) Loss 8.4661e-01 (1.3089e+00) Acc@1 76.00 ( 82.67)
Epoch: [0][1000/1200] Time 0.029 ( 0.030) Data 0.004 ( 0.004) Loss 7.5224e-01 (1.2495e+00) Acc@1 84.00 ( 83.27)
Epoch: [0][1100/1200] Time 0.029 ( 0.030) Data 0.004 ( 0.004) Loss 6.3211e-01 (1.1956e+00) Acc@1 88.00 ( 83.74)
Epoch: [0] * Time 0.030 Data 0.004 Loss 1.149 Acc@1 84.182
Test [Epoch:0]: * Time 0.014 Data 0.000 Loss 0.574 Acc@1 90.040
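Each log line shows a value for the current batch followed by a second value in parentheses. Assuming the number in parentheses is the running average over the epoch so far (which is consistent with both numbers being equal at iteration 0), such a meter can be sketched as follows. The class name RunningAverage is hypothetical and the code is not taken from the mnn package.

```python
class RunningAverage:
    """Track the latest value and the sample-weighted average of a metric."""

    def __init__(self):
        self.val = 0.0    # value of the most recent batch
        self.sum = 0.0    # sum of value * batch size over all batches
        self.count = 0    # number of samples seen so far

    def update(self, val, n=1):
        """Record a batch-level value computed over `n` samples."""
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        """Average over all samples seen so far (0.0 before any update)."""
        return self.sum / self.count if self.count else 0.0


loss_meter = RunningAverage()
loss_meter.update(2.0, n=50)  # first batch of 50 samples
loss_meter.update(1.0, n=50)  # second batch of 50 samples
print(f"Loss {loss_meter.val:.4e} ({loss_meter.avg:.4e})")  # Loss 1.0000e+00 (1.5000e+00)
```

Under this reading, the summary line at the end of the epoch reports the average over all batches, while the per-iteration values fluctuate from batch to batch.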
When training is finished, you will find a directory named after dir in the path specified by dump_path. The directory contains four files named after save_name with different suffixes:
*_config.yaml records all hyperparameters used in training so you can reproduce the experiments.
*_log.txt records the loss and accuracy of the model during training.
*.pth contains the model parameters at the last epoch.
*_best_model.pth contains the model parameters that achieved the highest accuracy on the validation set during training.
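A quick way to confirm that a run produced all of its outputs is to build the expected paths from dump_path, dir and save_name and check which ones exist. The sketch below infers the file names from the suffix patterns listed above, so treat the exact names as an assumption; the helper names are hypothetical.

```python
from pathlib import Path

# Suffixes taken from the list above.
OUTPUT_SUFFIXES = ("_config.yaml", "_log.txt", ".pth", "_best_model.pth")


def expected_outputs(dump_path, dir_name, save_name):
    """Build the paths of the four files a finished run should leave behind."""
    folder = Path(dump_path) / dir_name
    return [folder / f"{save_name}{suffix}" for suffix in OUTPUT_SUFFIXES]


def missing_outputs(dump_path, dir_name, save_name):
    """Return the expected files that are not present on disk."""
    return [p for p in expected_outputs(dump_path, dir_name, save_name)
            if not p.is_file()]


for path in expected_outputs("./checkpoint/", "mnist", "mnn_mnist"):
    print(path.name)
```

With the values used in this tutorial, this prints mnn_mnist_config.yaml, mnn_mnist_log.txt, mnn_mnist.pth and mnn_mnist_best_model.pth; an empty list from missing_outputs means all four were found.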
Reconstruct SNN based on trained MNN#
The parameters of a trained MNN can be used directly in an SNN without further fine-tuning. We also provide a pipeline to reconstruct an SNN from the trained MNN and run simulations with the following code:
dt = 1 # time step for simulation
input_type = 'poisson' # Using Poisson process to generate input spikes
num_trial = 100 # number of trials for validation
running_time = 100 # running time for each trial in ms
pregenerate = False # whether to pregenerate the input spikes
m = snn.functional.MnnSnnValidate(args, running_time=running_time, dt=dt, num_trials=num_trial,
                                  pregenerate=pregenerate, resume_best=False, input_type=input_type)
for index in range(5):  # run simulations with the first 5 samples in the validation set
    m.validate_one_sample(index, do_reset=True, dump_spike_train=True, record=True)
test set, Img idx: 0, target: 7, pred: tensor([7])
test set, Img idx: 1, target: 2, pred: tensor([2])
test set, Img idx: 2, target: 1, pred: tensor([1])
test set, Img idx: 3, target: 0, pred: tensor([0])
test set, Img idx: 4, target: 4, pred: tensor([4])
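With input_type = 'poisson', the input spikes are generated by a Poisson process. The sketch below shows one common way to do this in discrete time: in each step of length dt, an input spikes with probability rate × dt. It is an illustration only. The function name poisson_spike_train, the rate unit (spikes per ms, assuming dt is also in ms) and the per-step Bernoulli approximation are assumptions and need not match what MnnSnnValidate does internally.

```python
import random


def poisson_spike_train(rate_per_ms, running_time=100, dt=1, rng=None):
    """Sample a binary spike train approximating a Poisson process.

    rate_per_ms  : firing rate in spikes per millisecond (assumed unit)
    running_time : duration of one trial in ms
    dt           : time step in ms (assumed unit)
    The approximation is accurate when rate_per_ms * dt is much smaller than 1.
    """
    rng = rng or random.Random()
    num_steps = int(running_time / dt)
    p_spike = min(rate_per_ms * dt, 1.0)  # spike probability per time step
    return [1 if rng.random() < p_spike else 0 for _ in range(num_steps)]


# One input firing at 0.05 spikes/ms (50 Hz) for a 100 ms trial.
train = poisson_spike_train(0.05, running_time=100, dt=1, rng=random.Random(0))
print(len(train), sum(train))  # number of time steps and number of spikes
```

Averaging the spike count over many trials recovers rate × running_time, which is why the simulation above uses num_trial repetitions per sample.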
You will find another directory under dir, named save_name with the suffix _snn_validate_result.
It contains two types of files:
*.snnval stores information about the simulation run, such as spike count and simulation duration.
*.spt stores the spike trains of hidden neurons during the simulation as a sparse tensor with the shape (int(running_time/dt), num_trial, hidden_neurons).
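The shape formula and the idea of sparse storage can be illustrated without loading an actual file. The sketch below computes the expected shape for the settings used here and derives mean firing rates from a list of (time step, trial, neuron) spike events, which is the kind of index list a sparse tensor keeps. The function names are hypothetical, dt is assumed to be in ms, and how to read a .spt file is not covered here.

```python
from collections import Counter


def spike_tensor_shape(running_time, dt, num_trial, hidden_neurons):
    """Shape of the spike-train tensor as described above."""
    return (int(running_time / dt), num_trial, hidden_neurons)


def mean_firing_rates_hz(events, shape, dt_ms):
    """Mean firing rate per neuron in Hz, averaged over trials.

    events : iterable of (time_step, trial, neuron) tuples, one per spike
    shape  : (num_steps, num_trial, num_neurons)
    dt_ms  : time step in milliseconds (assumed unit)
    """
    num_steps, num_trial, num_neurons = shape
    counts = Counter(neuron for _, _, neuron in events)
    total_seconds = num_steps * dt_ms * num_trial / 1000.0
    return [counts.get(n, 0) / total_seconds for n in range(num_neurons)]


# Settings from this tutorial: 100 ms trials, dt = 1, 100 trials, 100 hidden neurons.
print(spike_tensor_shape(100, 1, 100, 100))  # (100, 100, 100)
```

Storing only the spike events is much more compact than a dense array of the same shape when most entries are zero.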