Tutorial on moment activation
1. Theoretical background
The spiking neuron model
Consider the current-based leaky integrate-and-fire (LIF) neuron model
\[ \frac{dV_i}{dt} = -L\,V_i(t) + I_i(t), \]
where the sub-threshold membrane potential \(V_i(t)\) of neuron \(i\) is driven by the total synaptic current \(I_i(t)\), and \(L=0.05\) \(\text{ms}^{-1}\) is the leak conductance. When \(V_i(t)\) exceeds the threshold \(V_{\rm th}=20\) mV, a spike is emitted, represented by a Dirac delta function. The membrane potential is then reset to the resting potential \(V_{\rm res}=0\) mV, followed by a refractory period \(T_{\rm ref}=5\) ms. The synaptic current takes the form
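\[ I_i(t) = \sum_j w_{ij}\,S_j(t), \]
where \(w_{ij}\) is the synaptic weight from neuron \(j\) to neuron \(i\), and \(S_j(t)=\sum_k \delta(t-t^k_j)\) represents the spike train of pre-synaptic neuron \(j\), with \(t^k_j\) the time of its \(k\)-th spike.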
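The sketch below makes the synaptic current concrete. It assumes the weighted-sum form \(I_i(t)=\sum_j w_{ij}S_j(t)\) for a single post-synaptic neuron whose inputs are independent Poisson spike trains. The weights, rates, and population sizes are hypothetical values chosen for illustration. For such inputs, Campbell's theorem for shot noise gives a mean current of \(\sum_j w_{ij} r_j\) and a noise intensity (variance per unit time) of \(\sum_j w_{ij}^2 r_j\). These two numbers are what a Gaussian white-noise description of the input must reproduce.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical input population: 200 excitatory and 50 inhibitory Poisson neurons
w = np.concatenate([np.full(200, 0.5), np.full(50, -1.0)])    # weights w_ij (mV per spike)
r = np.concatenate([np.full(200, 0.02), np.full(50, 0.01)])   # input rates r_j (spikes/ms)

dt, n_bins = 0.1, 20_000   # bin width (ms) and number of bins (2 s of input)

# Spike count of every input neuron in every bin: Poisson with mean r_j * dt
counts = rng.poisson(r * dt, size=(n_bins, r.size))
# Charge per bin, i.e. the integral of I(t) = sum_j w_j S_j(t) over one bin
charge = counts @ w

mu_bar_pred = w @ r            # predicted mean current (mV/ms)
sigma2_bar_pred = w**2 @ r     # predicted noise intensity (mV^2/ms)
mu_bar_est = charge.mean() / dt
sigma2_bar_est = charge.var() / dt

print(f"mean current:    {mu_bar_est:.3f} (predicted {mu_bar_pred:.3f})")
print(f"noise intensity: {sigma2_bar_est:.3f} (predicted {sigma2_bar_pred:.3f})")
```

Both estimates should agree with the predictions to within a few percent of sampling error. The agreement improves as more bins are drawn.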
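When a neuron receives many inputs whose individual effects on the membrane potential are small, the synaptic current can be replaced by Gaussian white noise, \(I_i(t) \approx \bar{\mu}_i + \bar{\sigma}_i\,\xi_i(t)\), where \(\xi_i(t)\) is a unit white-noise process. The mean \(\bar{\mu}_i\) and variance \(\bar{\sigma}^2_i\) are chosen such that the membrane potential distribution approximately matches that of the original model. This technique is known as the diffusion approximation. Our goal here is, given the input current statistics, to calculate the output spike count statistics defined by
\[ \mu_i = \lim_{\Delta t\to\infty}\frac{\mathbb{E}[N_i(\Delta t)]}{\Delta t} \]
and
\[ C_{ij} = \lim_{\Delta t\to\infty}\frac{\operatorname{Cov}[N_i(\Delta t),N_j(\Delta t)]}{\Delta t}, \]
where \(N_i(\Delta t)\) is the spike count of neuron \(i\) over a time window \(\Delta t\). Here, we refer to the moments \(\mu_i\) and \(C_{ij}\) as the mean firing rate and the firing co-variability, respectively. Such an input-output mapping for statistical moments is called the moment activation.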
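The mean firing rate \(\mu_i\) and firing co-variability \(C_{ij}\) are long-window limits of \(\mathbb{E}[N_i]/\Delta t\) and \(\operatorname{Cov}[N_i,N_j]/\Delta t\). They can be approximated from data by counting spikes in many non-overlapping windows of length \(\Delta t\).

The sketch below does this for two synthetic Poisson neurons that share a common input source. The rates and window length are hypothetical. Each count is the sum of a private and a shared Poisson variable, so the exact values are known:

- \(\mu_i = 0.02\) sp/ms
- \(C_{ii} = 0.02\) sp²/ms
- \(C_{12} = 0.005\) sp²/ms
- correlation coefficient 0.25

```python
import numpy as np

rng = np.random.default_rng(2)

window = 100.0                       # counting window Delta t (ms)
n_windows = 10_000                   # 1000 s of recording in total
r_private, r_shared = 0.015, 0.005   # hypothetical rates (spikes/ms)

# Spike count of each neuron in each window: private Poisson spikes plus a shared Poisson source
shared = rng.poisson(r_shared * window, size=n_windows)
N = rng.poisson(r_private * window, size=(2, n_windows)) + shared   # shape (neurons, windows)

mu = N.mean(axis=1) / window          # mean firing rate mu_i (sp/ms)
C = np.cov(N) / window                # firing co-variability C_ij (sp^2/ms)
rho = C[0, 1] / np.sqrt(C[0, 0] * C[1, 1])

print("mu  =", mu)
print("C   =", C)
print("rho =", rho)
```

For Poisson spike trains these ratios do not depend on \(\Delta t\). For general spike trains, the window must be long compared with the time scale of temporal correlations, which is why the definitions take \(\Delta t \to \infty\).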
Moment activation
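The moment activation (MA) is given by
\[ \mu_i = \phi_\mu(\bar{\mu}_i,\bar{\sigma}_i), \qquad \sigma_i = \phi_\sigma(\bar{\mu}_i,\bar{\sigma}_i), \qquad \rho_{ij} = \chi(\bar{\mu}_i,\bar{\sigma}_i)\,\chi(\bar{\mu}_j,\bar{\sigma}_j)\,\bar{\rho}_{ij}, \]
where \(\bar{\rho}_{ij}\) is the correlation coefficient between the input currents of neurons \(i\) and \(j\). The output correlation coefficient \(\rho_{ij}\) is related to the covariance as \(C_{ij}=\sigma_i\sigma_j\rho_{ij}\).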
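The covariance and correlation descriptions are interchangeable through \(C_{ij}=\sigma_i\sigma_j\rho_{ij}\). The helpers below (hypothetical names) convert between the two. They also apply the moment activation's correlation map, assuming it takes the form \(\rho_{ij}=\chi(\bar{\mu}_i,\bar{\sigma}_i)\,\chi(\bar{\mu}_j,\bar{\sigma}_j)\,\bar{\rho}_{ij}\) for \(i\neq j\). This is consistent with how the plotting routine in Section 2 computes the output correlation of two identical neurons (`chi_out**2*input_corr`). The \(\chi\) values in the example are made-up numbers; in practice they come from the moment activation.

```python
import numpy as np

def cov_to_corr(C):
    """Split a covariance matrix into standard deviations and correlation coefficients."""
    sigma = np.sqrt(np.diag(C))
    return sigma, C / np.outer(sigma, sigma)

def corr_to_cov(sigma, rho):
    """Rebuild the covariance C_ij = sigma_i * sigma_j * rho_ij."""
    return np.outer(sigma, sigma) * rho

def output_corr(rho_bar, chi):
    """Map input to output correlations, rho_ij = chi_i * chi_j * rho_bar_ij for i != j."""
    rho = np.outer(chi, chi) * rho_bar
    np.fill_diagonal(rho, 1.0)   # each neuron is perfectly correlated with itself
    return rho

C = np.array([[4.0, 1.2],
              [1.2, 9.0]])
sigma, rho = cov_to_corr(C)
print(sigma, rho[0, 1])               # standard deviations 2 and 3, correlation about 0.2

rho_bar = np.array([[1.0, 0.3],
                    [0.3, 1.0]])
chi = np.array([0.8, 0.5])            # hypothetical chi values
print(output_corr(rho_bar, chi)[0, 1])   # about 0.12 = 0.8 * 0.5 * 0.3
```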
The functions \(\phi_\mu\) and \(\phi_\sigma\) map the mean and standard deviation of the input current to those of the output spikes according to
\[ \phi_\mu(\bar{\mu},\bar{\sigma}) = \frac{1}{T_{\rm ref} + \frac{2}{L}\int_{I_{\rm lb}}^{I_{\rm ub}} g(x)\,dx}, \qquad \phi_\sigma(\bar{\mu},\bar{\sigma}) = \left(\frac{8}{L^2}\,[\phi_\mu(\bar{\mu},\bar{\sigma})]^3\int_{I_{\rm lb}}^{I_{\rm ub}} h(x)\,dx\right)^{1/2}, \]
where \(T_{\rm ref}\) is the refractory period and the integration bounds are \(I_{\rm lb}(\bar{\mu},\bar{\sigma}) = \tfrac{V_{\rm res}L-\bar{\mu}}{\sqrt{L}\bar{\sigma}}\) and \(I_{\rm ub}(\bar{\mu},\bar{\sigma}) = \tfrac{V_{\rm th}L-\bar{\mu}}{\sqrt{L}\bar{\sigma}}\). The constant parameters \(L\), \(V_{\rm res}\), and \(V_{\rm th}\) are identical to those in the LIF neuron model. The pair of Dawson-like functions \(g(x)\) and \(h(x)\) are defined as \(g(x)=e^{x^2}\int_{-\infty}^x e^{-u^2}du\) and \(h(x)=e^{x^2}\int_{-\infty}^x e^{-u^2}[g(u)]^2du\).
The function \(\chi\), which we refer to as the linear perturbation coefficient, is given by \(\chi(\bar{\mu},\bar{\sigma})=\tfrac{\bar{\sigma}}{\sigma}\tfrac{\partial\mu}{\partial\bar{\mu}}\), where \(\mu\) and \(\sigma\) are the output mean firing rate and standard deviation. It is derived from a linear perturbation analysis around \(\bar{\rho}_{ij}=0\). This approximation is justified because pairwise correlations between neurons in the brain are typically weak.
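The moment activation can also be evaluated numerically with general-purpose quadrature. This is a slow reference sketch, not the `mnn` implementation, whose fast routines such as `forward_fast_mean` are used later in this tutorial.

The sketch assumes the following expressions, together with the bounds \(I_{\rm lb}\), \(I_{\rm ub}\) and the functions \(g\), \(h\) defined above:

- \(\phi_\mu = 1/\big(T_{\rm ref} + \tfrac{2}{L}\int_{I_{\rm lb}}^{I_{\rm ub}} g(x)\,dx\big)\)
- \(\phi_\sigma^2 = \tfrac{8}{L^2}\phi_\mu^3\int_{I_{\rm lb}}^{I_{\rm ub}} h(x)\,dx\)

It uses `scipy.integrate.quad` and rewrites \(g\) with the scaled complementary error function `scipy.special.erfcx` for numerical stability. \(\chi\) is obtained from its definition by a central finite difference, which is a choice made here for simplicity.

```python
import numpy as np
from scipy.integrate import quad
from scipy.special import erfcx

# LIF constants from the text: L in 1/ms, voltages in mV, T_ref in ms
L, V_TH, V_RES, T_REF = 0.05, 20.0, 0.0, 5.0

def g(x):
    # g(x) = exp(x^2) * int_{-inf}^x exp(-u^2) du = (sqrt(pi)/2) * erfcx(-x)
    return 0.5 * np.sqrt(np.pi) * erfcx(-x)

def h(x):
    # h(x) = int_{-inf}^x exp(x^2 - u^2) g(u)^2 du (exp(x^2) moved inside the integral)
    val, _ = quad(lambda u: np.exp(x * x - u * u) * g(u) ** 2, -np.inf, x)
    return val

def bounds(mu_bar, sigma_bar):
    # integration bounds I_lb and I_ub
    scale = np.sqrt(L) * sigma_bar
    return (V_RES * L - mu_bar) / scale, (V_TH * L - mu_bar) / scale

def phi_mu(mu_bar, sigma_bar):
    lb, ub = bounds(mu_bar, sigma_bar)
    int_g, _ = quad(g, lb, ub)
    return 1.0 / (T_REF + 2.0 / L * int_g)

def phi_sigma(mu_bar, sigma_bar):
    lb, ub = bounds(mu_bar, sigma_bar)
    int_h, _ = quad(h, lb, ub)
    return np.sqrt(8.0 / L**2 * phi_mu(mu_bar, sigma_bar) ** 3 * int_h)

def chi(mu_bar, sigma_bar, eps=1e-4):
    # chi = (sigma_bar / sigma) * d mu / d mu_bar, derivative by central differences
    dmu = (phi_mu(mu_bar + eps, sigma_bar) - phi_mu(mu_bar - eps, sigma_bar)) / (2 * eps)
    return sigma_bar / phi_sigma(mu_bar, sigma_bar) * dmu

# Weak noise: the rate should approach the noise-free value 1 / (T_ref + ln(2) / L)
print(phi_mu(2.0, 0.1), 1.0 / (T_REF + np.log(2.0) / L))
print(phi_mu(1.0, 2.0), phi_sigma(1.0, 2.0), chi(1.0, 2.0))
```

The first line is a sanity check. As \(\bar{\sigma}\to 0\) with \(\bar{\mu} > L V_{\rm th}\), the rate should tend to the deterministic value \(1/\big(T_{\rm ref} + \tfrac{1}{L}\ln\tfrac{\bar{\mu}}{\bar{\mu}-LV_{\rm th}}\big)\), about 0.053 sp/ms for \(\bar{\mu}=2\). Very large bounds (strongly sub-threshold input with tiny \(\bar{\sigma}\)) make `erfcx` overflow, so this sketch is only meant for moderate inputs.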
2. Mean-driven vs fluctuation-driven spiking activity
Let us first focus on a single spiking neuron and explore its firing properties when driven by noisy input currents of varying mean and variance.
# First, the necessary imports
# Copy this notebook to the root directory of the repo (moment-neural-network) so that the mnn package can be imported
from mnn.mnn_core.mnn_utils import Param_Container, Mnn_Core_Func
import numpy as np
from matplotlib import pyplot as plt
# Simulator of spiking LIF neurons driven by Gaussian white-noise currents
class InteNFire():
    def __init__(self):
        self.L = 1/20        # leak conductance (1/ms)
        self.Vth = 20        # firing threshold (mV)
        self.Vres = 0        # reset potential (mV)
        self.Tref = 5        # refractory period (ms)
        self.Vspk = 50       # spike peak (mV), for visualization purposes only
        self.dt = 1e-2       # integration time step (ms)
        self.num_neurons = 2

    def forward(self, v, tref, is_spike, ff_current):
        '''Advance the membrane potential by one time step.'''
        # leaky integration of the input increment
        v += -v*self.L*self.dt + ff_current
        # neurons inside the refractory period
        is_ref = (tref > 0.0) & (tref < self.Tref)
        # neurons crossing threshold outside the refractory period
        is_spike = (v > self.Vth) & ~is_ref
        is_sub = ~(is_ref | is_spike)
        v[is_spike] = self.Vspk
        v[is_ref] = self.Vres
        # update refractory period timer
        tref[is_sub] = 0.0
        tref[is_ref | is_spike] += self.dt
        return v, tref, is_spike

    def run(self, T, input_mean, input_std, input_corr=None):
        '''Simulate integrate-and-fire neurons.
        T = simulation duration (ms)
        input_mean, input_std = mean and standard deviation of the white-noise input
        input_corr = input correlation coefficient between the two neurons (optional)
        '''
        self.T = T
        num_timesteps = int(round(self.T/self.dt))
        t = np.arange(num_timesteps)*self.dt       # same length as the recorded traces
        tref = np.zeros(self.num_neurons)          # tracker for refractory period
        v = np.zeros(self.num_neurons)             # initial voltage
        is_spike = np.zeros(self.num_neurons, dtype=bool)
        V = np.zeros((self.num_neurons, num_timesteps))   # full voltage trace
        input_curr = np.zeros((self.num_neurons, num_timesteps))
        if input_corr:  # correlated input between the two neurons
            # desired input correlation matrix
            rho = np.array([
                [1.0, input_corr],
                [input_corr, 1.0]
            ])
            # Cholesky factor: chol @ (independent normals) has correlation rho
            chol = np.linalg.cholesky(rho)
        else:
            chol = np.eye(self.num_neurons)
        for i in range(num_timesteps):
            noise = chol @ np.random.randn(self.num_neurons)
            # input increment over one step: mean*dt + std*sqrt(dt)*noise
            I = input_mean*self.dt + input_std*np.sqrt(self.dt)*noise
            v, tref, is_spike = self.forward(v, tref, is_spike, I)
            V[:, i] = v
            input_curr[:, i] = I/self.dt
        return V, t, input_curr
# Plotting routine
def plot_moment_activation(input_mean, input_std, input_corr=None):
    # evaluate the moment activation on a grid of input current mean and std
    curr_mean = np.linspace(-5, 5, 51)
    curr_std = np.linspace(0, 10, 51)
    X, Y = np.meshgrid(curr_mean, curr_std)
    extent = [curr_mean[0], curr_mean[-1], curr_std[0], curr_std[-1]]
    ma = Mnn_Core_Func()
    mean_out = ma.forward_fast_mean(X, Y)
    std_out = ma.forward_fast_std(X, Y, mean_out)
    with np.errstate(divide='ignore', invalid='ignore'):
        FF_out = std_out**2/mean_out
    FF_out[np.isnan(FF_out)] = 1.0  # NaN entries (e.g. 0/0 where the neuron is silent) are shown as 1
    chi_out = ma.forward_fast_chi(X, Y, mean_out, std_out)
    plt.figure(figsize=(3.5*3, 3))
    plt.subplot(1, 3, 1)
    plt.imshow(mean_out, origin='lower', extent=extent)
    plt.plot(input_mean, input_std, 'or')  # red dot: chosen input statistics
    plt.title("Mean firing rate (sp/ms)")
    plt.xlabel('Input current mean')
    plt.ylabel('Input current std')
    plt.colorbar()
    plt.subplot(1, 3, 2)
    plt.imshow(std_out, origin='lower', extent=extent)
    plt.plot(input_mean, input_std, 'or')
    plt.title("Firing std (sp/ms$^{1/2}$)")
    plt.xlabel('Input current mean')
    plt.colorbar()
    plt.subplot(1, 3, 3)
    plt.imshow(FF_out, origin='lower', extent=extent)
    plt.plot(input_mean, input_std, 'or')
    plt.title("Fano factor")
    plt.xlabel('Input current mean')
    plt.colorbar()
    plt.tight_layout()
    if input_corr:
        # output correlation for two identical neurons: chi^2 * input correlation
        corr_out = chi_out**2*input_corr
        plt.figure(figsize=(3.5*3, 3))
        plt.subplot(1, 3, 1)
        plt.imshow(chi_out, origin='lower', extent=extent)
        plt.plot(input_mean, input_std, 'or')
        plt.title("Linear perturbation coefficient")
        plt.xlabel('Input current mean')
        plt.ylabel('Input current std')
        plt.colorbar()
        plt.subplot(1, 3, 2)
        plt.imshow(corr_out, origin='lower', cmap='coolwarm', vmin=-1, vmax=1, extent=extent)
        plt.plot(input_mean, input_std, 'or')
        plt.title("Output correlation coefficient")
        plt.xlabel('Input current mean')
        plt.colorbar()
        plt.tight_layout()
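The noise-free case (`input_std=0`) gives a simple sanity check for a simulator like `InteNFire`. Starting from \(V_{\rm res}=0\), the membrane potential follows \(V(t)=\tfrac{\bar{\mu}}{L}(1-e^{-Lt})\). A spike therefore occurs only if \(\bar{\mu}/L > V_{\rm th}\). The inter-spike interval is then \(T_{\rm ref} + \tfrac{1}{L}\ln\tfrac{\bar{\mu}}{\bar{\mu}-LV_{\rm th}}\).

The sketch below is a compact, single-neuron re-implementation with essentially the same update rule. The helper name `lif_spike_count` is hypothetical, and the helper counts spikes instead of recording the voltage. It compares the simulated firing rate with this formula.

```python
import numpy as np

def lif_spike_count(mu_bar, sigma_bar, T=1000.0, dt=0.01, L=0.05,
                    v_th=20.0, v_res=0.0, t_ref=5.0, seed=0):
    """Hypothetical helper: number of spikes a single LIF neuron fires in T ms
    when driven by white noise with mean mu_bar and intensity sigma_bar."""
    rng = np.random.default_rng(seed)
    ref_steps = int(round(t_ref / dt))      # refractory period in time steps
    noise_scale = sigma_bar * np.sqrt(dt)
    v, ref_left, count = v_res, 0, 0
    for _ in range(int(round(T / dt))):
        if ref_left > 0:                     # held at v_res while refractory
            ref_left -= 1
            continue
        # Euler-Maruyama step of dV = (-L V + mu_bar) dt + sigma_bar dW
        v += (-L * v + mu_bar) * dt + noise_scale * rng.standard_normal()
        if v > v_th:                         # threshold crossing: spike and reset
            count += 1
            v = v_res
            ref_left = ref_steps
    return count

T = 1000.0
rate_sim = lif_spike_count(2.0, 0.0, T=T) / T
rate_theory = 1.0 / (5.0 + np.log(2.0 / (2.0 - 0.05 * 20.0)) / 0.05)
print(f"simulated {rate_sim:.4f} sp/ms, noise-free theory {rate_theory:.4f} sp/ms")
print(lif_spike_count(0.8, 0.0))   # 0 spikes: V saturates at mu_bar / L = 16 mV < V_th
```

Counting spikes instead of storing the full voltage trace keeps memory constant, which matters for the long runs needed to estimate spike count statistics.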
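Run the code below to simulate the spiking activity of a single LIF neuron receiving Gaussian white noise as input.
The first figure shows two time series:

- top: the membrane potential, with the threshold \(V_{\rm th}\) drawn as a red dashed line
- bottom: the input current that drives it

The second figure shows the moment activation (mean firing rate, firing variability, and Fano factor) over the two-dimensional plane spanned by the input current mean and standard deviation. The red dot marks the chosen input statistics.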
Observe the spiking activity under the following conditions:
1. Constant current without noise. Start from input_mean=0.8 and input_std=0, and then gradually increase input_mean. What is the critical input current necessary for generating spikes?
2. Mean-dominant activity. Start from input_mean=1 and input_std=1, and then gradually increase input_std. Observe how the presence of noise induces irregular spiking activity.
3. Fluctuation-dominant activity. Start from input_mean=0 and input_std=1, and then gradually increase input_std. Can you find a parameter regime in which the Fano factor is larger than one?
Can you explain the observed neural spiking activity using the moment activation?
# specify input current stats
input_mean = 0.8
input_std = 0
# simulate spiking activity
lif = InteNFire()
V, t, input_curr = lif.run(1000, input_mean, input_std)  # simulate 1000 ms
# plot spiking activity
plt.figure(figsize=(3.5*3,3))
plt.subplot(2,1,1)
plt.plot(t,V[0,:])
plt.plot([t[0],t[-1]], [lif.Vth, lif.Vth],'r--')
plt.ylabel('$V_m$ (mV)')
plt.subplot(2,1,2)
plt.plot(t[::10],input_curr[0,::10])
plt.xlabel('Time (ms)')
plt.ylabel('$I$ (mV/ms)')
plt.tight_layout()
plt.show()
# plot moment activation
plot_moment_activation(input_mean,input_std)
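The simulated neuron can be compared quantitatively with the heat maps by reading spike count statistics off the voltage trace returned by `lif.run`. The sketch below uses a hypothetical helper, `spike_count_stats`. It assumes that spikes are marked by `V == lif.Vspk`, as done in `InteNFire.forward`, and it counts spikes in non-overlapping windows. It is checked here on a synthetic trace with a known answer.

```python
import numpy as np

def spike_count_stats(V, dt, v_spk=50.0, window=100.0):
    """Hypothetical helper: firing rate (sp/ms) and Fano factor of each neuron,
    estimated from a voltage trace V of shape (neurons, time steps) in which
    spikes are marked by V == v_spk, using non-overlapping windows of `window` ms."""
    is_spike = V == v_spk
    steps = int(round(window / dt))          # time steps per counting window
    n_win = V.shape[1] // steps              # number of complete windows
    counts = is_spike[:, :n_win * steps].reshape(V.shape[0], n_win, steps).sum(axis=2)
    mean_count = counts.mean(axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        fano = counts.var(axis=1) / mean_count   # NaN for a silent neuron
    return mean_count / window, fano

# Synthetic check: neuron 0 fires regularly every 10 ms, neuron 1 is silent
dt = 0.01
V = np.zeros((2, 100_000))                   # 1000 ms at dt = 0.01 ms
V[0, 500::1000] = 50.0
rate, fano = spike_count_stats(V, dt)
print(rate)   # neuron 0: 0.1 sp/ms, neuron 1: 0
print(fano)   # neuron 0: 0 (perfectly regular), neuron 1: NaN (no spikes)
```

For the tutorial's output, call `spike_count_stats(V, lif.dt, v_spk=lif.Vspk)`. The moment activation describes the \(\Delta t\to\infty\) limit, so the window should be much longer than the typical inter-spike interval. A run considerably longer than 1000 ms is likely needed for a stable Fano factor estimate.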