Introduction


SNNAX is a lightweight library built on Equinox and JAX to provide a spiking neural network (SNN) simulator for deep learning. It is designed to be user-friendly and flexible, allowing users to define custom SNN layers while leveraging common deep learning layers from Equinox. SNNAX is fully compatible with JAX, enabling the use of JAX's function transformation features like vectorization with jax.vmap, automatic differentiation, and JIT compilation with XLA.
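As a rough illustration of what these transformations look like in plain JAX, the snippet below vectorizes, differentiates, and JIT-compiles a toy function. It uses no SNNAX code: the function `membrane_update` and its decay constant are hypothetical stand-ins chosen only for this sketch.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function, not part of SNNAX: one leaky update of a
# membrane potential `v` driven by an input current `x`.
def membrane_update(v, x, decay=0.9):
    return decay * v + (1.0 - decay) * x

# jax.vmap: apply the scalar function element-wise over a batch axis
batched_update = jax.vmap(membrane_update, in_axes=(0, 0))

# jax.grad: derivative of the output with respect to `v` (the first argument)
dv = jax.grad(membrane_update)

# jax.jit: compile the batched function with XLA
fast_update = jax.jit(batched_update)

v = jnp.zeros(4)
x = jnp.arange(4.0)
out = fast_update(v, x)       # one value per batch element
grad_at_zero = dv(0.0, 1.0)   # equals the decay constant
```

The three transformations are independent of each other and can be nested in any order, which is the property SNNAX relies on when its layers are used inside vmapped, differentiated, and jitted functions.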

Quick Start

The following code defines a simple SNN in SNNAX, using the snnax.snn.Sequential class to stack SNN layers and Equinox layers into a feed-forward architecture:

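import jax
import jax.numpy as jnp

import equinox as eqx
import snnax.snn as snn

import optax

# One PRNG key for each layer constructor that takes a key
key = jax.random.PRNGKey(0)
key1, key2, key3, key4 = jax.random.split(key, 4)

model = snn.Sequential(
    eqx.nn.Conv2d(2, 32, 7, 2, key=key1),
    snn.LIF((8, 8), [.9, .8], key=key2),
    snn.flatten(),
    eqx.nn.Linear(64, 11, key=key3),
    snn.LIF(11, [.9, .8], key=key4),
)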

Next, define a loss function for a single sample and use JAX's vectorization features to create a batched loss function. Note that the model's output is a tuple of membrane potentials and spikes. For a feed-forward SNN, use the last element of the spike list, out_spikes[-1], and sum the spikes across time to get the spike count.
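The spike-count readout described above can be tried in isolation. The sketch below is plain JAX and does not call SNNAX: the toy spike array, its (classes, time) layout, and the helper name `readout_loss` are assumptions made for illustration, and the cross-entropy is written out with jax.nn.log_softmax rather than optax.

```python
import jax
import jax.numpy as jnp

def readout_loss(out_spikes, tgt_class):
    # out_spikes: (num_classes, num_timesteps) binary spikes of the last layer
    # tgt_class:  (num_classes,) one-hot target
    spike_count = out_spikes.sum(-1)             # sum across time
    log_probs = jax.nn.log_softmax(spike_count)  # spike counts act as logits
    return -jnp.sum(tgt_class * log_probs)       # softmax cross-entropy

# Vectorize the single-sample loss over a leading batch axis
batched_readout_loss = jax.vmap(readout_loss, in_axes=(0, 0))

# Toy batch: 2 samples, 3 classes, 4 time steps
spikes = jnp.array([
    [[1, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
    [[0, 0, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1]],
], dtype=jnp.float32)
targets = jnp.array([[1, 0, 0], [0, 0, 1]], dtype=jnp.float32)

spike_counts = spikes.sum(-1)                    # shape (2, 3)
losses = batched_readout_loss(spikes, targets)   # shape (2,)
mean_loss = jnp.mean(losses)
```

The class whose output neuron fires most often receives the largest logit, so the loss is small when the target neuron dominates the spike count.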

from functools import partial

# Loss for a single sample
def loss_fn(model, in_states, in_spikes, tgt_class):
    out_state, out_spikes = model(in_states, in_spikes)

    # Spikes from the last layer are summed across time
    pred = out_spikes[-1].sum(-1)
    # optax.softmax_cross_entropy expects tgt_class as a one-hot (or probability) vector
    loss = optax.softmax_cross_entropy(pred, tgt_class)
    return loss


# Calculating the gradient with Equinox PyTree filters and
# subsequently jitting the resulting function.
# eqx.filter_value_and_grad differentiates with respect to the first
# argument, so the model has to come first.
@eqx.filter_jit
@eqx.filter_value_and_grad
def loss_and_grad(model, in_states, in_spikes, tgt_class):
    # Simple batched loss function: vmap over the leading axis of every input
    batched_loss_fn = jax.vmap(partial(loss_fn, model), in_axes=(0, 0, 0))
    return jnp.mean(batched_loss_fn(in_states, in_spikes, tgt_class))

Finally, train the model by feeding it input spike trains and states. Initialize the states of the SNN using the init_states method of the snnax.snn.Sequential class.

# ...
# Simple training loop
from tqdm import tqdm

for in_spikes, tgt_class in tqdm(dataloader):
    # Initializing the membrane potentials of LIF neurons
    states = model.init_states(key)

    # Jitting with Equinox PyTree filters
    loss, grads = loss_and_grad(model, states, in_spikes, tgt_class)

    # Update parameter PyTree with Equinox and optax
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)

Fully worked-out examples can be found in the examples directory.
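For readers who want to see the update pattern of the training loop in isolation before opening those examples, here is a minimal sketch in plain JAX. It is not SNNAX code: the parameter dictionary and the linear `loss_fn` are hypothetical stand-ins for the model, and plain gradient descent is swapped in for the optax optimizer.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the model: a dict of parameters (a PyTree)
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.1):
    # Value and gradient with respect to the first argument (the parameters)
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain gradient descent, applied leaf by leaf over the PyTree
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Toy regression data generated from known weights
x = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

losses = []
for _ in range(100):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))
```

The gradient has the same PyTree structure as the parameters it was taken with respect to, which is why the differentiated argument and the object being updated must be the same thing.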