Using a Kalman Filter to decode intended cursor position from intracortical neural data¶
In this tutorial, we will use this dataset, which contains four recording sessions of a macaque performing delayed center-out reaches while activity was recorded from the primary motor and dorsal premotor cortices. In a delayed center-out reach, the monkey moves a cursor on a screen with its hand to hit a target. The reaches are 'delayed' because once a target appears on the screen, the monkey must wait for a go cue before moving (allowing the researchers to study motor preparation).
In this tutorial, however, we want to build a system that moves a cursor on a screen by reading the monkey's brain signals (i.e. by decoding the intended motor action of the monkey). We will do this by predicting the monkey's hand velocities (our proxy for how the cursor should move on the screen) from its neural data.
This dataset comes with the neural data already preprocessed: spikes have been detected and their timings binned in time.
An overview of neural data preprocessing¶
If this dataset had not provided time bins and spike timings, how would we obtain them? Understanding this is not essential to building our decoding model, but it is helpful for appreciating the hard work that goes into preprocessing!
Once the monkey is implanted with electrode arrays in the regions of interest (motor cortex, in this case), the electrodes begin reading the electrical activity (in microvolts) of nearby neurons at a fixed sampling rate.
Filtering¶
Our data now looks like a large voltage-vs-time waveform. For our decoding task, we are only interested in specific frequency ranges of this data: we care less about slow, low-frequency drifts and much more about sudden, high-frequency spikes. We can filter the waveform to keep only the frequency regions of interest:
- Power-line filtering: electrical power around the world is delivered at 50 Hz or 60 Hz, and this frequency must be removed from our signals in case it leaked into our electrodes. A notch filter is used to remove a single frequency.
- A bandpass filter (or a combination of them) is then used to isolate the frequencies of interest (spike energy is concentrated roughly between 300 Hz and 3000 Hz).
These filters are often implemented in an analog manner through hardware, directly in the implant.
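In software, a rough digital equivalent of this filter chain can be sketched with `scipy.signal` (the 30 kHz sampling rate, 60 Hz mains frequency, and synthetic waveform below are all illustrative assumptions, not values from this dataset):

```python
import numpy as np
from scipy import signal

fs = 30_000  # assumed sampling rate (Hz); ~30 kHz is typical for spike recordings

# One second of synthetic raw voltage: 60 Hz power-line hum plus white noise
rng = np.random.default_rng(0)
t = np.arange(fs) / fs
waveform = 50 * np.sin(2 * np.pi * 60 * t) + rng.normal(0, 1, fs)

# Notch filter to remove the 60 Hz power-line component
b_notch, a_notch = signal.iirnotch(w0=60.0, Q=30.0, fs=fs)
notched = signal.filtfilt(b_notch, a_notch, waveform)

# Bandpass filter to isolate the spike band (300-3000 Hz)
sos = signal.butter(4, [300, 3000], btype="bandpass", fs=fs, output="sos")
spike_band = signal.sosfiltfilt(sos, notched)
```

After both stages, the dominant 60 Hz hum is gone and only spike-band energy remains.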
Once we have filtered data, what should we look for to help us decode intended cursor position?¶
Our data still looks like large voltage-vs-time waveforms, now filtered to include only the frequencies helpful for our decoding task. What useful information can we extract from it? Two common methods are:
Spike detection via a threshold: when the waveform crosses a specific voltage threshold, label the small window around the peak of that deflection as a spike event. The threshold is usually chosen as a multiple of the standard deviation of the background noise. A robust estimate of the noise standard deviation (assuming roughly Gaussian noise) is $$\sigma = \mathrm{median}(|x|) / 0.6745,$$ and the threshold is then set at some multiple of this (4-5x, for example).
Spike-band RMS power: take the RMS voltage of the entire spike band over each window and store that single value, instead of identifying individual spike events.
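A minimal sketch of this RMS approach, on a hypothetical 20 ms window of already-filtered data (all values synthetic):

```python
import numpy as np

# Hypothetical spike-band voltage for one channel over one 20 ms window
# sampled at 30 kHz (600 samples); values are synthetic
rng = np.random.default_rng(1)
window = rng.normal(0, 12.0, 600)

# One RMS value summarizes the whole window,
# instead of a list of discrete spike events
rms = np.sqrt(np.mean(window ** 2))
```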
This dataset, like many others, follows approach 1, detecting individual spike events. The data is structured such that spike times are binned at 1 ms.
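Approach 1 can be sketched on synthetic data as follows (the noise level and the 4.5x multiplier are illustrative choices, not values from this dataset):

```python
import numpy as np

rng = np.random.default_rng(0)
fs = 30_000

# One second of synthetic filtered voltage: Gaussian background noise (sigma = 10)
x = rng.normal(0, 10.0, fs)
x[5000] = -80.0  # inject one large negative deflection as a "spike"

# Robust noise estimate: sigma = median(|x|) / 0.6745
sigma = np.median(np.abs(x)) / 0.6745

# Threshold at -4.5 sigma (extracellular spikes usually deflect negative)
threshold = -4.5 * sigma
spike_samples = np.flatnonzero(x < threshold)  # includes the injected index 5000
```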
Discretizing our data: making time bins¶
The first method leaves us with many individual spike timings. We bin them in time for a few reasons:
- When controlling a cursor, we don't need near-instant updates (which would require a massive amount of computation). Updating every 10 ms or 20 ms is fast enough that we wouldn't notice the difference, so we can split the data into bins of this size and analyze each bin to decide what to do on the next update cycle.
- We want to understand how the spiking rate (e.g. spikes per second) changes over time. By counting the number of spikes in each bin, we turn discrete spike events into a rate signal that we can track over time.
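For example, binning a hypothetical list of spike times with `numpy.histogram` (all values illustrative):

```python
import numpy as np

# Hypothetical spike times (ms) for one channel over a 100 ms stretch
spike_times_ms = np.array([3, 7, 21, 24, 26, 55, 58, 91])

bin_size_ms = 20
counts, edges = np.histogram(spike_times_ms, bins=100 // bin_size_ms, range=(0, 100))

# Convert counts per bin to firing rates in spikes/s
rates = counts * (1000 / bin_size_ms)

print(counts.tolist())  # [2, 3, 2, 0, 1]
```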
Decoding neural data¶
Let's get to the fun part! How do we take these spike times, binned at 1 ms, and turn them into cursor updates? Each recording channel (implants can have hundreds or even thousands of them) has its own list of spike times in 1 ms bins; how can we combine all this information so that our decoder outputs cursor movement? Let's break this down:
Setup¶
## Download dataset and required packages if necessary
!pip install git+https://github.com/neurallatents/nlb_tools.git
!pip install dandi
!dandi download https://gui.dandiarchive.org/#/dandiset/000128
# Note: ensure numpy is < 2.0
PATH                                                                     STATUS   MESSAGE
000128/dandiset.yaml                                                     skipped  no change
000128/sub-Jenkins/sub-Jenkins_ses-full_desc-test_ecephys.nwb            error    FileExistsError
000128/sub-Jenkins/sub-Jenkins_ses-full_desc-train_behavior+ecephys.nwb  error    FileExistsError
Error: Encountered 2 errors while downloading.
# the dataset files already exist from a previous run, so re-download with overwrite
!dandi download https://dandiarchive.org/dandiset/000128 -e overwrite
PATH SIZE DONE DONE% CHECKSUM STATUS MESSAGE
000128/dandiset.yaml skipped no change
000128/sub-Jenkins/sub-Jenkins_ses-full_desc-test_ecephys.nwb 3.4 MB 3.4 MB 100% ok done
000128/sub-Jenkins/sub-Jenkins_ses-full_desc-train_behavior+ecephys.nwb 690.6 MB 690.6 MB 100% ok done
Summary: 694.0 MB 694.0 MB 1 skipped 1 no change
100.00% 2 done
## Imports
# %matplotlib widget # uncomment for interactive plots
from nlb_tools.nwb_interface import NWBDataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
# Load dataset
data_path = "000128/sub-Jenkins/" # change this if your data path is different; you can list the contents of your directory with 'ls'
dataset = NWBDataset(data_path, "*train", split_heldout=False)
Below, we can see all of our data, split into 1 ms time bins. For each time bin, we have access to several signals: true cursor position, eye position, hand position, hand velocity, and spike counts for each channel.
# View dataset
dataset.data
| signal_type | cursor_pos | eye_pos | hand_pos | hand_vel | spikes | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| channel | x | y | x | y | x | y | x | y | 1011 | 1021 | ... | 2861 | 2862 | 2871 | 2881 | 2882 | 2911 | 2931 | 2941 | 2951 | 2961 |
| clock_time | |||||||||||||||||||||
| 0 days 00:00:00 | -0.900000 | -5.700000 | 7.2 | 2.0 | -0.714908 | -40.526123 | -2.624567 | 29.977111 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 00:00:00.001000 | -0.907457 | -5.687027 | 7.2 | 2.1 | -0.717532 | -40.496146 | -2.707321 | 30.577662 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 0 days 00:00:00.002000 | -0.912768 | -5.672115 | 7.6 | 1.2 | -0.720323 | -40.464968 | -2.872729 | 31.744164 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 00:00:00.003000 | -0.914050 | -5.653433 | 7.4 | 1.4 | -0.723278 | -40.432658 | -3.019660 | 32.847931 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 00:00:00.004000 | -0.909980 | -5.629617 | 7.4 | 3.6 | -0.726362 | -40.399272 | -3.059403 | 33.895227 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 0 days 01:55:52.296000 | -114.378901 | -79.712313 | -95.0 | -117.5 | -114.334012 | -114.809976 | 0.905895 | -0.883716 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 01:55:52.297000 | -114.366164 | -79.728485 | -94.9 | -117.4 | -114.333252 | -114.810622 | 0.598148 | -0.420075 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 01:55:52.298000 | -114.365911 | -79.749577 | -94.6 | -117.7 | -114.332816 | -114.810816 | 0.218816 | 0.012961 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 01:55:52.299000 | -114.378419 | -79.774473 | -94.8 | -117.7 | -114.332814 | -114.810596 | -0.212940 | 0.393580 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 01:55:52.300000 | -114.400000 | -79.800000 | -97.8 | -118.2 | -114.333242 | -114.810029 | -0.427820 | 0.566803 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
6952301 rows × 190 columns
We can also access the dataset trial by trial. In each trial, the subject received a new target and began a new reach toward it.
The trial info dataframe has a number of fields containing information about each trial:
- trial_id - a number assigned to each trial during loading
- start_time - time when the trial begins
- end_time - time when the trial ends
- trial_type - the maze configuration that was used for the trial
- trial_version - a number 0-2 indicating which variant of the maze is presented. 0 is 1-target no-barrier, 1 is 1-target with barriers, 2 is 3-target with barriers
- maze_id - a unique identifier for the maze configuration used. Different maze sets were used for each session, so trial_type is not unique across dataset files
- success - whether the trial was successful. In provided training data, unsuccessful trials have already been removed
- target_on_time - time of target presentation
- go_cue_time - time of go cue
- move_onset_time - time of movement onset, calculated offline with a robust algorithm
- rt - reaction time in ms
- delay - time between target presentation and go cue in ms
- num_targets - number of targets displayed in the maze
- target_pos - x and y position of the target(s)
- num_barriers - number of barriers in the maze
- barrier_pos - position of the barrier(s). First two values are the x and y positions of the center of the barrier, last two values are the half-width and half-height of the barrier
- active_target - which target is reachable and was hit by the monkey. Its value corresponds to the index of the target in target_pos
dataset.trial_info
| trial_id | start_time | end_time | trial_type | trial_version | maze_id | success | target_on_time | go_cue_time | move_onset_time | rt | delay | num_targets | target_pos | num_barriers | barrier_pos | active_target | split | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 days 00:00:00 | 0 days 00:00:03.321000 | 25 | 2 | 84 | True | 0 days 00:00:00.880000 | 0 days 00:00:01.478000 | 0 days 00:00:01.905000 | 427 | 598 | 3 | [[-111, -82], [-108, 81], [118, 72]] | 8 | [[69, 31, 14, 99], [69, 54, 5, 101], [-62, -48... | 2 | val |
| 1 | 1 | 0 days 00:00:03.400000 | 0 days 00:00:06.521000 | 3 | 1 | 3 | True | 0 days 00:00:04.291000 | 0 days 00:00:04.739000 | 0 days 00:00:05.280000 | 541 | 448 | 1 | [[-116, -5]] | 6 | [[-69, -16, 13, 69], [-120, -62, 83, 15], [95,... | 0 | val |
| 2 | 2 | 0 days 00:00:06.600000 | 0 days 00:00:09.856000 | 22 | 1 | 66 | True | 0 days 00:00:07.471000 | 0 days 00:00:07.969000 | 0 days 00:00:08.346000 | 377 | 498 | 1 | [[-82, -86]] | 9 | [[34, -41, 86, 8], [9, -42, 33, 19], [7, -41, ... | 0 | train |
| 3 | 3 | 0 days 00:00:09.900000 | 0 days 00:00:12.946000 | 29 | 2 | 100 | True | 0 days 00:00:10.853000 | 0 days 00:00:11.335000 | 0 days 00:00:11.752000 | 417 | 482 | 3 | [[-109, 2], [2, 82], [132, -65]] | 9 | [[-9, 52, 43, 8], [-50, 91, 14, 64], [-133, -5... | 1 | train |
| 4 | 4 | 0 days 00:00:13 | 0 days 00:00:15.481000 | 21 | 0 | 65 | True | 0 days 00:00:13.687000 | 0 days 00:00:14.235000 | 0 days 00:00:14.507000 | 272 | 548 | 1 | [[27, 82]] | 0 | [] | 0 | val |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 2290 | 2290 | 0 days 01:55:36.600000 | 0 days 01:55:39.796000 | 34 | 1 | 91 | True | 0 days 01:55:37.362000 | 0 days 01:55:38.277000 | 0 days 01:55:38.585000 | 308 | 915 | 1 | [[116, -77]] | 7 | [[66, -43, 30, 9], [-66, 1, 11, 70], [-35, 50,... | 0 | train |
| 2291 | 2291 | 0 days 01:55:39.900000 | 0 days 01:55:42.736000 | 15 | 1 | 75 | True | 0 days 01:55:40.717000 | 0 days 01:55:41.265000 | 0 days 01:55:41.641000 | 376 | 548 | 1 | [[133, -81]] | 9 | [[-33, 47, 37, 6], [-77, 48, 61, 11], [-64, -2... | 0 | train |
| 2292 | 2292 | 0 days 01:55:42.800000 | 0 days 01:55:45.766000 | 23 | 0 | 67 | True | 0 days 01:55:43.465000 | 0 days 01:55:44.396000 | 0 days 01:55:44.714000 | 318 | 931 | 1 | [[94, -86]] | 0 | [] | 0 | train |
| 2293 | 2293 | 0 days 01:55:45.800000 | 0 days 01:55:49.201000 | 25 | 2 | 84 | True | 0 days 01:55:46.631000 | 0 days 01:55:46.663000 | 0 days 01:55:47.616000 | 953 | 32 | 3 | [[-111, -82], [-108, 81], [118, 72]] | 8 | [[69, 31, 14, 99], [69, 54, 5, 101], [-62, -48... | 2 | val |
| 2294 | 2294 | 0 days 01:55:49.300000 | 0 days 01:55:52.301000 | 16 | 0 | 76 | True | 0 days 01:55:50.025000 | 0 days 01:55:50.807000 | 0 days 01:55:51.183000 | 376 | 782 | 1 | [[-118, -83]] | 0 | [] | 0 | val |
2295 rows × 18 columns
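Since the continuous data is indexed by `clock_time`, one trial's data can be pulled out by slicing between its `start_time` and `end_time`. A toy stand-in for the real frames (synthetic values, not from this dataset) shows the pattern:

```python
import numpy as np
import pandas as pd

# Toy stand-in for dataset.data: a timedelta-indexed frame at 1 ms resolution
idx = pd.timedelta_range(start="0ms", periods=100, freq="ms", name="clock_time")
data = pd.DataFrame({"hand_vel_x": np.arange(100)}, index=idx)

# Toy stand-in for one row of dataset.trial_info
trial = {"start_time": pd.Timedelta("10ms"), "end_time": pd.Timedelta("30ms")}

# Label-based slicing on a timedelta index is inclusive on both ends
trial_data = data.loc[trial["start_time"]:trial["end_time"]]

print(len(trial_data))  # 21
```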
Filtering dataset¶
What parts of this large dataset do we want to access?
- We want the spike counts in each time bin. We have access to spike times across many recording channels. In this dataset, channel numbers starting with 1 come from the primary motor cortex (M1), and channel numbers starting with 2 come from the dorsal premotor cortex. Since we are interested in executed movement rather than movement planning, we will only keep the channels in M1.
- We want to associate those spiking rates with intended movement: hand positions and velocities.
To train our decoder, we will use the spike counts in each time bin as input and the hand velocities as labels. We also resample the bins from 1 ms to 20 ms: 1 ms updates are unnecessarily fine-grained for cursor control, and larger bins give smoother, less noisy spike counts.
signal_types = dataset.data.columns.get_level_values('signal_type')
channels = dataset.data.columns.get_level_values('channel')
# Convert channels to numeric (strings become NaN)
channel_nums = pd.to_numeric(channels, errors='coerce')
# Filtering to only include spike channels in primary motor cortex (channel numbers 1xxx)
spike_mask = (signal_types == 'spikes') & (channel_nums >= 1000) & (channel_nums < 2000)
# Filtering to only include hand positions and velocities
hand_pos_mask = (signal_types == 'hand_pos')
hand_vel_mask = (signal_types == 'hand_vel')
# Combine all masks
mask = spike_mask | hand_pos_mask | hand_vel_mask
# Apply filter
dataset.data = dataset.data.loc[:, mask]
print(f"Kept {spike_mask.sum()} spike channels")
print(f"Final shape: {dataset.data.shape}")
Kept 92 spike channels
Final shape: (6952301, 96)
# Extract each signal type from your dataset
hand_pos = dataset.data['hand_pos']
hand_vel = dataset.data['hand_vel']
spikes = dataset.data['spikes']
# Resample with appropriate aggregation functions
hand_pos_resampled = hand_pos.resample('20ms').mean()
hand_vel_resampled = hand_vel.resample('20ms').mean()
spikes_resampled = spikes.resample('20ms').sum()
# Combine back into single DataFrame with MultiIndex structure
resampled_data = pd.concat(
[hand_pos_resampled, hand_vel_resampled, spikes_resampled],
axis=1,
keys=['hand_pos', 'hand_vel', 'spikes']
)
print(f"Original shape: {dataset.data.shape}")
print(f"Resampled shape: {resampled_data.shape}")
dataset.data = resampled_data
Original shape: (6952301, 96)
Resampled shape: (347616, 96)
# Round trial times to nearest 20ms bin
time_columns = ['start_time', 'end_time', 'target_on_time',
'go_cue_time', 'move_onset_time']
for col in time_columns:
dataset.trial_info[col] = (
dataset.trial_info[col] / pd.Timedelta('20ms')
).round() * pd.Timedelta('20ms')
dataset.data # take a peek at our new resampled, 20ms bins!
| hand_pos | hand_vel | spikes | |||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| channel | x | y | x | y | 1011 | 1021 | 1022 | 1031 | 1032 | 1033 | ... | 1801 | 1812 | 1831 | 1841 | 1851 | 1861 | 1881 | 1891 | 1901 | 1902 |
| clock_time | |||||||||||||||||||||
| 0 days 00:00:00 | -0.737494 | -40.185452 | -1.339941 | 38.272109 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 00:00:00.020000 | -0.653940 | -39.304780 | 9.786980 | 46.763506 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 00:00:00.040000 | -0.529994 | -38.465552 | -3.312316 | 34.765720 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 0 days 00:00:00.060000 | -0.831260 | -37.897353 | -23.142283 | 23.071452 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 0 days 00:00:00.080000 | -1.309288 | -37.592735 | -23.083476 | 4.854699 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 0 days 01:55:52.220000 | -114.379208 | -114.806499 | 0.526060 | -1.209532 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 01:55:52.240000 | -114.367861 | -114.786945 | 0.344761 | 4.817995 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 01:55:52.260000 | -114.370232 | -114.693716 | -0.025136 | 0.021171 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 01:55:52.280000 | -114.347036 | -114.782456 | 1.706303 | -4.396547 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 0 days 01:55:52.300000 | -114.333242 | -114.810029 | -0.427820 | 0.566803 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
347616 rows × 96 columns
dataset.trial_info # make sure all our trial start and end times are aligned with 20ms boundaries
| trial_id | start_time | end_time | trial_type | trial_version | maze_id | success | target_on_time | go_cue_time | move_onset_time | rt | delay | num_targets | target_pos | num_barriers | barrier_pos | active_target | split | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 days 00:00:00 | 0 days 00:00:03.320000 | 25 | 2 | 84 | True | 0 days 00:00:00.880000 | 0 days 00:00:01.480000 | 0 days 00:00:01.900000 | 427 | 598 | 3 | [[-111, -82], [-108, 81], [118, 72]] | 8 | [[69, 31, 14, 99], [69, 54, 5, 101], [-62, -48... | 2 | val |
| 1 | 1 | 0 days 00:00:03.400000 | 0 days 00:00:06.520000 | 3 | 1 | 3 | True | 0 days 00:00:04.300000 | 0 days 00:00:04.740000 | 0 days 00:00:05.280000 | 541 | 448 | 1 | [[-116, -5]] | 6 | [[-69, -16, 13, 69], [-120, -62, 83, 15], [95,... | 0 | val |
| 2 | 2 | 0 days 00:00:06.600000 | 0 days 00:00:09.860000 | 22 | 1 | 66 | True | 0 days 00:00:07.480000 | 0 days 00:00:07.960000 | 0 days 00:00:08.340000 | 377 | 498 | 1 | [[-82, -86]] | 9 | [[34, -41, 86, 8], [9, -42, 33, 19], [7, -41, ... | 0 | train |
| 3 | 3 | 0 days 00:00:09.900000 | 0 days 00:00:12.940000 | 29 | 2 | 100 | True | 0 days 00:00:10.860000 | 0 days 00:00:11.340000 | 0 days 00:00:11.760000 | 417 | 482 | 3 | [[-109, 2], [2, 82], [132, -65]] | 9 | [[-9, 52, 43, 8], [-50, 91, 14, 64], [-133, -5... | 1 | train |
| 4 | 4 | 0 days 00:00:13 | 0 days 00:00:15.480000 | 21 | 0 | 65 | True | 0 days 00:00:13.680000 | 0 days 00:00:14.240000 | 0 days 00:00:14.500000 | 272 | 548 | 1 | [[27, 82]] | 0 | [] | 0 | val |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 2290 | 2290 | 0 days 01:55:36.600000 | 0 days 01:55:39.800000 | 34 | 1 | 91 | True | 0 days 01:55:37.360000 | 0 days 01:55:38.280000 | 0 days 01:55:38.580000 | 308 | 915 | 1 | [[116, -77]] | 7 | [[66, -43, 30, 9], [-66, 1, 11, 70], [-35, 50,... | 0 | train |
| 2291 | 2291 | 0 days 01:55:39.900000 | 0 days 01:55:42.740000 | 15 | 1 | 75 | True | 0 days 01:55:40.720000 | 0 days 01:55:41.260000 | 0 days 01:55:41.640000 | 376 | 548 | 1 | [[133, -81]] | 9 | [[-33, 47, 37, 6], [-77, 48, 61, 11], [-64, -2... | 0 | train |
| 2292 | 2292 | 0 days 01:55:42.800000 | 0 days 01:55:45.760000 | 23 | 0 | 67 | True | 0 days 01:55:43.460000 | 0 days 01:55:44.400000 | 0 days 01:55:44.720000 | 318 | 931 | 1 | [[94, -86]] | 0 | [] | 0 | train |
| 2293 | 2293 | 0 days 01:55:45.800000 | 0 days 01:55:49.200000 | 25 | 2 | 84 | True | 0 days 01:55:46.640000 | 0 days 01:55:46.660000 | 0 days 01:55:47.620000 | 953 | 32 | 3 | [[-111, -82], [-108, 81], [118, 72]] | 8 | [[69, 31, 14, 99], [69, 54, 5, 101], [-62, -48... | 2 | val |
| 2294 | 2294 | 0 days 01:55:49.300000 | 0 days 01:55:52.300000 | 16 | 0 | 76 | True | 0 days 01:55:50.020000 | 0 days 01:55:50.800000 | 0 days 01:55:51.180000 | 376 | 782 | 1 | [[-118, -83]] | 0 | [] | 0 | val |
2295 rows × 18 columns
Introduction to the Kalman Filter¶
I strongly recommend you read Parts 1, 2, and 3 of Alan Zucconi's guide to Kalman Filters: https://www.alanzucconi.com/2022/07/24/kalman-filter-1/. I will use terminology and concepts that he introduces in these readings.
Our Kalman Filter will allow us to estimate hand positions and velocities (and therefore INTENDED cursor positions and velocities) directly from neural signals. It does this by assuming that all sources of noise in the system are Gaussian and that the system dynamics are linear. We can define:
Our state transition model (how do hand position and velocity change over time?): $$x_n = Ax_{n-1} + w_n$$ where $x_n$ is our state encoding position and velocity, $x_n = [x, y, v_x, v_y]$; $A$ is a precomputed state transition matrix (given some set of positions and velocities, how can we predict the next set?); and $w_n$ is Gaussian process noise (the way our system evolves is not perfect, and its noise can be approximated as Gaussian), $w_n \sim N(0, Q)$, where $Q$ is the covariance of our process noise.
Our observation model: it's important to note that we have an indirect observation model (we don't measure hand positions and velocities directly; rather, we must infer them from neural signals): $$z_n = Cx_n + v_n$$ where $C$ is a matrix that maps state space (4D) to neural observation space and $v_n$ is Gaussian measurement noise, $v_n \sim N(0, R)$, where $R$ is the covariance of that noise. This noise exists because neither the way neurons fire nor the way we measure them is perfect - noise is involved!
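To make these two equations concrete, here is a toy simulation of a single step of each model. The constant-velocity $A$, the random $C$, and the noise covariances below are illustrative placeholders, not the values we will fit from the data later:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 4D state [x, y, vx, vy] with a hypothetical constant-velocity A:
# positions integrate velocity over one dt = 0.02s bin, velocities persist.
dt = 0.02
A = np.array([
    [1, 0, dt, 0],
    [0, 1, 0, dt],
    [0, 0, 1,  0],
    [0, 0, 0,  1],
], dtype=float)
x_prev = np.array([0.0, 0.0, 10.0, -5.0])  # at the origin, moving right/down

# State transition: x_n = A x_{n-1} + w_n, with w_n ~ N(0, Q)
Q = np.diag([0.01, 0.01, 1.0, 1.0])
w = rng.multivariate_normal(np.zeros(4), Q)
x_n = A @ x_prev + w

# Observation: z_n = C x_n + v_n, mapping the 4D state to (say) 3 channels
C = rng.standard_normal((3, 4))            # hypothetical observation matrix
R = 0.1 * np.eye(3)
v = rng.multivariate_normal(np.zeros(3), R)
z_n = C @ x_n + v

print("next state:", x_n)
print("observation:", z_n)
```

Note that without the noise terms, the next position would be exactly $[0.2, -0.1]$: the old position plus velocity times the 20ms bin width.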
Precomputing values for Filter¶
The Kalman Filter requires us to precompute some values before using it for inference:
- The state transition model requires us to precompute $A$, the matrix that defines how hand positions and velocities evolve in the absence of any observations, and $Q$, the covariance of the Gaussian noise associated with this evolution.
- The observation model requires us to precompute $C$, the matrix that defines how the state (x, y, $v_x$, $v_y$) maps to neural observations, and $R$, the covariance of the Gaussian noise associated with these measurements.
- Define our initial conditions: our mean initial state (given that targets are dispersed across the whole screen, the center of the screen is probably the best guess), and $P_0$, the initial covariance of our estimated state $\hat{x}$ (often initialized to a scaled identity matrix).
Let's learn some more about the nature of spiking in our dataset:
spike_data = dataset.data['spikes']
print("Spike count statistics:")
mean_spikes = np.mean(spike_data.values)
max_spikes = np.max(spike_data.values)
print(f"Mean: {mean_spikes:.6f} spikes/20ms")
print(f"Max: {max_spikes:.1f}")
total_entries = spike_data.size
total_spikes = mean_spikes * total_entries
print(f"Total spikes: {total_spikes:.0f}")
print(f"% zeros: {100 * (spike_data.values == 0).sum() / spike_data.size:.2f}%")
Spike count statistics:
Mean: 0.044067 spikes/20ms
Max: 6.0
Total spikes: 1409305
% zeros: 95.88%
Calculating $A$ and $Q$ for our state transition model¶
To find how hand positions and velocities change with time, we can use linear regression. However, we must keep one thing in mind: we don't want the regression to consider the transition from the last time bin of one trial to the first time bin of the next, because that would look like the hand teleporting back to the center of the screen!
To do this, we can extract pairs (state at $t-1$, state at $t$) across all trials while excluding cross-trial transitions. We can then use the least-squares method to solve for $A$ in: $$x_n = Ax_{n-1} + w_n$$ $Q$ is simply $Cov(w_n)$, where we compute $w_n = x_n - Ax_{n-1}$.
Transitions we don't want to include:
- from the last bin of one trial -> the first bin of the next, as discussed
- transitions involving NaN data! Some data is padded with NaN values, and we can't include those in linear regression
# Extract state variables: [x, y, vx, vy]
states = pd.DataFrame({
'x': dataset.data[('hand_pos', 'x')],
'y': dataset.data[('hand_pos', 'y')],
'vx': dataset.data[('hand_vel', 'x')],
'vy': dataset.data[('hand_vel', 'y')]
})
print(f"Original: {len(states):,} 20ms time bins")
Original: 347,616 20ms time bins
# === MARK TRIAL BOUNDARIES FIRST === -> we don't want our linear regression model to learn from cross-trial transitions
is_trial_end = states.index.isin(dataset.trial_info['end_time'])
print(f"Trial boundaries: {is_trial_end.sum():,}")
# === CHECK NaN ===
n_nan = states.isna().sum().sum()
n_rows_with_nan = states.isna().any(axis=1).sum()
print(f"NaN: {n_nan:,} values in {n_rows_with_nan:,} rows")
# === CREATE SHIFTED DATA ===
states_next = states.shift(-1)
# === VALIDITY MASK ===
# Invalid if: NaN at t, NaN at t+1, OR trial boundary
has_nan_current = states.isna().any(axis=1)
has_nan_next = states_next.isna().any(axis=1)
# === FILTER VALID TRANSITIONS ===
valid = ~(has_nan_current | has_nan_next | is_trial_end)
X_prev = states[valid].values
X_next = states_next[valid].values
print(f"Valid transitions: {valid.sum():,} / {len(states):,}")
# === COMPUTE A (STATE TRANSITION MATRIX) ===
A = np.linalg.lstsq(X_prev, X_next, rcond=None)[0].T
# === COMPUTE Q (PROCESS NOISE COVARIANCE) ===
residuals = X_next - X_prev @ A.T
Q = np.cov(residuals.T)
# Validation
r2 = 1 - np.sum(residuals**2, 0) / np.sum((X_next - X_next.mean(0))**2, 0)
print(f"\nA matrix:\n{A}")
print(f"\nQ diagonal: {np.diag(Q)}")
print(f"R²: x={r2[0]:.4f}, y={r2[1]:.4f}, vx={r2[2]:.4f}, vy={r2[3]:.4f}")
Trial boundaries: 2,295
NaN: 23,280 values in 5,820 rows
Valid transitions: 339,501 / 347,616

A matrix:
[[ 9.99066658e-01  1.98899671e-05  1.99022995e-02 -5.28341156e-04]
 [-3.06785927e-04  9.99061807e-01  4.35695009e-04  1.97851797e-02]
 [-8.86301583e-02  7.95444853e-04  9.80171017e-01 -5.24464490e-02]
 [-2.93817432e-02 -9.01933105e-02  4.30288463e-02  9.66416635e-01]]

Q diagonal: [1.13212049e-01 1.13944317e-01 1.10393974e+03 1.11079057e+03]
R²: x=1.0000, y=0.9999, vx=0.9439, vy=0.9305
We can also tune $Q$ based on empirical results:
- Large entries in $Q$ --> the filter trusts its state transition model less and reacts more to noisy observations (choppier movement)
- Small entries in $Q$ --> the filter trusts its model more and reacts less to noisy observations (slower, smoother movement)
# Scale factors: <1 = trust prediction more (smoother), >1 = trust observations more (more responsive)
q_scale = np.diag([0.02, 0.02, 3.0, 4.0]) # dampen position, amplify velocity
Q_tuned = q_scale @ Q @ q_scale
print(f"\nQ original diagonal: {np.diag(Q)}")
print(f"Q tuned diagonal: {np.diag(Q_tuned)}")
# Use the tuned Q going forward
Q = Q_tuned
Q original diagonal: [1.13212049e-01 1.13944317e-01 1.10393974e+03 1.11079057e+03]
Q tuned diagonal: [4.52848195e-05 4.55777269e-05 9.93545762e+03 1.77726492e+04]
Computing C and R for observation model¶
$C$ is the matrix that, when multiplied with a state and added to some noise, gives us our neural data. Essentially, we want to run linear regression on our equation: $$z_n = Cx_n + d + v_n$$ where $d$ is our bias vector. $R$ is simply $Cov(v_n)$, where we compute $v_n = z_n - Cx_n - d$.
We will first apply a preprocessing pipeline for our observation model¶
Feeding completely raw spike data into a Kalman filter would create some problems:
- Spike counts are Poisson-distributed, not Gaussian. We use a square-root transform to make the spike counts more Gaussian.
- We have lots of noisy channels that are ~96% zeros -> we can use PCA to concentrate the signal into a few uncorrelated components
- There is a natural lag in neural data: neurons fire before the observed change in state (the hand moving). So, if we are at time $n$ and want to know the hand's current position and velocity, the most informative neural data is not at time $n$, but at around time $n - 5$ (if we assume a neural lag of ~100ms with 20ms bins). So, instead of a single observation $z_n$ per time step, we include current AND past observations in a vector $z_{lagged, n}$. Now, when we fit the matrix $C$, the model can learn whether neural data from previous time steps predicts our state better!
# === OBSERVATION MODEL WITH TIME-LAGGED SPIKE HISTORY ===
# Key insight: neural activity PRECEDES movement by ~100-150ms (motor planning).
# Using only z_t to predict x_t misses this temporal structure.
# By concatenating [z_t, z_{t-1}, ..., z_{t-n_lags}], C can learn the lead-lag relationship.
# --- Step 1: Square-root transform ---
neural_obs_raw = dataset.data['spikes']
neural_obs_sqrt = np.sqrt(neural_obs_raw)
# --- Step 2: PCA ---
n_pca_components = 20
is_trial_end = states.index.isin(dataset.trial_info['end_time'])
has_nan_state = states.isna().any(axis=1)
has_nan_neural = neural_obs_sqrt.isna().any(axis=1)
valid_base = ~(has_nan_state | has_nan_neural | is_trial_end)
pca = PCA(n_components=n_pca_components)
pca.fit(neural_obs_sqrt[valid_base].values)
print(f"PCA variance explained: {pca.explained_variance_ratio_.sum():.1%} with {n_pca_components} components")
neural_pca_all = pd.DataFrame(
pca.transform(neural_obs_sqrt.values),
index=neural_obs_sqrt.index,
columns=[f'pc{i}' for i in range(n_pca_components)]
)
dataset.data_neural_pca = neural_pca_all
# --- Step 3: Build time-lagged observation matrix ---
# n_lags=5 at 20ms = 100ms of history, capturing the neural-to-movement delay
n_lags = 5
print(f"Using {n_lags} lags ({n_lags * 20}ms history)")
Z_pca = neural_pca_all.values # (T, n_pca_components)
# Stack [z_t, z_{t-1}, ..., z_{t-n_lags}] into a wide matrix
Z_lagged_list = []
for lag in range(n_lags + 1): # lag 0 through n_lags
Z_lagged_list.append(np.roll(Z_pca, lag, axis=0))
Z_lagged = np.hstack(Z_lagged_list) # (T, n_pca_components * (n_lags+1))
n_obs_features = Z_lagged.shape[1]
print(f"Observation dimension: {n_obs_features} ({n_pca_components} PCA × {n_lags+1} time steps)")
# Validity mask: exclude first n_lags rows (rolled data is invalid), NaN rows, trial boundaries
valid_lag = np.ones(len(Z_lagged), dtype=bool)
valid_lag[:n_lags] = False
# Also exclude rows near trial boundaries (lags could cross trials)
for _, trial in dataset.trial_info.iterrows():
trial_start_idx = states.index.get_indexer([trial['start_time']], method='nearest')[0]
# Mark the first n_lags rows of each trial as invalid
for offset in range(n_lags):
idx = trial_start_idx + offset
if 0 <= idx < len(valid_lag):
valid_lag[idx] = False
valid = valid_base.values & valid_lag
X = states.values[valid] # (n_valid, 4)
Z = Z_lagged[valid] # (n_valid, n_obs_features)
print(f"Valid observations: {valid.sum():,}")
PCA variance explained: 58.0% with 20 components
Using 5 lags (100ms history)
Observation dimension: 120 (20 PCA × 6 time steps)
Valid observations: 329,212
Now that we have preprocessed data, we can find $C$ and $R$.
# --- Step 4: Fit C with Ridge regression (centered, with bias) ---
x_mean = X.mean(axis=0)
z_mean = Z.mean(axis=0)
X_centered = X - x_mean
Z_centered = Z - z_mean
ridge = Ridge(alpha=10.0) # stronger regularization for higher-dimensional Z
ridge.fit(X_centered, Z_centered)
C = ridge.coef_ # (n_obs_features, 4)
d = z_mean - C @ x_mean
print(f"✓ C matrix: {C.shape}")
print(f"✓ d (bias) vector: {d.shape}")
# --- Step 5: Compute R and scale it ---
Z_pred = X @ C.T + d
residuals = Z - Z_pred
R = np.cov(residuals.T) + 1e-6 * np.eye(n_obs_features)
# --- Validation ---
ss_res = np.sum(residuals**2, axis=0)
ss_tot = np.sum((Z - Z.mean(axis=0))**2, axis=0)
r2_per_component = 1 - ss_res / ss_tot
print(f"\nObservation model quality (lagged PCA):")
print(f" Mean R²: {r2_per_component.mean():.4f}")
print(f" Best R²: {r2_per_component.max():.4f}")
print(f" Components with R² > 0.01: {(r2_per_component > 0.01).sum()} / {n_obs_features}")
✓ C matrix: (120, 4)
✓ d (bias) vector: (120,)

Observation model quality (lagged PCA):
  Mean R²: 0.0083
  Best R²: 0.0327
  Components with R² > 0.01: 30 / 120
Just like we did for $Q$, we can scale R based on how our empirical results pan out. Scaling down R trusts our observations more.
# R SCALING: reduce R to increase Kalman gain and make filter more responsive.
# The empirical R overestimates noise because our linear model can't capture
# all neural-to-movement structure. Scaling down trusts observations more.
r_scale = 0.5
R = R * r_scale
print(f"\n✓ R matrix: {R.shape} (scaled by {r_scale})")
✓ R matrix: (120, 120) (scaled by 0.5)
Defining initial conditions¶
# Initial state: use the mean state from training data as a reasonable prior
x0 = x_mean.copy() # better than zeros - this is where the hand typically is
P0 = np.diag([100, 100, 1000, 1000]) # keep reasonable initial uncertainty
print(f"x0 (initial state): {x0}")
print(f"P0 diagonal: {np.diag(P0)}")
x0 (initial state): [  4.77250598 -36.45700461   6.27394769  -0.10812666]
P0 diagonal: [ 100  100 1000 1000]
Implementing the Kalman Filter¶
Kalman Filters are great because if your state requires an update every 5ms but you only receive observations every 20ms, the filter falls back on your state transition model while it is not receiving observations.
In this model, conceptually, we could do the following:
- A prediction step every 5ms: the Kalman filter updates the estimated state $\hat{x}_n$ using our state transition model, $\hat{x}_n = A\hat{x}_{n-1}$.
- A correction step every 20ms: the Kalman filter reads spike counts from each channel, giving a num_channels-dimensional observation vector $z_n = Cx_n + v_n$.
(For simplicity, the implementation further below runs both steps at the same 20ms cadence.)
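As a side note, that asynchronous 5ms/20ms scheme could be sketched roughly as follows. The function name and the toy matrices in the usage lines are hypothetical, and a real 5ms predictor would need $A$ and $Q$ re-discretized for the shorter time step:

```python
import numpy as np

def async_filter_sketch(A, Q, C, R, x0, P0, z_stream, predicts_per_obs=4):
    """Hypothetical asynchronous loop: several prediction-only steps
    (e.g. every 5ms) per observation (e.g. every 20ms)."""
    x, P = x0.copy(), P0.copy()
    estimates = []
    for z in z_stream:
        # Prediction-only updates while no observation is available:
        for _ in range(predicts_per_obs):
            x = A @ x
            P = A @ P @ A.T + Q
            estimates.append(x.copy())  # state keeps evolving without data
        # Correction once the observation finally arrives:
        y = z - C @ x                     # innovation
        S = C @ P @ C.T + R
        K = P @ C.T @ np.linalg.inv(S)    # Kalman gain
        x = x + K @ y
        P = (np.eye(len(x)) - K @ C) @ P
        estimates[-1] = x.copy()          # replace with corrected estimate
    return np.array(estimates)

# Toy usage: 2D state [pos, vel], 1D position observation fixed at 1.0
A = np.array([[1.0, 0.02], [0.0, 1.0]])
Q = 0.01 * np.eye(2)
C = np.array([[1.0, 0.0]])
R = np.array([[0.1]])
est = async_filter_sketch(A, Q, C, R, np.zeros(2), np.eye(2),
                          z_stream=[np.array([1.0])] * 5)
print(est.shape)  # 4 prediction steps per each of 5 observations
```

Notice that the estimated position is pulled toward the observed value of 1.0 only at correction steps; between them, the state simply coasts on the transition model.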
How does this correction step work? The goal is to update $\hat{x}_n$ given $z_n$ and $\hat{x}_{n-1}$. Here are the steps:
- Carry out another prediction step to find an estimate for $\hat{x}_n$.
- Project this estimate into our observation space (what neural firing rates do we expect to observe given this estimate?) -> the expected observation.
- We then compare the expected observation with our real observation $z_n$; their difference is the innovation $y_n$. The Kalman gain decides how much of $y_n$ to apply: when our measurement noise $R$ is large relative to the predicted state uncertainty, the gain is small (trust the observation less), and vice versa.
There's a lot of math involved in deriving the equations for the Kalman gain and covariance update after step 3. I hope to add this sometime soon, but for now: $$\hat{x}_n = \hat{x}_{n|n-1} + K_n y_n$$ $$K_n = P_{n|n-1} C^T S^{-1}_n = P_{n|n-1} C^T (CP_{n|n-1}C^T + R)^{-1}$$ $$P_n = (I - K_nC)P_{n|n-1}$$ where $\hat{x}_{n|n-1}$ and $P_{n|n-1}$ are the state and covariance estimates from the prediction step.
Let's pick a specific trial to try the filter out on, and define a function so we can call it during evaluation easily.
def kalman_filter(trial_data, trial_spikes_pca, A, Q, R, C, d, x0, P0, n_lags=5, dt=0.020):
"""
Kalman Filter with time-lagged neural observations.
Parameters
----------
trial_spikes_pca : ndarray (n_steps, n_pca) - PCA-transformed spike counts
n_lags : int - number of past timepoints to include in observation vector
"""
n_steps = len(trial_data)
n_state = 4
n_pca = trial_spikes_pca.shape[1]
x_est = np.zeros((n_steps, n_state))
P_history = np.zeros((n_steps, n_state, n_state))
time_array = np.arange(n_steps) * dt
x_est[0] = x0
P_history[0] = P0
x_true = trial_data.values
# Ring buffer for maintaining lag history
z_history = np.zeros((n_lags + 1, n_pca))
z_history[0] = trial_spikes_pca[0]
for k in range(1, n_steps):
# === PREDICT ===
x_pred = A @ x_est[k-1]
P_pred = A @ P_history[k-1] @ A.T + Q
# === BUILD LAGGED OBSERVATION ===
# Shift history back and insert new observation
z_history[1:] = z_history[:-1] # shift older entries
z_history[0] = trial_spikes_pca[k] # newest at index 0
# Concatenate: [z_t, z_{t-1}, ..., z_{t-n_lags}]
z = z_history.flatten()
# === CORRECT ===
z_pred = C @ x_pred + d
innovation = z - z_pred
S = C @ P_pred @ C.T + R
K = P_pred @ C.T @ np.linalg.inv(S)
x_est[k] = x_pred + K @ innovation
P_history[k] = (np.eye(n_state) - K @ C) @ P_pred
return {
'time': time_array,
'x_est': x_est,
'x_true': x_true,
'P': P_history,
}
Evaluating our Kalman Filter¶
We can track how our Kalman Filter's state estimates fare against the true state values by plotting them. Let's first define a function for plotting the Kalman results against the real results:
def plot_kalman_results(results, trial_id):
"""
results: the return value from our kalman filter function
trial_id: which trial ID do we wish to test with?
"""
fig, axes = plt.subplots(4, 1, figsize=(14,12))
time = results['time']
x_est = results['x_est']
x_true = results['x_true']
labels = ['x position (mm)', 'y position (mm)', 'x velocity (mm/s)', 'y velocity (mm/s)']
for i, (ax, label) in enumerate(zip(axes, labels)):
ax.plot(time, x_true[:, i], 'k-', label='True', linewidth=2, alpha=0.7)
ax.plot(time, x_est[:, i], 'r--', label='Kalman Estimate', linewidth=1.5)
ax.set_ylabel(label, fontsize=10)
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)
axes[-1].set_xlabel('Time (seconds)', fontsize=11)
fig.suptitle(f'Kalman Filter Performance - Trial {trial_id}\n' +
'20ms Prediction & Observation Steps',
fontsize=14, fontweight='bold')
plt.tight_layout()
return fig
Let's pick an arbitrary trial to test: trial 3.
# === Run Kalman Filter ===
trial_info = dataset.trial_info.iloc[3]
start_time = trial_info['start_time']
end_time = trial_info['end_time']
trial_data = dataset.data.loc[start_time:end_time, [
('hand_pos', 'x'),
('hand_pos', 'y'),
('hand_vel', 'x'),
('hand_vel', 'y')
]]
trial_spikes_raw = dataset.data.loc[start_time:end_time, 'spikes']
trial_spikes_sqrt = np.sqrt(trial_spikes_raw.values)
trial_spikes_pca = pca.transform(trial_spikes_sqrt)
print(f"Trial states: {trial_data.shape}")
print(f"Trial neural (PCA): {trial_spikes_pca.shape}")
print(f"C: {C.shape}, d: {d.shape}, R: {R.shape}")
results = kalman_filter(trial_data, trial_spikes_pca, A, Q, R, C, d, x0, P0, n_lags=n_lags, dt=0.020)
time = results['time']
x_est = results['x_est']
x_true = results['x_true']
# === RMS Error ===
rms = np.sqrt(np.mean((x_est - x_true)**2, axis=0))
labels_short = ['x pos', 'y pos', 'x vel', 'y vel']
print("\n=== RMS Error ===")
for i, label in enumerate(labels_short):
print(f" {label}: {rms[i]:.2f}")
# === Plot ===
plot_kalman_results(results, 3)
Trial states: (153, 4)
Trial neural (PCA): (153, 20)
C: (120, 4), d: (120,), R: (120, 120)

=== RMS Error ===
  x pos: 54.30
  y pos: 101.68
  x vel: 412.76
  y vel: 461.31