Interaction Between
Artificial Intelligence Systems and Primate Brains
by
Shixian Wen
A Dissertation Presented to the
FACULTY OF THE USC GRADUATE SCHOOL
UNIVERSITY OF SOUTHERN CALIFORNIA
In Partial Fulfillment of the
Requirements for the Degree
DOCTOR OF PHILOSOPHY
(COMPUTER SCIENCE)
May 2022
Copyright 2022 Shixian Wen
Table of Contents

List of Tables
List of Figures
Abstract
1. Rapid transfer of brain-computer interfaces to new neuronal ensembles or participants via generative modelling
1.1 Abstract
1.2 Introduction
1.3 Results
1.3.1 Experimental setup and data preparation
1.3.2 Spike synthesizer and general structure
1.3.3 Validation of spike synthesizer using a random reaching task
1.3.4 Synthesized neural spikes can accelerate the training and improve generalization of cross-session and cross-subject decoding
1.4 Discussion
1.5 Methods
1.5.1 The BCI decoder
1.5.2 Bidirectional LSTM
1.5.3 Generative adversarial network (GAN)
1.5.4 Constrained Conditional LSTM GAN (cc-LSTM-GAN, the spike synthesizer)
1.5.5 Fine tuning
1.5.6 Correlation
1.5.7 Data Augmentation
1.5.8 Hyperparameters
2. Capturing spike train temporal pattern with Wavelet Average Coefficient for Brain Machine Interface
2.1 Abstract
2.2 Introduction
2.3 Experiment paradigm
2.4 Methods
2.4.1 Wavelet Framework
2.4.2 Kernel Function Module
2.4.3 Discrete Wavelet Transform
2.4.4 Preprocessing Module for Generating WAC
2.4.5 Comparison between trend feature Q and spike counts
2.4.6 Sliding Window for Wiener and Kalman Filter
2.4.7 LSTM decoder using WAC as inputs
2.5 Results
2.5.1 Sliding Window improves decoding performances of the classical Wiener and Kalman filters in high temporal resolution
2.5.2 Wavelet framework further improves the performance of Kalman and Wiener filters augmented by sliding windows
2.5.3 Sliding window size correlates with movement frequency
2.5.4 Using WAC as inputs improves the decoding performance of the LSTM decoder in high temporal resolution
2.6 Discussion
2.7 Materials and Method
3. Beneficial Perturbation Network for designing general adaptive artificial intelligence systems
3.1 Abstract
3.2 Introduction
3.3 Types of methods for enabling lifelong learning
3.4 Adversarial directions and perturbations
3.5 Beneficial directions and perturbations, & the effects of beneficial perturbations in multitask sequential learning scenario
3.6 Beneficial perturbation network
3.7 Experiment
3.7.1 Experimental Setup for Incremental Tasks
3.7.2 Experimental Setup for Eight Sequential Object Recognition Tasks
3.7.3 Experimental Setup for 100 permuted MNIST dataset
3.7.4 Our model and baselines
3.8 Results
3.8.1 The beneficial perturbations can bias the network and maintain the decision boundary
3.8.2 Quantitative analysis for incremental tasks
3.8.3 Quantitative analysis for eight sequential object recognition tasks
3.8.4 Quantitative analysis for 100 permuted MNIST dataset
3.9 Discussion
4. Beneficial Perturbation Network for Defending Adversarial Examples
4.1 Abstract
4.2 Introduction
4.3 Related work – adversarial training
4.4 Beneficial Perturbation Network
4.4.1 High-level ideas – difference between BPN and adversarial training in fixing distribution drifts of input data
4.4.2 Formulation of beneficial perturbations
4.4.3 Creation of reverse adversarial attack
4.4.4 Computation Costs
4.4.5 Loss Function, Forward and Backward Rules
4.4.6 Extending BPN to deep convolutional networks
4.5 Experiments
4.5.1 Datasets
4.5.2 Network Structure
4.5.3 Various attack methods
4.6 Beneficial Perturbation Network
4.6.1 In scenario 1: BPN can defend adversarial examples with additional negligible computational costs
4.6.2 In scenario 2: BPN can alleviate the decay of clean sample accuracy
4.6.3 In scenario 3: BPN can alleviate the decay of clean sample accuracy
4.6.4 BPN can generalize to unseen attacks that it has never been trained on
4.7 Discussion
4.7.1 Intriguing property of beneficial perturbations
4.7.2 Beneficial perturbations: The opposite “twins” of adversarial perturbations
5. What can we learn from misclassified ImageNet images?
5.1 Abstract
5.2 Motivation
5.3 Observations and General Framework
5.4 Our approach
5.4.1 The Superclassing ImageNet dataset
5.4.2 Super-Sub framework
5.4.2.1 Vanilla Network Architectures and Vanilla Inference Rules
5.4.2.2 Efficient Network Architectures and Efficient Inference Rules
5.5 Experiments
5.6 Results and Analysis
5.7 Discussion
6. Conclusion
7. Acknowledgements
8. References
9. Supplementary
9.1 Supplementary for rapid transfer of brain-machine interfaces to new neuronal ensembles or participants
9.1.1 Supplementary Figures
9.1.2 Supplementary Discussion
9.2 Supplementary for capturing spike train temporal pattern with Wavelet Average Coefficient for Brain Machine Interface
9.2.1 Recursive equation of Wavelet framework for Kalman filter with sliding window augmentation
9.2.2 Classical Kalman filters with sliding window augmentation
9.2.3 Supplementary Figures
9.3 Supplementary for Beneficial Perturbation Network for designing general adaptive artificial intelligence systems
9.3.1 Clarification of memory storage costs
9.3.2 Clarification of parameter costs
9.3.3 Choice of Hyperparameter
9.3.4 Difference between Transfer Learning and Continual Learning
9.3.5 Algorithms for BD+PSP
List of Tables

Table 1.1. Hyperparameters
Table 3.1. Task 1 performance with "single-head" evaluation
Table 3.2. Test accuracy (in percent correct) achieved by each method with "multi-head" evaluation
Table 4.1. Computation costs of BPN trained on clean examples
Table 4.2. Scenario 1: Training on only clean examples for both BPN and classical network
Table 4.3. Scenario 2: Training on only adversarial examples for both BPN and classical network
Table 4.4. Scenario 3: Training on both clean and adversarial examples for both BPN and classical network
Table 4.5. BPN can generalize to unseen attacks
List of Figures

Figure 1. The big picture of my research: interaction between Primate Brains and Artificial Intelligence Systems
Figure 2. From A.I. to brain
Figure 3. From brain to A.I.
Figure 1.1. Experimental paradigm and training the baseline BCI LSTM decoder
Figure 1.2. General framework
Figure 1.3. Normalized position activity map
Figure 1.4. Firing rates
Figure 1.5. Time series of binned spike counts
Figure 1.6. Cross-session decoding
Figure 1.7. Cross-subject decoding
Figure 2.1. Center-out tasks and locomotion tasks
Figure 2.2. Overview of wavelet framework
Figure 2.3. Comparison between neural signal waveform and spike counts, and temporal features Q capturing the temporal patterns of spike trains
Figure 2.4. Sliding window structure
Figure 2.5. Decoding performance for locomotion tasks and center-out tasks
Figure 2.6. Influence of window size and different hyperparameters
Figure 2.7. Decoding performance for locomotion tasks using LSTM decoder
Figure 3.1. With BPN, one can switch at runtime the network parameters that are globally optimal for each task
Figure 3.2. Concept
Figure 3.3. Defining adversarial perturbations in input space vs. beneficial perturbations in activation space
Figure 3.4. Beneficial perturbation network (BD + EWC or BD + PSP variant) with two tasks
Figure 3.5. Visualization of classification regions
Figure 3.6. Results for a fully-connected network with 5 hidden layers of 300 ReLU units
Figure 3.7. 100 permuted MNIST datasets results
Figure 4.1. Difference in training pipelines between adversarial training and BPN to defend against adversarial examples
Figure 4.2. Difference between adversarial training and BPN in fixing the data distribution drifts
Figure 4.3. Structural difference between normal network (baseline) and BPN for forward and backward pass
Figure 4.4. BPN extension to deep convolutional neural network
Figure 5.1. Superclassing ImageNet dataset
Figure 5.2. Confusion matrix for inter-superclass prediction
Figure 5.3. Performance gap between classifying an image for subclasses inside the superclass it belongs to and classifying an image for all subclasses from all superclasses
Figure 5.4. Two-stage Super-Sub framework
Figure 5.5. Efficient implementation of Super-Sub framework
Figure 5.6. The performances of Efficient and Vanilla implementations of the Super-Sub framework
Supplementary Figure 1.1. Detailed general framework
Supplementary Figure 1.2. Velocity activity maps
Supplementary Figure 1.3. Position activity maps
Supplementary Figure 1.4. Cross-session decoding
Supplementary Figure 1.5. Cross-subject decoding
Supplementary Figure 1.6. Correlations across neural spike train samples for each neuron, sorted by the averaged correlation coefficient for each neuron
Supplementary Figure 1.7. Detailed structure of CC-LSTM-GAN
Supplementary Figure 1.8. Averaged performance for cross-session decoding with GAN-augmentation and Real-only methods as the number of dropped neurons increases
Supplementary Figure 1.9. Visualization examples of actual movement trajectory
Supplementary Figure 1.10. Normalized velocity activity map
Supplementary Figure 1.11. Normalized acceleration activity map
Supplementary Figure 2.S1. Reconstruction of neural signal from wavelet coefficients and scaling function coefficients
Supplementary Figure 2.S2. Visualization of ankle movement
Supplementary Figure 2.S3. Influence of window size and different hyperparameters in 5-fold cross-validation for Kalman Filter
Supplementary Figure 2.S4. Decoding performance for locomotion tasks and center-out tasks measured by mean square errors between decoded covariates and ground truths in 5-fold cross-validation
Supplementary Figure 3.S1. Flow chart of a typical Type 4 method
Abstract

Figure 1. The big picture of my research: interaction between Primate Brains and Artificial Intelligence Systems. Inspirations from understanding the mechanisms of primate brains can help design better artificial intelligence systems. Conversely, improvements in artificial intelligence systems provide us with better tools to discover the mechanisms of primate brains.

Recent technological improvements, such as the emergence of deep neural networks, enable us to better understand the mechanisms of human brains. Conversely, a better understanding of the mechanisms of human brains helps us build better bio-inspired systems that can act like humans, think like humans, and react like humans. The two entities form a symbiotic relationship (Fig. 1). During the last five years of study with Dr. Laurent Itti, I have been working on several projects that explore the interaction between better understanding human brains using machine learning tools and designing better bio-inspired deep machine learning systems. There are two research pathways:
(i) From Artificial Intelligence systems (A.I.) to the brain (Fig. 2).
Figure 2. From A.I. to brain
In the first pathway, we leverage current mathematical and A.I. tools (e.g., deep generative models, recurrent neural networks, wavelet transforms) to better understand brain functions and structures. These understandings would improve the applicability of brain-computer interfaces and enhance the diagnosis of and intervention in brain disorders.
(ii) From brain to A.I. (Fig. 3).
Figure 3. From brain to A.I.
In the second pathway, current A.I. systems have some limitations (e.g., catastrophic forgetting, adversarial examples; Fig. 3). We leveraged inspirations from primate brains (e.g., the hippocampus, the visual cortex) to design better bio-inspired A.I. systems that address these limitations. These improvements would help us design the next generation of general and adaptive artificial intelligence systems.
In the following chapters, I will describe five research projects.
1) Rapid transfer of brain-machine interfaces to new neuronal ensembles or participants (published in Nature Biomedical Engineering, impact factor 26.7)
2) Capturing spike train temporal pattern with wavelet average coefficient for brain machine interface (published in Scientific Reports, impact factor 4.57)
3) Beneficial Perturbation Network for designing general adaptive artificial intelligence systems (published in IEEE Transactions on Neural Networks and Learning Systems, impact factor 11.68)
4) Beneficial perturbation network for defending adversarial examples (submitted to Pattern Recognition Letters, impact factor 2.81)
5) What can we learn from misclassified ImageNet images? (in preparation)
1. Rapid transfer of brain-computer interfaces to new neuronal ensembles or participants via generative modelling
1.1 Abstract
Brain computer interfaces (BCI) that directly link the brain to artificial actuators
have the potential to circumvent severe paralysis. However, obtaining sufficient
training data for the algorithms that map neural signals onto actions can be
difficult, expensive, or even impossible. A recent trend in machine vision is to
develop generative models which learn from an image distribution to synthesize
a virtually unlimited number of new, similar images. Here, we developed a new
type of generative model that learns a mapping from hand movements
(kinematics) to associated spike trains (neural attributes). After training on one
experimental session, we rapidly adapted the learned mapping to new sessions
or subjects using limited additional neural data. The adapted model synthesized
new spike trains to accelerate the training and improve the generalization of a
BCI decoder. The approach is general and fully data-driven, and hence could
apply to neuroscience problems beyond motor control.
1.2 Introduction
A motor brain computer interface (BCI) [1] is a system that enables users to control artificial actuators or even contraction of paralyzed muscles [2,3], by decoding motor output from recorded neural activity. Many methods have been proposed to build such a decoder, including linear Wiener Filters [1,4], Kalman Filters [5,6], Particle Filters [7,8], Point Process methods [9,10], and Long Short-Term Memory (LSTM) [11,12] networks. However, current BCI decoders face several limitations. First, the most powerful of these methods (LSTM) typically requires large amounts of neural data to achieve good performance. Second, these decoders typically generalize poorly over time, requiring periodic recalibration. Lastly, decoders are user-specific and must be trained from scratch for each subject, posing problems for clinical applications, where training data is difficult to acquire.
The applicability of BCI decoders could be greatly improved if they could generalize across recording sessions and to new subjects. In the cross-session scenario, a decoder trained with data from one recording session is expected to generalize to data from another session, despite possibly having different sets of recorded neurons. This scenario is important for designing BCI decoders that can maintain decoding performance despite glial scarring [13,14], relative motion between electrodes and brain, or cell death [14], which may change the effective number and identity of recorded neurons from day to day. Even if those problems can be minimized, natural neural plasticity [15] might still require adaptation of the decoder over time. In the cross-subject scenario, a decoder trained with data from one subject would be expected to generalize to data from another subject, usually also with a different number of neurons, possibly after some limited additional training using data from the new subject. Leveraging data from the first subject (or subjects) would be beneficial, for example, when simultaneously obtaining sufficient neural data and covariates of interest from subsequent subjects is expensive, difficult, or impossible [16] (e.g., paralyzed patients cannot generate motor outputs, or some complex task variables are difficult to track). In addition, neural data from the first subject might be inherently easier to decode (e.g., the quality of the signal collected by the implanted electrode arrays might be better).
Even with ample available data, cross-session and cross-subject adaptation could leverage similar structure in the neural data despite differences in the recordings [17,18]. Recently, Pandarinath et al. [19] demonstrated that a latent dynamic system, trained on neural data from multiple sessions, can successfully predict movement kinematics from additional neural data of these sessions, but they did not show generalization to completely new sessions or subjects. Farshchian et al. [20] demonstrated that adversarial domain adaptation of a latent representation of firing rates could successfully predict movement kinematics from latent-signal inputs over many days using a fixed decoder. However, BCI decoders that can generalize to different sessions [21,22] or subjects without complete re-training have not yet been demonstrated. We believe this is because current approaches lack a principled representation of neural attributes (e.g., position, velocity and acceleration activity maps, velocity neural tuning curves, distribution of mean firing rates, and correlation between spike trains) [23].
Here, we leverage state-of-the-art machine learning techniques [24–29] to explore the interactions between a generative spike synthesizer and a BCI decoder to improve generalization across sessions and subjects. We trained a deep-learning spike synthesizer with one session of motor cortical neural population data recorded from a reaching monkey. The spike synthesizer can synthesize realistic spike trains with realistic neural attributes. With the help of the synthesized spike trains, our results show, for the first time, how one can accelerate the training of BCI decoders and improve generalization across sessions and subjects. With small amounts of training data, we demonstrate modest but highly significant improvements in the training of a BCI decoder for cross-session and cross-subject decoding. Further, we demonstrate that our method can transfer some useful information and boost cross-subject decoding performance beyond the best achievable performance when training only on data from the second subject.
1.3 Results
1.3.1 Experimental setup and data preparation
Two monkeys (Monkey C and Monkey M) were chronically implanted with electrode arrays (Blackrock Microsystems) in the arm representation of the primary motor cortex (M1). The monkeys were seated in front of a video screen and grasped the handle of a planar manipulandum that controlled the position of a cursor. We recorded neural spiking activity on each electrode while the monkeys made reaching movements to a sequence of targets appearing in random locations on the screen [30]. After the cursor reached a given target, a new target appeared, to which the monkeys could reach immediately (Fig. 1.1.a).
We collected two sessions of neural data: one with 33 minutes and 69 neurons
(session one), and the other with 37 minutes and 77 neurons (session two) from
Monkey C. We also collected one session with 11 minutes and 60 neurons
(session one) from Monkey M. We parsed and binned all neural and kinematic
data at 10 ms time resolution.
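For concreteness, here is a minimal sketch of this kind of binning, assuming spike times are stored per neuron in seconds; the function and array layout are illustrative assumptions, not the dissertation's actual preprocessing code.

```python
import numpy as np

def bin_spikes(spike_times, duration_s, bin_ms=10):
    """Bin per-neuron spike times into counts at a fixed temporal resolution.

    spike_times: list of 1-D arrays, one array of spike times (seconds) per neuron.
    Returns an array of shape (n_bins, n_neurons) of spike counts.
    """
    n_bins = int(duration_s * 1000 / bin_ms)
    edges = np.arange(n_bins + 1) * (bin_ms / 1000.0)
    counts = np.stack([np.histogram(st, bins=edges)[0] for st in spike_times], axis=1)
    return counts

# Example: 69 simulated neurons over 33 minutes, binned at 10 ms.
rng = np.random.default_rng(0)
duration = 33 * 60  # seconds
spikes = [np.sort(rng.uniform(0, duration, rng.integers(5000, 20000)))
          for _ in range(69)]
binned = bin_spikes(spikes, duration)  # shape (198000, 69)
```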
Figure 1.1. (a) Experimental paradigm: monkeys were seated in front of a video
screen and grasped the handle of a planar manipulandum that controlled the
position of a cursor. Monkeys made reaching movements to a sequence of
randomly-placed targets appearing on the screen while we recorded neural
activity in primary motor cortex using an implanted electrode array. (b) Training
the baseline BCI LSTM decoder. Recorded spike trains are input to the BCI
LSTM (Brain-Computer Interface, Long Short-Term Memory, a recurrent neural
network that can learn to decode spike trains). The decoder outputs predicted
kinematics, by first learning a time-varying generalizable internal representation
(symbols t-1, t, t+1), and then mapping it to kinematic space (using readout
weights WBCI). During training, actual kinematics (ground truth) are compared
to the predicted ones using an L2 loss function (i.e., a Euclidean distance
comparison) and used to refine the decoder (Methods).
1.3.2 Spike synthesizer and general structure
Generative adversarial networks (GANs) [24–29] provide a tool to learn, end-to-end, a mapping from a random noise input to an output point that belongs to a desired data distribution. The process of training a GAN can be thought of as an adversarial game between a generator and a discriminator. The role of the generator is to produce fake data that seem real, while the discriminator learns to recognize whether data are real or fake. Competition in this adversarial game can improve both components until the generated fake data are indistinguishable from real data. The random noise input allows the GAN to synthesize different instances or variations around the desired target point. For example, in computer vision, after training a GAN on images of people with various hairstyles and clothes, the generator can synthesize, for any new person, many realistic-looking images of that person with different hairstyles [25] or with different clothes [26]. Note that current deep generative models can only generate samples from the distribution they have been trained on. Thus, GANs cannot generalize to new kinds of things that they have never been trained with (e.g., a GAN trained on images of shirts could not generate images of pants).
Our synthesizer (Fig. 1.2, step 1; a sequential adaptation of a GAN) learns a direct mapping from hand kinematics to spike trains. This is achieved through the training of a new type of spike-based GAN (Methods, Supplementary Fig. 1.S.1, step 1). After training, it can capture the embedded neural attributes. Like machine vision GANs, which cannot generalize from shirts to pants, here we expect that our model can synthesize new spike trains with good neural attributes for kinematics seen during training, but cannot generalize to new kinematics never encountered at training time. Next, new neural data from a second session, or a different subject, is split into training and test sets. The training set, which can be very small (e.g., 35 seconds of data), is used to adapt [31,32] the synthesizer to the new domain (Fig. 1.2, step 2). Once adapted, the synthesizer outputs spike trains that emulate the properties of the new data, and thus can be used to assist the training of a BCI decoder for that session. To train the BCI decoder (Fig. 1.2, step 3), the synthesized spike trains are combined with the same limited training set that was used to adapt the synthesizer. We show that training on this combination of real and synthesized spike trains yields a better BCI decoder than training on the limited neural data alone, because the spike synthesizer significantly increases the diversity and quantity of neural data available to train the BCI decoder. In essence, this achieves smart data augmentation [30,33], whereby a limited amount of new real neural data combines with and adapts a previously learned mapping from kinematics to spike trains, delivering an updated mapping and synthesizer that can generate sufficiently realistic data to effectively train the BCI decoder.
Figure 1.2. General framework. Step 1: Training a spike synthesizer on the neural data from session one of Monkey C (S.1, M.C) to learn a direct mapping from kinematics to spike trains and to capture the embedded neural attributes. Gaussian noise and real kinematics are input to the spike synthesizer (consisting of a Generator and a Readout). The spike synthesizer generates realistic synthesized spike trains by first learning the embedded neural attributes using a Generator (a bidirectional LSTM recurrent neural network) through a bidirectional time-varying generalizable internal representation (symbols t-1, t, t+1). Different instances of Gaussian noise combined with new kinematics yield different embedded neural attributes that all have similar properties to those used for training. Then, the Readout maps the embedded neural attributes to spike trains (using readout weights WG). Step 2: Adapting the spike synthesizer to produce synthesized spike trains suitable for another session or subject from real kinematics and Gaussian noise. We first freeze the generator to preserve the embedded neural attributes, or virtual neurons, learned previously. Then, we substitute and fine-tune the readout module using limited neural data from another session or subject (session two of Monkey C (S.2, M.C) or session one of Monkey M (S.1, M.M)). The fine-tuned readout module adapts the captured expression of these neural attributes into spike trains suitable for another session or subject. Step 3: Training a BCI decoder for another session or subject using the combination of the same small amount of real neural data used for fine-tuning (in step 2) and a large amount of synthesized spike trains (from step 2).
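To make the synthesizer's structure concrete, below is a minimal, hypothetical PyTorch sketch of one adversarial update for a kinematics-conditioned sequence GAN in the spirit of Fig. 1.2, step 1. All sizes, learning rates, and the plain BCE objective are illustrative assumptions; the actual cc-LSTM-GAN (Methods, Section 1.5.4) adds constraints not shown here.

```python
import torch
import torch.nn as nn

T, K, N, Z = 100, 6, 69, 32   # time steps, kinematic dims, neurons, noise dims (assumed)

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(K + Z, 128, batch_first=True, bidirectional=True)
        self.readout = nn.Linear(256, N)    # maps internal representation to spikes
    def forward(self, kin, z):
        h, _ = self.lstm(torch.cat([kin, z], dim=-1))
        return torch.relu(self.readout(h))  # non-negative synthetic spike counts

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(K + N, 128, batch_first=True)
        self.head = nn.Linear(128, 1)
    def forward(self, kin, spikes):
        h, _ = self.lstm(torch.cat([kin, spikes], dim=-1))
        return self.head(h[:, -1])          # real/fake score per sequence

G, D = Generator(), Discriminator()
opt_g = torch.optim.Adam(G.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(D.parameters(), lr=1e-4)
bce = nn.BCEWithLogitsLoss()

kin = torch.randn(8, T, K)                  # a batch of real kinematics (placeholder)
real = torch.poisson(torch.ones(8, T, N))   # placeholder for real binned spikes
z = torch.randn(8, T, Z)                    # Gaussian noise input

# Discriminator step: push real spikes toward 1, synthesized spikes toward 0.
fake = G(kin, z).detach()
loss_d = bce(D(kin, real), torch.ones(8, 1)) + bce(D(kin, fake), torch.zeros(8, 1))
opt_d.zero_grad(); loss_d.backward(); opt_d.step()

# Generator step: fool the discriminator into scoring fakes as real.
loss_g = bce(D(kin, G(kin, z)), torch.ones(8, 1))
opt_g.zero_grad(); loss_g.backward(); opt_g.step()
```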
1.3.3 Validation of spike synthesizer using a random reaching task.
We trained the spike synthesizer on session one of Monkey C, and
characterized both the recorded and virtual neurons using properties such as
firing rates and position, velocity and acceleration activity maps. In addition,
we measured the correlations between the recorded real and synthesized
neurons.
1) The spike synthesizer can synthesize spikes with realistic position (velocity,
acceleration)-related activity
Figure 1.3. Normalized position activity map, constructed as the histogram of neural activity as a function of position. (a) Position activity map for real neuron 35, normalized across the workspace. (b) Corresponding position activity map for virtual neuron 35. (c,d) Position activity maps for real and virtual neuron 3. (e) Histogram of mean squared error between the real and generated activity maps for all neurons. The purple line is the trimmed averaged mean square error (based on 99% of samples, 0.13) between real neurons. It provides a reasonable bound for quantifying the difference between real and virtual neurons.
We first asked whether the virtual neurons had position (velocity, acceleration) activity maps that resembled those of real neurons. This would indicate that the spike synthesizer captured how M1 neurons encode position information. To answer this question, we compared the position, velocity and acceleration activity maps built from synthesized spike trains to those of actual spikes. We counted the number of spikes for different hand positions and normalized them with respect to the averaged spike counts across the workspace. Fig. 1.3.a shows the normalized position activity map for real neuron 35, and Fig. 1.3.b shows its virtual counterpart. The MSE between the two maps is 0.0086, which in this example is lower than the trimmed average MSE between real neurons (0.13, based on 99% of samples). Note how the position activity maps have similar light and dark spots. Fig. 1.3.c shows the normalized position activity map for real neuron 3, and Fig. 1.3.d shows its virtual counterpart. The MSE between the two maps is 0.21, higher than the average MSE. In this example, the position activity maps exhibit similar overall features but differ slightly in the exact location of the peaks. Fig. 1.3.e shows a summary position histogram for all real-virtual neuron pairs. The histogram is left-skewed, around the mean of 0.13. Sixty-one of sixty-nine neurons (88%) had an error less than the trimmed average. Fig. 1.S.10 (Fig. 1.S.11) shows a summary velocity (acceleration) histogram for all real-virtual neuron pairs. Those histograms are left-skewed, with a trimmed mean of 0.11 (0.10). Sixty of sixty-nine neurons (87%) had an error less than the average. This shows that, with respect to position (velocity, acceleration)-related activity, the model has learned realistic virtual neurons.
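A minimal sketch of how such a normalized position activity map and the MSE comparison above can be computed, assuming 2-D hand positions and per-bin spike counts for one neuron; the grid size is an illustrative assumption.

```python
import numpy as np

def position_activity_map(pos_xy, spike_counts, n_grid=16):
    """pos_xy: (T, 2) hand positions; spike_counts: (T,) counts for one neuron."""
    total, xe, ye = np.histogram2d(pos_xy[:, 0], pos_xy[:, 1],
                                   bins=n_grid, weights=spike_counts)
    visits, _, _ = np.histogram2d(pos_xy[:, 0], pos_xy[:, 1], bins=[xe, ye])
    # Average activity per visited position bin (zero where never visited).
    rate = np.divide(total, visits, out=np.zeros_like(total), where=visits > 0)
    return rate / rate.mean()  # normalize by the workspace-average activity

def map_mse(real_map, virtual_map):
    """Mean squared error between two activity maps, as in Fig. 1.3.e."""
    return np.mean((real_map - virtual_map) ** 2)
```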
2) The spike synthesizer can synthesize spikes with realistic firing patterns
We then asked whether the spike synthesizer can learn to synthesize spike trains from specific kinematics (Fig. 1.4.a) with realistic firing rates (Fig. 1.4.b,c). We found that the spike synthesizer produces firing rates (distribution of firing rates across all neurons) that are not different from those of real neurons (Fig. 1.4.d; Kolmogorov–Smirnov test: the samples are not from different distributions, p=0.1536, i.e., the test failed to show that the distributions differ).
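A minimal sketch of this comparison, assuming 10 ms binned counts; the placeholder Poisson data stand in for the real and synthesized spike trains.

```python
import numpy as np
from scipy import stats

def mean_rates(binned_counts, bin_ms=10):
    """binned_counts: (T, n_neurons) -> mean firing rate (Hz) per neuron."""
    return binned_counts.mean(axis=0) / (bin_ms / 1000.0)

rng = np.random.default_rng(0)
real = rng.poisson(0.1, size=(60000, 69))   # placeholder real binned counts
synth = rng.poisson(0.1, size=(60000, 69))  # placeholder synthesized counts

# Two-sample KS test on the per-neuron firing-rate distributions; a large
# p-value, as reported above, fails to show that the distributions differ.
stat, p = stats.ks_2samp(mean_rates(real), mean_rates(synth))
```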
Then, we asked whether the synthesized spike trains are correlated with real spike trains more than would be expected by chance. To assess this, we used pairwise correlations computed after placing the spikes in 250 ms time bins [34], a time scale relevant to behavioral movements. We compared the correlation coefficients between pairs of real and synthesized spike trains to those between real spike trains and trains generated by a homogeneous Poisson process, an estimate of chance level (also see Supplementary information for an alternate measure that uses randomly shuffled real spike trains). The correlation coefficient is equal to one if spike trains are identical and zero if they are independent. Fig. 1.5.a shows time series of binned spike counts for neuron 8 (an example from the left portion of Fig. 1.S6.b), with a correlation between synthesized and real neural data of 0.59. In comparison, the correlation (Fig. 1.5.b) between the neural data from a homogeneous Poisson distribution and real neural data is 0.03. Fig. 1.5.c shows similar plots for neuron 59 (an example from the right portion of Fig. 1.S6.b). Here, the correlation between actual and synthesized data is 0.18, while that between real neural and homogeneous Poisson data is 0.07 (Fig. 1.5.d). In summary, for 91% (63 out of 69) of the neurons (Fig. 1.5.e), the correlations of the actual spike trains with synthesized trains were higher than the correlations between neural spikes of randomly shuffled neurons.
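A minimal sketch of this correlation analysis, assuming 10 ms binned counts rebinned to 250 ms; the placeholder data and the simple Pearson formula are illustrative assumptions.

```python
import numpy as np

def rebin(counts, factor=25):
    """(T, N) counts at 10 ms -> (T//factor, N) counts at 250 ms."""
    T = counts.shape[0] // factor * factor
    return counts[:T].reshape(-1, factor, counts.shape[1]).sum(axis=1)

def per_neuron_corr(a, b):
    """Pearson correlation per neuron between two (T, N) binned series."""
    a = a - a.mean(axis=0)
    b = b - b.mean(axis=0)
    return (a * b).sum(axis=0) / (np.linalg.norm(a, axis=0) * np.linalg.norm(b, axis=0))

rng = np.random.default_rng(0)
real = rng.poisson(0.1, size=(60000, 69)).astype(float)   # placeholder real counts
synth = real + rng.poisson(0.02, size=real.shape)          # stand-in synthesized data
poisson = rng.poisson(real.mean(axis=0), size=real.shape)  # rate-matched chance baseline

r_synth = per_neuron_corr(rebin(real), rebin(synth))
r_chance = per_neuron_corr(rebin(real), rebin(poisson.astype(float)))
```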
In sum, the spike synthesizer learned to generate spike trains that are better
correlated with real spike trains than one would expect by chance, although
some differences are apparent between real and synthesized data. We thus
next asked whether the synthesized spike trains are sufficiently realistic to
effectively assist the training of a BCI decoder, which is the primary goal of this
study.
Figure 1.4. (a) Movement kinematics (position). (b) Real neural data for the kinematics in (a). (c) Synthesized spike trains from the spike synthesizer for the kinematics in (a). (d) Firing rates (distribution of firing rates across all neurons) for the kinematics in (a), for real and virtual neurons.
Figure 1.5. (a,b) Time series of binned spike counts for neuron 8 (good example, left portion of Supplementary Fig. 1.6.b). The correlation between synthesized and real neural data is 0.59. In contrast, the correlation between generated (red, from a homogeneous Poisson distribution) and real (black) neural data is 0.03. (c,d) Time series of binned spike counts for neuron 59 (bad example, right portion of Supplementary Fig. 1.6.b). The correlation between synthesized and real neural data is 0.18, while the correlation between neural data generated from a Poisson distribution and real neural data is 0.07. (e) Scatter plot for synthesized neural data vs. the randomly-shuffled real spike train baseline. Each black point represents a neuron. The vertical axis is the correlation between synthesized and real neural data across all neural spike train samples for each neuron, with blue standard error bars. The horizontal axis is the correlation between neural data of randomly shuffled neurons across all neural spike train samples, with red standard error bars.
1.3.4 Synthesized neural spikes can accelerate the training and improve
generalization of cross-session and cross-subject decoding
To test the utility of our spike synthesizer, we explored whether it can accelerate the training and improve the generalization of a BCI decoder across sessions and subjects. We trained the spike synthesizer on the data of the first session of Monkey C (S.1, M.C; Fig. 1.2, step 1). Then, we split new neural data from a second session, either from the same subject (session 2 of Monkey C; S.2, M.C) or from a different subject (first session of Monkey M; S.1, M.M), into training and test sets. We froze the generator of the spike synthesizer to preserve the embedded neural attributes. Then, we fine-tuned the readout module of the spike synthesizer with as little as 35 seconds of the new training set to learn its dynamics. After training the BCI decoder with various combinations of real and synthesized data, we tested it on an independent test set from the other session (S.2, M.C) or subject (S.1, M.M), and compared the decoding performance to that of other data augmentation methods (Supplementary). These included either using real neural data alone (Real-Only) to train the BCI decoder, or three other data augmentation methods: Real augmentation (duplicate and concatenate available real neural data), Stretch augmentation (duplicate, randomly stretch, and concatenate available real neural data), and Mutation augmentation (duplicate, add noise, and concatenate available real neural data).
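As an illustration of the freeze-and-fine-tune step (Fig. 1.2, step 2), here is a minimal sketch reusing the hypothetical Generator from the sketch in Section 1.3.2. The neuron count and optimizer settings are assumptions, and the actual procedure fine-tunes with the full adversarial objective (Methods, Section 1.5.5).

```python
import torch

G = Generator()                     # in practice, the generator pretrained on S.1, M.C
for p in G.lstm.parameters():
    p.requires_grad = False         # freeze the embedded neural attributes

n_new_neurons = 77                  # e.g., session two of Monkey C
G.readout = torch.nn.Linear(256, n_new_neurons)  # substitute a fresh readout

# Fine-tune only the readout on the limited new data (as little as ~35 s),
# using the usual adversarial losses with a discriminator sized for the new data.
opt = torch.optim.Adam(G.readout.parameters(), lr=1e-4)
```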
1) Synthesized spike trains accelerate the training of a BCI decoder when the
neural data from another session or subject is limited
We first tested whether synthesized spike trains can accelerate training for cross-session decoding. When using less than 17 minutes of new data from S.2, M.C to train the BCI decoder, our GAN-augmentation method performed better than the other augmentation methods (Supplementary, Fig. 1.S.4) and the Real-Only method (Fig. 1.6.a). Without any data augmentation, the BCI decoder needed at least 8.5 minutes of training data to converge. However, using our spike synthesizer, it becomes possible to train the BCI decoder with much less real data. At the extreme, with only 35 seconds of new neural data (S.2, M.C), augmented by 22 minutes of synthesized data, cross-session decoding performance was better than for competing methods (used to extend the original 35 seconds of new data to 22 minutes). The averaged best performance for all kinematics (the lowest point of each curve in Fig. 1.6.b) of the GAN-augmentation method was 7.2%, 4.8%, 5.6% and 16% better than the Stretch-Augmentation, Mutation-Augmentation, Real-Concatenation, and Real-Only methods, respectively. Fig. 1.6.c shows general performance results for all kinematic variables over the percentage of new neural data used, presented in a format similar to receiver operating characteristic (ROC) curves. The area under the GAN-augmentation curve was 0.030, 0.020, 0.015, and 0.33 greater than for the Stretch-Augmentation, Mutation-Augmentation, Real-Concatenation, and Real-Only methods, respectively. To achieve accuracy saturation (Fig. 1.6.d; >=95% of the peak for the Real-Only method trained on all neural data from S.2, M.C), GAN-augmentation only requires 1.82 minutes of additional neural data from S.2, M.C, compared to 5.47, 6.02, 8.67 and 12.93 minutes (4.27x, 4.70x, 6.77x and 10.10x) for Mutation-Augmentation, Real-Concatenation, Stretch-Augmentation and Real-Only, respectively. The Real-Only method requires 8.8 minutes of additional neural data to converge and achieves an average correlation coefficient of 0.77 across kinematics. In comparison, GAN-Augmentation requires only 0.67 minutes of additional neural data to achieve the same performance. In summary, our GAN-augmentation is the best among all tested methods for cross-session BCI training.
Synthesized spike trains can also accelerate the training for cross-subject decoding. When using less than 9 minutes of new data from S.1, M.M to train the BCI decoder, our GAN-augmentation method performed better than the other augmentation methods (Supplementary, Fig. 1.S.5) and the Real-Only method (Fig. 1.7.a). Without any data augmentation, the BCI decoder needed at least 8.5 minutes of training data to converge. Fig. 1.7.b shows the general performance results for all kinematic variables over a wide range of new neural data. The averaged best performance for all kinematics (the lowest point of each curve in Fig. 1.7.b) of the GAN-augmentation method is 60%, 17%, 11% and 6% better than the Stretch-Augmentation, Mutation-Augmentation, Real-Concatenation and Real-Only methods, respectively. Fig. 1.7.c shows general performance results for all kinematic variables over the percentage of new neural data used, presented in a format similar to ROC curves. The area under the GAN-augmentation curve was 0.11, 0.052, 0.039, and 0.58 greater than for the Stretch-Augmentation, Mutation-Augmentation, Real-Concatenation, and Real-Only methods, respectively. None of the tested methods achieved performance comparable to GAN-augmentation, since the GAN-augmentation method can transfer learned dynamics and improve the decoding performance beyond the best performance achievable on data from Monkey M alone (further explained in the following section). To achieve accuracy saturation, GAN-augmentation only needs 0.1 minutes of additional neural data from S.1, M.M, compared to 1.25, 2.65, 4.02 and 8.67 minutes (12.50x, 26.50x, 40.20x and 86.70x) for Real-Concatenation, Mutation-Augmentation, Stretch-Augmentation and Real-Only, respectively. Thus, our GAN-augmentation is again the best among all tested methods for accelerating cross-subject BCI training.
2) Transferring learned dynamics and improving beyond the best-achievable cross-subject decoding
Training a spike synthesizer that learns good neural attributes from Monkey C (with its better decoding performance) might transfer useful information to boost cross-subject decoding performance beyond the best achievable on data from Monkey M alone. When neural data are ample for both Monkey C (Fig. 1.6.a) and Monkey M (Fig. 1.7.a), the decoding performance on position and velocity is equally good for both monkeys, though the performance on acceleration for Monkey C is better than that for Monkey M (0.85 vs. 0.43 for acc x; 0.82 vs. 0.48 for acc y). The better decoding performance on acceleration data from Monkey C might come from the better quality of signals collected by the electrode arrays, or from the fact that neural data from Monkey C is inherently easier for the decoder than that from Monkey M. A lower quality of the collected neural data from Monkey M might decrease the decoding accuracy of accelerations, because accelerations are usually the hardest to decode; but this could be improved through prior learning from Monkey C. Since the decoding performance on position and velocity for Monkey M is already good, we wondered whether any useful information learned from neural data of Monkey C could improve the best achievable decoding performance on acceleration for Monkey M.
Here, we compared the best achievable performance of all methods (the highest points of the acceleration x and acceleration y curves in Fig. 1.7.a and Fig. 1.S.5). For acceleration y, the best achievable performance for GAN-Augmentation was 0.64, significantly better than the best achievable performance for Real-Only (0.48; p = 10⁻¹⁴, t-test), Stretch-Augmentation (0.62; p = 10⁻³, t-test), Mutation-Augmentation (0.51; p = 10⁻¹¹, t-test) and Real-Concatenation (0.59; p = 10⁻⁴, t-test). The results for acceleration x are similar to those for acceleration y. Thus, even with ample neural data for both monkeys, the neural attributes learned from the first monkey can transfer some useful information to improve the best achievable decoding performance for the second monkey.
Figure 1.6. Cross-session decoding. There are five methods: GAN-Augmentation (red), Mutation-Augmentation (purple), Stretch-Augmentation (orange), Real-Augmentation (blue), and Real-Only (green). The cutoff time for b) and c) is 20.53 minutes, the minimum amount of time needed for other methods to achieve performance comparable to GAN-augmentation. a) The performance for each kinematic in 5-fold cross-validation. The horizontal axis is the number of minutes of neural data from session two of Monkey C used. The vertical axis is the correlation coefficient between the decoded kinematics and real kinematics on an independent test set from session two of Monkey C. Synthesized spike trains that capture the neural attributes accelerate the training of a BCI decoder for cross-session decoding. b) Average percentage by which each method performs worse than the GAN-Augmentation method. The horizontal axis is the number of minutes of neural data from session two of Monkey C used. The vertical axis is the average performance over all six kinematic signals of each method, expressed as a percentage worse than the GAN-Augmentation method. c) Average performances presented similarly to a receiver operating characteristic (ROC) curve. The horizontal axis is the percentage of neural data from the training set of session two of Monkey C used. The vertical axis is the same as in a). d) Amount of additional neural data (S.2, M.C) needed for each method to achieve accuracy saturation (>=95% of the peak for the Real-Only method trained on all neural data from S.2, M.C).
Figure 1.7. Cross-subject decoding. Methods and their notations are the same as in Fig. 1.6. a) Performance for each kinematic in 5-fold cross-validation. The horizontal axis is the number of minutes of neural data from Monkey M used. The vertical axis is the correlation coefficient between the decoded kinematics and real kinematics on an independent test set from Monkey M. When the neural data from another subject are limited, synthesized spike trains that capture the neural attributes accelerate cross-subject decoding. In addition, synthesized spike trains have learned generalizable information that can boost cross-subject decoding performance on acceleration over the best achievable performance. b) Average percentage by which each method performs worse than the GAN-Augmentation method. The horizontal axis is the number of minutes of neural data from session one of Monkey M used. The vertical axis is the same as in Fig. 1.6.b. c) Average performances presented similarly to a receiver operating characteristic (ROC) curve. The horizontal axis is the percentage of neural data from the training set of session one of Monkey M used. The vertical axis is the same as in a). d) Amount of additional neural data (S.1, M.M) needed for each method to achieve accuracy saturation (>=95% of the peak for the Real-Only method trained on all neural data from S.1, M.M).
1.4 Discussion
The intuition of this paper is that building a spike synthesizer that captures
underlying neural attributes can accelerate the training and improve the
generalization of BCI decoders. We have introduced a new spike synthesizer
that can learn a mapping from kinematics to neural data. The spike synthesizer
captures and reproduces the neural attributes in its training set. After the fine-
tuning procedure, we showed that the synthesizer adapted to data from a new
session or even a new subject, and can synthesize new spike trains. The
synthesized spike trains can accelerate the training of a BCI decoder and
improve its generalization across sessions and subjects.
We can interpret the improvement in decoding performance from the perspective of the statistics of movements. Understanding the statistics of movements can help us in many situations where obtaining simultaneous recordings of both neural activity and kinematics is difficult. Dyer et al. [35] proposed a fundamentally new approach, distribution-alignment decoding (DAD), which leverages the statistics of movements to achieve semi-supervised decoder training. They built prior distributions of motor-output variables during a simple planar reaching task. DAD was then able to align the distributions of decoder outputs to kinematic distribution priors. However, it is not clear whether the approach would work for more complex movements, as one would need many prior distributions of motor-output variables (high-level templates). Complex movements might involve sub-movements such as holding still, rapid reaching, slow reaching, etc. It might take a lot of time to craft these templates manually, and it is hard to choose the right form of the templates for a task and to combine them properly. In comparison, our spike synthesizer recreated neural attributes such as position, velocity and acceleration activity maps and velocity neural tuning curves (low-level templates) directly from the data. For more complex kinematics, our spike synthesizer might synthesize spike trains drawn from different distributions for holding and rapid reaching by combining those neural attributes in an autonomous way. There is no need to handcraft high-level templates (prior distributions of motor-output variables) if they can be constructed from more fundamental low-level ones.
The neural attributes captured by the spike synthesizer (position activity maps and velocity neural tuning curves; Fig. 1.S.2, 1.S.3, 1.S.10, 1.S.11) have a special name in the literature [36–41]: motor primitives. Motor primitives are a generalization of movement components or of the templates (kinematic distribution priors) used by DAD. It has been suggested that the motor cortex may control movement through flexible combinations of motor primitives, elementary building blocks that can be combined to give rise to complex motor behavior. Thoroughman et al. [40] defined a motor primitive as the directional neural tuning curve for each neuron, fitted by a Gaussian function. They built movement trajectories through linear combinations of those tuning curves. In related research, Stroud et al. [41] used gain patterns over neurons to predict movement trajectories. Tanmay et al. [42] learned a diverse set of motor primitives through a latent representation of a neural abstraction network from kinematics given in demonstrations. Then, they split new complex kinematics into subparts and fit those subparts with the motor primitives. The key problem in motor primitive research is how to learn the primitives, and then how to combine them to create complex motion. Here, our results suggest that our spike synthesizer learned motor primitives through its internal representation (similar to Tanmay et al. [42]) from both kinematics and neural data (in contrast to Tanmay et al., who only use kinematics), together with their combination rules, in an autonomous and principled way. In addition, using an extended version of Thoroughman's [40] definition of a motor primitive that includes both position and velocity tuning for each neuron, we can successfully reproduce these motor primitives by reconstructing position and velocity tuning curves for each neuron from the internal representation.
Recently, Pandarinath et al. [19] proposed an interesting method, LFADS, to infer latent dynamics from single-trial neural spike data by leveraging a deep learning auto-encoder. Our method is complementary to this work in the following ways. First, LFADS focuses on constructing a mapping from the neural data to low-dimensional latent variables (neural population dynamics), while simultaneously reconstructing the same neural data from a Poisson process parameterized by these latent variables. In contrast, our method uses a generative adversarial network to create a mapping directly from the kinematics to the neural data in an end-to-end manner. In addition, the mapping does not impose any prior distribution on the data.
Recent advances in the analysis of neural population activity, especially with the help of BCI methods, have revealed the importance of the covariance structure of neuronal populations in controlling movements and motor skill acquisition [43–45]. Although, in our work, the spike synthesizer was trained to learn a mapping between kinematic and neural data, internally the spike synthesizer maps the kinematics to a generalizable internal representation that captures embedded neural attributes, including high-order statistics between neurons such as covariance. We expect that the covariance structure imposed by this generalizable internal representation can accelerate training even for BCI decoders that aim to generalize across tasks. Finally, our framework of learning the internal representation from raw data and additional synthesized data is general and fully data-driven, in contrast to neural encoding models [46–49] that assume Gaussian or Poisson distributions. Hence, our framework could be applied to other neuroscience encoding and decoding problems beyond motor control with minimal domain-specific modifications.
1.5 Methods
1.5.1 The BCI decoder.
We use the state-of-the-art Long Short-Term Memory (LSTM) network [11,12] as the decoder. Recurrent neural networks can use their feedback connections to store a representation of recent input in their hidden states. However, with traditional backpropagation through time, they often suffer from exploding or vanishing gradients. LSTM creates an uninterrupted gradient flow and thus achieves better performance. The structure of the LSTM cell can be formalized as

\begin{pmatrix} i \\ f \\ o \\ g \end{pmatrix} = \begin{pmatrix} \sigma \\ \sigma \\ \sigma \\ \tanh \end{pmatrix} W \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix} \quad (1)

c_t = f \odot c_{t-1} + i \odot g \quad (2)

h_t = o \odot \tanh(c_t) \quad (3)
where x_t is the input at time t and h_t is the hidden state at time t. W is the weight matrix mapping h_{t-1} and x_t to the gates, and \sigma is the sigmoid function. i is the input gate, deciding whether to write to the cell; f is the forget gate, deciding whether to erase the cell; g is the gate gate, deciding how much to write to the cell; o is the output gate, deciding how much of the cell to reveal; c_t is the cell state. For the LSTM decoder, we unroll the LSTM cell over 200 timesteps per sample. The input dimension is (N, T, D), where N is the number of samples, T is the number of timesteps, and D is the feature dimension. Our input is batched neural spikes with 128 samples, 200 timesteps, and the number of neurons (69 for session 1 and 77 for session 2 of Monkey one, 60 for session 1 of Monkey two) as the feature dimension. The hidden dimension h is 200 for the LSTM decoder, so the LSTM decoder output has dimensions (N=128, T=200, H=200). We feed this output into a fully connected layer to produce the kinematics (dimensions [128, 200, 6]). We apply dropout [50] and learning rate decay [51] while training the LSTM decoder.
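To make this architecture concrete, here is a minimal sketch of such a decoder in PyTorch (our illustration only; the class and argument names are hypothetical, and details of the actual implementation may differ):

import torch.nn as nn

class LSTMDecoder(nn.Module):
    # Sketch: neural spikes [N, T, n_neurons] -> kinematics [N, T, 6].
    def __init__(self, n_neurons, hidden=200, n_kin=6, p_drop=0.5):
        super().__init__()
        self.lstm = nn.LSTM(n_neurons, hidden, batch_first=True)
        self.drop = nn.Dropout(p_drop)           # dropout regularization
        self.readout = nn.Linear(hidden, n_kin)  # fully connected layer

    def forward(self, spikes):                   # spikes: [N, T, n_neurons]
        h, _ = self.lstm(spikes)                 # h: [N, T, hidden]
        return self.readout(self.drop(h))        # kinematics: [N, T, 6]

# e.g., session 1 of Monkey C (69 neurons), batches of shape [128, 200, 69]
decoder = LSTMDecoder(n_neurons=69)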
1.5.2 Bidirectional LSTM [52].
The output of a sequence at the current time step depends not only on the sequence before it, but also on the sequence after it. So, to better capture the neural attributes, we use a bidirectional LSTM to build the generator and discriminator in our Constrained Conditional LSTM GAN model. At each time step t, this network has two hidden states, one for left-to-right propagation and another for right-to-left propagation. The update rule is
\vec{h}_t = g(\vec{W} x_t + \vec{V} \vec{h}_{t-1} + \vec{b}) \quad (4)

\overleftarrow{h}_t = g(\overleftarrow{W} x_t + \overleftarrow{V} \overleftarrow{h}_{t+1} + \overleftarrow{b}) \quad (5)

where \vec{h}_t and \overleftarrow{h}_t maintain the left-to-right and right-to-left hidden states, respectively, at time t, and g is the LSTM cell update function of Eqs. 1-3.
1.5.3 Generative adversarial network (GAN).
GANs [21] provide a tool to learn a map from a random noise input to a desired data distribution in an end-to-end way, updating their parameters via backpropagation [53]. Thus, a GAN does not require any assumption about the data distribution: it is purely data-driven and does not need a strong prior model, which would limit generality. The process of training a GAN can be thought of as an adversarial game between a generator and a discriminator. The generator's role is akin to producing fake currency and using it without detection, while the discriminator learns to detect the counterfeit currency. Competition in this adversarial game can improve both components' abilities until the counterfeits are indistinguishable from real currency. After this competition, the generator can take random noise (which provides variations) as input to output different kinds of realistic bills with different textures. Several approaches have been proposed for image synthesis using GANs, enhanced to generate output images for a particular object class, such as conditional GAN [24], Semi-Supervised GAN [54], InfoGAN [27], AC-GAN [25] and cGANs [29]. In the fake-currency scenario, by injecting conditions (the denomination of each bill) into the input of the GAN, we can select which kind of bill we want to generate (e.g., a 100-dollar bill), while the noise provides the varied texture of the bills (e.g., smooth or wrinkled).
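As a reference point for this adversarial game, here is a minimal sketch of one GAN update step in PyTorch (illustrative only; G, D and the optimizers are placeholders, and this shows the standard non-saturating GAN losses rather than our full constrained conditional objective):

import torch
import torch.nn.functional as F

def gan_step(G, D, real, z, opt_g, opt_d):
    fake = G(z)
    # Discriminator: push D(real) toward 1 and D(fake) toward 0.
    d_real, d_fake = D(real), D(fake.detach())
    d_loss = F.binary_cross_entropy(d_real, torch.ones_like(d_real)) \
           + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake))
    opt_d.zero_grad(); d_loss.backward(); opt_d.step()
    # Generator: push D(G(z)) toward 1 (non-saturating loss).
    d_fake = D(fake)
    g_loss = F.binary_cross_entropy(d_fake, torch.ones_like(d_fake))
    opt_g.zero_grad(); g_loss.backward(); opt_g.step()
    return d_loss.item(), g_loss.item()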
1.5.4 Constrained Conditional LSTM GAN (cc-LSTM-GAN, the spike synthesizer).
We propose the constrained conditional LSTM GAN to model the behavior of M1 from the kinematics. A normal LSTM model takes an input of dimension (N, T, D), where N is the number of samples in one minibatch, T is the time horizon, and D is the feature dimension. We chose 2 seconds as the time horizon in our experiments. The third input dimension, D, is the number of neurons for the discriminator, or the noise dimension for the generator. For each item in the batch, we have a 2-second slice of neural spikes from D neurons. Since the number of neurons is the third input dimension of the LSTM in the discriminator, our discriminator treats different neurons as individuals that have different neural tuning properties. Thus, the cc-LSTM-GAN encoding model is a multiple neural encoding model.
Training assistant LSTM decoder (GAN-ta LSTM decoder). We train an LSTM decoder (hidden dimension h: 200) on neural data ([N, T, M], where N is the sample size 128, T is the time horizon 200, and M is the number of neurons: 69 for S.1, M.C; 77 for S.2, M.C; 60 for S.1, M.M) from one monkey and freeze its parameters when we train our Constrained Conditional Bidirectional LSTM GAN. This decoder applies constraints to the cc-LSTM-GAN: we want to maintain decoding performance while we train the encoder.
Bidirectional LSTM generator. The bidirectional-LSTM generator takes Gaussian noise ([N, T, D], where N is the sample size 128, T is the time horizon 200, and D is the Gaussian noise dimension 6) and real kinematics ([N, T, D], where N is the sample size 128, T is the time horizon 200, and D is the kinematics dimension 6) as inputs and synthesizes the corresponding spike trains. We feed the outputs (dimensions [N, T, 2*H], where N is the sample size 128, T is the time horizon 200, and H is the hidden dimension 200) of the bidirectional LSTM into a fully connected layer to synthesize spike trains with the correct number of neurons (dimensions [N, T, M], where N is the sample size 128, T is the time horizon 200, and M is the number of neurons: 69 for S.1, M.C; 77 for S.2, M.C; 60 for S.1, M.M). We apply a scaled tanh function [55] as the output layer of the fully connected layer, which maps a real value to [-0.5, 0.5] and gives us a probability representation of whether the current bin contains a spike event or not. For example, if the value in the current bin is -0.3, the probability that there is a spike event in this bin is 0.2 (-0.3 + 0.5). For analysis, we sample from a Bernoulli distribution to generate the neural firing patterns, whereas for decoding, the LSTM decoder directly takes the probability as input.
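A minimal sketch of this sampling step in PyTorch (illustrative; the function name is ours):

import torch

def sample_spikes(gen_out):
    # gen_out lies in [-0.5, 0.5] from the scaled-tanh output layer;
    # shifting by 0.5 turns it into a per-bin spike probability,
    # e.g., -0.3 -> 0.2, from which we draw a Bernoulli realization.
    return torch.bernoulli(gen_out + 0.5)  # binary spike trains, same shape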
Bidirectional LSTM discriminator. The discriminator is a bidirectional LSTM. It takes the synthesized spike trains ([N, T, M], where N is the sample size 128, T is the time horizon 200, and M is the number of neurons: 69 for S.1, M.C; 77 for S.2, M.C; 60 for S.1, M.M) and neural data ([N, T, M], with N, T, M as for the synthesized spike trains) as inputs and learns to determine whether a sample comes from the neural data or from the synthesized spike trains. We feed the outputs (processed by a sigmoid activation) of the bidirectional LSTM into a fully connected layer to obtain a decision value in (0, 1), a probability that decides whether the current sample is real or fake. We feed the output of the bidirectional LSTM into another fully connected layer to get the decoded kinematics; this lets us apply the category constraints.
Multiple neural encoding model (cc-LSTM-GAN). We use a conditional structure with a GAN category loss to let the discriminator distinguish both the data source distribution (neural spikes) and the data label distribution (kinematics corresponding to these spikes). When the input of the bidirectional LSTM discriminator is the neural data (synthesized spike trains), the real (fake) embedding features are the output of the bidirectional LSTM discriminator. The GAN embedding category loss is the L2 loss between the real and fake embedding features:

\text{GAN embedding category loss} = \left\| \text{Real embedding features} - \text{Fake embedding features} \right\|_2^2 \quad (6)
When the input of the bidirectional LSTM discriminator is the neural data (synthesized spike trains), the real (fake) decoded kinematics are the output of the fully connected layer after the bidirectional LSTM discriminator. The real (fake) GAN decoding category loss is the L2 loss between the real (fake) decoded kinematics and the real kinematics. The GAN category loss is the average of the real GAN decoding category loss, the fake GAN decoding category loss and the GAN embedding category loss:

\text{GAN category loss} = \tfrac{1}{3} \times (\text{Real GAN decoding category loss} + \text{Fake GAN decoding category loss} + \text{GAN embedding category loss}) \quad (7)
To maintain the source distribution, our cc-LSTM-GAN needs to play the min-max game: we minimize the discriminator GAN loss (L_D) and the generator GAN loss (L_G), where

L_D = -\mathbb{E}_{x \sim p_{data}}[\log D(x \mid k)] - \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z \mid k) \mid k))] \quad (8)

L_G = -\mathbb{E}_{z \sim p(z)}[\log D(G(z \mid k) \mid k)] \quad (9)

where z is the Gaussian noise, x is the neural spikes, k is the kinematics, p(z) is the noise distribution, and p_{data} is the data distribution. L_D is the discriminator loss and L_G the generator loss, both computed with the cross-entropy loss function.
To further maintain the statistical structure of the real neurons, we want to maximize the inner product between the neural data and the synthesized spike trains. The hyperparameter a adjusts the scale of the inner product loss; we use a = 0.0001 in our implementation. Since all losses are minimized, the inner product loss is

\text{inner product loss} = -a \,(\text{synthesized spike trains} \cdot \text{neural data}) \quad (10)
The pre-trained GAN-ta LSTM decoder takes the neural data and the synthesized spike trains as inputs and decodes the corresponding kinematics. The decoded generated kinematics are the kinematics decoded from synthesized spike trains; the decoded real kinematics are the kinematics decoded from neural data. We apply an L2 loss between the real kinematics and the decoded generated kinematics, and an L2 loss between the decoded generated kinematics and the decoded real kinematics. The pre-trained GAN-ta LSTM decoder thus helps our generator synthesize spike trains that are more realistic in terms of the GAN-ta LSTM decoder's performance. The total generator loss is a weighted combination of the generator GAN loss (L_G), the GAN category loss, the inner product loss, the L2 loss between decoded generated kinematics and decoded real kinematics, and the L2 loss between decoded generated kinematics and real kinematics:
\text{Total generator loss} = 0.7\, L_G + 0.2 \times \text{GAN category loss} + 0.1 \times \text{inner product loss} + 0.1 \times \text{L2 loss(decoded generated kinematics, decoded real kinematics)} + 0.1 \times \text{L2 loss(decoded generated kinematics, real kinematics)} \quad (11)
The total discriminator loss is a weighted combination of the discriminator GAN loss (L_D) and the GAN category loss. We train the network on the real GAN training set (a subset of the neural data from the first monkey), minimizing the total discriminator and generator losses:

\text{Total discriminator loss} = 0.8\, L_D + 0.2 \times \text{GAN category loss} \quad (12)
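For clarity, the two weighted objectives in Eqs. 11-12 can be expressed as a small helper (a sketch; the argument names stand for loss tensors computed elsewhere):

def total_losses(L_G, L_D, cat_loss, ip_loss, l2_dec_dec, l2_dec_real):
    # Eq. 11: weighted combination of the generator-side loss terms.
    g_total = 0.7 * L_G + 0.2 * cat_loss + 0.1 * ip_loss \
              + 0.1 * l2_dec_dec + 0.1 * l2_dec_real
    # Eq. 12: weighted combination of the discriminator-side loss terms.
    d_total = 0.8 * L_D + 0.2 * cat_loss
    return g_total, d_total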
1.5.5 Fine tuning.
We trained the cc-LSTM-GAN on the session one data of Monkey one. In the fine-tuning process, we took a limited amount of neural data from session two of Monkey one (cross-session) or from Monkey two (cross-subject). We added another fully connected layer (readout module) on top of the bidirectional LSTM generator and used it to generate the corresponding neural spikes with the same number of neurons as the added data. We froze the parameters of the hidden units of the bidirectional LSTM generator and trained only this new fully connected layer with the limited new data. The loss function of the fine-tuning process is the inner product loss of Eq. 10. Then, we fed the kinematics corresponding to this limited amount of neural data into the generator multiple times to synthesize a large number of spike trains.
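A minimal sketch of this freeze-and-readout step in PyTorch (illustrative; names are ours):

import torch.nn as nn

def attach_readout(generator, n_new_neurons, hidden=200):
    # Freeze the pretrained bidirectional-LSTM generator; only the new
    # readout layer is trained on the limited new data (Eq. 10 loss).
    for p in generator.parameters():
        p.requires_grad = False
    # 2 * hidden: forward and backward hidden states are concatenated.
    return nn.Linear(2 * hidden, n_new_neurons)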
1.5.6 Correlation.
To calculate the correlations across neural spike train samples for each neuron, we parsed the synthesized and real neural spikes into 250-ms bins and set the time horizon to 25 s. We calculated the correlations for each sample and averaged the absolute values of the correlations for each neuron. We compared the result against that of randomly shuffled real spike trains (comparing the real neural data with shuffled neural data from other neurons) and against neural spikes generated from a homogeneous Poisson distribution. The homogeneous Poisson distribution assumes that the generation of each spike depends only on a homogeneous firing rate and is independent of all other spikes; thus, it is a good baseline for comparison.
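A sketch of this computation for one neuron in NumPy (our reading of the procedure; function and argument names are hypothetical):

import numpy as np

def mean_abs_correlation(spikes, bin_ms=250, horizon_s=25, dt_ms=5):
    # spikes: [n_samples, n_raw_bins] binary trains at dt_ms resolution.
    per_bin = bin_ms // dt_ms               # raw bins per 250-ms bin
    n_bins = horizon_s * 1000 // bin_ms     # bins in the 25-s horizon
    binned = spikes[:, :n_bins * per_bin] \
        .reshape(len(spikes), n_bins, per_bin).sum(-1)
    c = np.corrcoef(binned)                 # sample-by-sample correlations
    iu = np.triu_indices_from(c, k=1)       # unique sample pairs
    return np.abs(c[iu]).mean()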
1.5.7 Data Augmentation.
We used multiple data augmentation methods to train a BCI decoder and achieve better decoding performance than training a BCI decoder on the neural data alone. The start point of the limited neural data is always the start of the neural data in each fold; the end point is the start point plus the required amount of limited data.

Real-Only method: Without any data augmentation, we trained the BCI decoder on the neural data only and tested on the neural data of an independent test set. This method required at least 8.5 minutes of neural data in the training set for the BCI decoder to converge (our criterion for convergence is that the training loss decreases during training; with less than 8.5 minutes of training data, the training loss does not decrease no matter how long we train the LSTM BCI decoder). If the BCI decoder fails to converge, the correlation coefficient is defined as the correlation between randomly shuffled real kinematics and real kinematics (0.00027, chance level).
Real-Concatenation: We took a limited amount of neural data from the training set and concatenated it multiple times until it was equal to or longer than the whole training set. We trained the BCI decoder on the concatenated neural data and tested on the neural data of an independent test set.
Mutation-Augmentation: We flipped each value of the limited neural data with 5% probability (spike to non-spike or non-spike to spike), as sketched below. We repeated this step several times and concatenated the mutated neural data and its kinematics until they were equal to or longer than the whole training set. We trained the BCI decoder on the mutated neural data and tested on the neural data of an independent test set.
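A one-function sketch of the mutation step in NumPy (illustrative names):

import numpy as np

def mutate(spikes, p_flip=0.05, seed=0):
    # Flip each binary bin (spike <-> no spike) with probability p_flip.
    rng = np.random.default_rng(seed)
    flip = rng.random(spikes.shape) < p_flip
    return np.where(flip, 1 - spikes, spikes)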
Stretch-Augmentation: We stretched the limited neural data by 10 percent and filled the empty stretched slots of the neural data with zero or one (50% probability each). We calculated the averaged absolute gradient of each kinematic signal for each slot, and filled the empty stretched slots of the kinematics by adding or subtracting (50% probability each) the averaged absolute gradient from the value of the last slot. We repeated the above steps several times and concatenated the stretched neural data and its kinematics until they were equal to or longer than the whole training set. We trained the BCI decoder on the stretched neural data and tested on the neural data of an independent test set.
GAN-Augmentation: We combined the synthesized spike trains (22 minutes in our paper) with the augmented data from the Real-Concatenation method. We trained the BCI decoder on this combination of synthesized spike trains and concatenated neural data and tested on the neural data of an independent test set.
1.5.8 Hyperparameters

Table 1.1 Hyperparameters

Module              | Sample size | Time steps | Hidden dims | Training epochs | Learning rate | Optimizer | Activation
Generator           | 128         | 200        | 200         | 4000            | 0.0006*       | Adam [56] | tanh
Discriminator       | 128         | 200        | 200         | 4000            | 0.0003        | SGD       | sigmoid
GAN-ta LSTM decoder | 128         | 200        | 200         | 200             | 0.003*        | Adam      | N/A
LSTM BCI decoder ** | 128         | 200        | 200         | 200             | 0.003*        | Adam      | N/A

* with exponential learning rate decay; ** same hyperparameters for all data augmentation methods
2. Capturing spike train temporal pattern with Wavelet Average Coefficient for
Brain Machine Interface
2.1 Abstract
Motor brain machine interfaces (BMIs) directly link the brain to artificial actuators and have the potential to mitigate severe body paralysis caused by neurological injury or disease. Most BMI systems involve a decoder that analyzes neural spike counts to infer movement intent. However, many classical BMI decoders 1) fail to take advantage of temporal patterns of spike trains, possibly over long time horizons; and 2) are insufficient to achieve good BMI performance at high temporal resolution, as the underlying Gaussian assumption of decoders based on spike counts is violated. Here, we propose a new statistical feature that represents temporal patterns or temporal codes of spike events with a richer description – wavelet average coefficients (WAC) – to be used as decoder input instead of spike counts. We constructed a wavelet decoder framework by using WAC features with a sliding-window approach, and compared the resulting decoder against classical decoders (Wiener and Kalman family) and against new deep-learning-based decoders (Long Short-Term Memory) using spike count features. We found that the sliding-window approach boosts decoding temporal resolution, and that using WAC features significantly improves decoding performance over using spike count features.
2.2 Introduction
Motor brain machine interfaces (BMIs) utilize signal processing and machine learning techniques to decode recorded neuronal activity into motor commands. These techniques include the Wiener filter [4,57], Kalman filter [7,58,59], Particle filter [8,60,61], Point Process filter (PPF) [9,62–66], and Long Short-Term Memory (LSTM) from deep learning [67,68].
BMIs that continuously decode the spiking activity of neuronal ensembles often utilize a decoding scheme where neuronal firing rates are represented as the number of spikes within non-overlapping time bins; the time step for generating the decoded signal is equal to the bin width. Classical BMIs (e.g., Wiener and Kalman filters) assume the spike counts within each bin are Gaussian and update every 50-100 ms. This bin width usually provides good temporal resolution and a sufficient amount of neuronal data for accurate decoding, but the Gaussian assumption can sometimes be violated. With wider bins, neural data is better approximated by a Gaussian distribution, but increasing bin size hinders temporal resolution. Several recent publications [9,62–66,69] have argued that even a temporal resolution of 50-100 ms is insufficient for high BMI performance, and that a resolution of 5 ms is preferable. However, decoding at such a high temporal resolution would severely decrease decoding performance, as spike counts in 5-ms bins severely violate the classical filters' approximately-Gaussian assumption. This can be solved either by a point process model (e.g., modeling each bin's counts as a Poisson process) or by our sliding window approach, which is much easier to implement. We therefore develop a sliding window approach for the Kalman and Wiener filters, where a wide window (e.g., consisting of 10 bins, each 5 ms wide) slides by small time increments (e.g., 5 ms), to achieve both high temporal resolution and a near-Gaussian data distribution.
To understand the dynamics of neurons [70], it is important to characterize their firing patterns. In a rate coding scheme, information is encoded in the number of spikes per observation (spike counts, mean firing rates [23], etc.), and any information possibly encoded in the temporal structure of the spike trains is ignored. For example, the neural spike train sequence (1 for a spike, 0 for no spike) 000111 can mean something different from 100101, even though the mean firing rate is the same for both sequences. More importantly, precise spike timing and high-frequency firing-rate fluctuations have been found to carry information [71,72]. Functions of the brain are more temporally precise than the use of rate encoding alone seems to allow [73]. Temporal codes [72–78] employ those features of the spiking activity that cannot be described by the firing rate alone (e.g., time to first spike, phase of firing, etc.).
In BMI design, classical decoders often use spike counts (a rate coding scheme) as input features. However, spike counts fail to take full advantage of the distribution and correlation of the historical data: they neglect the distribution of spikes in the current bin and the relations among distributions over past bins, and they cannot derive information contained in quiet periods where there are no spikes. Thus, there is a pressing need to develop better temporal coding features than spike counts. We argue that important information is encoded not only by spikes at specific time instants, but also by the quiet periods that do not contain any spikes.
To address this, we propose a new feature (wavelet average coefficients, WAC) that can describe a greater variety of temporal sequences of spike events than mere binned spike counts allow. The extracted WAC enables the decoders to incorporate information from a long history (e.g., 500 ms). Such a long history is achievable because WAC captures the dynamical pattern of the spike events over time, which lets us better exploit the information contained in spike events. Indeed, WAC exploits information in both one events (a spike in a bin) and zero events (no spike in a bin) over the longer time horizon of the whole window. We test the framework on multi-electrode array recordings in monkeys performing reaching and locomotion tasks. By tuning the sliding window size of the wavelet framework, we find that sliding window size correlates with movement frequency. Our results show that the wavelet framework boosted the decoding performance of the Wiener filter, the Kalman filter and the LSTM decoder at high temporal resolution. The resulting decoders also outperformed decoders using spike counts as input features on monkey data in reaching and locomotion tasks.
2.3 Experimental paradigm
Four monkeys were chronically implanted with electrode arrays in the primary motor cortex (M1). We recorded neural activity in primary motor cortex using an implanted electrode array while monkeys performed "center-out" and "locomotion" tasks (Fig. 2.1). For center-out, we recorded neural activity from two monkeys making reaching movements to targets (3 sessions and 153 neurons from monkey one, and 2 sessions and 153 neurons from monkey two). For locomotion, we recorded neural activity from two additional monkeys (490 neurons for monkey three and 388 neurons for monkey four) while they walked forward for 10 minutes at 12.5 cm/second and backward for 12.5 minutes at 12.5 cm/second.
Figure 2.1. 1) Center-out task: monkeys were seated in front of a video screen and grasped the handle of a planar manipulandum that controlled the position of a cursor. Monkeys made reaching movements to a sequence of targets appearing on the screen while we recorded neural activity in primary motor cortex using an implanted electrode array. 2) Locomotion task: monkeys walked on a treadmill. We measured ankle x and ankle y positions while we recorded neural activity in primary motor cortex. We acknowledge artist MinJun Xu for creating the artwork.
2.4 Methods
2.4.1 Wavelet Framework
Our wavelet framework consists of four separate modules (Fig. 2.2a): the Kernel Function Module, the Discrete Wavelet Transform Module, the Preprocessing Module, and the Decoder Module. (1) To extract the dynamical pattern of spike events, the Kernel Function Module converts neural spike trains with different distributions into different discrete neural signal waveforms. (2) The Discrete Wavelet Transform Module [79] encodes discrete neural signal waveforms with different shapes into different trend features Q. (3) The Preprocessing Module selects the right trend features and further shrinks the number of trend features by averaging each of them to produce WAC. This allows us to use a few parameters to describe the dynamical patterns over a large horizon (e.g., 500 ms of neural spike trains), and thus prevents overfitting. (4) The sliding-window-based Decoder Module decodes the kinematics from WAC. Together, WAC and the sliding window enable the decoder to decode kinematics with high temporal resolution and high decoding accuracy.
Figure 2.2. a) Overview of the wavelet framework. Given the neural spikes, our model uses kernel functions to transform them into a discrete neural signal waveform that fluctuates with the distribution of spike events. Our model uses the discrete wavelet transform to encode this discrete neural signal waveform into a trend feature tensor Q to capture the temporal pattern of neural spikes. By preprocessing the trend features and using them as input to a traditional decoder such as the Wiener or Kalman filter with a sliding window approach, we can produce decoded kinematics with high temporal resolution and high decoding accuracy. b) The Kernel Function Module converts spike trains to a discrete neural signal waveform that fluctuates with the distribution of the spike events. c) The Discrete Wavelet Transform Module captures the temporal patterns of a signal. d) The Preprocessing Module reduces the dimensions by averaging each trend feature Q and selects suitable trend features as WAC.
2.4.2 Kernel Function Module
To extract the dynamical pattern of spike events, the Kernel Function Module (Fig. 2.2b) takes spike trains (x[n]) as input and translates different spike event distributions into different discrete neural signal waveforms (k[n]). The kernel function is

k[n] = k[n-1] + 2(x[n] - 0.5), \quad \text{s.t. } k[0] = 0, \; \forall n \in [1, T] \quad (1)

where T is the time horizon. This kernel function outputs a discrete neural signal waveform that fluctuates with the distribution of spike events (see Discussion for more details).
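Eq. 1 steps the waveform up by 1 for each spike and down by 1 for each empty bin; a minimal NumPy sketch (our illustration, with a hypothetical function name):

import numpy as np

def kernel_waveform(x):
    # Eq. (1): k[n] = k[n-1] + 2*(x[n] - 0.5), with k[0] = 0.
    x = np.asarray(x, dtype=float)
    return np.concatenate(([0.0], np.cumsum(2 * (x - 0.5))))

kernel_waveform([1, 1, 0, 1, 0, 0])  # -> [0, 1, 2, 1, 2, 1, 0]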
2.4.3 Discrete Wavelet Transform
The wavelet transform [79] is an excellent tool for capturing the temporal patterns of a signal. The Fourier transform decomposes a signal into different frequency components using periodic exponential functions. Similarly, the wavelet transform decomposes a signal into different wavelet coefficients and a scaling function coefficient using detail functions at different scales (see Supplementary Fig. 2.S1). The scaling function coefficient encodes information from the large scale (trend) of the signal; the wavelet coefficients encode information from the small scales (details) of the signal. In contrast to the Fourier transform, the discrete wavelet transform localizes "spike trends" in both time and frequency at different scales. Here, the Discrete Wavelet Transform Module (Fig. 2.2c) encodes discrete neural signal waveforms with different shapes into different trend features Q (the concatenation of the scaling function coefficient and the wavelet coefficients). Trend features Q allow us to represent the discrete neural signal waveform with a few parameters. We use db3 wavelets [80] as the basis to decompose the neural signal waveforms (corresponding to the high- and low-pass filters in Fig. 2.2c). This step essentially allows us to describe a complicated waveform such as in Fig. 2.3a with a few numbers.
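In practice this decomposition is a one-liner; here is a sketch using the PyWavelets library (illustrative; the function name is ours):

import pywt  # PyWavelets

def trend_features(k_wave, level=5):
    # db3 multilevel decomposition: returns [cA_level, cD_level, ..., cD_1],
    # i.e., one scaling-function (trend) coefficient array plus `level`
    # wavelet (detail) coefficient arrays.
    return pywt.wavedec(k_wave, 'db3', level=level)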
Figure 2.3. a) Comparison between the neural signal waveform and spike counts. There are 10 bins (gray); the size of each bin is 50 ms, and the time resolution of the spike trains (blue) is 5 ms. The spike counts (black square waves with the number of spikes on top) in each bin ignore the temporal patterns of spike events (e.g., bins 3 and 4 have the same number of spikes, but the distributions of spike events differ). The neural signal waveform (pink) fluctuates with the distribution of spike events. b) Trend features Q capture the temporal patterns of spike trains, since we can reconstruct the neural signal waveform from the encoded trend features Q (see Supplementary Fig. 2.S1 for more information about the reconstruction). The neural signal waveform in each sliding window is encoded into a trend feature Q. BMI decoders can decode kinematics better from trend features Q than from spike counts.
2.4.4 Preprocessing Module for Generating WAC
The Preprocessing Module (Fig. 2.2d) selects the suitable trend features and further reduces the dimensions of the trend features Q by averaging each of them to produce WAC. For example, if we decompose the neural signal waveforms using the discrete wavelet transform 5 times, we have one scaling function coefficient (c, dimension [7]) and five wavelet coefficients (d_i, i \in [1, 5], with dimensions [100, 50, 25, 13, 7], respectively). We use a single number to represent each coefficient by averaging it across its dimensions. We then have one averaged scaling function coefficient \bar{c} and five averaged wavelet coefficients \bar{d}_l, l \in [1, 5]. One can select \bar{c} as WAC, since it represents the large scale (trend) of the neural signal waveforms. In addition, as we show in the Results, combining the averaged scaling function coefficient (large scale) with averaged wavelet coefficients (small scale) can further improve decoder performance (e.g., selecting \bar{c} together with two of the \bar{d}_l as WAC). Thus, additional small-scale information is helpful for decoding.
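A sketch of this averaging and selection step (illustrative; builds on the trend_features sketch above):

import numpy as np

def wac(coeffs, keep=(0,)):
    # coeffs: [cA, cD_level, ..., cD_1] from the wavelet decomposition.
    # Average each coefficient array to one number, then keep the selected
    # entries as WAC (keep=(0,) keeps only the averaged scaling coefficient).
    means = [float(np.mean(c)) for c in coeffs]
    return [means[i] for i in keep]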
2.4.5 Comparison between trend features Q and spike counts
Spike counts fail to capture the temporal patterns of spike events (Fig. 2.3a) and use only a single number to summarize the firing rates. In comparison, the neural signal waveform in each sliding window is encoded into a trend feature Q (Fig. 2.3b). Trend features Q capture the temporal patterns of spike events with a richer description. Through various experiments (see Results), we found that BMI decoders can decode kinematics better from trend features than from spike counts, as trend features encode the temporal patterns of spike events.
2.4.6 Sliding Window for Wiener filter and Kalman filter
Figure 2.4. Sliding window structure. We bin the neural spikes into 5- or 10-ms bins, each containing at most one spike. The window size is the length of the sliding window. The tap size is the number of sliding windows. The lag size is the time lag between consecutive sliding windows. The slide size is how far the whole sliding window structure moves along the timeline from global time instant n-1 to global time instant n. The slide size equals the bin size in our paper; we move 1 bin at a time.
Here, we propose a sliding window structure (Fig. 2.4). We combined the sliding window structure with classical Wiener and Kalman filters and compared their performance between 1) using WAC features as inputs (our wavelet framework) and 2) using spike counts as inputs (classical approaches with the sliding window improvement; Supplementary). It is worth noting that WAC allows us to use a long window size (e.g., 500 ms) compared to the short window size used with spike counts (e.g., 50 ms); WAC thus provides longer historical information to the decoders.
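A minimal sketch of cutting a spike train into such sliding windows (illustrative names):

import numpy as np

def sliding_windows(spikes, win, slide=1):
    # Overlapping windows of length `win` bins, advanced `slide` bins per
    # step (slide = 1 bin in our paper, i.e., slide size = bin size).
    starts = range(0, len(spikes) - win + 1, slide)
    return np.stack([spikes[s:s + win] for s in starts])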
Wavelet framework for the Wiener filter with sliding window augmentation: We use a 5-ms bin size, 1-s window size, 4 taps (number of sliding windows), 50-ms lag size and 5-ms slide size. Here, as an example, we decompose the neural signal waveforms five times using the discrete wavelet transform. After averaging the trend features Q, we have one averaged scaling function coefficient \bar{c} and five averaged wavelet coefficients \bar{d}_l, l \in [1, 5]. We choose \bar{c} and \bar{d}_l, l \in [1, 3], as WAC, calculated for each neuron i and sliding window j. The updating rule is:

y[n] = \sum_{i=1}^{N} \sum_{j=1}^{4} \left( \sum_{l=1}^{3} w_{ijl} \, \bar{d}_{ijl}[n] + w^{c}_{ij} \, \bar{c}_{ij}[n] \right) \quad (2)

where y[n] is the covariates at time n, N is the number of neurons, 4 is the tap size, l indexes the three averaged wavelet coefficients, w_{ijl} is the weight for neuron i, sliding window j and averaged wavelet coefficient \bar{d}_{ijl}, and w^{c}_{ij} is the weight for neuron i, sliding window j and averaged scaling function coefficient \bar{c}_{ij}.
Wavelet framework for the Kalman filter with sliding window augmentation: We use a 5-ms bin size, 1-s window size, 1 tap (number of sliding windows), 0-ms lag size (since we only use 1 tap) and 5-ms slide size. Here, as an example, we decompose the neural signal waveforms three times using the discrete wavelet transform. After averaging the trend features Q, we have one averaged scaling function coefficient \bar{c} and three averaged wavelet coefficients \bar{d}_l, l \in [1, 3]. We choose \bar{c} as WAC. The state space model for the Kalman filter is:

\bar{c}[n+1] = A \, \bar{c}[n] + w[n] \quad (3)

y[n] = C \, \bar{c}[n] + v[n] \quad (4)

where n is the time instant, y[n] is the covariates, w[n] and v[n] are zero-mean Gaussian noise, and A and C are time-invariant parameters estimated during training. The recursive equations of the Kalman filter are given in the Supplementary, Eqns. 1 to 5.
Classical Wiener filters and Kalman filters (see Supplementary, Eqns. 6 to 12) with sliding window augmentation: For Wiener filters, we use a 5-ms bin size, 50-ms window size, 4 taps, 5-ms lag size and 5-ms slide size. The updating rule is:

y[n] = \sum_{i=1}^{N} \sum_{j=1}^{M} w_{ij} \, x_{ij}[n] \quad (5)

where y[n] is the covariates at time n, N is the number of neurons, M is the number of taps, w_{ij} is the weight for neuron i at sliding window j, and x_{ij}[n] is the spike count calculated from sliding window j of neuron i at time n.
2.4.7 LSTM decoder using WAC as inputs
To test whether WAC can improve the decoding performance of the state-of-the-art LSTM decoder [67,68] (see Supplementary, Eqns. 13 to 15), we compared the performance of the LSTM decoder using WAC as inputs to that of the LSTM decoder using spike counts as inputs.
2.5 Results
2.5.1 Sliding window improves decoding performance of the classical Wiener and Kalman filters at high temporal resolution
The decoding performance of the sliding window Kalman (Wiener) filter is better than that of the classical Kalman (Wiener) filter at 5-ms temporal resolution (Fig. 2.5). The reason is that spike counts in 5-ms bins severely violate the Gaussian assumption of the Kalman and Wiener filters, but a sliding window structure with a 50- or 100-ms window enables the classical decoders to maintain approximately Gaussian distributions while still keeping high temporal resolution. Thus, it yields better decoding accuracy.
Figure 2.5. Decoding performance for locomotion tasks and center-out tasks, measured by the correlation coefficient between decoded covariates and ground truths in 5-fold cross-validation (mean +/- S.D., n = 5 folds). We use the Wilcoxon signed-rank test to validate the results. a) Locomotion walking-forward task for Monkey 3. b) Locomotion walking-forward task for Monkey 4. c) Locomotion walking-backward task for Monkey 3. d) Locomotion walking-backward task for Monkey 4. e) Center-out task for Monkey 1. f) Center-out task for Monkey 2.
2.5.2 Wavelet framework further improves the performance of Kalman and Wiener filters augmented by sliding windows
The decoding performance of the wavelet framework for sliding-window-augmented Kalman (Wiener) filters is better than that of Kalman (Wiener) filters augmented with the sliding window alone at 5-ms temporal resolution (Fig. 2.5). The reason is that WAC enables decoders to use a long window size (e.g., 500 ms), compared to the short window size (e.g., 50 ms) used with spike counts; WAC thus provides longer historical information to the decoder. In addition, the bins that contain no spike are as important for our decoder as the bins that contain one spike, and the distribution of spike events is encoded inside WAC. In summary, our trend features WAC, which capture the dynamic pattern of neural spikes as encoded by the discrete wavelet transform, provide better features than traditional spike counts. As a consequence, decoders using WAC achieve better decoding performance than decoders using spike counts.
2.5.3 Sliding window size correlates with movement frequency
We tested the decoding performance for each covariate as a function of sliding window size. In center-out tasks, the best sliding window size for decoding position is around 500 ms (Fig. 2.6a), while the best sliding window size for decoding velocity is around 350 ms (Fig. 2.6b). Thus, we conclude that monkey brains encode position (slowly changing; position increases monotonically) with coarser temporal resolution (i.e., a longer window size), while encoding velocity (fast changing; joystick velocity increases from 0 to some top speed, then decreases back to 0) with finer temporal resolution. In the more complicated locomotion task in a 3D environment (Supplementary Fig. 2.S2), monkey brains exhibit a temporally coarser encoding of ankle x (Fig. 2.6c, around 500-ms window size; time period 3.2 seconds, amplitude 0.25), which has a larger amplitude and a slow rate of change. Meanwhile, monkey brains exhibit a temporally finer encoding of ankle y (Fig. 2.6d, around 350-ms window size; time period 1.9 seconds, amplitude 0.05), which usually oscillates back and forth rapidly. In addition, monkey brains do not appear to encode movement information at a large time scale (e.g., 1 second), as shown by the large decline in performance there.
Figure 2.6. Influence of window size and different hyperparameters, measured by the correlation coefficient between decoded covariates and ground truths in 5-fold cross-validation (mean +/- S.D., n = 5 folds, 5-ms temporal resolution). a, b) Monkey 1's center-out task for cursor position y and velocity y; 10-ms slide size, 50-ms lag size, db3 basis. c, d) Monkey 3's locomotion task for left ankle x and ankle y; 10-ms slide size, 50-ms lag size, db3 basis.
2.5.4 Using WAC as inputs improves the decoding performance of the LSTM decoder at high temporal resolution
To show that WAC is a richer feature than spike counts across decoding platforms (from simple regression models such as the Wiener or Kalman filter to advanced deep networks such as the LSTM), we demonstrated that the decoding performance of an LSTM decoder using WAC is better than that of an LSTM decoder using spike counts at 5-ms temporal resolution (Fig. 2.7).
Figure 2.7. Decoding performance for locomotion tasks using the LSTM decoder, measured by the correlation coefficient between decoded covariates and ground truths in 5-fold cross-validation (mean +/- S.D., n = 5 folds, 5-ms temporal resolution). We use a t-test on z-scores transformed from the correlation coefficients to validate the result.
2.6 Discussion
This work makes three major contributions: 1) we proposed a new statistical feature, WAC, which captures the distribution of spike events; 2) we developed a new wavelet framework combined with a sliding window to leverage WAC, enabling classical decoders to work well at high temporal resolution, and demonstrated that BMI decoders using WAC achieve better decoding performance than decoders using classical spike counts as inputs; 3) we found that sliding window size correlates with movement frequency.
Why are the temporal patterns, or codes, of neural spikes so important? Precise spike timing is significant in neural encoding, as several studies have found that the temporal resolution of the neural code is on a millisecond time scale [73,76,81]. In the encoding of visual stimuli, Gollisch et al. claimed that neurons of the retina encode the spatial structure of an image in the relative timing between stimulus onset and the first action potential (time to first spike) [73]. In the encoding of gustatory stimuli, Carleton et al. claimed that gustatory neurons deploy both rate and temporal coding schemes to differentiate between different tastant types. In our wavelet framework, WAC captures the temporal patterns or codes of neural spikes. It captures not only the spike events at specific time instants, but also the information encoded by the quiet periods that contain no spikes. As a result, WAC incorporates more information than classical rate-related statistical features (e.g., spike counts), and WAC features can thus improve BMI performance over classical statistical features.
WAC allows decoders to incorporate information from a very long history of data. The state space prior model of classical decoders such as the Kalman filter and the Point Process filter only lets those decoders look back at the spike counts within the previous one or few bins. Using spike counts as input features does not give the model enough information about the overall distribution of spikes. For example, Shanechi et al. [65] proposed a linear dynamical model in which the kinematic state at time t only includes information from the kinematic state, brain control state, and Gaussian noise at time t-1. This oversimplified prior model thus fails to provide the information that could be accumulated from all historical data. In comparison, WAC encodes information from a very long history of data and represents it succinctly. When WAC is combined with our sliding window approach, it provides abundant historical information for classical decoders.
The wavelet transform is not new in neuroscience. For example, it has been used for spike sorting [82], spike detection [83,84], capturing direction-related [85] and speed-related [86] features, stably tracking neural information over a long time [87], and denoising neural signals [88,89]. In particular, Lee et al. [89] built a BMI decoder that is robust to large background noise by leveraging high-frequency components (wavelet coefficients calculated from a wavelet transform applied directly to spike trains), since the wavelet transform can localize high-frequency information in the spike trains. In contrast, our method uses kernel functions to transform the temporal patterns of spike trains into a discrete signal waveform that fluctuates with those temporal patterns. Our method then leverages the equivalent of the low-frequency components of the neural signals (the scaling function coefficient calculated from the wavelet transform of the discrete signal waveform) to improve decoder performance, since these represent the temporal patterns of spike trains.
2.7 Materials and Methods
All animal procedures were performed in accordance with the National
Research Council’s Guide for the Care and Use of Laboratory Animals and were
approved by the Duke University Institutional Animal Care and Use Committee.
The study was carried out in compliance with the ARRIVE guidelines.
3. Beneficial Perturbation Network for designing general adaptive artificial
intelligence systems
3.1 Abstract
The human brain is the gold standard of adaptive learning. It can not only learn and benefit from experience, but also adapt to new situations. In contrast, deep neural networks learn only one sophisticated but fixed mapping from inputs to outputs. This limits their applicability to more dynamic situations, where the input-to-output mapping may change with different contexts. A salient example is continual learning - learning new independent tasks sequentially without forgetting previous tasks. Continual learning of multiple tasks in artificial neural networks using gradient descent leads to catastrophic forgetting, whereby a previously learned mapping for an old task is erased when learning new mappings for new tasks. Here, we propose a new, biologically plausible type of deep neural network with extra, out-of-network, task-dependent biasing units to accommodate these dynamic situations. This allows, for the first time, a single network to learn potentially unlimited parallel input-to-output mappings and to switch on the fly between them at runtime. Biasing units are programmed by leveraging beneficial perturbations (opposite to well-known adversarial perturbations) for each task. Beneficial perturbations for a given task bias the network toward that task, essentially switching the network into a different mode to process that task. This largely eliminates catastrophic interference between tasks. Our approach is memory-efficient and parameter-efficient, can accommodate many tasks, and achieves state-of-the-art performance across different tasks and domains.
3.2 Introduction
The human brain is the benchmark of adaptive learning. While interacting with new environments that are not fully known to an individual, it is able to quickly learn and adapt its behavior to achieve goals as well as possible across a wide range of environments, situations, tasks, and problems. In contrast, deep neural networks learn only one sophisticated but fixed mapping between inputs and outputs, limiting their application in more complex and dynamic situations in which the mapping rules do not stay the same but change according to different tasks or contexts. One such failure case is continual learning - learning new independent tasks sequentially without forgetting previous tasks. In the domain of image classification, for example, each task may consist of learning to recognize a small set of new objects. A standard neural network learns only a fixed mapping rule between inputs and outputs after training on each task; training the same neural network on a new task destroys the learned fixed mapping of an old task. Thus, current deep learning models based on stochastic gradient descent suffer from so-called "catastrophic forgetting" [90-92], in that they forget all previous tasks after training on each new one.
Figure 3.1. With BPN, one can switch at runtime to the network parameters that are globally optimal for each task. Training trajectories are illustrated in loss and parameter space. The green curve shows loss as a function of network parameters for a first task A, with optimal parameters shown by the green circle. The purple curve and circle correspond to a second task B. Training first on task A, then on task B with stochastic gradient descent (SGD, without any constraints on parameters, gray) leads to optimal parameters for task B (purple circle), but those are destructive for task A. When, instead, learning task B using EWC or PSP (with some constraints on parameters, yellow), the solution is a compromise that can be sub-optimal for both tasks (black circle). Beneficial perturbations (blue curve for task A, red curve for task B) push the representation learned by EWC or PSP back to their task-optimal states.
Here, we propose a new, biologically plausible (see Discussion) method - the Beneficial Perturbation Network (BPN) - to accommodate these dynamic situations. The key new idea is to allow one neural network to learn potentially unlimited task-dependent mappings and to switch between them at runtime. To achieve this, we first leverage existing lifelong learning methods to reduce interference between successive tasks (Elastic Weight Consolidation, EWC [93], or parameter superposition, PSP [94]). We then add out-of-network, task-dependent bias units to provide per-task correction for any remaining parameter drifts due to the learning of a sequence of tasks. We compute the most beneficial biases - beneficial perturbations - for each task in a manner inspired by recent work on adversarial examples. The central difference is that, instead of adding adversarial perturbations that can force the network into misclassification, beneficial perturbations push the drifted representations of old tasks back to their initial task-optimal working states (Fig. 3.1).
Figure 3.2. Concept. Type 1 - constrain the network weights while training the new task: (a) retraining models such as elastic weight consolidation [93] retrain the entire network learned on previous tasks while using a regularizer to prevent drastic changes to the original model. Type 2 - expanding and retraining methods (b-c): (b) expanding models such as progressive neural networks [95] expand the network for new task t without any modifications to the network weights for previous tasks; (c) expanding models with partial retraining, such as dynamically expandable networks [96], expand the network for new task t with partial retraining of the network weights for previous tasks. Type 3 - episodic memory methods (d): methods such as Gradient Episodic Memory [97] store a subset of the original dataset from previous tasks in episodic memory and replay it with new data during the training of new tasks. Type 4 - partition networks (e): these use context or mask matrices to partition the core network into several sub-networks for different tasks [94,98-101]. Type 5 - beneficial perturbation methods (f): beneficial perturbation networks create beneficial perturbations, stored in bias units, for each task. Beneficial perturbations bias the network toward that task and thus allow the network to switch into different modes to process different independent tasks. The normal weights learned from previous tasks are retrained using elastic weight consolidation [93] or parameter superposition [94]. (g) Strengths and weaknesses of each type of method.
There are three major benefits of BPN. 1) BPN is memory- and parameter-efficient: to demonstrate this, we validate our BPN for continual learning on incremental tasks. We test it on multiple public datasets (incremental MNIST [102], incremental CIFAR-10 and incremental CIFAR-100 [103]), on which it achieves better performance than the state of the art. For each task, by adding bias units that store beneficial perturbations to every layer of a 5-layer fully connected network, we introduce only a 0.3% increase in parameters, compared to a 100% parameter increase for models that train a separate network, and 11.9% - 60.3% for dynamically expandable networks [96]. Our model does not need any episodic memory to store data from previous tasks, and does not need to replay such data during the training of new tasks, in contrast to episodic memory methods [92,97,104,105]. Our model does not need large context matrices, in contrast to partition methods [94,98-101,106-108]. 2) BPN achieves state-of-the-art performance across different datasets and domains: to demonstrate this, we consider a sequence of eight unrelated object recognition datasets (Experiments). After training on the eight complex datasets sequentially, the average test accuracy of BPN is better than the state of the art. 3) BPN has the capacity to accommodate a large number of tasks: to demonstrate this, we test a sequence of 100 permuted MNIST tasks (Experiments). A variant of BPN that uses PSP to constrain the normal network achieves 30.14% better performance than the second best, the original PSP [94], a partition method which performs well on incremental tasks and the eight object recognition tasks. Thus, BPN has a promising future for solving continual learning compared to the other types of methods.
To lay the foundation of our approach, we start by introducing the following key concepts: Sec. 3.3: different types of methods for enabling lifelong learning; Sec. 3.4: adversarial directions and perturbations; Sec. 3.5: beneficial directions and perturbations, and the effects of beneficial perturbations in sequential learning scenarios; Sec. 3.6: structure and updating rules of BPN. We then present experiments (Sec. 3.7), results (Sec. 3.8) and discussion (Sec. 3.9).
3.3 Types of methods for enabling lifelong learning
Four major types of methods have been proposed to alleviate catastrophic forgetting. Type 1: constrain the network weights to preserve performance on old tasks while training the new task [93,109,110] (Fig. 3.2a). A famous example of type 1 methods is EWC [93]. EWC constrains certain parameters based on how important they are to previously seen tasks, with importance calculated from the task-specific Fisher information matrix. However, relying solely on constraining the parameters of the core network eventually exhausts the core network's capacity to accommodate new tasks: after learning many tasks, EWC cannot learn any more because the parameters become too constrained (see Results). Type 2: dynamic network expansion [95,96,109,111] creates new capacity for the new task, which can often be combined with constrained network weights for previous tasks (Fig. 3.2b-c). However, this type is not scalable because it is not parameter efficient (e.g., 11.9% - 60.3% additional parameters per task for dynamically expandable networks [96]). Type 3: using an episodic memory [97,104,105] to store a subset of the original dataset from previous tasks, then rehearsing it while learning new tasks to maintain accuracy on the old tasks (Fig. 3.2d). However, this type is not scalable because it is neither memory- nor parameter-efficient. All three approaches attempt to shift the network's single fixed mapping, initially obtained by learning the first task, to a new one that satisfies both old and new tasks; they create a new, but still fixed, mapping from inputs to outputs across all tasks so far, combined. Type 4: partition networks: using task-dependent context [94,98,101,108] or mask matrices [99,100,106,107] to partition the original network into several small sub-networks (Fig. 3.2e; flow chart in Fig. 3.S1a). Zeng et al. [98] used context matrices to partition the network into independent subspaces spanned by rows of the weight matrices to avoid interference between tasks. However, context matrices introduce as many additional parameters as training a separate neural network for each new task (an additional 100% of parameters per task). To reduce parameter costs, Cheung et al. [94] proposed binary context matrices, further restricted to diagonal matrices with -1 and 1 values. The restricted context matrices [98] (1 and -1 values) behave similarly to mask matrices [99] (0 and 1 values) that split the core network into several sub-networks for different tasks. With too many tasks, the core network eventually runs out of capacity to accommodate any new task, because no vacant route or subspace is left. Although type 4 methods create multiple input-to-output mappings for different tasks, many of these methods are too expensive in terms of parameters, and none of them has enough capacity to accommodate numerous tasks, because methods such as PSP run out of unrealized capacity of the core network.
In marked contrast to the above artificial neural network methods, here we propose a fundamentally new fifth type (Fig. 3.2f, flow chart - Fig. 3.S1b): we add out-of-network, task-dependent bias units to the neural network. Bias units enable a neural network to switch into different modes to process different
independent tasks through beneficial perturbations (the memory storage cost
of these new bias units is actually lower than the cost of adding a new mask or
context matrix). With only an additional 0.3% of parameters per mode, this
structure allows BPN to learn potentially unlimited task-dependent mappings
from inputs to outputs for different tasks. The strengths and weaknesses of
each type are in Fig. 3.2g.
3.4 Adversarial directions and perturbations
Three spaces of a neural network are important for this and the following
sections: The input space is the space of input data (e.g., pixels of an image);
the parameter space is the space of all the weights and biases of the network;
the activation space is the space of all outputs of all neurons in all layers in the
network.
By adding a carefully computed "noise" (adversarial perturbations) to the input space of a picture, without changing the neural network, one can force the network into misclassification. The noise is usually computed by backpropagating the gradient in a so-called "adversarial direction", such as by using the fast gradient sign method (FGSD) [112]. For example, consider a task of recognizing handwritten digits "1" versus "2". Adversarial perturbations aimed at misclassifying an image of digit 2 as digit 1 may be obtained by backpropagating from the class digit 1 to the input space, following any of the available adversarial directions. In Fig. 3.3a, adding adversarial perturbations to the input image can be viewed as adding an adversarial direction vector (gray arrows AD) to the clear (non-perturbed) input image of digit 2. The resulting vector crosses the decision boundary. Thus, adversarial perturbations can force the neural network into misclassification, here from digit 2 to digit 1. Because the dimensionality of adversarial directions is around 25 for MNIST [112], when we project them into a 2D space, we use fan-shaped gray arrows to depict those dimensions.
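To make the adversarial direction concrete, the following sketch (our own illustrative PyTorch code; the function name and epsilon value are assumptions, not taken from the cited FGSD reference) computes a targeted perturbation that pushes images of digit 2 toward the class digit 1:

import torch
import torch.nn.functional as F

def targeted_fgsd_perturbation(model, images, target_labels, epsilon=0.1):
    """Backpropagate from the attacker's target class to the input space and
    step along the (signed) adversarial direction."""
    images = images.clone().requires_grad_(True)
    loss = F.cross_entropy(model(images), target_labels)
    loss.backward()
    # Descending the loss toward the target class moves the input across
    # the decision boundary, e.g., from digit 2 toward digit 1.
    return -epsilon * images.grad.sign()

Adding the returned perturbation to the clean images yields adversarial examples that the unchanged network misclassifies as the target class.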
3.5 Beneficial directions and perturbations, and the effects of beneficial perturbations in multitask sequential learning scenarios
In this section, we first introduce the definition of beneficial directions and beneficial perturbations. Then, we explain why beneficial perturbations can help a network recover from the parameter drift of old tasks after learning new tasks and can push task representations back to their initial task-optimal working region.
We consider two incremental digit recognition tasks; Task A (recognizing 1s and 2s) and Task B (recognizing 3s and 4s). Attack and defense researchers usually view adversarial examples as a curse of neural networks, but we view them as a gift for solving continual learning. Instead of adding input "noise" (adversarial perturbations) to the input space, calculated from other classes to force the network into misclassification, we add "noise" to the activation space, using beneficial perturbations stored in bias units added to the parameter space (Supplementary Fig. 3.S1b), calculated from the input's own correct class to assist
in correct classification. To understand beneficial perturbations, we first explain beneficial directions. Beneficial directions are vectors that point toward the high-confidence classification region of each class (Fig. 3.3b); $BD_1$ ($BD_2$) are the beneficial directions that point to the high-confidence classification region of digit 1 (digit 2). The point $A$ represents the activation of the normal neurons of each layer generated from an input image of task A. $A + BD_1$ ($A + BD_2$) pushes the activation $A$ across the decision boundary and toward $R_1^{hc}$ ($R_2^{hc}$). Thus, the network would classify $A + BD_1$ ($A + BD_2$) as digit 1 (2) with high confidence. To overcome catastrophic forgetting, we create beneficial perturbations for each task and store them in task-dependent bias units (Fig. 3.4, Supplementary Fig. 3.S1b). Beneficial perturbations allow a neural network to operate in different modes by biasing the network toward a particular task, even though the shared normal weights become contaminated by other tasks. The beneficial perturbations for each task are created by aggregating the beneficial direction vectors sequentially for each class through mini-batch backpropagation. For example, during the training of task A, the network has been trained on two images from digit 1 ($1_a$ and $1_b$) and two images from digit 2 ($2_a$ and $2_b$). The beneficial perturbations for task A are the summation of the beneficial directions calculated from each image ($BD_1^a + BD_1^b + BD_2^a + BD_2^b$ in Fig. 3.3c, where $BD_i^j$ is the beneficial direction for sample $j$ in class $i$). During the training of task B, with gradient descent, the point $A$ (Fig. 3.3b) drifts to $A'$, which lies inside the classification regions for task B ($R_3$ or $R_4$). The drifted $A'$ alone cannot be classified as digit 1 or 2 since it lies outside the classification regions of task A ($R_1$ or $R_2$). However, during testing of task A, after training task B, adding the beneficial perturbations for task A to the drifted activation $A'$ drags it back to the correct classification regions for task A ($R_1$ or $R_2$ in Fig. 3.3c). Thus, beneficial perturbations bias the neural network toward that task and push task representations back to their initial task-optimal working region. Note that in this work we focus on adding more compact beneficial perturbations to the activation space, as adding perturbations to the input space has already been explored in adversarial attack methods, and adding perturbations to the parameter space is unlikely to be scalable due to the very large number of parameters in a typical neural network.
Figure 3.3. Defining adversarial perturbations in input space vs. beneficial perturbations in activation space. We consider two digit recognition tasks; Task A (recognizing 1s and 2s) and Task B (recognizing 3s and 4s). (a) Adversarial directions (AD). Adding adversarial perturbations (calculated from digit 1) to input digit 2 can be viewed as adding an adversarial direction vector (gray arrow) to the clear input image of digit 2 in the input space. Thus, the network misclassifies the clear input image of digit 2 as digit 1. Beneficial directions do not operate by adding beneficial perturbations to the clear input image of digit 2 in the input space to assist correct classification (orange arrow). (b) Beneficial directions (class specific) for each class of task A. $R_1$ ($R_2$) is the classification region (region of constant estimated label) of digit 1 (digit 2) from the MNIST dataset. Subregion $R_1^{hc}$ ($R_1^{lc}$) is the high (low) confidence classification region of digit 1, and likewise for $R_2^{hc}$ ($R_2^{lc}$) for digit 2. The point $A$ is the activation of the normal neurons of each layer from an input image of task A. It lies in the intersection of $R_1^{lc}$ and $R_2^{lc}$. $BD_1$ ($BD_2$) are beneficial directions for class digit 1 (digit 2). $A + BD_1$ (blue arrows) and $A + BD_2$ (red arrows) push the activation $A$ across the decision boundary and toward $R_1^{hc}$ and $R_2^{hc}$, respectively. Thus, the network classifies $A + BD_1$ ($A + BD_2$) as digit 1 (digit 2) with high confidence. (c) After training task B, the beneficial perturbations (task specific) for task A push the drifted representation of inputs from task A back to the initial optimal working region of task A. $R_3$ ($R_4$) is the classification region (region of constant estimated label) of digit 3 (digit 4) from the MNIST dataset. $BD_1$ ($BD_2$) is a beneficial direction for digit 1 (digit 2). During the training of task A, the network has been trained on two images from digit 1 ($1_a$ and $1_b$) and two images from digit 2 ($2_a$ and $2_b$). Thus, the beneficial perturbations for task A are the vector $BD_1^a + BD_1^b + BD_2^a + BD_2^b$. After training task B, with gradient descent, point $A$ in (b) drifts to $A'$, which lies inside the classification regions of task B ($R_3$ or $R_4$). The drifted point $A'$ alone cannot be correctly classified as digit 1 or 2 because it lies outside the classification regions of task A ($R_1$ or $R_2$). At test time, adding the beneficial perturbations for task A to the activations of $A'$ can drag it back to the correct classification regions for task A (intersection of $R_1^{lc}$ and $R_2^{lc}$). Thus, it biases the network's outputs toward the correct classification region and pushes task representations back to their initial task-optimal working region.
Figure 3.4. Beneficial perturbation network (BD + EWC or BD + PSP variant) with two tasks. (a) Structure of the beneficial perturbation network. (b) Train on task A. Backpropagate through the network to the bias units for task A in the beneficial direction (FGSD) using the input's own correct class (digit labels 1 and 2); normal weights are updated by gradient descent. (c) Test on task A. Feed the input images to the network, activate the bias units for task A, and add the stored beneficial perturbations to the activations. The beneficial perturbations bias the network into the mode for classifying the digits 1, 2 task. (d) Train on task B. Backpropagate through the network to the bias units for task B in the beneficial direction (FGSD) using the input's own correct class (digit labels 3 and 4); normal weights are constrained by EWC or PSP. (e) Test on task B. Feed the input images to the network, activate the bias units for task B, and add the stored beneficial perturbations to the activations. The beneficial perturbations bias the network into the mode for classifying the digits 3, 4 task.
3.6 Beneficial perturbation network
We implemented two variants of BPN: BD + EWC and BD + PSP (Experiments).
The backbone - BD (updating extra out-of-network bias units in beneficial directions to create beneficial perturbations) is the same for both methods. The only difference is that BD + EWC (BD + PSP) uses the EWC (PSP) method to retrain the normal weights while attempting to minimize disruption of old tasks. Here, we choose BD + EWC to explain our method (for BD + PSP, see Supplementary). We use a scenario with two tasks for illustration; task A - recognizing MNIST digits 1 and 2; task B - recognizing MNIST digits 3 and 4. BPN has task-dependent bias units ($BIAS_i^t \in \mathbb{R}^K$, where K is the number of normal neurons in each layer, i is the layer number, and t is the task number) in each layer to store the beneficial perturbations. The beneficial perturbations are formulated as an additive contribution to each layer's weighted activations. Unlike most adversarial perturbations, beneficial perturbations are not specific to each example, but are applied to all examples in each task (Fig. 3.3c, d). We define beneficial perturbations as a task-dependent bias term:

$V_i = \sigma(W_i V_{i-1} + b_i + BIAS_i^t) \quad \forall i \in [1, n] \quad (1)$

where $V_i$ is the activation at layer i, $W_i$ is the normal weights at layer i, $BIAS_i^t$ is the task-dependent bias units at layer i for task t, $\sigma(\cdot)$ is the nonlinear activation function at each layer, $b_i$ is the normal bias term at layer i, and n is the number of layers.

For a simple fully connected network (Fig. 3.4a), the forward functions are

$V_1 = \sigma(W_1 X^t + b_1 + BIAS_1^t) \quad (2)$
$V_2 = \sigma(W_2 V_1 + b_2 + BIAS_2^t) \quad (3)$
$y = \mathrm{Softmax}(W_3 V_2 + b_3 + BIAS_3^t) \quad (4)$

where y is the output, $X^t$ is the input data for task t, Softmax is the normalization function, and the other notations are the same as in Eqn. 1. During the training of a specific task, the bias units in each layer are the product of two terms: $M_i^t \in \mathbb{R}^{1 \times H}$ and $W_i^t \in \mathbb{R}^{H \times K}$ (H is the hidden dimension, a hyper-parameter; K is the number of normal neurons in each layer; and t is the task number). After training a specific task, we discard both $M_i^t$ and $W_i^t$, and only keep their product $BIAS_i^t$, reducing memory and parameter costs to a negligible amount (a 0.3% increase in parameters per task, and a 4*K Bytes increase per layer per task; it is just a bias term). After training on different sequential tasks, at test time, the stored beneficial perturbations from the task-specific bias units can bias the neural network outputs toward each task. Thus, they allow the BPN to switch into different modes to process different tasks. We use the forward and backward rules (Alg. 1, Alg. 2) to update the BPN.
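As a minimal sketch of the forward pass in Eqns. 1-4 (our own illustrative PyTorch code; class and variable names are ours, and for simplicity the bias units are stored directly rather than as the product of the two factors used during training), a BPN layer simply adds a task-selected bias to each layer's weighted activation:

import torch
import torch.nn as nn

class BPNLayer(nn.Module):
    """Fully connected layer with task-dependent bias units (Eqn. 1)."""
    def __init__(self, in_dim, out_dim, num_tasks):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)  # normal weights W and bias b
        # One stored BIAS vector per task: only K extra values per task per layer.
        self.bias_units = nn.Parameter(torch.zeros(num_tasks, out_dim))

    def forward(self, x, task_id):
        return torch.relu(self.fc(x) + self.bias_units[task_id])

At test time, switching the network into the mode for a given task amounts to passing that task's id so its stored bias units are added to the activations; the shared normal weights are reused unchanged.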
For training, first, during the training of task A, our goal is to maximize the probability $P(y = y^A \mid X^A, W_i^{normal}, BIAS_i^A) \ \forall i \in [1, n]$ by selecting the bias units corresponding to task A. Thus, we set up our optimization function accordingly, where $y^A$ is the true label for data in task A (MNIST input images 1, 2), $X^A$ is the data for task A, and the other notations are the same as in Eqn. 1. We update $M^A$ in the beneficial direction (FGSD) as $-\epsilon \, \mathrm{sign}(\nabla_{M^A} L(M^A, y^A))$ to generate the beneficial perturbations for task A, where $M^A$ is the first term of the bias units for task A. We update $W^A$ (the second term of the bias units for task A) in the gradient direction. The factorization allows the bias units for task A to better learn the beneficial perturbations for task A (a vector toward the working space of task A that has a non-negligible network response for MNIST digits 1, 2; similar to Fig. 3.3b, c). We use a softmax cross-entropy loss to optimize Eqn. 1. After training task A, the bias units for task A, $BIAS^A$, are the product of $M^A$ and $W^A$. We discard $M^A$ and $W^A$ to reduce the memory storage and parameter costs, and freeze $BIAS^A$ to ensure that the beneficial perturbations are not corrupted by other tasks (task B). Then, we discard all of the MNIST input images 1, 2 because all of the information is stored inside the bias units for task A, and we do not need to replay these images when we train on the following sequential tasks.
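A sketch of this two-part update (our own code under the notation above; epsilon and the learning rate are illustrative hyperparameters) is:

import torch

def update_bias_unit_factors(M, W_bias, epsilon=0.01, lr=0.001):
    """Called after loss.backward() on a task-A mini-batch.
    M: first factor, stepped in the beneficial direction (negative FGSD step
       computed from the input's own correct class).
    W_bias: second factor, updated by plain gradient descent."""
    with torch.no_grad():
        M -= epsilon * M.grad.sign()   # beneficial direction (sign of gradient)
        W_bias -= lr * W_bias.grad     # gradient direction
        M.grad.zero_()
        W_bias.grad.zero_()

# After task A finishes: BIAS_A = M @ W_bias is frozen; M and W_bias are discarded.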
After training task A, during the training of task B (Fig. 3.4d), our goal is to maximize the probability $P(y = y^B \mid X^B, W_i^{normal}, BIAS_i^B)$ by selecting the bias units corresponding to task B. To minimize the disruption of task A, we apply EWC or PSP constraints on the normal weights. We set up our optimization function accordingly, where $y^B$ is the true label for data in task B (MNIST input images 3, 4), $X^B$ is the data for task B, EWC(.) is the EWC constraint [93] on the normal weights, and the other notations are the same as in Eqn. 1. In the loss function of Alg. 2, $\lambda \sum_j F_j (W_j - W_{A,j}^*)^2$ is the EWC constraint on the normal weights, where j labels each parameter, $F_j$ is the Fisher information for parameter j (determining which parameters are most important for a task [93]), $\lambda$ sets how important the old task is compared to the new one, $W_j$ is normal weight j, and $W_{A,j}^*$ is the optimal normal weight j after training on task A. Apart from the additional EWC constraint, training task B and all subsequent tasks then simply proceeds in the same manner as for task A above.
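For reference, a sketch of the EWC penalty on the normal weights (our own code; the dictionaries fisher and old_params, mapping parameter names to diagonal Fisher estimates and to the post-task-A optimal weights, are assumed bookkeeping):

import torch

def ewc_penalty(model, fisher, old_params, lam):
    """lambda * sum_j F_j * (W_j - W*_{A,j})^2 over the shared normal weights."""
    penalty = 0.0
    for name, p in model.named_parameters():
        if name in fisher:  # constrain only the shared normal weights
            penalty = penalty + (fisher[name] * (p - old_params[name]) ** 2).sum()
    return lam * penalty

# total_loss = cross_entropy_on_task_B + ewc_penalty(model, fisher, old_params, lam)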
For testing, after training task B, we test the accuracy for task A on a test set by manually activating the bias units corresponding to task A (Fig. 3.4c, Alg. 1). Although the shared normal weights have been contaminated by task B, the intact bias units for task A that store the beneficial perturbations can still bias the network outputs toward task A (setting the network into a mode to process input from task A; see Results). In other words, the task-dependent bias units can still maintain a high probability $P(y = y^A \mid X^A, W_i^{normal}, BIAS_i^A)$ for task A. During testing of task B, we test the accuracy for task B on a test set by manually activating the bias units corresponding to task B (Fig. 3.4e, Alg. 1). The bias units for task B can bias the network outputs toward task B and maintain a high probability $P(y = y^B \mid X^B, W_i^{normal}, BIAS_i^B)$ for task B, in case the shared normal weights are further modified by later tasks. In scenarios with more than two tasks, the forward and backward algorithms for later tasks are the same as for task B, except that they select and update their own bias units.
In sum, beneficial perturbations act upon the network not by adding biases to the input data (as adversarial examples do, Fig. 3.3a), but instead by dragging the drifted activations back to the correct working region in activation space for the current task (Fig. 3.1 and Fig. 3.3c). The intriguing properties of task-dependent beneficial perturbations in maintaining high probabilities for different tasks can further be explained in two ways. First, the beneficial perturbations from the bias units can be viewed as features that capture, for example, how "furry" the images are for task A (or B). Ilyas et al. [113] showed that training a neural network only on such features is sufficient to make correct classifications on the dataset that generated them. They argued that these features carry sufficient information for a neural network to make correct classifications. In our continual learning scenario, although the shared normal weights ($W_i$) have been contaminated after the sequential training of all tasks, by activating the corresponding bias units, the task-dependent bias units still have sufficient information to bias the network toward that task. In other words, the task-dependent bias units can maintain high probabilities - $P(y = y^A \mid X^A, W_i^{normal}, BIAS_i^A)$ for task A or $P(y = y^B \mid X^B, W_i^{normal}, BIAS_i^B)$ for task B. Thus, the bias units can assist the network in making correct classifications. Second, Elsayed et al. [114] showed how carefully computed adversarial perturbations for each new task, embedded in the input space, can repurpose machine learning models to perform a new task. Here, the beneficial perturbations can be viewed as task-dependent beneficial "programs" [114] in the parameter space. Once activated, these task-dependent "programs" can maximize the probability for the corresponding tasks.
3.7 Experiment
3.7.1 Experimental Setup for Incremental Tasks
To demonstrate that BPN is very parameter efficient and can learn different tasks in an online and continual manner, we used a fully-connected neural network with 5 hidden layers of 300 ReLU units. We tested it on three public computer vision datasets with "single-head evaluation", where the output space consists of all the classes from all tasks learned so far.
1. Incremental MNIST. A variant of the MNIST dataset [102] of handwritten digits with 10 classes, where each task introduces a new set of classes. We consider 5 tasks; each new task concerns examples from a disjoint subset of 2 classes.
2. Incremental CIFAR-10. A variant of the CIFAR object recognition dataset [103] with 10 classes. We consider 5 tasks; each new task has 2 classes.
3. Incremental CIFAR-100. A variant of the CIFAR object recognition dataset [103] with 100 classes. We consider 10 tasks; each new task has 2 classes (we use 20 of the 100 classes for the CIFAR-100 experiment).
3.7.2 Experimental Setup for Eight Sequential Object Recognition Tasks
To demonstrate the superior performance of BPN across different datasets and domains, we consider a sequence of eight object recognition datasets:
1. Oxford Flowers [115] for fine-grained flower classification (8,189 images in 102 categories);
2. MIT Scenes [116] for indoor scene classification (15,620 images in 67 categories);
3. Caltech-UCSD Birds [117] for fine-grained bird classification (11,788 images in 200 categories);
4. Stanford Cars [118] for fine-grained car classification (16,185 images of 196 categories);
5. FGVC-Aircraft [119] for fine-grained aircraft classification (10,200 images in 70 categories);
6. VOC Actions [120], the human action classification subset of the VOC challenge 2012 (3,334 images in 10 categories);
7. Letters, the Chars74K dataset [121] for character recognition in natural images (62,992 images in 62 categories);
8. the Google Street View House Numbers (SVHN) dataset [122] for digit recognition (99,289 images in 10 categories).
For a fair comparison, we use the same AlexNet [123] architecture pretrained on ImageNet [124] as Aljundi et al. [125,126], and test on the 8 sequential tasks with a "multi-head evaluation", where each task has its own classification layer (introducing the same parameter costs for every method) and output space. All methods have a task oracle at test time to decide which classification layer to use. We run the different methods on the following sequence: Flowers -> Scenes -> Birds -> Cars -> Aircraft -> Actions -> Letters -> SVHN.
3.7.3 Experimental Setup for 100 permuted MNIST dataset
To demonstrate that BPN has the capacity to accommodate a large number of tasks, we tested it on 100 permuted MNIST datasets generated from randomly permuted handwritten MNIST digits. We consider 100 tasks; each new task has 10 classes. We used a fully connected neural network with 4 hidden layers of 128 ReLU units (a core network with small capacity) to compare the performance of different methods. Type 4 methods, such as Parameter Superposition (PSP [94]), exhaust the unrealized capacity and inevitably dilute the capacity of the core network under a large number of tasks: in their Fig. 3.2, with a network that has 128 hidden units (leftmost panel), the average task performance for all tasks trained so far is 95% after training one task, but decreases to 50% after training fifty tasks, while a larger network with 2048 hidden units shows a much smaller decrease, from 96% to about 90% (see Fig. 3.2 in their paper, rightmost panel). The reason is that this method generates a random diagonal binary matrix for each task, which in essence is a key or selector for that task. As more and more tasks are learned, those keys start to overlap more, causing interference among tasks. In comparison, BPN can counteract the dilution; hence it can accommodate a large number of tasks.
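A sketch of how the permuted tasks can be generated (our own code; the dataset root and seed are illustrative, and the actual experimental loader may differ):

import torch
from torchvision import datasets, transforms

def make_permuted_mnist_tasks(num_tasks=100, seed=0):
    """One fixed random pixel permutation per task over flattened 28x28 images."""
    g = torch.Generator().manual_seed(seed)
    perms = [torch.randperm(28 * 28, generator=g) for _ in range(num_tasks)]
    mnist = datasets.MNIST(root="data", train=True, download=True,
                           transform=transforms.ToTensor())
    return mnist, perms

# Task t's view of a flattened image x (shape 784): x_t = x.view(-1)[perms[t]]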
3.7.4 Our model and baselines
We compared the proposed Beneficial Perturbation Network (Beneficial Perturbation + Elastic Weight Consolidation, BD + EWC (variant 1, the eleventh model below), and Beneficial Perturbation + Parameter Superposition, BD + PSP (variant 2, the twelfth model below)) to 11 alternatives to demonstrate its superior performance.
1. Single Task Learning (STL). We consider several 5-layer fully-connected neural networks. Each network is trained on one task separately. Thus, STL does not suffer from catastrophic forgetting at all. It is used as an upper bound.
2. Elastic Weight Consolidation (EWC) [93]. The loss is regularized to avoid catastrophic forgetting.
3. Gradient Episodic Memory with task oracle (GEM (*)) [97]. GEM uses a task oracle to build a final linear classifier (FLC) per task. The final linear classifier adapts the output distributions to the subset of classes for each task. GEM uses an episodic memory to store a subset of the observed examples from previous tasks, which are interleaved with new data from the latest task to produce a new classifier for all tasks so far. We use the notation GEM (*) for the rest of the paper, where * is the size of the episodic memory (number of training images stored) for each class.
4. Incremental Moment Matching (IMM) [109]. IMM incrementally matches the moment of the posterior distribution of the neural network with an L2 penalty and applies it equally to changes to the shared parameters.
5. Learning without Forgetting (LwF) [111]. First, LwF freezes the shared parameters while learning a new task. Then, LwF trains all the parameters until convergence.
6. Encoder Based Lifelong Learning (EBLL) [105]. Based on LwF, EBLL uses an autoencoder to capture the features that are crucial for each task.
7. Synaptic Intelligence (SI) [127]. While training on a new task, SI estimates the importance weights in an online manner. Parameters important for previous tasks are penalized during the training of the new task.
8. Memory Aware Synapses (MAS) [125]. Similar to SI, MAS estimates the importance weights through the sensitivity of the learned function to the training data. Parameters important for previous tasks are penalized during the training of the new task.
9. Sparse coding through Local Neural Inhibition and Discounting (SLNID) [126]. SLNID proposes a new regularizer that penalizes neurons that are active at the same time, to create sparse and decorrelated representations for different tasks.
10. Parameter Superposition (PSP) [94]. PSP uses task-specific context matrices to map inputs from different tasks to different subspaces spanned by rows in the weight matrices, to avoid interference between tasks. We use the binary superposition model of PSP throughout the paper because it is not only more memory efficient but also, in our testing, performed better than other PSP variants (e.g., complex superposition).
11. BD + EWC (ours): Beneficial Perturbation Network (variant 1). The first term ($M^t$) of the bias units is updated in the beneficial direction (BD) using the FGSD method. The second term ($W^t$) of the bias units is updated in the gradient direction. The normal weights are updated with EWC constraints.
12. BD + PSP (ours): Beneficial Perturbation Network (variant 2). The first term ($M^t$) of the bias units is updated in the beneficial direction (BD) using the FGSD method. The second term ($W^t$) of the bias units is updated in the gradient direction. The normal weights are updated using PSP (binary superposition model, Supplementary).
13. GD + EWC: The update rules and network structure are the same as BD + EWC, except that the first term ($M^t$) of the bias units is updated in the gradient direction (GD). This method has the same parameter costs as BD + EWC. The failure of GD + EWC suggests that the good performance of BD + EWC does not come merely from the additional dimensions provided by the bias units.
3.8 Results
3.8.1 The beneficial perturbations can bias the network and maintain the
decision boundary
To show that the advantages of our method really come from the beneficial perturbations and not just from adding dimensions to the neural network, we compare updating the first term of the bias units in the beneficial direction (BD + EWC, which uses beneficial perturbations) versus in the gradient direction (GD + EWC, which benefits only from the additional dimensions that our bias units provide). We use a toy example (classifying 3 groups of normally distributed clusters) to demonstrate this and to visualize the decision boundary (Fig. 3.5). We randomly generate 3 normally distributed clusters at different locations. We have two tasks - Task 1: separate the black cluster from the red cluster; Task 2: separate the black cluster from the light blue cluster. The yellower (bluer) the heatmap, the higher (lower) the confidence that the neural network classifies a location into the black cluster. After training task 2, both plain gradient descent and GD + EWC forget task 1 (the dark blue boundary around the red cluster disappears). However, BD + EWC not only learns how to classify task 2 (clear decision boundary between the light blue and black clusters), but also remembers how to classify the old task 1 (clear decision boundary between the red and black clusters). Thus, it is the beneficial perturbations that bias the network outputs and maintain the decision boundary for each task, not just the addition of more dimensions.
Figure 3.5. Visualization of classification regions: classifying 3 randomly generated normally distributed clusters. Task 1: separate the black from the red cluster. Task 2: separate the black from the light blue cluster. The yellower (bluer) the heatmap, the higher (lower) the chance that the neural network classifies a location as the black cluster. After training task 2, only BD + EWC remembers task 1 by maintaining its decision boundary between the black and red clusters. Both plain gradient descent and GD + EWC forget task 1 entirely.
3.8.2 Quantitative analysis for incremental tasks
Our BPN achieves comparable or better performance than PSP, GEM, EWC, and GD + EWC in "single-head" evaluations, where the output space consists of all the classes from all tasks learned so far. In addition, it introduces negligible parameter and memory storage costs per task. Fig. 3.6 and Tab. 3.1 summarize performance for all datasets and methods. STL has the best performance since it trains on each task separately and does not suffer from catastrophic forgetting at all; thus, STL is the upper bound. BD + EWC performed slightly worse than STL (1%, 4%, and 1% worse for the incremental MNIST, CIFAR-10, and CIFAR-100 datasets). BD + EWC achieved comparable or better performance than GEM. On incremental CIFAR-100 (10 tasks, 2 classes per task), BD + EWC outperformed PSP, GEM (256), and GEM (10) by 1.80%, 6.96%, and 22.4%, respectively. BD + PSP outperformed PSP, GEM (256), and GEM (10) by 2.40%, 7.59%, and 23.1%. Comparing memory storage costs (Tab. 3.1, Supplementary), to achieve similar performance, BD + EWC only introduces an additional 4,808 Bytes of memory per task, which is only 0.1% of the memory storage cost required by GEM (256). BD + PSP only introduces 20,776 Bytes, or 0.44% of the memory storage cost required by GEM (256). The memory storage cost of BD + EWC is 30% of that of PSP, and the memory storage cost of BD + PSP is of the same order of magnitude as PSP. EWC alone rapidly decreased to 0% accuracy. This confirms similar results on EWC performance on incremental datasets [92,128-130] in "single-head" evaluations, although EWC generally performs well in "multi-head" tasks. GD + EWC has the same additional dimensions as BD + EWC, but GD + EWC failed in the continual learning scenario. This result suggests that it is not the additional dimensions of the bias units, but the beneficial perturbations, that help overcome catastrophic forgetting.
Figure 3.6. Results for a fully-connected network with 5 hidden layers of 300
ReLU units. (a) Incremental MNIST tasks (5 tasks, 2 classes per task). (b)
Incremental CIFAR-10 tasks (5 tasks, 2 classes per task). For a and b, the
dashed line indicates the start of a new task. The vertical axis is the accuracy
for each task. The horizontal axis is the number of epochs. (c) Incremental
CIFAR-100 tasks (10 tasks, 2 classes per task). The vertical axis is the accuracy
for task 1. The horizontal axis is the number of tasks.
TABLE 3.1
Task 1 performance with "single-head" evaluation after training all sequential tasks on the incremental MNIST, CIFAR-10, and CIFAR-100 datasets. We include the additional memory storage cost per task (extra components that must be stored on disk after training each task; Supplementary) of the GEM, BD + EWC, BD + PSP, and PSP methods.
3.8.3 Quantitative analysis for eight sequential object recognition tasks
The eight sequential object recognition tasks demonstrate the superior performance of BPN (BD + PSP or BD + EWC) compared to the state-of-the-art, and its ability to learn sequential tasks across different datasets and different domains. Our BPN achieves much better performance than IMM [28], LwF [30], EWC [23], EBLL [40], SI [55], MAS [1], SLNID [2], and PSP [4] in "multi-head" evaluations, where each task has its own classification layer and output space. After training on the 8 sequential object recognition datasets, we measured the test accuracy for each dataset and calculated the average performance (Tab. 3.2). On average, BD + PSP (ours) outperforms all other methods: PSP (7.52% better), SLNID (8.02% better), MAS (11.73% better), SI (16.60% better), EBLL (17.07% better), EWC (17.75% better), LwF (18.96% better), and IMM (35.62% better). Although MAS, SI, and EBLL performed better than EWC alone, with the help of our beneficial perturbations (BD), BD + EWC achieves better performance than these methods: MAS (0.34% better), SI (4.71% better), EBLL (5.13% better), and EWC (5.74% better). By including BD (BD + PSP and BD + EWC), we significantly boost performance compared to using PSP or EWC alone (black arrows in Tab. 3.2).
TABLE 3.2
Test accuracy (in percent correct) achieved by each method with "multi-head" evaluation for each dataset after training on the 8 sequential object recognition datasets. (Dash (--) means that the results are not available in their papers. Star (*) means that we did not reproduce the methods and the results were taken from SLNID [2] and MAS [1]; thus, we keep the same percentage table format as theirs.)
3.8.4 Quantitative analysis for 100 permuted MNIST dataset
The 100 permuted MNIST dataset demonstrates that our BPN has the capacity to accommodate a large number of tasks. After training 100 permuted MNIST tasks, the average task performance of BD + PSP is 30.14% better than PSP, and the average task performance of BD + EWC is 35.47% higher than EWC (Fig. 3.7a). As the number of tasks increases (Fig. 3.7a), the average task performance of BD + PSP becomes increasingly better than PSP. The reason is that adding new tasks significantly dilutes the capacity of the original network in Type 4 methods (e.g., PSP), as there are limited routes or subspaces left to form sub-networks. In this case, even though the core network can no longer fully separate each task, the beneficial perturbations (BD) can drag the misrepresented activations back to the correct working space of each task and recover their separation (as demonstrated in Fig. 3.3). Thus, the BD of BD + PSP can still increase the capacity of the network and boost performance. Similarly, the BD component in BD + EWC boosts performance, increasing the capacity of the network to accommodate more tasks than EWC alone (Fig. 3.7b). In addition, after training 100 tasks (Fig. 3.7b), the accuracy of BD + EWC for the first 50 tasks is higher than PSP, likely because BD + EWC did not severely dilute the core network's capacity while PSP did. This means BD + EWC has a larger capacity than PSP. In contrast, the lower performance on the last 50 tasks for BD + EWC comes from the constraints of EWC (which do not allow the parameters learned for new tasks to deviate strongly from the parameters trained on old tasks). Although the performance of PSP is much better than EWC, with the help of BD, BD + EWC still reaches a similar performance as PSP.
Figure 3.7. 100 permuted MNIST results for a fully-connected network with 4 hidden layers of 128 ReLU units. This network is relatively small for these tasks and hence does not offer much available redundancy or unrealized capacity. (a) The average task accuracy over all tasks trained so far, as the number of tasks increases. (b) After training 100 tasks, the average task accuracy for groups of 10 tasks. We use a t-test to validate the results.
3.9 Discussion
We proposed a fundamentally new biologically plausible type of method -
beneficial perturbation network (BPN), a neural network that can switch into
different modes to process independent tasks, allowing the network to create
potentially unlimited mappings between inputs and outputs. We successfully
demonstrated this in the continual learning scenario. Our experiments
demonstrate that the performance of BPN is better than the state-of-the-art. 1) BPN is more parameter efficient (0.3% increase per task) than the various network expansion and network partition methods. It does not need a large episodic memory to store data from previous tasks, unlike episodic memory methods, or large context matrices, unlike partition methods. 2) BPN
achieves state-of-the-art performance across different datasets and domains.
3) BPN has a larger capacity to accommodate a higher number of tasks than
the partition networks. Through visualization of classification regions and
quantitative results, we validate that beneficial perturbations can bias the
network towards a task, allowing the network to switch into different modes.
Thus, BPN significantly contributes to alleviating catastrophic forgetting and
achieves much better performance than other types of methods.
Elsayed et al. [114] showed how carefully computed adversarial perturbations embedded in the input space can repurpose machine learning models to perform a new task without changing the parameters of the models. This attack finds a single adversarial perturbation for each task, to cause the model to perform a task chosen by the adversary. This adversarial perturbation can thus be considered a program to execute each task. Here, we leverage similar ideas. But, in sharp contrast, instead of using malicious programs embedded in the input space to attack a system, we embed beneficial perturbations ('beneficial programs') into the network's parameter space (the bias terms), enabling the network to switch into different modes to process different tasks. The goal of both approaches is similar - maximizing the probability P(current task | image input, program) of the current task given the image input and the corresponding program for the current task. This can be achieved either by forcing the network to perform an attack task, as in Elsayed et al., or by assisting it to perform a beneficial task, as in our method. The addition of programs to either the input space (Elsayed et al.'s method) or the network's activation space (our method) helps the network maximize this probability for a specific task.
We suggest that the intriguing property of beneficial perturbations - that they can bias the network toward a task - may come from the properties of adversarial subspaces. Following the adversarial direction, such as by using the fast gradient sign method (FGSD) [131], can help generate adversarial examples that span a continuous subspace of large dimensionality (an adversarial subspace). Because of "excessive linearity" in many neural networks [112,132], due to features including rectified linear units and Maxout, the adversarial subspace often occupies a large portion of the total input space. Once an adversarial input lies in the adversarial subspace, nearby inputs also tend to lie in it. Interestingly, this corroborates recent findings by Ilyas et al. [113] that imperceptible adversarial noise can not only be used for adversarial attacks on an already-trained network, but also as features during training. For instance, after training a network on dog images perturbed with adversarial perturbations calculated from cat images, the network can achieve good classification accuracy on a test set of cat images. This result shows that those features (adversarial perturbations), calculated from the cat training set, contain sufficient information for a machine learning system to make correct classifications on the test set of cat images. In our method, we calculate those features for each task and store them in the bias units. In this case, although the normal weights have been modified (information from old tasks is corrupted), the stored beneficial features for each task have sufficient information to bias the network and enable it to make correct predictions.
BPN is loosely inspired by its counterpart in the human brain: having task-dependent modules, such as the bias units in our Beneficial Perturbation Network or long-term memories in the hippocampus (HPC) [133] in a brain network, is crucial for a system to switch into different modes to process different tasks. During weight consolidation, the HPC [134-137] fuses features from different tasks into coherent memory traces. Over days to weeks, as memories mature, the HPC progressively stores permanent abstract high-level long-term memories in remote memory storage areas (neocortical regions). The HPC can then maintain and mediate their retrieval independently when a specific memory is needed.
We suggest that when a specific memory is retrieved, it helps the HPC switch into distinct modes to process different tasks. Thus, our analogy between the HPC and BPN can be formulated as follows: during the training of BPN, updating the shared normal weights using EWC or PSP in theory leads to distinct task-dependent representations (similar to the coherent memory traces in the HPC). However, some overlap between these representations is inevitable, because model parameters become too constrained for EWC, or PSP runs out of the unrealized capacity of the core network. To circumvent this effect, bias units (akin to long-term memories in the neocortical areas) are trained independently for each task. At test time, the bias units for a given task are activated to push the representations of old tasks back to their initial task-optimal working regions, in a manner analogous to the HPC independently maintaining and mediating the retrieval of long-term memories.
An alternative biological explanation evokes the concept of factorized codes. In biological neuronal populations, neurons can be active for one task or, in many cases, for more than one task. At the population level, different tasks are encoded by different neuronal ensembles, which can overlap. In our model, the PSP component deploys binary keys to activate task-specific readouts in hidden layers, in analogy to neuronal task ensembles. When activating a BD component for a task, we would be further disambiguating a task-specific ensemble, particularly across neurons that are active for more than one task. The reason is that adding task-specific beneficial perturbations to the activations of hidden layers can shift the distribution of the net activation (akin to a DC offset or carrier frequency). Evidence from nonhuman primate experiments [138,139] and human behavioral results [140] supports this factorized code theory. Electrophysiological experiments in monkeys demonstrated that neurons in prefrontal cortex either represent competing categories independently [138] or can represent multiple categories [139]. In human behavioral experiments, "humans tend to form factorized representation that optimally segregated the tasks" [140]. In addition, recent neural network simulations [141] demonstrated that a "network developed mixed task selectivity similar to recorded prefrontal neurons after learning multiple tasks sequentially with a continual learning technique". Thus, having factorized representations for different tasks is important for enabling lifelong learning and designing a general adaptive artificial intelligence system.
4 Beneficial Perturbation Network for Defending Adversarial examples
4.1 Abstract
Deep neural networks can be fooled by adversarial attacks: adding carefully computed small adversarial perturbations to clean inputs can cause misclassification on state-of-the-art machine learning models. The reason is that neural networks fail to accommodate the distribution drift of the input data caused by adversarial perturbations. Here, we present a new solution - the Beneficial Perturbation Network (BPN) - to defend against adversarial attacks by fixing the distribution drift. During training, BPN generates and leverages beneficial perturbations (somewhat opposite to the well-known adversarial perturbations) by adding new, out-of-network biasing units. Biasing units influence the parameter space of the network, to preempt and neutralize future adversarial perturbations on input data samples. To achieve this, BPN creates reverse adversarial attacks during training, at very little cost, by recycling the training gradients already computed. Reverse attacks are captured by the biasing units, and the biases can in turn effectively defend against future adversarial examples. Reverse attacks are a shortcut, i.e., they affect the network's parameters without requiring the instantiation of adversarial examples that could assist training. We provide comprehensive empirical evidence showing that 1) BPN is robust to adversarial examples and much more running-memory and computationally efficient than classical adversarial training; 2) BPN can defend against adversarial examples with negligible additional computation and parameter costs compared to training only on clean examples; 3) BPN hurts the accuracy on clean examples much less than classical adversarial training; 4) BPN can improve the generalization of the network; and 5) BPN trained only with the Fast Gradient Sign attack can generalize to defend against PGD attacks.
4.2 Introduction
Neural networks have led to a series of breakthroughs in many fields, such as image classification [142,143] and natural language processing [142,144]. Model performance on clean examples was the main evaluation criterion for these applications until the unveiling of weaknesses to adversarial attacks by Szegedy et al. and Biggio et al. [145,146]. Neural networks were shown to be vulnerable to adversarial perturbations: carefully computed small perturbations added to legitimate clean examples to create so-called "adversarial examples" can cause misclassification on state-of-the-art machine learning models. The reason is that adding adversarial perturbations to the input image introduces a distribution drift in the input data. Although the adversarial perturbations are often too small to be recognized by human eyes, the resulting distribution drift is sufficient to cause misclassification on machine learning models. To fix distribution drifts, a question arises: can we simulate reverse adversarial attacks during training to preempt and neutralize the effects of future adversarial perturbations?
In this paper, we define the Beneficial Perturbation Network (BPN). BPN introduces a reverse adversarial attack to defend against adversarial examples. The key new idea is that BPN generates and leverages beneficial perturbations during training (somewhat opposite to adversarial perturbations; see Eqn. 5, Eqn. 6, and Eqn. 7 for detailed mathematical expressions), stored in extra, out-of-network biasing units. These units can influence the parameter space of the network, fixing distribution drifts at test time by neutralizing the effects of adversarial perturbations on data samples. The central difference between adversarial and beneficial perturbations is that, instead of adding input "noise" at test time (adversarial perturbations) calculated from other classes to force the network into misclassification, we add "noise" during training to the parameter space (beneficial perturbations), calculated from the input's own correct class to assist correct classification.
Figure 4.1: Difference in training pipelines between adversarial training and BPN for defending against adversarial examples. (a) Classical adversarial training has two steps: (1) generating adversarial perturbations from the corresponding clean examples and adding the adversarial perturbations to the clean examples (creation of adversarial examples); (2) training the network, usually on both clean and adversarial examples. (b) BPN creates a shortcut with only one step: training on clean examples. This shortcut is feasible because BPN can generate beneficial perturbations during training on clean examples, with negligible additional costs - a 0% (0.006%) increase for the forward (backward) pass compared to training a network on clean examples. BPN is more computationally and running-memory efficient than typical adversarial training because it recycles the gradients already computed during clean training to create beneficial biases directly in the parameter space of the network, instead of having to instantiate new perturbed images. The learned beneficial perturbations can neutralize the effects of adversarial perturbations on the data samples at test time.
We evaluated BPN on multiple datasets (MNIST, FashionMNIST, and TinyImageNet) in three experimental scenarios:
i. Training a network on clean examples only (our main use-case scenario). This scenario is preferable when the computational budget is modest and one wants to preserve clean-sample accuracy while still achieving moderate robustness to adversarial examples. In this case, BPN can defend against adversarial examples with negligible additional computation costs (a 0% increase for the forward pass and 0.006% for the backward pass) compared to simple clean training. As a comparison, during so-called adversarial training, which is the current SOTA (see Sec. 4.3), the network creates one or more adversarial examples per clean sample, which requires at least twice the computational power.
ii. Training on adversarial examples only. This scenario can be used with a modest computational budget that prioritizes robustness to adversarial examples while still preserving some amount of clean-sample accuracy. When using only adversarial examples, the decision boundaries most sensitive to adversarial directions are strengthened, and this has the interesting effect of indirectly causing the model to learn some degree of clean-sample representation. Compared to a classic network trained only on adversarial examples, BPN is more robust to future adversarial attacks, while also performing much better on clean samples that it has in fact never been trained on.
iii. Training on both clean and adversarial examples. This scenario can be used with an abundant computational budget. In this case, BPN is shown to be marginally superior to classical adversarial training on both clean and adversarial examples. The reason is that BPN can further improve the generalization of the network through diversification of the training set [147-151]. In addition, networks trained with classical adversarial learning generalize very poorly to attacks that they have not been trained on, and it is infeasible and expensive to introduce all unknown attack samples into adversarial training [148]. Here, we found experimentally that BPN trained only with FGSM can not only defend against FGSM attacks well, but also generalize to defend against attacks that it has never been trained on (e.g., the PGD attack).
To lay out the foundation of our approach, we start by introducing the following key concepts: Sec. 4.3: adversarial training; Sec. 4.4: the difference between BPN and adversarial training in fixing distribution drifts of the input data (Sec. 4.4.1), and the structure, updating rules (Sec. 4.4.2, Sec. 4.4.5), computation costs (Sec. 4.4.4), and extension to deep convolutional networks (Sec. 4.4.6) of BPN. We then present experiments (Sec. 4.5), results (Sec. 4.6), and discussion (Sec. 4.7).
4.3 Related Work – adversarial training
Researchers have proposed a number of adversarial defense strategies to increase the robustness of deep learning systems. Adversarial training [131,152], in which a network is trained on both adversarial examples ($x^{adv}$) and clean examples ($x^{clean}$) with class labels $y^{true}$, is perhaps the most popular defense against adversarial attacks, withstanding strong attacks. Adversarial examples are the summation of adversarial perturbations lying inside the input space ($\delta^{adv}$) and clean examples: $x^{adv} = x^{clean} + \delta^{adv}$. Given a classifier with a classification loss function $L$ and parameters $\theta$, the objective function of adversarial training is:

$\min_{\theta} L(x^{clean}, x^{adv}, y^{true}; \theta) \quad (1)$
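As a concrete reference, a minimal sketch of one adversarial-training step with an FGSM attack (our own illustrative PyTorch code; real implementations vary in attack strength, number of attack steps, and loss weighting):

import torch
import torch.nn.functional as F

def adversarial_training_step(model, optimizer, x_clean, y, epsilon=0.03):
    """One step of Eqn. 1: train on clean examples and their FGSM counterparts."""
    # Extra forward/backward pass to build x_adv = x_clean + delta_adv.
    x = x_clean.clone().requires_grad_(True)
    F.cross_entropy(model(x), y).backward()
    x_adv = (x_clean + epsilon * x.grad.sign()).detach()

    # Train on both batches: roughly twice the compute and memory of clean training.
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x_clean), y) + F.cross_entropy(model(x_adv), y)
    loss.backward()
    optimizer.step()
    return loss.item()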
However, adversarial training suffers from at least three difficulties:
1. Expensive in terms of running memory and computation costs. On larger datasets, such as ImageNet, adversarial training can take multiple days on a single GPU. Kannan et al. [153] used 53 P100 GPUs and Xie et al. [154] used 100 V100s for targeted adversarial training on ImageNet. Tramer et al. [148] generate more than one adversarial example for each clean example. These implementations require at least double the amount of running memory on the GPU to store the adversarial examples alongside the clean examples. In addition, during adversarial training, the network has to train on both clean and adversarial examples; hence, adversarial training typically requires at least twice the computational power of training on clean examples only.
2. Accuracy trade-off. Although adversarial training can improve robustness against adversarial examples, it sometimes hurts accuracy on clean examples. Thus, there is an accuracy trade-off between the adversarial examples and the clean examples [147,149-151]. Because most of the test data in real applications are clean examples, test accuracy on clean data should be as good as possible. This accuracy trade-off therefore hinders the practical usefulness of adversarial training, because it often ends up lowering performance on the original dataset.
3. Impractical to foresee multiple attacks. Networks trained with classical adversarial learning generalize very poorly to attacks that they have not been trained on. Thus, even with sufficient computational resources to train a network on both adversarial and clean examples, it is infeasible and expensive to introduce all unknown attack samples into adversarial training. For example, Tramer et al. [148] proposed Ensemble Adversarial Training, which increases the diversity of adversarial perturbations in a training set by generating adversarial perturbations transferred from other models. They won the competition on Defenses against Adversarial Attacks, though again at an extraordinary computation and running-memory cost. In summary, a crucial milestone for the field of adversarial learning is achieving a model that can generalize to unseen attacks.
4.4 Beneficial Perturbation Network
Three spaces of a neural network are important: 1) The input space is the space
of input data (e.g., pixels of an image); 2) the parameter space is the space of
all the weights and biases of the network; 3) the activation space is the space
of all outputs of all neurons in all layers in the network.
4.4.1 High-level ideas - difference between BPN and adversarial training in
fixing distribution drifts of input data.
Figure 4.2. Difference between adversarial training ($a_1$-$a_3$, input space) and BPN ($b_1$-$b_3$, activation space) in fixing the data distribution drifts caused by adversarial perturbations, for recognizing handwritten digits "1" versus "2". ($a_1$): After training a model on clean input images, digits "1" and "2" are separated by a purple decision boundary. ($a_2$): Adding adversarial perturbations to test input images of digit 1 can be viewed as adding adversarial direction vectors (red arrows, $\delta^{adv}$) to the clean (non-perturbed) input images. Such adversarial vectors cross the decision boundary, forcing the neural network into misclassification (here from digit 1 to digit 2). ($a_3$): Adversarial training: training a model on both clean and adversarial examples leads the model to learn a new decision boundary that incorporates both clean and adversarial examples, but at great computation and running-memory cost. ($b_1$) and ($b_2$) are similar to ($a_1$) and ($a_2$), but are represented in activation space. ($b_3$): BPN. Theoretically, beneficial perturbations can be seen as the opposite of adversarial perturbations: they act to reverse the latter. Beneficial perturbations work in activation space instead of the classical input space. In BPN, adding beneficial perturbations to the activation representations of adversarial examples corresponds to adding beneficial direction vectors (green arrows, $\delta^{bd}$) to the representations of adversarial examples of digit 1. The resulting vectors cross the decision boundary and drag the misclassified adversarial examples back to the correct classification region.
First, how do adversarial attacks fool a neural network? For example, consider a task of recognizing handwritten digits "1" versus "2" (Fig. 4.2, $a_1$, $a_2$ in input space, or their representations in Fig. 4.2, $b_1$, $b_2$ in activation space). Adversarial perturbations aimed at misclassifying an image of digit 1 as digit 2 may be obtained by backpropagating from the class digit 2 to the input space, following any of the available adversarial directions. In input or activation space, adding adversarial perturbations to the input image can be viewed as adding an adversarial direction vector (red arrows, $\delta^{adv}$) to the clean (non-perturbed) input image of digit 1. The resulting vector crosses the decision boundary (the input distribution drift problem). Therefore, adversarial perturbations can force the neural network into misclassification, here from digit 1 to digit 2, because the network failed to accommodate the distribution drift of the input data.
The primary goal of both adversarial training and BPN is to make the network more robust to adversarial attacks, but they differ in how they accomplish this.
1) In adversarial training, after training the network on both clean and adversarial examples, the network learns a new and more robust decision boundary that can accommodate input distribution drifts (Fig. 4.2, $a_2$, $a_3$). In adversarial training, decision boundary robustness is achieved via a data-driven approach. In other words, by including adversarial samples for each clean image, the decision boundary is enlarged so as to incorporate both clean and adversarial examples inside the same-label classification region. In the deployment stage, because the decision boundary is strengthened, it becomes harder to adversarially fool the network.
2) In BPN, to strengthen the decision boundary, we add beneficial perturbations to the activation representations of adversarial examples. In Fig. 4.2, $b_2$, $b_3$, this corresponds to adding a beneficial perturbation vector (green arrows, $\delta^{bd}$) to the activation representations of adversarial examples of digit 1. The resulting vector crosses the decision boundary and drags the misclassified adversarial examples back to the correct classification region (recovering from the input distribution drift caused by adversarial perturbations). Thus, the beneficial perturbations have the effect of neutralizing the adversarial perturbations and recovering the clean examples in activation space. In mathematical terms, as $\delta^{adv*}$ and $\delta^{bd}$ cancel out, we have:

$x^{clean*} = x^{clean*} + \delta^{adv*} + \delta^{bd} \quad (2)$

In Eqn. 2, $x^{clean*}$ and $\delta^{adv*}$ are the activation representations of clean examples and adversarial perturbations, respectively. As a result, BPN can achieve robustness and correctly classify both clean and adversarial examples by training only on clean samples. Hence, unlike adversarial training, which requires several adversarial samples per clean image, BPN achieves the same goal via a much cheaper route: neutralizing future adversarial attacks with only the addition of low-cost beneficial perturbations in activation space.
4.4.2 Formulation of beneficial perturbations
Beneficial perturbations are formulated as an additive contribution to each layer's weighted activations (Fig.4.3 b):

V_{i+1} = W_i V_i + b_{bp_i}    (3)

where W_i, V_i and b_{bp_i} are the weight, activation and beneficial perturbation bias at layer i. A beneficial perturbation bias has the same structure as the normal bias term b (Fig.4.3 a), but it is used to store the beneficial perturbations (\delta_{bp}).
Figure 4.3. Structural difference between a normal network (baseline) and BPN for the forward (a-b) and backward pass (c-d). (a) Forward rules of the normal network (baseline). (b) Forward rules of BPN. The beneficial perturbation bias (b_{bp}) plays the same role as the normal bias (b) in the forward pass. (c) Backward rules of the normal network (baseline). We only show the update rule for the normal bias term. (d) Backward rules of BPN. The difference is that we update the beneficial perturbation bias term using FGSM.
4.4.3 Creation of reverse adversarial attack
Two questions arise: 1. How do we create a reverse adversarial attack (a beneficial perturbation)? 2. How are beneficial perturbation biases trained in a direction opposite to adversarial directions? Instead of adding "noise" (adversarial perturbations) to the input space, calculated from other classes (as in adversarial training), in BPN we add "noise" (beneficial perturbations, \delta_{bp}) to the activation space to assist in correct classification. These correct "noises" are learned using the gradient calculated from the input's own correct label (y_{true}).
In BPN, to create beneficial perturbations, the first step is computing the adversarial direction (d_i, Eqn.4, FGSM [131]), using the input's own correct label (y_{true}). However, instead of using the adversarial direction directly to create an adversarial sample, we invert its sign (db_{bp_i} = -d_i) in Eqn.5 and perform gradient descent in the opposite direction in Eqn.6, away from the adversarial vector (d_i). This computation is repeated for each layer i:

d_i = \epsilon \cdot \mathrm{sign}(\nabla_{b_{bp_i}} L(b_{bp_i}, y_{true}, \theta))    (4)

db_{bp_i} = -d_i    (5)

b_{bp_i} \leftarrow b_{bp_i} + \eta \, db_{bp_i}    (6)

where \eta is the learning rate, d_i is the adversarial perturbation, b_{bp_i} is the beneficial perturbation bias, db_{bp_i} is the gradient for the beneficial perturbation bias, \epsilon is a hyperparameter that decides how far we go along the Fast Gradient Sign direction, y_{true} is the true label (the input's own correct class), and \theta are the parameters of the neural network.
The creation of beneficial perturbations only requires the true labels and the same gradients as we would normally have to train a vanilla deep neural network. Thus, we can generate the gradient for beneficial perturbations at layer i by recycling the gradients already computed while training the network with classic gradient descent (Fig.4.3). Therefore, we can consolidate Eqn.5 and Eqn.6 into Eqn.7:

b_{bp_i} \leftarrow b_{bp_i} - \eta \, \epsilon \cdot \mathrm{sign}(Grad_{i+1})    (7)

where Grad_{i+1} is the gradient calculated by classical stochastic gradient descent from the next layer i+1, and \epsilon is the same as in Eqn.4.
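In code, recycling the already-computed gradients makes the Eqn.7 update a one-line addition to a standard training step. A hedged sketch (assuming BPLinear modules as above, with their b_bp parameters excluded from the optimizer so they are updated only by the rule below; model, criterion, optimizer and the bp_layers list are assumed to be defined):

eta, eps = 0.01, 0.3         # learning rate and FGSM step size (illustrative values)

outputs = model(x)           # x: clean inputs, y_true: their correct labels
loss = criterion(outputs, y_true)
optimizer.zero_grad()
loss.backward()              # single backward pass; no extra gradient computation
optimizer.step()             # usual update of the weights W

with torch.no_grad():
    for layer in bp_layers:  # every layer carrying a beneficial perturbation bias
        # b_bp <- b_bp - eta * eps * sign(Grad_{i+1})   (Eqn.7)
        layer.b_bp -= eta * eps * torch.sign(layer.b_bp.grad)
        layer.b_bp.grad = None  # cleared manually, since b_bp is not in the optimizer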
4.4.4 Computation Costs
To generate beneficial perturbations, we do not introduce any extra computation costs beyond FGSM. The forward (backward) pass computation costs of BPN are only 0.00% (0.006%) more FLOPS than the costs of the base network trained on clean examples only (Tab.4.1). BPN creates a shortcut (Fig.4.1) in the training process that simulates a reverse adversarial attack in activation space. Because of this shortcut, we do not have to instantiate the adversarial examples in the original image space (input space) as in standard adversarial training. This advantage saves a lot of time and enables BPN to defend against adversarial examples robustly, even after training only on clean examples.
Table 4.1. Computation costs of BPN trained on clean examples compared to a classical network trained on clean examples, on ResNet-50. For the forward (backward) pass, the computation costs of BPN are 0.00% (0.006%) more FLOPS than the classical network.
4.4.5 Loss Function, Forward and Backward Rules
We present the loss function, forward and backward rules for BPN, where d_* is the gradient for *, Grad_{i+1} is the gradient calculated by stochastic gradient descent from the next layer, x and y_{true} are the image inputs and their true labels, f and L are the network model and the cross-entropy loss, and j is an iterator over the first dimension of Grad_{i+1}; other notations are the same as in the forward rules.
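In compact form, restating Eqns.3 and 7 in this notation (a sketch; the per-layer index i runs over the fully connected layers carrying a beneficial perturbation bias):

Forward pass (Eqn.3):            V_{i+1} = W_i V_i + b_{bp_i}
Loss:                            L(f(x), y_{true}; \theta), with L the cross-entropy loss
Backward pass, weights (SGD):    W_i \leftarrow W_i - \eta \, \nabla_{W_i} L
Backward pass, beneficial bias:  b_{bp_i} \leftarrow b_{bp_i} - \eta \, \epsilon \, \mathrm{sign}(Grad_{i+1})   (Eqn.7)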
4.4.6 Extending BPN to deep convolutional networks
Most deep convolutional neural networks consist of two parts: a feature extraction part (convolutional and non-linear layers) and a classifier (fully connected layers). Here, we introduce beneficial perturbation biases (b_{bp}) into the last few fully connected layers of the deep convolutional network, replacing the normal bias terms (Fig.4.4). We use FGSM (Eqn.7) to update those beneficial perturbation biases.
Figure 4.4. BPN extension to a deep convolutional neural network. A deep convolutional neural network consists of two parts: a feature extraction part (convolutional and non-linear layers) and a classifier part (fully connected layers). We introduce beneficial perturbation biases (replacing the normal bias terms) into the last few fully connected layers of the deep convolutional network and update them using FGSM.
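A hedged sketch of this replacement for ResNet-18 (reusing the illustrative BPLinear module from Sec.4.4.2; the head layout mirrors the three fully connected layers with 1028 hidden units described below, but is otherwise an assumption):

import torch.nn as nn
from torchvision.models import resnet18

num_classes = 200                      # e.g., TinyImageNet
backbone = resnet18(pretrained=True)   # feature extraction part kept as-is
feature_dim = backbone.fc.in_features
backbone.fc = nn.Sequential(           # classifier part with b_bp biases
    BPLinear(feature_dim, 1028),
    nn.ReLU(),
    BPLinear(1028, 1028),
    nn.ReLU(),
    BPLinear(1028, num_classes),
)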
4.5 Experiments
4.5.1 Datasets
MNIST. MNIST [102] is a dataset of handwritten digits, with a training set of 60,000 examples and a test set of 10,000 examples.
FashionMNIST. FashionMNIST [155] is a dataset of article images, with a training set of 60,000 examples and a test set of 10,000 examples.
TinyImageNet. TinyImageNet is a subset of ImageNet [156], a large visual dataset. TinyImageNet consists of 200 classes and has a training set of 100k examples and a test set of 10k examples.
4.5.2 Network structure
For MNIST and FashionMNIST (LeNet). We use the convolutional and non-linear layers of LeNet [102] as the feature extraction part (classical LeNet). Then, for the classifier part, we create our version of LeNet (LeNet with beneficial perturbation biases) by adding beneficial perturbation biases into the fully connected layers, replacing the normal biases.
For TinyImageNet (ResNet-18). We use the convolutional and non-linear layers of ResNet-18 [143] for the feature extraction part (classical ResNet-18). Then, we use three fully connected layers with 1028 hidden units as the classifier. We create our version of ResNet-18 (ResNet-18 with beneficial perturbation biases) by adding beneficial perturbation biases into the fully connected layers, replacing the normal biases. We trained the BPN (ResNet-18) for 5000 epochs on TinyImageNet.
4.5.3 Various attack methods
To demonstrate how BPN can successfully defend against adversarial attacks, we used the advertorch toolbox [147] to generate adversarial examples and tested our BPN structure on adversarial examples generated by various attack methods. We employ only white-box attacks, where the attacker has access to the model's parameters. All adversarial examples were generated directly from the same model that they attacked (e.g., BPN generates adversarial examples against itself, and likewise the classical network generates adversarial examples against itself). We tested three adversarial attacks:
(1) PGD Linf [157]: Projected Gradient Descent attack with order = Linf.
(2) PGD L2 [157]: Projected Gradient Descent attack with order = L2.
(3) FGSM [131]: one-step fast gradient sign method.
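For reference, a sketch of how these attacks can be instantiated with advertorch (parameter values follow the ones reported in Tab.4.5; model is the network under attack, and the exact advertorch version and defaults may differ):

import torch.nn as nn
from advertorch.attacks import GradientSignAttack, LinfPGDAttack, L2PGDAttack

loss_fn = nn.CrossEntropyLoss(reduction="sum")
fgsm = GradientSignAttack(model, loss_fn=loss_fn, eps=0.3)
pgd_linf = LinfPGDAttack(model, loss_fn=loss_fn, eps=0.3,
                         nb_iter=40, eps_iter=0.01, rand_init=True)
pgd_l2 = L2PGDAttack(model, loss_fn=loss_fn, eps=0.3,
                     nb_iter=40, eps_iter=0.01, rand_init=True)

x_adv = pgd_linf.perturb(x_clean, y_true)   # white-box adversarial test batch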
4.6 Results
4.6.1 In scenario 1: BPN can defend against adversarial examples at negligible additional computational cost
When the neural network can only be trained on clean examples because of a modest computation budget, the biggest achievement of BPN is that it can defend against adversarial examples with only a low computational overhead. BPN achieves much better test accuracy on adversarial examples than a classical network (baseline, Tab.4.2, MNIST: 98.88% vs. 18.08%, FashionMNIST: 54.07% vs. 11.87%, TinyImageNet: 53.29% vs. 1.45%). Thus, for companies with modest computation resources, BPN can help a system achieve moderate robustness against adversarial examples while only introducing negligible additional computation costs. For example, on FashionMNIST, our method uses only 59% of the training time of adversarial training with just one adversarial example per clean example, saving 43.51 minutes of training time over 500 training epochs on an NVIDIA Tesla-V100 platform. For reference, classical stochastic gradient descent training on clean examples would be at 50% of the time of the same adversarial training framework. Mathematically, our method only introduces 0.006% extra cost compared to clean training, because we only introduce one sign and one multiplication operation for each fully connected layer (Tab.4.1), which means BPN should approach 50% of the training time of adversarial training. However, in our current implementation, we incur an extra 9% cost because we created a custom layer in the PyTorch framework that introduces a lot of overhead. This could be greatly improved if the custom layers were incorporated into the PyTorch framework with a C++ implementation. The computational savings would be huge on a larger dataset such as ImageNet [158].
Table 4.2. Scenario 1: training on only clean examples for both BPN and the classical network (CN) because of a modest computational budget. Testing on clean examples (Cln Ex) and adversarial examples (Adv Ex) (generated by FGSM, ϵ = 0.3 for MNIST, FashionMNIST and TinyImageNet). CN does poorly on adversarial examples, while BPN successfully defends against adversarial examples (emphasized by a light cyan background).
4.6.2 In scenario 2: BPN can alleviate the decay of clean-sample accuracy
When the neural network can only be trained on adversarial examples because of modest computation power, BPN not only becomes naturally more robust to adversarial examples, but also performs much better on clean samples that it has in fact never been trained on. Training a classical neural network (adversarial-only network) on only adversarial examples produces high test accuracy on adversarial examples but hurts test accuracy on clean examples (which it never saw) (Tab.4.3). The clean-sample accuracy of this adversarial-only network decreases from 99.01%, 89.17%, and 64.30% (upper bound: classical network trained only on clean images) to 95.54%, 65.64%, and 18.67% (Tab.4.3) for the MNIST, FashionMNIST and TinyImageNet datasets, respectively. Compared to this adversarial-only network, BPN not only achieves better test accuracy on adversarial examples (Tab.4.3, MNIST: 99.27%, FashionMNIST: 92.07%, TinyImageNet: 79.92%), but also achieves better test accuracy on clean examples (Tab.4.3, MNIST: 97.32%, FashionMNIST: 71.54%, TinyImageNet: 20.69%). This accuracy on clean examples is still worse than the upper bound (classical network trained only on clean examples), but it is much better than the accuracy of the classical network trained only on adversarial images. The reason is that beneficial perturbations convert some adversarial examples into clean examples through the neutralization effect (Eqn.2). As a consequence, the increased diversity of the clean examples improves the generalization of BPN.
Table 4.3. Scenario 2: training on only adversarial examples for both BPN and the classical network (CN) with a modest computational budget. Testing on clean examples (Cln Ex) and adversarial examples (Adv Ex) (generated by FGSM, ϵ = 0.3 for MNIST, FashionMNIST and TinyImageNet). BPN achieves better classification accuracy on clean examples than CN (emphasized by a light cyan background).
4.6.3 In scenario 3: BPN is marginally superior to classical adversarial training
When the neural network can be trained on both clean and adversarial examples because of abundant computation power, BPN is marginally superior to classical adversarial training (Tab.4.4). The reason is that BPN can further improve the generalization of the network by diversifying the training set through the neutralization effect (Eqn.2). BPN achieves slightly higher accuracy on clean examples than the classical network (Tab.4.4, MNIST 99.13% vs. 99.09%, FashionMNIST 89.65% vs. 89.49%, TinyImageNet 66.84% vs. 66.56%). In addition, BPN achieves higher accuracy on adversarial examples than the classical network (Tab.4.4, MNIST 97.62% vs. 97.01%, FashionMNIST 95.39% vs. 94.98%, TinyImageNet 88.16% vs. 85.75%).
However, we should normally avoid this scenario because training on both clean and adversarial examples is expensive in terms of running memory and computation costs. In addition, and most importantly, it is infeasible and expensive to introduce all unknown attack samples into adversarial training [148] (this would require generating at least one attack-specific adversarial image per attack). Instead of this data-driven data-augmentation approach, our model should be generalizable to unknown attacks; see our results in Sec.4.6.4.
Table 4.4. Scenario 3: training on both clean and adversarial examples for both BPN and the classical network (CN) with abundant computational power. Testing on clean examples (Cln Ex) and adversarial examples (Adv Ex) (generated by FGSM, ϵ = 0.3 for MNIST, FashionMNIST and TinyImageNet). BPN is marginally superior to CN. However, we should avoid this scenario because it is expensive in terms of running memory and computation costs, and it is infeasible to introduce all unknown attack samples into adversarial training.
4.6.4 BPN can generalize to unseen attacks that it has never been trained on
Here, we go back to scenario 1 to test the generalization ability of BPN. When the neural network can only be trained on clean examples because of a modest computational budget, we trained BPN on clean examples with FGSM and tested on adversarial examples generated by various attack methods (e.g., FGSM and PGD attacks).
In standard adversarial training, a model can only defend against the kind of adversarial examples that it has been trained on [131,153] (e.g., a model trained on adversarial examples created by the FGSM attack can only defend against the FGSM attack and would fail to defend against other attacks such as the PGD attack). Similarly, if BPN is updated only by the adversarial direction generated by the FGSM attack, the expectation would be that BPN can only defend against the FGSM attack. However, from Tab.4.5, we found that BPN trained only with FGSM not only defends against FGSM attacks well, but also generalizes to improve robustness against even harder attacks that it has never been trained on (e.g., the PGD attack in Tab.4.5). This feature alone gives it an edge over standard adversarial training, since it is infeasible to introduce all unknown attack samples into adversarial training.
The influence of the adversarial perturbation budget is discussed in the Supplementary section, where we also discuss further possibilities to improve the robustness and generalization of BPN by exploring more BPN structures and training procedures.
Table 4.5. BPN can generalize to unseen attacks. We go back to scenario 1, where both BPN and the classical network (CN) are trained on clean examples of MNIST and TinyImageNet because of a modest computation budget. Testing on adversarial examples generated by the FGSM attack (ϵ = 0.3) and the PGD attack (ϵ = 0.3, number of iterations = 40, attack step size = 0.01, random initialization = True, order = Linf or L2). BPN trained with FGSM is not only robust to the FGSM attack, but also successfully generalizes to defend against PGD attacks that it has never been trained on.
4.7 Discussion
We proposed a new solution for defending against adversarial examples, which we refer to as the Beneficial Perturbation Network (BPN). BPN, for the first time, leverages beneficial perturbations (the opposite of the well-known adversarial perturbations) to counteract the effects of adversarial perturbations on input data. Compared to adversarial training, this approach offers four main advantages: (1) BPN can effectively defend against adversarial examples with negligible additional running memory and computation costs; (2) BPN alleviates the accuracy trade-off, hurting the accuracy on clean examples less than classical adversarial training; (3) the increased diversity of the training set improves the generalization of the network; (4) unlike adversarial training, which can only defend against the kind of adversarial examples it has been trained on, we found experimentally that BPN can generalize to unseen attacks that it has never been trained on.
4.7.1 Intriguing property of beneficial perturbations
We suggest that the intriguing property of beneficial perturbations, namely that they neutralize the effects of adversarial examples, might come from the properties of adversarial subspaces. Following the adversarial direction, such as with the fast gradient sign method (FGSM) [131], can help generate adversarial examples that span a continuous subspace of large dimensionality (the adversarial subspace). Because of "excessive linearity" in many neural networks [132,153], due to features including rectified linear units and Maxout, the adversarial subspace often occupies a large portion of the total input space. Once an adversarial input lies in the adversarial subspace, nearby inputs also tend to lie in it. Interestingly, this corroborates recent findings by Ilyas et al. [113] that imperceptible adversarial noise can not only be used for adversarial attacks on an already-trained network, but also as features during training. For instance, after training a network on dog images perturbed with adversarial perturbations calculated from cat images, the network can achieve good classification accuracy on a test set of cat images. This result shows that those features (adversarial perturbations) calculated from the cat training set contain sufficient information for a machine learning system to make correct classifications on the test set of cat images. Beneficial perturbations operate in an analogous manner: we calculate those features and store them in the beneficial perturbation bias. In this case, although the input data have been modified (distribution shifts of the input data, whose information is corrupted by adversarial perturbations), the stored beneficial features carry sufficient information to neutralize the effects of adversarial examples and enable the network to make correct predictions.
4.7.2 Beneficial perturbations: the opposite "twins" of adversarial perturbations
Beneficial perturbations can be viewed as the opposite "twins" of adversarial perturbations. Much research is underway on how to generate ever more advanced adversarial perturbations [131,145,157,159,160] to fool ever more sophisticated machine learning systems. However, there is little research [161] on how to generate beneficial perturbations and on their possible applications. For example, Wen et al. [161] demonstrated that beneficial perturbations can largely eliminate catastrophic forgetting (where training the same neural network on a new task destroys the knowledge learned from old tasks) on subsequent tasks.
5 What can we learn from misclassified ImageNet images?
5.1 Abstract
Understanding the patterns of misclassified ImageNet images is particularly
important, as it could guide us to design deep neural networks that generalize
better. However, the richness of ImageNet imposes difficulties for researchers
to visually find any useful patterns of misclassification. Here, to help find these
patterns, we propose "Superclassing ImageNet dataset". It is a subset of
ImageNet which consists of 10 superclasses, each containing 7-116 related
subclasses (e.g., 52 bird types, 116 dog types). By training neural networks
on this dataset, we found that: (i) Misclassifications are rarely across
superclasses, but mainly among subclasses within a superclass. (ii) Ensemble
networks trained each only on subclasses of a given superclass perform better
than the same network trained on all subclasses of all superclasses. Hence, we
propose a two-stage Super-Sub framework, and demonstrate that: (i) The
framework improves overall classification performance by 3.3%, by first
inferring a superclass using a generalist superclass-level network, and then
using a specialized network for final subclass-level classification. (ii) Although
the total parameter storage cost increases to a factor N+1 for N superclasses
compared to using a single network, with finetuning, delta and quantization
aware training techniques this can be reduced to 0.2N+1. Another advantage
of this efficient implementation is that the memory cost on the GPU during
inference is equivalent to using only one network. The reason is we initiate each
subclass-level network through addition of small parameter variations (deltas)
to the superclass-level network. (iii) Finally, our framework promises to be more
scalable and generalizable than the common alternative of simply scaling up a
vanilla network in size, since very large networks often suffer from overfitting
and gradient vanishing.
108
5.2 Motivation
Deep neural networks have led to a series of breakthroughs in image classification [123,162]. To improve classification accuracy, researchers have proposed different kinds of strategies, including designing better network structures (e.g., ResNet [143] and DenseNet [163]), designing better optimizers (e.g., the Adam optimizer [56] and the RMSprop optimizer [164]), applying better loss formulations (e.g., contrastive loss [165]), and using large-scale pretraining [166].
Figure 5.1. Superclassing ImageNet dataset. The Superclassing ImageNet dataset is a subset of the ImageNet dataset. It contains broad classes which each subsume several of the original ImageNet classes. It consists of 10 superclasses: Bird, Boat, Car, Cat, Dog, Fruit, Fungus, Insect, Monkey and Truck. Each superclass contains several subclasses; for example, the Bird superclass consists of the Chicken, Ostrich and Black Swan subclasses, among others.
However, few researchers have explored the patterns of misclassifications made by current DNNs. For instance, Dodge et al. [167] found that DNN performance is much lower than human performance on distorted images. Understanding these patterns is particularly important because it could guide us to design more robust DNNs that make fewer mistakes and generalize better.
In addition, while in the early days the answer to improving performance was often to build ever larger networks, this approach has more recently been shown not to scale well because of (i) overfitting [50,168,169] and (ii) gradient vanishing [170]. (i) When there is too much overfitting, this results in poor generalization on the test set. For example, Laurence et al. [169] (in their Fig.5.2) showed that, generally, as network size increases, training error decreases but this ultimately leads to overfitting. As a consequence, test errors exhibit a "U"-shaped behavior as a function of network size. (ii) The gradient vanishing problem limits the number of parameters and layers when scaling up the size of networks. Residual networks [143] are famous for easing the training of networks that are substantially deeper: they create a gradient highway that reduces the effect of gradient vanishing. However, residual networks still cannot avoid the problem entirely. For example, in Tab.4.6 of the original residual network paper [143], the performance of a residual network with 1202 layers is lower than that of residual networks with 32, 44, 56 or 110 layers. We find similar scaling issues, as detailed in our Results. This prompts us to explore alternatives for boosting performance with more parameters rather than naively expanding the size of a single network.
5.3 Observations and General Framework
Often, in image classification research, datasets with the richness of ImageNet [158] are desired. However, the richness of these datasets imposes extraordinary difficulties for researchers trying to visually find any useful pattern of errors. For example, there are 1000 classes in ImageNet; it is impractical to find any useful pattern in a 1000-by-1000 confusion matrix. To smooth the path towards finding these patterns, without adding complexity to the ImageNet dataset, we "superclass" ImageNet into the "Superclassing ImageNet dataset". The Superclassing ImageNet dataset is a hierarchical subset of ImageNet that contains broad superclasses (derived from WordNet, see Sec.5.4.1) which each subsume several of the original ImageNet classes (Fig.5.1). Our dataset contains 10 superclasses, and each superclass contains 7 to 116 subclasses, for a total of 253 subclasses. Thus, by adopting this hierarchical subdivision, we initially only navigate through 10 superclasses instead of the 1000 classes of the original ImageNet dataset. We found useful patterns in a 10-by-10 confusion matrix (Fig.5.2) and in the summary of classification accuracies for all subclasses within each superclass (Fig.5.3). We highlight two observations from training convolutional neural networks on this dataset:
Figure 5.2. Confusion matrix for inter-superclass prediction. Correct classification (green), mild misclassification (orange), severe misclassification (red). The convolutional neural network (ResNet-18) barely makes inter-superclass misclassifications, as the averaged correct prediction performance is 96.52%.
(i) Misclassifications are seldom across superclasses (Fig.5.2), but mainly
intra-superclass.
We trained a ResNet-18 neural network (pretrained on ImageNet) on all images
from all superclasses of the Superclassing ImageNet dataset to make a 10-way
superclass prediction. We named it the Superclass Network. The high values
on the diagonal of the confusion matrix in Fig.5.2 mean that the Superclass
Network barely makes any inter-superclass mistakes, as the averaged correct
prediction performance is 96.52%. For example, the probability of recognizing
a bird image as belonging to the Bird superclass is 96.77%, while the probability
of recognizing a bird image as belonging to the Car superclass is only 0.08%.
In addition, for a network that is still trained on all images from all superclasses, but whose goal is to perform a 253-way classification directly among all subclasses, the averaged correct prediction performance decreases to 71.18%. This lower performance suggests that most misclassifications occur within superclasses (intra-superclass).
(ii) Ensemble networks trained each only on subclasses of a given
superclass (upperbound) perform better than the same network trained
on all subclasses of all superclasses (lowerbound; Fig.5.3).
In more detail:
1. Upperbound (with a superclass oracle): We trained a ResNet-18 neural network (pretrained on ImageNet) on images from one particular superclass of the Superclassing ImageNet dataset. We named it a Subclass Network. There are ten superclasses in the Superclassing ImageNet dataset, so we have ten Subclass Networks. For a test image, with a superclass oracle, we select the right Subclass Network and only infer among the subclasses inside the superclass the image belongs to (upperbound). For example, a Subclass Network for the Dog superclass is trained only on dog images; it would only predict the 116 dog subclasses (e.g., Husky, Malamute, Papillon, etc.). The correct prediction performance is 75.07% when averaged across all 10 superclasses.
2. Lowerbound (without a superclass oracle): We trained a ResNet-18 neural network (pretrained on ImageNet) on all images from all subclasses of the Superclassing ImageNet dataset. For a test image, the network infers the subclass among all 253 subclasses across all superclasses, without any prior knowledge of which superclass the test image belongs to (lowerbound). For exposition purposes, in Fig.5.3, the results are shown for all subclasses within each superclass. The averaged performance across all 253 subclasses is 71.18%. The performance gap can be quantified as: the upperbound is 5.47% better than the lowerbound.
Figure 5.3. Performance gap between classifying an image among the subclasses inside the superclass it belongs to (upperbound, green, with a superclass oracle; the averaged correct prediction performance is 75.07%) and classifying an image among all subclasses from all superclasses (lowerbound, red, without a superclass oracle; the averaged correct prediction performance is 71.18%). The upperbound is 5.47% better than the lowerbound.
In this paper, our goal is to design a framework whose classification performance can approach the upperbound performance without requiring a superclass oracle. Here, we propose a two-stage Super-Sub framework (Fig.5.4) to reduce the performance gap. In detail, in the first stage, we use a Superclass Network to decide the superclass of the test image. In the second stage, we choose the corresponding Subclass Network according to the superclass decision made in the first stage. This Subclass Network then outputs the final fine-grained subclass of the test image. More generally, our two-stage Super-Sub framework enables a neural network system to automatically adjust its behavior according to different inputs [171], here achieved through the aforementioned hierarchical two-step inference. From our first observation, the Superclass Network barely makes inter-superclass mistakes, so the role of the Superclass Network can be approximated as a superclass oracle, helping to choose the correct Subclass Network for subclass classification. As a result, our two-stage Super-Sub framework is 3.30% better than the lowerbound on the new Superclassing ImageNet dataset.
Nonetheless, our vanilla implementation of the framework incurs a large cost in both parameter storage and inference runtime memory on the GPU (a factor of N+1 for N superclasses compared to a single network). To improve this, we propose a method that keeps the same inference runtime memory as a single network and reduces the parameter storage cost to 0.2N+1, based on finetuning [172], delta-encoding and quantization aware training (QAT) [173] techniques, all while preserving the same accuracy boost. Most importantly, our framework promises to be more scalable and generalizable than the common alternative of simply scaling up a vanilla network in size, since very large networks suffer from overfitting and gradient vanishing. We demonstrate our framework's capacity for scalability (see Results, Sec.5.6) by increasing the representational capacity starting from an already very large network, i.e., ResNet-152.
5.4 Our approach
5.4.1 The Superclassing ImageNet dataset
The Superclassing ImageNet dataset (Fig.5.1) is a new dataset that contains broad classes which each subsume several of the original ImageNet classes. It is a subset of the ImageNet dataset that contains 10 superclasses with 253 subclasses, 311,279 training images and 12,650 test images.
Making the custom dataset: We leverage the WordNet hierarchy, according to which ImageNet is organized [174], to create the Superclassing ImageNet dataset with 10 superclasses (Tab.5.1):
Table 5.1. Superclassing ImageNet dataset. For each row, we present the name of each superclass and its number of subclasses, number of training images and number of test images.
Figure 5.4. Two-stage Super-Sub framework. In stage 1, a Superclass Network is trained to decide the superclass of a test image. In stage 2, we choose the corresponding Subclass Network to decide the subclass within the superclass decided in stage 1.
5.4.2 Super-Sub framework
Our goal is to approach the upperbound performance without having a superclass oracle. We propose a two-stage Super-Sub framework to achieve this goal (Fig.5.4). In the first stage, a Superclass Network is trained to decide the superclass of a test image. Since the Superclass Network barely makes inter-superclass mistakes, it can be approximated as a superclass oracle. This approximated superclass oracle helps us select the right Subclass Network for the second stage. In the second stage, we use the selected Subclass Network to decide the subclass within the superclass decided in the first stage.
5.4.2.1 Vanilla Network Architectures and Vanilla Inference Rules
Superclass Network: The Superclass Network decides the superclass of a test image (e.g., Bird, Boat, Cat, etc.). We trained a ResNet-18 neural network on all images from all superclasses of the Superclassing ImageNet dataset. Formally, we trained the Superclass Network by minimizing the loss function:

\min_{\theta_{sup}} L(f_{sup}(x_{sup}; \theta_{sup}), y_{sup})    (1)

where f_{sup} and L are the Superclass Network and the cross-entropy loss, x_{sup} and y_{sup} are the image inputs from all superclasses and their superclass labels, and \theta_{sup} are the parameters of the Superclass Network. The learning rate is 0.01 and the number of training epochs is 200.
Vanilla Subclass Networks: Each Subclass Network decides among the subclasses within its corresponding superclass. Each Vanilla Subclass Network is a ResNet-18. We trained each Subclass Network on images from one particular superclass of our Superclassing ImageNet dataset. Each Subclass Network only infers within the range of subclasses of the superclass it corresponds to. If there are N superclasses, there are N Subclass Networks. For example, a Subclass Network for the Bird superclass is trained on bird images only; it would only predict the subclasses (e.g., Chicken, Ostrich, Black Swan, etc.) within the Bird superclass. Formally, we trained the Subclass Networks by minimizing the loss function:

\min_{\theta_{sub_i}} L(f_{sub_i}(x_i; \theta_{sub_i}), y_i), \quad \forall i \in \{Bird, Boat, Car, ..., Truck\}    (2)

where i is an iterator that indicates the current superclass, f_{sub_i} is the Subclass Network corresponding to superclass i, L is the cross-entropy loss, x_i and y_i are the input images from superclass i and their corresponding subclass labels, and \theta_{sub_i} are the parameters of the Subclass Network for superclass i. For our experiments, we use a learning rate of 0.01 with 100 epochs. The settings of all hyper-parameters are the same for all Subclass Networks.
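A minimal training sketch for the ten Vanilla Subclass Networks (Eqn.2); superclass_loaders and num_subclasses are assumed, illustrative data structures, not part of the released code:

import torch
import torch.nn as nn
from torchvision.models import resnet18

subclass_nets = {}
for name, loader in superclass_loaders.items():   # e.g., {"Bird": DataLoader, ...}
    net = resnet18(pretrained=True)
    net.fc = nn.Linear(net.fc.in_features, num_subclasses[name])
    opt = torch.optim.SGD(net.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(100):                      # 100 epochs, as above
        for x, y in loader:                       # y: subclass label within superclass
            opt.zero_grad()
            criterion(net(x), y).backward()
            opt.step()
    subclass_nets[name] = net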
Vanilla Inference Rules: Formally, we define the vanilla inference rules as follows: the Superclass Network first predicts the superclass of the test image, and the Subclass Network corresponding to that predicted superclass then predicts the final subclass.
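A minimal sketch of this two-stage rule (superclass_names, superclass_net and subclass_nets are assumed to come from the training sketch above):

def vanilla_infer(x, superclass_net, subclass_nets, superclass_names):
    # Stage 1: 10-way superclass decision (x is a single test image, batch of 1).
    s_hat = superclass_net(x).argmax(dim=1).item()
    name = superclass_names[s_hat]                 # e.g., "Dog"
    # Stage 2: fine-grained decision within the chosen superclass.
    y_hat = subclass_nets[name](x).argmax(dim=1).item()
    return name, y_hat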
Figure 5.5. Efficient implementation of the Super-Sub framework: one network at a time for inference. The compressed deltas (20% of their original size) between the Superclass Network and the Subclass Finetuned Networks are stored in main memory. The Subclass Finetuned Networks are reconstructed by adding the corresponding deltas to the parameters of the Superclass Network. This approach is less time consuming because the compressed deltas incur less I/O cost.
5.4.2.2 Efficient Network Architectures and Efficient Inference Rules
In this subsection, our goal is to implement the Super-Sub framework efficiently, with a particular focus on the deployment of deep learning systems on edge devices. It is costly to design an edge device capable of loading all networks onto its GPU as required by our vanilla implementation. Memory-wise, it is cheaper to load one network at a time onto the GPU. However, on edge devices, switching between different networks incurs expensive (in terms of time) input/output (I/O) operations between the GPU and main memory. Since at inference time we need to keep switching between the Superclass Network and a Subclass Network, the cost of these I/O operations would become expensive. Here, we propose an efficient implementation (Fig.5.5) that trades the expensive I/O operations for likely cheaper add operations.
The problem domains of the Subclass Networks are subsets of the problem domain of the Superclass Network. For example, recognizing an image of a Husky within the Dog superclass is a subproblem of recognizing an image of a dog among various superclasses. Since this is a domain reduction problem [174], Subclass Networks can be created efficiently by finetuning the Superclass Network (Subclass Finetuned Networks). We found that the deltas between the parameters of the Superclass Network and each Subclass Finetuned Network are small and, because of this, we are able to greatly compress them. Thus, in our efficient implementation, to recreate a Subclass Network, we load the compressed deltas (reducing the I/O operations to 20%) and add them, on the GPU, to the parameters of the Superclass Network (Fig.5.5).
Superclass Network: The network architecture and training of the Superclass Network in the efficient implementation are the same as for the Vanilla Superclass Network described in Sec.5.4.2.1.
Subclass Finetuned Network: Instead of training each Subclass Network from scratch for each superclass, we created it by finetuning the Superclass Network on the images from one particular superclass of the Superclassing ImageNet dataset, by minimizing the loss function:

f_{sub\_ft_i}(x_i, y_i; \theta_{ft_i}) \leftarrow \min_{\theta_{ft_i}} L(f_{sub\_ft_i}(x_i; \theta_{ft_i}), y_i), \quad \forall i \in \{Bird, Boat, Car, ..., Truck\}    (3)

where f_{sub\_ft_i} is the Subclass Finetuned Network for superclass i, initialized from the parameters of the Superclass Network; other notations are the same as in Eqn.1 and Eqn.2. The hyper-parameters are the same as for the Vanilla Subclass Networks.
Deltas: The deltas for a particular superclass are created by subtracting the parameters of the Superclass Network from those of the corresponding Subclass Finetuned Network:

deltas_i = \theta_{ft_i} - \theta_{sup}    (4)

where deltas_i are the deltas for superclass i; other notations are as in Eqn.1 and Eqn.2. deltas_i is calculated through mini-batch gradient descent via Eqn.3. Since these gradients are in half precision (16 bits), deltas_i is represented in half precision (50% of the size needed to store the Subclass Finetuned Network in full 32-bit precision). Then, we use 7-Zip to further compress deltas_i (down to 44% of the size needed to store the Subclass Finetuned Network in full precision) and store them. In the following, we show how to further reduce the compressed delta size to a final 20% with Quantization Aware Training techniques. For inference, to recreate the Subclass Finetuned Network, we load the corresponding compressed deltas onto the GPU and add them to the Superclass Network.
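A hedged sketch of this delta pipeline (7-Zip's LZMA compression is approximated here with Python's lzma module; the helper names are illustrative, and tensors whose shapes differ between the two networks, e.g. the classifier head, would have to be stored in full instead):

import io, lzma, torch

def make_compressed_deltas(superclass_net, finetuned_net):
    sup, ft = superclass_net.state_dict(), finetuned_net.state_dict()
    # Eqn.4 in half precision, for floating-point tensors of matching shape.
    deltas = {k: (ft[k] - sup[k]).half()
              for k in sup if sup[k].is_floating_point()}
    buf = io.BytesIO()
    torch.save(deltas, buf)
    return lzma.compress(buf.getvalue())

def reconstruct(superclass_net, compressed, template_net):
    deltas = torch.load(io.BytesIO(lzma.decompress(compressed)))
    state = {k: v + deltas[k].float() if k in deltas else v
             for k, v in superclass_net.state_dict().items()}
    template_net.load_state_dict(state)   # Subclass Finetuned Network, rebuilt
    return template_net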
Efficient Inference Rules: Formally, the efficient inference rules mirror the vanilla inference rules, except that the selected Subclass Finetuned Network is first reconstructed on the GPU by adding the decompressed deltas to the parameters of the Superclass Network.
Quantization Aware Training (QAT): To further reduce the compressed size of the deltas, we used Quantization Aware Training (QAT) [173], available in the PyTorch framework, to finetune the Superclass Network (see Supplementary). With QAT, values are limited to 256 possibilities in 8-bit quantization. Thus, with less variability, the compressed size of the deltas is further reduced to 20% of the size of the original quantized Subclass Network.
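A sketch of the QAT step with PyTorch's eager-mode quantization API (the model is assumed to already follow the QuantStub/DeQuantStub conventions that this API requires):

import torch

model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# ... finetune on the images of the target superclass as in Eqn.3 ...

model.eval()
quantized = torch.quantization.convert(model)   # 8-bit weights: only 256 possible
                                                # values, so the deltas compress better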
In summary, in our efficient implementation there is always only one network
running on the GPU. Furthermore, the extra main memory cost for each
Subclass Network is reduced to around 20% of its original size.
5.5 Experiments
We evaluated our Vanilla and Efficient implementations on the Superclassing
ImageNet Dataset. We report classification accuracy on each superclass and
average accuracy across all superclasses. In addition, we report the
compression ratio for storing the deltas used in the efficient implementation.
To the best of our knowledge, we are the first to leverage the patterns of misclassified images to improve the scalability and generalizability of DNNs. Because our method explores a novel direction, we did not find directly applicable benchmarks to compare against. Instead, we use the lowerbound and upperbound defined in Sec.5.3 as natural references for our framework's performance.
5.6 Results and Analysis
The Super-Sub framework improved overall classification performance without a superclass oracle: On average, with a ResNet-18 backbone, our Efficient (Vanilla) implementation is 3.20% (3.30%) better than the lowerbound (Fig.5.6) and approaches the upperbound. The performances of the Efficient and Vanilla implementations are almost equivalent. With a ResNet-152 backbone, our Vanilla implementation is 8.00% better than the lowerbound.
Figure 5.6. Performances of the Efficient (rose-red) and Vanilla (yellow) implementations of the Super-Sub framework. Both implementations reduced the performance gap between the upperbound (green) and the lowerbound (red). The rose-red and yellow ellipses represent the positions of the Efficient and Vanilla implementations' performances relative to the upperbound and lowerbound. The definitions of upperbound and lowerbound are the same as in Fig.5.3.
Deltas are small and can be greatly compressed: We found that the deltas of the parameters between the Superclass Network and the Subclass Finetuned Networks are small (Fig.5.7.a,b; the black curve fluctuates around 0). The histograms of the deltas for different superclasses further confirm this (Fig.5.7.c). In addition, we found that we can compress the deltas because their distributions are highly concentrated around 0 (Fig.5.7.c). The average compression ratio is 0.44 without QAT (Fig.5.7.d) and 0.20 with QAT (Fig.5.7.e). This means that the GPU only needs to load 20% of the size of the original quantized Subclass Finetuned Network to recreate it from an already loaded Superclass Network.
The Super-Sub framework promises to be more scalable and generalizable: To demonstrate this, we pick a large network (ResNet-152) as the backbone of our Super-Sub framework. ResNet-152, in terms of memory size, marks a threshold after which gradient vanishing and overfitting can start occurring, prohibiting further layer increases within the traditional ResNet regime. This is often referred to as the "U"-shaped test error performance as a function of network size. We show that if we scale up ResNet-152 using our Super-Sub framework, which adds 440 MB of memory, we observe a huge performance boost from 78.43% to 84.63% accuracy. In contrast, if ResNet-152 is scaled up naively just by adding layers, which here would be approximately equal to expanding to a ResNet-455 (a Super-Sub framework with a ResNet-152 backbone is roughly equivalent to a ResNet-455 in terms of memory), we observe a stark performance degradation, with accuracy dropping to 68.01%.
Figure 5.7. a-b) Visualization of the weights of a convolutional layer and a batch normalization layer in the Superclass Network (red) and a Subclass Finetuned Network (for the Fruit superclass, blue). The black curve is the difference between the two networks. c) Histograms of the deltas between the Superclass Network and the Subclass Finetuned Network for each superclass. For d-e), the dashed line is the average compression ratio across all superclasses. d) Compression ratios for Subclass Networks without QAT. e) Compression ratios for Subclass Networks with QAT.
5.7 Discussion
In this paper, we studied the patterns of misclassified ImageNet images and leveraged these patterns to improve the scalability and generalizability of a DNN. By training neural networks on the "Superclassing ImageNet dataset", we found that: (i) misclassifications are rarely inter-superclass but mainly intra-superclass; (ii) ensemble networks trained each only on the subclasses of a given superclass perform better than the same network trained on all subclasses of all superclasses. Hence, we proposed the Super-Sub framework and demonstrated that: (i) the Super-Sub framework reduces the performance gap without requiring a superclass oracle; (ii) we can leverage finetuning, delta-encoding and quantization techniques to create an efficient implementation with low parameter and memory costs; and (iii) our Super-Sub framework promises to be more scalable and generalizable.
To date, scaling up the size and depth of a DNN quickly becomes prohibitive due to the emergence of gradient vanishing [170] and overfitting issues. Our two-stage Super-Sub framework provides a new and efficient way of scaling up a DNN beyond the current limitations. Our framework subdivides the overall classification problem domain into separate, smaller, semantically grouped sub-domains. Each image first goes through a high-level classification by the Superclass Network, which then routes it to the appropriate Subclass Network specialized in fine-grained classification for that sub-domain. In our approach, each sub-domain is only a subset of the original overall classification problem. Hence, each can be handled by a smaller model with fewer parameters and therefore less overfitting, compared to using a larger model to tackle the original problem directly. As previously discussed, naively increasing the size and depth of a single DNN is not scalable because of gradient vanishing. Here, our model addresses this by increasing the number of parameters "laterally", across several separate network modules (sub-domains), which naturally confers more robustness to gradient vanishing and allows much greater potential for upward scalability.
In sum, our approach is a more scalable and generalizable alternative to naively expanding the size of a single network. Nonetheless, despite these benefits, the two-stage Super-Sub framework has some limitations: (i) errors in the first stage are unrecoverable; we could create better superclasses that maximize the Superclass Network performance using a data-driven approach. (ii) We have not yet fully implemented the efficient version on an edge device; this may require a next generation of devices that support operations such as on-device delta decompression and weight addition, if our approach gains sufficient general interest.
6 Conclusion
My research so far has shown the symbiotic closed-loop relationship between artificial intelligence systems and primate brains. I have successfully developed bio-inspired artificial intelligence systems. In addition, I leveraged tools from mathematics and artificial intelligence to better understand the mechanisms of primate brains. In the future, I hope my Ph.D. research can further help 1) the development and deployment of next-generation artificial intelligence systems, 2) the development of personalized approaches to the intervention of brain disorders, and 3) the development of effective and economical approaches to the early diagnosis of brain disorders.
7 Acknowledgements
I would like to thank my advisor Prof. Dr. Laurent Itti for his strong support and advice over the years. I would also like to thank all members of the USC iLab and those who have supported, inspired, and entertained me during my research.
The work in this thesis was supported by: C-BRIC (one of six centers in JUMP, a Semiconductor Research Corporation (SRC) program sponsored by DARPA), the Army Research Office (W911NF2020053), and the Intel and CISCO Corporations.
8 References
1. Wessberg, J. et al. Real-time prediction of hand trajectory by ensembles
of cortical neurons in primates. Nature (2000) doi:10.1038/35042582.
2. Velliste, M., Perel, S., Spalding, M. C., Whitford, A. S. & Schwartz, A. B.
Cortical control of a prosthetic arm for self-feeding. Nature (2008)
doi:10.1038/nature06996.
3. Taylor, D. M., Tillery, S. I. H. & Schwartz, A. B. Direct cortical control of
3D neuroprosthetic devices. Science (2002) doi:10.1126/science.1070291.
4. Carmena, J. M. et al. Learning to control a brain-machine interface for
reaching and grasping by primates. PLoS Biol. (2003)
doi:10.1371/journal.pbio.0000042.
5. Li, Z. et al. Unscented Kalman filter for brain-machine interfaces. PLoS
One (2009) doi:10.1371/journal.pone.0006243.
6. Wu, W. et al. Neural Decoding of Cursor Motion Using a Kalman Filter.
Adv. Neural Inf. Process. Syst. 15 Proc. 2002 Conf. (2003)
doi:10.1.1.6.8776.
7. Wu, W., Gao, Y., Bienenstock, E., Donoghue, J. P. & Black, M. J.
Bayesian population decoding of motor cortical activity using a Kalman
filter. Neural Comput. (2006) doi:10.1162/089976606774841585.
8. Gao, Y., Black, M. J., Bienenstock, E., Wu, W. & Donoghue, J. P. A
quantitative comparison of linear and non-linear models of motor cortical
activity for the encoding and decoding of arm motions. in International
IEEE/EMBS Conference on Neural Engineering, NER (2003).
doi:10.1109/CNE.2003.1196789.
9. Eden, U. T., Frank, L. M., Barbieri, R., Solo, V. & Brown, E. N. Dynamic
Analysis of Neural Encoding by Point Process Adaptive Filtering. Neural
Comput. (2004) doi:10.1162/089976604773135069.
10. Eden, U. T. Point process adaptive filters for neural data analysis:
Theory and applications. in Proceedings of the IEEE Conference on
Decision and Control (2007). doi:10.1109/CDC.2007.4434708.
11. Hochreiter, S. & Schmidhuber, J. Long Short-Term Memory. Neural
Comput. (1997) doi:10.1162/neco.1997.9.8.1735.
12. Glaser, J. I., Chowdhury, R. H., Perich, M. G., Miller, L. E. & Kording, K. P. Machine learning for neural decoding. arXiv preprint arXiv:1708.00909 (2017).
13. Moeendarbary, E. et al. The soft mechanical signature of glial scars in
the central nervous system. Nat. Commun. (2017)
doi:10.1038/ncomms14787.
14. Kozai, T. D. Y., Jaquins-Gerstl, A. S., Vazquez, A. L., Michael, A. C. &
Cui, X. T. Brain tissue responses to neural implants impact signal
sensitivity and intervention strategies. ACS Chem. Neurosci. (2015)
doi:10.1021/cn500256e.
15. Duffau, H. Brain Plasticity and Reorganization Before, During, and After
Glioma Resection. in Glioblastoma (2016). doi:10.1016/B978-0-323-
47660-7.00018-5.
16. Tkach, D., Reimer, J. & Hatsopoulos, N. G. Observation-based learning
for brain-machine interfaces. Current Opinion in Neurobiology (2008)
doi:10.1016/j.conb.2008.09.016.
17. Gallego, J. A., Perich, M. G., Chowdhury, R. H., Solla, S. A. & Miller, L.
E. A stable, long-term cortical signature underlying consistent behavior.
bioRxiv (2018) doi:10.1101/447441.
18. Gao, P. et al. A theory of multineuronal dimensionality, dynamics and
measurement. bioRxiv (2017) doi:10.1101/214262.
19. Pandarinath, C. et al. Inferring single-trial neural population dynamics
using sequential auto-encoders. Nat. Methods (2018)
doi:10.1038/s41592-018-0109-9.
20. Farshchian A, Gallego J A, Cohen J P, et al. Adversarial Domain
Adaptation for Stable Brain-Machine Interfaces. in (2019).
21. Sussillo, D., Stavisky, S. D., Kao, J. C., Ryu, S. I. & Shenoy, K. V. Making brain-machine interfaces robust to future neural variability. Nat. Commun. 7, 13749 (2016).
22. Degenhart, A. D., Bishop, W. E., Oby, E. R., Tyler-Kabara, E. C., Chase, S. M., Batista, A. P. & Yu, B. M. Stabilization of a brain–computer interface via the alignment of low-dimensional spaces of neural activity. Nat. Biomed. Eng. 1–14 (2020).
23. Gerstner, W., Kistler, W. M., Naud, R. & Paninski, L. Neuronal
dynamics: From single neurons to networks and models of cognition.
Neuronal Dynamics: From Single Neurons to Networks and Models of
Cognition (2014). doi:10.1017/CBO9781107447615.
24. Goodfellow, I. J., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A. & Bengio, Y. Generative adversarial nets. in Advances in Neural Information Processing Systems (2014).
25. Wei, L., Hu, L., Kim, V., Yumer, E. & Li, H. Real-Time Hair Rendering
Using Sequential Adversarial Networks. in Lecture Notes in Computer
Science (including subseries Lecture Notes in Artificial Intelligence and
Lecture Notes in Bioinformatics) (2018). doi:10.1007/978-3-030-01225-
0_7.
26. Jetchev, N. & Bergmann, U. The Conditional Analogy GAN: Swapping
Fashion Articles on People Images. in Proceedings - 2017 IEEE
International Conference on Computer Vision Workshops, ICCVW 2017
(2018). doi:10.1109/ICCVW.2017.269.
27. Chen, X., Duan, Y., Houthooft, R., Schulman, J., Sutskever, I. & Abbeel, P. InfoGAN: interpretable representation learning by information maximizing generative adversarial nets. in Advances in Neural Information Processing Systems (2016).
28. Mirza, M. & Osindero, S. Conditional generative adversarial nets. arXiv preprint arXiv:1411.1784 (2014).
29. Odena, Augustus and Olah, Christopher and Shlens, J. Conditional
image synthesis with auxiliary classifier gans. (2017).
30. Ho, D., Liang, E., Stoica, I., Abbeel, P. & Chen, X. Population based augmentation: Efficient learning of augmentation policy schedules. in ICML (2019).
31. Salfi, F., D’Atri, A., Tempesta, D., De Gennaro, L. & Ferrara, M.
Boosting slow oscillations during sleep to improve memory function in
elderly people: A review of the literature. Brain Sciences (2020)
doi:10.3390/brainsci10050300.
32. Arnold, A., Nallapati, R. & Cohen, W. W. A comparative study of
methods for transductive transfer learning. in Proceedings - IEEE
International Conference on Data Mining, ICDM (2007).
doi:10.1109/ICDMW.2007.109.
33. Lin, M., Chen, Q. & Yan, S. Network in network. arXiv preprint arXiv:1312.4400 (2013).
34. Tchumatchenko, T., Geisel, T., Volgushev, M. & Wolf, F. Spike
correlations - What can they tell about synchrony? Frontiers in
Neuroscience (2011) doi:10.3389/fnins.2011.00068.
35. Dyer, E. L. et al. A cryptography-based approach for movement
decoding. Nat. Biomed. Eng. (2017) doi:10.1038/s41551-017-0169-7.
36. Ijspeert, A. J., Nakanishi, J., Hoffmann, H., Pastor, P. & Schaal, S. Dynamical movement primitives: Learning attractor models for motor behaviors. Neural Computation (2013) doi:10.1162/NECO_a_00393.
37. Schaal, S. Dynamic Movement Primitives -A Framework for Motor
Control in Humans and Humanoid Robotics. Adapt. Motion Anim. Mach.
(2006).
38. Poggio, T. & Bizzi, E. Generalization in vision and motor control. Nature
(2004) doi:10.1038/nature03014.
39. Nuyujukian, P. et al. Performance sustaining intracortical neural
prostheses. J. Neural Eng. (2014) doi:10.1088/1741-2560/11/6/066003.
40. Thoroughman, K. A. & Shadmehr, R. Learning of action through
adaptive combination of motor primitives. Nature (2000)
doi:10.1038/35037588.
41. Stroud, J. P., Porter, M. A., Hennequin, G. & Vogels, T. P. Motor
primitives in space and time via targeted gain modulation in cortical
networks. Nat. Neurosci. (2018) doi:10.1038/s41593-018-0276-0.
42. Shankar, T., Pinto, L., Tulsiani, S. & Gupta, A. DISCOVERING MOTOR
PROGRAMS BY RECOMPOSING DEMONSTRATIONS. in ICLR 2020
(2020).
43. Costa, R. M., Ganguly, K., Costa, R. M. & Carmena, J. M. Emergence
of Coordinated Neural Dynamics Underlies Neuroprosthetic Learning
and Skillful Control. Neuron (2017) doi:10.1016/j.neuron.2017.01.016.
44. Golub, M. D. et al. Learning by neural reassociation. Nat. Neurosci.
(2018) doi:10.1038/s41593-018-0095-3.
45. Sadtler, P. T. et al. Neural constraints on learning. Nature (2014)
doi:10.1038/nature13665.
46. Gold, J. I. & Shadlen, M. N. The Neural Basis of Decision Making. Annu.
Rev. Neurosci. (2007) doi:10.1146/annurev.neuro.29.051605.113038.
47. Felsen, G. & Dan, Y. A natural approach to studying vision. Nature
Neuroscience (2005) doi:10.1038/nn1608.
48. Paninski, L., Pillow, J. & Lewi, J. Statistical models for neural encoding,
decoding, and optimal stimulus design. Progress in Brain Research
(2007) doi:10.1016/S0079-6123(06)65031-0.
49. Paninski, L. Superlinear Population Encoding of Dynamic Hand
Trajectory in Primary Motor Cortex. J. Neurosci. (2004)
doi:10.1523/JNEUROSCI.0919-04.2004.
50. Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I. &
Salakhutdinov, R. Dropout: A simple way to prevent neural networks
from overfitting. J. Mach. Learn. Res. (2014).
51. Senior, A., Heigold, G., Ranzato, M. & Yang, K. An empirical study of
learning rates in deep neural networks for speech recognition. in
ICASSP, IEEE International Conference on Acoustics, Speech and
Signal Processing - Proceedings (2013).
doi:10.1109/ICASSP.2013.6638963.
52. Schuster, M. & Paliwal, K. K. Bidirectional recurrent neural networks.
IEEE Trans. Signal Process. (1997) doi:10.1109/78.650093.
53. LeCun, Y. A theoretical framework for Back-Propagation. Proceedings
of the 1988 Connectionist Models Summer School (1988)
doi:10.1007/978-3-642-35289-8.
54. Odena, A. Semi-supervised learning with generative adversarial
networks. (2016).
55. Montavon, G. et al. Neural Networks: Tricks of the Trade. Springer
Lecture Notes in Computer Sciences (2012).
56. Kingma, D. P. & Ba, J. L. Adam: A method for stochastic optimization. in
3rd International Conference on Learning Representations, ICLR 2015 -
Conference Track Proceedings (2015).
57. Hatsopoulos, N., Joshi, J. & O’Leary, J. G. Decoding Continuous and
Discrete Motor Behaviors Using Motor and Premotor Cortical
Ensembles. J. Neurophysiol. (2004) doi:10.1152/jn.01245.2003.
58. Black, M. J. et al. Connecting brains with machines: the neural control of
2D cursor movement. in International IEEE/EMBS Conference on
Neural Engineering, NER (2003). doi:10.1109/CNE.2003.1196893.
59. Wu, W. et al. Modeling and decoding motor cortical activity using a
switching Kalman filter. IEEE Trans. Biomed. Eng. (2004)
doi:10.1109/TBME.2004.826666.
60. Brockwell, A. E. Recursive Bayesian Decoding of Motor Cortical Signals
by Particle Filtering. J. Neurophysiol. (2004) doi:10.1152/jn.00438.2003.
61. Shoham, S. et al. Statistical encoding model for a primary motor cortical
brain-machine interface. IEEE Trans. Biomed. Eng. (2005)
doi:10.1109/TBME.2005.847542.
62. Shanechi, M. M., Wornell, G. W., Williams, Z. M. & Brown, E. N. A
parallel point-process filter for estimation of goal-directed movements
from neural signals. in 2010 IEEE International Conference on Acoustics,
Speech and Signal Processing (2010).
63. Truccolo, W., Eden, U., Fellows, M., Donoghue, J. & Brown, E. A Point
Process Framework for Relating Neural Spiking Activity to Spiking
History, Neural Ensemble, and Extrinsic Covariate Effects. J.
Neurophysiol. (2005) doi:10.1152/jn.00697.2004.
64. Shanechi, M. M., Wornell, G. W., Williams, Z. M. & Brown, E. N.
Feedback-controlled parallel point process filter for estimation of goal-
directed movements from neural signals. IEEE Trans. Neural Syst.
Rehabil. Eng. (2013) doi:10.1109/TNSRE.2012.2221743.
65. Shanechi, M. M., Orsborn, A. L. & Carmena, J. M. Robust Brain-
Machine Interface Design Using Optimal Feedback Control Modeling
and Adaptive Point Process Filtering. PLoS Comput. Biol. (2016)
doi:10.1371/journal.pcbi.1004730.
66. Shanechi, M. M. et al. Rapid control and feedback rates enhance
neuroprosthetic control. Nat. Commun. (2017)
doi:10.1038/ncomms13825.
67. Glaser, J. I., Chowdhury, R. H., Perich, M. G., Miller, L. E. &
Kording, K. P. Machine learning for neural decoding. arXiv Prepr.
arXiv1708.00909 (2017).
68. Tseng, P.-H., Urpi, N. A., Lebedev, M. & Nicolelis, M. Decoding
movements from cortical ensemble activity using a long short-term
memory recurrent network. Neural Comput. 31, 1085–1113 (2019).
69. Sadras, N., Pesaran, B. & Shanechi, M. M. A point-process matched
filter for event detection and decoding from population spike trains. J.
Neural Eng. 16, 66016 (2019).
70. Harth, E. M., Csermely, T. J., Beek, B. & Lindsay, R. D. Brain functions
and neural dynamics. J. Theor. Biol. 26, 93–120 (1970).
71. Gerstner, W. & Kistler, W. M. Spiking neuron models: Single neurons,
populations, plasticity. (Cambridge University Press, 2002).
72. Dayan, P. & Abbott, L. F. Theoretical neuroscience: computational and
mathematical modeling of neural systems. (Computational
Neuroscience Series, 2001).
73. Gollisch, T. & Meister, M. Rapid neural coding in the retina with relative
spike latencies. Science 319, 1108–1111 (2008).
74. Aldworth, Z. N., Dimitrov, A. G., Cummins, G. I., Gedeon, T. & Miller, J.
P. Temporal encoding in a nervous system. PLoS Comput Biol 7,
e1002041 (2011).
75. Hallock, R. M. & Di Lorenzo, P. M. Temporal coding in the gustatory
system. Neurosci. & Biobehav. Rev. 30, 1145–1160 (2006).
76. Jolivet, R., Rauch, A., Lüscher, H.-R. & Gerstner, W. Predicting spike
timing of neocortical pyramidal neurons by simple threshold models. J.
Comput. Neurosci. 21, 35–49 (2006).
77. Carleton, A., Accolla, R. & Simon, S. A. Coding in the mammalian
gustatory system. Trends Neurosci. 33, 326–334 (2010).
78. Kostal, L., Lansky, P. & Rospars, J.-P. Neuronal coding and spiking
randomness. Eur. J. Neurosci. 26, 2693–2701 (2007).
79. Discrete Wavelet Transform (DWT). in Encyclopedia of Multimedia (ed.
Furht, B.) 188 (Springer US, 2008). doi:10.1007/978-0-387-78414-
4_305.
80. Daubechies, I. Ten lectures on wavelets. vol. 61 (SIAM, 1992).
81. Butts, D. A. et al. Temporal precision in the neural code and the
timescales of natural vision. Nature 449, 92–95 (2007).
82. Quiroga, R. Q., Nadasdy, Z. & Ben-Shaul, Y. Unsupervised spike
detection and sorting with wavelets and superparamagnetic clustering.
Neural Comput. 16, 1661–1687 (2004).
83. Yang, Y., Kamboh, A. & Mason, A. J. Adaptive threshold spike
detection using stationary wavelet transform for neural recording
implants. in 2010 Biomedical Circuits and Systems Conference
(BioCAS) 9–12 (2010).
84. Brychta, R. J. et al. Wavelet methods for spike detection in mouse renal
sympathetic nerve activity. IEEE Trans. Biomed. Eng. 54, 82–93 (2006).
85. Robinson, N., Vinod, A. P., Guan, C., Ang, K. K. & Peng, T. K. A
Wavelet-CSP method to classify hand movement directions in EEG
based BCI system. in 2011 8th International Conference on Information,
Communications & Signal Processing 1–5 (2011).
86. Robinson, N., Vinod, A. P., Ang, K. K., Tee, K. P. & Guan, C. T. EEG-
based classification of fast and slow hand movements using wavelet-
CSP algorithm. IEEE Trans. Biomed. Eng. 60, 2123–2132 (2013).
87. Zhang, M. et al. Extracting wavelet based neural features from human
intracortical recordings for neuroprosthetics applications. Bioelectron.
Med. 4, 11 (2018).
88. Carotti, E. S. G., Shalchyan, V., Jensen, W. & Farina, D. Denoising and
compression of intracortical signals with a modified MDL criterion. Med.
& Biol. Eng. & Comput. 52, 429–438 (2014).
89. Lee, G. C. F., Libedinsky, C., Guan, C. & So, R. Use of wavelet
transform coefficients for spike detection for a Robust Intracortical Brain
Machine Interface. in 2017 8th International IEEE/EMBS Conference on
Neural Engineering (NER) 540–543 (2017).
90. McCloskey, M. & Cohen, N. J. Catastrophic Interference in
Connectionist Networks: The Sequential Learning Problem. Psychol.
Learn. Motiv. - Adv. Res. Theory (1989) doi:10.1016/S0079-
7421(08)60536-8.
91. French, R. M. Dynamically constraining connectionist networks to
produce distributed, orthogonal representations to reduce catastrophic
interference. in Proceedings of the Sixteenth Annual Conference of the
Cognitive Science Society (2019). doi:10.4324/9781315789354-58.
92. Rios, A. & Itti, L. Closed-Loop GAN for continual Learning. arXiv Prepr.
arXiv1811.01146 (2018).
93. Kirkpatrick, J. et al. Overcoming catastrophic forgetting in neural
networks. Proc. Natl. Acad. Sci. U. S. A. (2017)
doi:10.1073/pnas.1611835114.
94. Cheung, B., Terekhov, A., Chen, Y., Agrawal, P. & Olshausen, B.
Superposition of many models into one. in Advances in Neural
Information Processing Systems (2019).
95. Rusu, A. A. et al. Progressive neural networks. arXiv Prepr.
arXiv1606.04671 (2016).
96. Yoon, J., Yang, E., Lee, J. & Hwang, S. J. Lifelong Learning with
Dynamically Expandable Networks. (2018).
97. Lopez-Paz, D. & Ranzato, M. Gradient episodic memory for continual
learning. in Advances in Neural Information Processing Systems 6467–
6476 (2017).
98. Zeng, G., Chen, Y., Cui, B. & Yu, S. Continual learning of context-
dependent processing in neural networks. Nat. Mach. Intell. 1, 364–372
(2019).
99. Mallya, A., Davis, D. & Lazebnik, S. Piggyback: Adapting a single
network to multiple tasks by learning to mask weights. in Proceedings of
the European Conference on Computer Vision (ECCV) 67–82 (2018).
100. Du, X., Charan, G., Liu, F. & Cao, Y. Single-Net Continual Learning with
Progressive Segmented Training (PST). arXiv Prepr. arXiv1905.11550
(2019).
101. Yoon, J., Kim, S., Yang, E. & Hwang, S. J. ORACLE: Order robust
adaptive continual learning. (2019).
102. LeCun, Y., Bottou, L., Bengio, Y. & Haffner, P. Gradient-based learning
applied to document recognition. Proc. IEEE 86, 2278–2324 (1998).
103. Krizhevsky, A. & Hinton, G. Learning multiple layers of features from
tiny images. (2009).
104. Rebuffi, S.-A., Kolesnikov, A., Sperl, G. & Lampert, C. H. iCaRL:
Incremental classifier and representation learning. in Proceedings of the
IEEE Conference on Computer Vision and Pattern Recognition 2001–
2010 (2017).
105. Rannen, A., Aljundi, R., Blaschko, M. B. & Tuytelaars, T. Encoder based
lifelong learning. in Proceedings of the IEEE International Conference
on Computer Vision 1320–1328 (2017).
106. Farajtabar, M., Azizan, N., Mott, A. & Li, A. Orthogonal gradient descent
for continual learning. in International Conference on Artificial
Intelligence and Statistics 3762–3773 (2020).
107. Srivastava, R. K., Masci, J., Kazerounian, S., Gomez, F. &
Schmidhuber, J. Compete to compute. in Advances in neural
information processing systems 2310–2318 (2013).
108. Masse, N. Y., Grant, G. D. & Freedman, D. J. Alleviating catastrophic
forgetting using context-dependent gating and synaptic stabilization.
Proc. Natl. Acad. Sci. 115, E10467–E10475 (2018).
109. Lee, S. W., Kim, J. H., Jun, J., Ha, J. W. & Zhang, B. T. Overcoming
catastrophic forgetting by incremental moment matching. in Advances in
Neural Information Processing Systems (2017).
110. Aljundi, R., Babiloni, F., Elhoseiny, M., Rohrbach, M. & Tuytelaars, T.
Memory Aware Synapses: Learning What (not) to Forget BT - Computer
Vision – ECCV 2018. Eccv (2018).
111. Li, Z. & Hoiem, D. Learning without Forgetting. IEEE Trans. Pattern
Anal. Mach. Intell. (2018) doi:10.1109/TPAMI.2017.2773081.
112. Tramèr, F., Papernot, N., Goodfellow, I., Boneh, D. & McDaniel, P.
The Space of Transferable Adversarial Examples. arXiv Prepr.
arXiv1704.03453 (2017).
113. Ilyas, A. et al. Adversarial examples are not bugs, they are features. in
Advances in Neural Information Processing Systems (2019).
114. Elsayed, G. F., Goodfellow, I. & Sohl-Dickstein, J. Adversarial
Reprogramming of Neural Networks. arXiv Prepr. arXiv1806.11146
(2018).
115. Nilsback, M.-E. & Zisserman, A. Automated flower classification over a
large number of classes. in 2008 Sixth Indian Conference on Computer
Vision, Graphics & Image Processing 722–729 (2008).
116. Quattoni, A. & Torralba, A. Recognizing indoor scenes. in 2009 IEEE
Conference on Computer Vision and Pattern Recognition 413–420
(2009).
117. Wah, C., Branson, S., Welinder, P., Perona, P. & Belongie, S. The
caltech-ucsd birds-200-2011 dataset. (2011).
118. Krause, J., Stark, M., Deng, J. & Fei-Fei, L. 3d object representations
for fine-grained categorization. in Proceedings of the IEEE International
Conference on Computer Vision Workshops 554–561 (2013).
119. Maji, S., Rahtu, E., Kannala, J., Blaschko, M. & Vedaldi, A. Fine-grained
visual classification of aircraft. arXiv Prepr. arXiv1306.5151 (2013).
120. Everingham, M. et al. The pascal visual object classes challenge: A
retrospective. Int. J. Comput. Vis. 111, 98–136 (2015).
121. De Campos, T. E., Babu, B. R. & Varma, M. Character recognition in
natural images. VISAPP (2) 7, (2009).
122. Netzer, Y. et al. Reading digits in natural images with unsupervised
feature learning. (2011).
123. Krizhevsky, A., Sutskever, I. & Hinton, G. E. Imagenet classification with
deep convolutional neural networks. in Advances in neural information
processing systems 1097–1105 (2012).
124. Russakovsky, O. et al. Imagenet large scale visual recognition
challenge. Int. J. Comput. Vis. 115, 211–252 (2015).
125. Aljundi, R., Babiloni, F., Elhoseiny, M., Rohrbach, M. & Tuytelaars, T.
Memory aware synapses: Learning what (not) to forget. in Proceedings
of the European Conference on Computer Vision (ECCV) 139–154
(2018).
126. Aljundi, R., Rohrbach, M. & Tuytelaars, T. Selfless sequential learning.
arXiv Prepr. arXiv1806.05421 (2018).
127. Zenke, F., Poole, B. & Ganguli, S. Continual learning through synaptic
intelligence. in 34th International Conference on Machine Learning,
ICML 2017 (2017).
128. Kemker, R. & Kanan, C. Fearnet: Brain-inspired model for incremental
learning. arXiv Prepr. arXiv1711.10563 (2017).
129. Parisi, G. I., Kemker, R., Part, J. L., Kanan, C. & Wermter, S. Continual
lifelong learning with neural networks: A review. Neural Networks (2019)
doi:10.1016/j.neunet.2019.01.012.
130. Kemker, R., McClure, M., Abitino, A., Hayes, T. L. & Kanan, C.
Measuring catastrophic forgetting in neural networks. in Thirty-second
AAAI conference on artificial intelligence (2018).
131. Goodfellow, I. J., Shlens, J. & Szegedy, C. Explaining and harnessing
adversarial examples. arXiv Prepr. arXiv1412.6572 (2014).
132. Goodfellow, I. Adversarial Examples and Adversarial Training. Stanford
cs231n Lecture 16 slides (2017).
133. Bakker, A., Kirwan, C. B., Miller, M. & Stark, C. E. L. Pattern separation
in the human hippocampal CA3 and dentate gyrus. Science
(2008) doi:10.1126/science.1152882.
134. Lesburguères, E. et al. Early tagging of cortical networks is required for
the formation of enduring associative memory. Science (2011)
doi:10.1126/science.1196164.
135. Squire, L. R. & Alvarez, P. Retrograde amnesia and memory
consolidation: a neurobiological perspective. Curr. Opin. Neurobiol.
(1995) doi:10.1016/0959-4388(95)80023-9.
136. Frankland, P. W. & Bontempi, B. The organization of recent and remote
memories. Nat. Rev. Neurosci. 6, 119 (2005).
137. Helfrich, R. F. et al. Bidirectional prefrontal-hippocampal dynamics
organize information transfer during sleep in humans. Nat. Commun.
10, 1–16 (2019).
138. Roy, J. E., Riesenhuber, M., Poggio, T. & Miller, E. K. Prefrontal cortex
activity during flexible categorization. J. Neurosci. 30, 8519–8528
(2010).
139. Cromer, J. A., Roy, J. E. & Miller, E. K. Representation of multiple,
independent categories in the primate prefrontal cortex. Neuron 66,
796–807 (2010).
140. Flesch, T., Balaguer, J., Dekker, R., Nili, H. & Summerfield, C.
Comparing continual task learning in minds and machines. Proc. Natl.
Acad. Sci. 115, E10313–E10322 (2018).
141. Yang, G. R., Joglekar, M. R., Song, H. F., Newsome, W. T. & Wang, X.-
J. Task representations in neural networks trained to perform many
cognitive tasks. Nat. Neurosci. 22, 297–306 (2019).
142. Brown, T. B. et al. Language models are few-shot learners. arXiv Prepr.
arXiv2005.14165 (2020).
143. He, K., Zhang, X., Ren, S. & Sun, J. Deep residual learning for image
recognition. in Proceedings of the IEEE conference on computer vision
and pattern recognition 770–778 (2016).
144. Devlin, J., Chang, M.-W., Lee, K. & Toutanova, K. Bert: Pre-training of
deep bidirectional transformers for language understanding. arXiv
Prepr. arXiv1810.04805 (2018).
145. Szegedy, C. et al. Intriguing properties of neural networks. arXiv Prepr.
arXiv1312.6199 (2013).
146. Biggio, B. et al. Evasion attacks against machine learning at test time. in
Joint European conference on machine learning and knowledge
discovery in databases 387–402 (2013).
147. Di, X., Yu, P. & Tian, M. Towards Adversarial Training with Moderate
Performance Improvement for Neural Network Classification. arXiv
Prepr. arXiv1807.00340 (2018).
148. Tramèr, F. et al. Ensemble adversarial training: Attacks and defenses.
arXiv Prepr. arXiv1705.07204 (2017).
149. Raghunathan, A., Xie, S. M., Yang, F., Duchi, J. C. & Liang, P.
Adversarial Training Can Hurt Generalization. arXiv Prepr.
arXiv1906.06032 (2019).
150. Stanforth, R., Fawzi, A., Kohli, P. et al. Are Labels Required for
Improving Adversarial Robustness? arXiv Prepr. arXiv1905.13725
(2019).
151. Zhang, H. et al. Theoretically principled trade-off between robustness
and accuracy. arXiv Prepr. arXiv1901.08573 (2019).
152. Huang, R., Xu, B., Schuurmans, D. & Szepesvári, C. Learning with a
Strong Adversary. CoRR abs/1511.03034, (2015).
153. Kannan, H., Kurakin, A. & Goodfellow, I. Adversarial logit pairing. arXiv
Prepr. arXiv1803.06373 (2018).
154. Xie, C., Wu, Y., Maaten, L. van der, Yuille, A. L. & He, K. Feature
denoising for improving adversarial robustness. in Proceedings of the
IEEE Conference on Computer Vision and Pattern Recognition 501–509
(2019).
155. Xiao, H., Rasul, K. & Vollgraf, R. Fashion-mnist: a novel image dataset
for benchmarking machine learning algorithms. arXiv Prepr.
arXiv1708.07747 (2017).
156. Deng, J. et al. Imagenet: A large-scale hierarchical image database. in
2009 IEEE conference on computer vision and pattern recognition 248–
255 (2009).
157. Madry, A., Makelov, A., Schmidt, L., Tsipras, D. & Vladu, A. Towards
deep learning models resistant to adversarial attacks. arXiv Prepr.
arXiv1706.06083 (2017).
158. Deng, J. et al. ImageNet: A Large-Scale Hierarchical Image Database.
in CVPR09 (2009).
159. Kurakin, A., Goodfellow, I. & Bengio, S. Adversarial examples in the
physical world. arXiv Prepr. arXiv1607.02533 (2016).
160. Narodytska, N. & Kasiviswanathan, S. P. Simple black-box adversarial
perturbations for deep networks. arXiv Prepr. arXiv1612.06299 (2016).
161. Wen, S., Rios, A., Ge, Y. & Itti, L. Beneficial Perturbation Network for
Designing General Adaptive Artificial Intelligence Systems. IEEE Trans.
Neural Networks Learn. Syst. (2021)
doi:10.1109/TNNLS.2021.3054423.
162. LeCun, Y. et al. Backpropagation applied to handwritten zip code
recognition. Neural Comput. 1, 541–551 (1989).
163. Huang, G., Liu, Z., Van Der Maaten, L. & Weinberger, K. Q. Densely
connected convolutional networks. in Proceedings of the IEEE
conference on computer vision and pattern recognition 4700–4708
(2017).
164. Tieleman, T. & Hinton, G. E. Neural networks for machine learning.
Coursera (Lecture 6.5 - RMSprop) (2012).
165. Khosla, P. et al. Supervised contrastive learning. arXiv Prepr.
arXiv2004.11362 (2020).
166. Mahajan, D. et al. Exploring the limits of weakly supervised pretraining.
in Proceedings of the European Conference on Computer Vision
(ECCV) 181–196 (2018).
167. Dodge, S. & Karam, L. A study and comparison of human and deep
learning recognition performance under visual distortions. in 2017 26th
international conference on computer communication and networks
(ICCCN) 1–7 (2017).
168. Salman, S. & Liu, X. Overfitting mechanism and avoidance in deep
neural networks. arXiv Prepr. arXiv1901.06566 (2019).
169. Lawrence, S., Giles, C. L. & Tsoi, A. C. Lessons in neural network
training: Overfitting may be harder than expected. in AAAI/IAAI 540–545
(1997).
170. Hochreiter, S., Bengio, Y., Frasconi, P. & Schmidhuber, J.
Gradient flow in recurrent nets: the difficulty of learning long-term
dependencies. (2001).
171. Han, Y. et al. Dynamic neural networks: A survey. arXiv Prepr.
arXiv2102.04906 (2021).
172. Tajbakhsh, N. et al. Convolutional neural networks for medical image
analysis: Full training or fine tuning? IEEE Trans. Med. Imaging 35,
1299–1312 (2016).
173. Krishnamoorthi, R. Quantization aware training.
174. Tawarmalani, M. & Sahinidis, N. V. Convexification and global
optimization in continuous and mixed-integer nonlinear programming:
theory, algorithms, software, and applications. vol. 65 (Springer Science
& Business Media, 2013).
175. Sussillo, D., Stavisky, S. D., Kao, J. C., Ryu, S. I. & Shenoy, K. V.
Making brain-machine interfaces robust to future neural variability. Nat.
Commun. (2016) doi:10.1038/ncomms13749.
9. Supplementary
9.1 Supplementary for rapid transfer of brain-machine interfaces to new
neuronal ensembles or participants
9.1.1 Supplementary figures
Supplementary Figure 1.1: Step 1: training a neural spike synthesizer on the neural data from session one of Monkey one to learn a direct mapping from kinematics to spike trains and to capture the embedded neural attributes. Step 2: freezing the generator that captures the embedded neural attributes and fine-tuning the readout modules for different sessions or subjects to allow for variations in neural attributes, using the neural data from session two of Monkey one or the neural data from session one of Monkey two; then synthesizing a large number of spike trains suitable for the new session or subject. Step 3: training a BCI decoder for the new session or subject using the same small amount of real neural data used for fine-tuning (in Step 2) plus a large number of synthesized spike trains (from Step 2). Step 4: testing the same BCI decoder on an independent test set from the new session or subject.
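To make Step 2 concrete, here is a minimal PyTorch-style sketch of freezing the generator and fine-tuning only the readout module; the layer types, sizes, loss, and optimizer settings are illustrative assumptions, not the actual cc-LSTM-GAN implementation described in Methods.

import torch
import torch.nn as nn

# Stand-ins for the trained modules: a generator that maps kinematics to a
# latent representation, and a per-session / per-subject readout module that
# maps this representation to spike trains. All sizes are illustrative.
generator = nn.LSTM(input_size=6, hidden_size=128, batch_first=True)
readout = nn.Linear(128, 96)

# Step 2: freeze the generator, which captures the shared neural attributes.
for p in generator.parameters():
    p.requires_grad = False

# Fine-tune only the readout module on a small amount of neural data from the
# new session or subject.
optimizer = torch.optim.Adam(readout.parameters(), lr=1e-4)
kinematics = torch.randn(8, 100, 6)     # batch x time x kinematic covariates
target_spikes = torch.rand(8, 100, 96)  # toy targets standing in for real data

features, _ = generator(kinematics)
loss = nn.functional.mse_loss(readout(features), target_spikes)
loss.backward()
optimizer.step()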
Supplementary Figure 1.2: For a-d), we calculated the hand velocity direction and the spike counts in each 300 ms bin, and plotted spike counts vs. hand velocity direction for each real and virtual neuron. The red line is the velocity neural tuning curve fitted by a cosine function. Each black dot is the spike count of one bin at a given angle, and the heatmap shows the density of black dots in each area. a) Real velocity neural tuning curve for neuron 32. b) Generated velocity neural tuning curve for neuron 32. c) Real velocity neural tuning curve for neuron 57. d) Generated velocity neural tuning curve for neuron 57. e) Histogram of preferred directions for real neurons. f) Histogram of preferred directions for virtual neurons.
Supplementary Figure 1.3: a) Position activity maps for all virtual and real neurons, with a clipped color bar.
Supplementary Figure 1.4: Cross-session decoding. The GAN-Augmentation, Mutation-Augmentation, Stretch-Augmentation, Real-Concatenation and Real-Only methods are shown as red, purple, orange, blue and green curves with error bars. The horizontal axis is the number of minutes of neural data from session two of Monkey C used. The vertical axis is the correlation coefficient between the decoded kinematics and the real kinematics on an independent test set from session two of Monkey C. Synthesized spike trains that capture the neural attributes accelerate the training of a BCI decoder for cross-session decoding.
Supplementary Figure 1.5: Cross-subject decoding. The GAN-Augmentation, Mutation-Augmentation, Stretch-Augmentation, Real-Concatenation and Real-Only methods are shown as red, purple, orange, blue and green curves with error bars. The horizontal axis is the number of minutes of neural data from Monkey M used. The vertical axis is the correlation coefficient between the decoded kinematics and the real kinematics on an independent test set from Monkey M. When the neural data from another subject are limited, synthesized spike trains that capture the neural attributes improve cross-subject decoding performance on acceleration. Even with ample neural data for both subjects, the neural attributes learned from one subject can transfer useful knowledge that improves the best achievable decoding performance on the acceleration of another subject.
Supplementary Figure 1.6: Correlations across neural spike train samples, sorted by the averaged correlation coefficient (Methods) for each neuron. (a) Random permutation baseline. The blue curve is the correlation between synthesized (from the spike synthesizer) and real neural data, with a shaded blue error bar. The red curve is the correlation between real spike trains under random permutation, with a shaded red error bar. (b) Homogeneous Poisson baseline. The blue curve is as in (a). The red curve is the correlation between synthesized (from a homogeneous Poisson distribution) and real neural data, with a shaded red error bar.
Supplementary Figure 1.7: Detailed structure of the CC-LSTM-GAN.
Supplementary Figure 1.8: Averaged performance for cross-session decoding (using 11.73 minutes of neural data from session two of Monkey C, the amount required for Real-Only to achieve good performance) for the GAN-Augmentation (red curve) and Real-Only (green curve) methods as the number of dropped neurons increases. We randomly dropped neurons and repeated the procedure 10 times for each number of dropped neurons. GAN-Augmentation is more sensitive to dropping small numbers of neurons, suggesting that all artificial neurons contribute to the overall performance. While Real-Only is less sensitive to dropping small numbers of neurons, its overall performance is lower and noisier with respect to which exact neurons are dropped (larger error bars).
Supplementary Figure 1.9: Visualization examples of actual movement trajectories for the GAN-Augmentation, Mutation-Augmentation, Stretch-Augmentation, Real-Concatenation and Real-Only methods compared to ground truth.
Supplementary Figure 1.10: Normalized velocity activity map, constructed as the histogram of neural activity as a function of velocity. (a) Velocity activity map for real neuron 35, normalized across the workspace. (b) Corresponding velocity activity map for virtual neuron 35. (c, d) Velocity activity maps for real and virtual neuron 51. (e) Histogram of the mean squared error between the real and generated activity maps for all neurons. The purple line is the trimmed average mean squared error between real neurons (based on 99% of samples; 0.11). It provides a reasonable bound for quantifying the difference between real and virtual neurons.
Supplementary Figure 1.11: Normalized acceleration activity map, constructed as the histogram of neural activity as a function of acceleration. (a) Acceleration activity map for real neuron 35, normalized across the workspace. (b) Corresponding acceleration activity map for virtual neuron 35. (c, d) Acceleration activity maps for real and virtual neuron 51. (e) Histogram of the mean squared error between the real and generated activity maps for all neurons. The purple line is the trimmed average mean squared error between real neurons (based on 99% of samples; 0.10). It provides a reasonable bound for quantifying the difference between real and virtual neurons.
9.1.2 Supplementary Discussion:
Stabilization
Sussillo et al. [175] demonstrated that training an RNN decoder on many months of previously recorded data can make it more robust to future neural variability. Here, we train a spike synthesizer (a GAN) to learn good neural attributes from a single session of neural data. Inspired by that paper, our spike synthesizer could be trained to capture more general neural attributes across multiple sessions and monkeys. With more general neural attributes, the spike synthesizer could synthesize neural data that would further enhance cross-session and cross-subject decoding. Our work differs in at least three aspects:
1) Our spike synthesizer could learn more generalizable neural attributes
2) The decoding approach generalizes to multiple sessions and subjects
3) We can achieve saturating performance with much less historical training data
Degenhart et al. [22] demonstrated that a manifold-based stabilizer can help a BCI decoder recover proficient control under different instability conditions such as tuning changes, drop-outs, or baseline shifts. The authors hypothesized that "even though the specific neurons being recorded may change over time, the recorded population activity reflects a stable underlying representation of movement intent that lies within the neural manifold". This is a sound hypothesis for neural data collected from the same subject. However, this approach, which requires a subset of stable electrodes, cannot be used across subjects or when the specific neurons change dramatically. Even within a subject, the method requires a significant number of stable neurons in order to accomplish the realignment; it will fundamentally not work between subjects. In comparison, our spike synthesizer learns general neural attributes and can quickly adapt itself to new sessions or subjects using limited additional neural data.
Possible Caveat
We could use large amounts of neural data from multiple monkeys and multiple sessions to build a better GAN embedding. This embedding would be expected to generalize better. Yet, for cross-subject decoding, more abundant neural data might yield lower performance for the already "good" covariates such as pos x and pos y. Since, in our work, the GAN embedding is learned from only one monkey, fine-tuning the read-out module alone is not enough to synthesize neural data that perfectly match the distribution of the other subjects. Thus, the combination of large amounts of additional neural data with synthesized neural data could yield lower performance on already good covariates. This caveat could be addressed by training our CC-LSTM-GAN with neural data from multiple monkeys and sessions.
Relative loss weights for equations 11 and 12
To train the CC-LSTM-GAN well, one must assign a large portion of the weight to the GAN discriminator and generator losses (equations 11 and 12). The weights of the other loss terms (e.g., the decoder loss or the inner product loss) should be small in comparison. The exact weights were optimized by trial and error, as it takes 3-4 days to train the CC-LSTM-GAN on one GPU, which limited our ability to explore the space of all possible combinations more thoroughly.
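For concreteness, a minimal sketch of such a weighted objective follows; the weight values and term names are assumptions chosen only to illustrate the dominance of the adversarial term, not the tuned values used in the experiments.

# Illustrative relative weights: the adversarial (GAN) term dominates, while
# the auxiliary terms (decoder loss, inner product loss) stay small.
w_gan, w_dec, w_inner = 1.0, 0.05, 0.05

def generator_objective(gan_loss, decoder_loss, inner_product_loss):
    # Weighted sum in the spirit of equation 12.
    return w_gan * gan_loss + w_dec * decoder_loss + w_inner * inner_product_loss

total = generator_objective(0.7, 2.3, 1.1)  # placeholder scalar losses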
9.2 Supplementary for capturing spike train temporal pattern with Wavelet
Average Coefficient for Brain Machine Interface
9.2.1 Recursive equation of Wavelet framework for Kalman filter with sliding
window augmentation
where 𝐾 is the gain and Σ_* is the covariance matrix for variable *.
9.2.2 Classical Kalman filters with sliding window augmentation
We use a 5 ms bin size, 50 ms window size, 1 tap, 0 ms lag size and 5 ms slide size. The state space of the Kalman filter is:

x_{n+1} = A x_n + w_n (6)

y_n = C x_n + v_n (7)

where n is the time instance, x_n is the kinematic covariates (the state), y_n is the binned spike counts (the observation), w_n and v_n are zero-mean Gaussian noise, and A and C are time-invariant parameters estimated during training. From the state space model, the recursive equations of the Kalman filter are:

x̂_{n|n-1} = A x̂_{n-1|n-1}

Σ_{x,n|n-1} = A Σ_{x,n-1|n-1} A^T + Σ_w

K_n = Σ_{x,n|n-1} C^T (C Σ_{x,n|n-1} C^T + Σ_v)^{-1}

x̂_{n|n} = x̂_{n|n-1} + K_n (y_n - C x̂_{n|n-1})

Σ_{x,n|n} = (I - K_n C) Σ_{x,n|n-1}

where K_n is the gain and Σ_* is the covariance matrix for variable *.
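A minimal NumPy sketch of one predict/update step of this recursion is given below; the variable names are assumptions, with Q and R denoting the covariances of w and v.

import numpy as np

def kalman_step(x_est, P, y, A, C, Q, R):
    # Predict the kinematic state and its covariance.
    x_pred = A @ x_est
    P_pred = A @ P @ A.T + Q
    # Compute the Kalman gain.
    K = P_pred @ C.T @ np.linalg.inv(C @ P_pred @ C.T + R)
    # Update with the new observation (binned spike counts).
    x_new = x_pred + K @ (y - C @ x_pred)
    P_new = (np.eye(len(x_est)) - K @ C) @ P_pred
    return x_new, P_new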
Supplementary Figure 2.S1. Reconstruction of the neural signal from wavelet coefficients (d) and scaling function coefficients (c). a) Neural spike trains. b) Neural signal processed by a Poisson kernel. c) Neural signal reconstructed from the scaling coefficients c only. d)-g) Neural signals reconstructed from c plus progressively more levels of wavelet coefficients d, adding one level at a time.
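The progressive reconstructions in panels (c)-(g) can be reproduced with a short PyWavelets sketch like the one below; the toy signal, the decomposition level, and the db3 basis are assumptions for illustration (db3 is the basis used in Supplementary Figure 2.S3).

import numpy as np
import pywt

# Toy smoothed spike-count signal standing in for the kernel-processed trace.
signal = np.random.poisson(0.5, 256).astype(float)

# wavedec returns [c, d_L, ..., d_1]: scaling coefficients first, then detail
# (wavelet) coefficients from the coarsest to the finest level.
coeffs = pywt.wavedec(signal, 'db3', level=4)

# Reconstruct from c alone, then add one detail level at a time, mirroring
# panels (c)-(g).
for keep in range(len(coeffs)):
    kept = [c if i <= keep else np.zeros_like(c) for i, c in enumerate(coeffs)]
    recon = pywt.waverec(kept, 'db3')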
Supplementary Figure 2.S2. Visualization of ankle movement. Ankle x has a larger amplitude and period than ankle y. a) Movement of ankle x. The amplitude is roughly 0.25 and the period is roughly 3.2 seconds. b) Movement of ankle y. The amplitude is roughly 0.05 and the period is roughly 1.9 seconds.
Supplementary Figure 2.S3. Influence of window size and other hyperparameters in 5-fold cross-validation for the Kalman filter. a, b) Monkey 1's center-out task for cursor position y and velocity y, 10 ms slide size, db3 basis. c, d) Monkey 3's locomotion task for left ankle x and ankle y, 10 ms slide size, db3 basis.
Supplementary Figure 2.S4. Decoding performance for locomotion tasks and center-out tasks, measured by the mean squared error between decoded covariates and ground truth in 5-fold cross-validation.
9.3 Supplementary for Beneficial Perturbation Network for designing
general adaptive artificial intelligence systems
9.3.1 Clarification of memory storage costs
For the memory storage costs, we only consider what components are
necessary to be stored on the disk after training each task, that is, the cost of
storing the model for later re-use. In other words, the memory storage cost is
defined as “the number of bytes required to store all of the parameters in the
trained model" [19] (in table 2 of Iandola et al.). This is usually the metric that
is reported along with the number of operations to run a model (e.g., mobilenet
web page, darknet web page, SqueezeNet [19] and Additive Parameter
Superposition [52]). The extra memory storage costs of EWC are zero under
this definition. Indeed, even though it requires a lot of transient memory during
training, in the end the contents of this memory are used only to constrain
network weight updates, and they are discarded once a training run is complete.
At test time, a network trained with EWC has the same number of parameters,
uses the same amount of runtime memory, and the same amount of operations
as the original model. For example, consider that after training 5 sequential
tasks, we want to train a new task 6. There are 5 steps: 1) Load the trained
model from disk; 2) EWC would make a duplication of the parameters learned
so far and just loaded from disk, and put them into transient memory (RAM); 3)
During the training of task 6, EWC calculates the Fisher information matrix and
applies the EWC constraints, which relies on the contents of the transient
memory; 4) Delete the duplications of the parameters in the transient memory;
5) Save the parameters of the latest model onto disk. In contrast, after training each
task, GEM needs to store some images from that task onto the disk. Likewise,
PSP needs to store the context matrix for each new task to disk, and BD + EWC
needs to store the bias units to disk. Thus, the extra memory storage costs of
BD + EWC are just the memory storage costs of the bias units (BD).
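A minimal PyTorch-style sketch of the EWC workflow described above is given below, assuming a toy model and a placeholder Fisher matrix; it only illustrates why the transient EWC buffers never reach the disk.

import torch
import torch.nn as nn

model = nn.Linear(784, 10)  # stand-in for the model loaded from disk (step 1)

# Steps 2-3: duplicate the loaded parameters into transient memory (RAM) and,
# together with the Fisher information, use them only to constrain training.
old_params = {n: p.detach().clone() for n, p in model.named_parameters()}
fisher = {n: torch.ones_like(p) for n, p in model.named_parameters()}  # placeholder

def ewc_penalty(lam=1.0):
    return lam * sum((fisher[n] * (p - old_params[n]) ** 2).sum()
                     for n, p in model.named_parameters())

# ... train the new task with loss = task_loss + ewc_penalty() ...

# Steps 4-5: discard the transient buffers, then save only the model's own
# parameters, so the on-disk size is unchanged from the original model.
del old_params, fisher
torch.save(model.state_dict(), 'model_after_task.pt')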
9.3.2 Clarification of parameter costs
Similar to memory storage costs, we only consider what components are
necessary to be stored on disk after training each task (0.3% increase per task),
that is, the cost of storing the model for later re-use. It should be noted that BPN
needs large additional weight matrices (called 𝑊 in the paper) during training.
Likewise, EWC essentially doubles the size of the network during training, to
create the Fisher information matrix used by this method. However, both our
weight matrices and EWC's Fisher information matrix are discarded after
training. So, while the overall growth in the number of parameters is negligible,
the number of parameters needed during training is surely higher than with
vanilla SGD.
9.3.3 Choice of Hyperparameter
We found that a large λ (2 * 10 ) for the EWC constraint in Eqn.6 can effectively prevent the parameters from drifting far from those of old tasks. However, if λ is too large, the overly strict constraint hinders the learning of new tasks.
We tested the hyperparameter H for 𝑀 and 𝑊 from 1 to 2500. The more complex the task, the larger H needs to be for the BPN, since a larger H provides more degrees of freedom to learn beneficial perturbations [17, 7]. For example, H is 25-255 for the Permuted MNIST task and 100-900 for the eight sequential object recognition tasks. However, if H is too large for the complexity of the task (e.g., 2500 for the eight object recognition tasks), the performance of the BPN decreases.
After 5 epochs of training, the network converged in Fig.3.6, as only 2 classes per task need to be trained in that figure (but see Tab.3.2 and Fig.3.7 for the more complex 8-dataset benchmark, where each task may have up to 200+ classes; this one required more epochs per task to converge, up to 300 epochs per dataset).
9.3.4 Difference between Transfer Learning and Continual Learning
Continual learning is a different idea from transfer learning.
For transfer learning, after learning the first task and when receiving a new task to learn, the parameters of the network are fine-tuned on the new task's data. Thus, transfer learning is expected to suffer from forgetting the old tasks while being advantageous for the new task. The shared convolutional layers do, however, benefit from more general embeddings learned from a much harder task (a model pretrained on ImageNet).
In comparison, for continual learning, the focus is on learning new independent tasks sequentially without forgetting previous tasks. To achieve this, our BPN updates the shared normal weights using EWC or PSP. In theory, these updates lead to orthogonal, hence non-overlapping and localized, task representations. In practice, however, overlap is inevitable, because the parameters become too constrained under EWC, or PSP runs out of unrealized capacity in the core network. Thus, we introduce bias units trained independently for each task. At test time, the bias units for a given task are activated to push the representations of old tasks back to their initial task-optimal working regions.
One of the benefits of continual learning is that learning new tasks can be aided
by the knowledge already accumulated while learning the previous tasks.
9.3.5 Algorithms for BD+PSP
The forward and backward rules for BD + PSP are detailed in Alg.3 and Alg.4.
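To convey the flavor of these rules, a rough PyTorch sketch of a combined forward pass follows; the binary context binding, the bias-unit shape, and all names are assumptions, and the authoritative forward and backward rules remain those in Alg.3 and Alg.4.

import torch
import torch.nn as nn

class BDPSPLayer(nn.Module):
    def __init__(self, d_in, d_out, n_tasks):
        super().__init__()
        self.W = nn.Linear(d_in, d_out)  # shared normal weights
        # One fixed random binary context per task (PSP-style binding key).
        self.register_buffer('contexts', torch.sign(torch.randn(n_tasks, d_in)))
        # One trainable bias-unit vector per task (the beneficial perturbations).
        self.bias_units = nn.Parameter(torch.zeros(n_tasks, d_out))

    def forward(self, x, task_id):
        bound = x * self.contexts[task_id]  # bind the input to the task context
        return self.W(bound) + self.bias_units[task_id]

# Example: y = BDPSPLayer(100, 50, n_tasks=8)(torch.randn(4, 100), task_id=2)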
Supplementary Figure 3.S1. The dashed line indicates the start of a new task.
(a) Flow chart of a typical Type 4 method, here illustrated using a pictorial
representation similar to that of Progressive Segmented Training (PST) [8]. PST
subdivides the core network by freezing the important weight parameters for
old tasks, and allowing new tasks to update the remaining free weight
parameters. (b) Flow chart of the Beneficial Perturbation Network (BPN; a new Type 5 method). BPN adds task-dependent beneficial perturbations to the activations, biasing the network toward that task, and retrains the weight parameters of the core network, with constraints from EWC [23], PSP [4], or another similar approach.
Abstract
Recent technological improvements, such as the emergence of deep neural networks, enable us to better understand the mechanisms of human brains. Conversely, a better understanding of the mechanisms of human brains helps us build better bio-inspired systems that control themselves, think, and react like humans. The two entities form a symbiotic relationship. During the last five years of study with Prof. Laurent Itti, I have worked on several projects exploring the interaction between better understanding human brains using machine learning tools and designing better bio-inspired deep machine learning systems. There are two research pathways: (i) From Artificial Intelligence (A.I.) systems to the brain - in the first pathway, we leverage current mathematical and A.I. tools (e.g., deep generative models, recurrent neural networks, the wavelet transform) to better understand brain functions and structures. These understandings would improve the applicability of brain-computer interfaces and enhance the diagnosis and treatment of brain disorders. (ii) From the brain to A.I. - in the second pathway, current A.I. systems have some limitations (e.g., catastrophic forgetting, adversarial examples). We leverage inspiration from primate brains (e.g., the hippocampus, the visual cortex) to design better bio-inspired A.I. systems that address these limitations. These improvements would help us design the next generation of general and adaptive artificial intelligence systems.