Introduction

This project was done while working in the Agile Systems Lab at Georgia Tech. Hawkmoths display remarkable agility in flight. Unlike conventional flight mechanisms, flapping flight is far more complex and relies on the synchronized activity of the flight muscles. The precise timing of muscle spikes carries more information about the flight than the spike count alone, which helps hawkmoths track and respond to external stimuli.

Goals

Action potentials from all major wing muscles of a tethered hawkmoth (Manduca sexta) were recorded as it responded to a moving flower stimulus. The motion of the flower was simulated and projected to estimate how the moth perceives it.

O'Shaughnessy et al. (2020) developed a framework for finding causal explanations in a network's learned latent space: it identifies which latent factors are responsible for which changes in the output of a given mapping. Our RNN could predict muscle spikes from an input sequence of frames, but it was not clear what the causal reason behind a given output was.

The goal of the project was to determine the causal latent factors in the pretrained Recurrent Neural Network (RNN) using this Generative Causal Explanations (GCE) framework. 

Work Done

Unlike conventional frame-based vision, hawkmoths perceive changes over time. An event polarity matrix (EPM) based representation was therefore built to estimate how the hawkmoth sees the events. A Recurrent Neural Network had already been trained to map this event polarity matrix to the spikes in the flight muscles.
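
As an illustration, here is a minimal sketch of how such an event-polarity representation could be computed from consecutive grayscale frames (the threshold and exact encoding are assumptions for illustration, not the lab's actual pipeline):

```python
import numpy as np

def event_polarity_matrix(prev_frame, curr_frame, threshold=0.05):
    """Encode per-pixel brightness changes as events:
    +1 where the pixel got brighter, -1 where it got darker, 0 otherwise."""
    diff = curr_frame.astype(np.float32) - prev_frame.astype(np.float32)
    epm = np.zeros_like(diff, dtype=np.int8)
    epm[diff > threshold] = 1     # ON events
    epm[diff < -threshold] = -1   # OFF events
    return epm

# A sequence of frames becomes a sequence of EPMs that the RNN consumes.
frames = [np.random.rand(64, 64) for _ in range(10)]  # placeholder frames
epms = [event_polarity_matrix(a, b) for a, b in zip(frames[:-1], frames[1:])]
```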

The GCE framework uses a Convolutional Variational Autoencoder (CVAE) as its base to encode the frames into a latent space. The encoder consists of initial convolutional and pooling layers followed by linear layers; its last layer maps the input data to the parameters of the latent distribution in which it lies.
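
A minimal PyTorch sketch of such an encoder, assuming single-channel 64x64 frames and an 8-dimensional latent space (layer sizes are illustrative, not the exact architecture used):

```python
import torch
import torch.nn as nn

class ConvEncoder(nn.Module):
    """Maps a 1x64x64 frame to the mean and log-variance of its latent Gaussian."""
    def __init__(self, latent_dim=8):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # -> 16x32x32
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # -> 32x16x16
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(32 * 16 * 16, latent_dim)
        self.fc_logvar = nn.Linear(32 * 16 * 16, latent_dim)

    def forward(self, x):
        h = self.features(x)
        return self.fc_mu(h), self.fc_logvar(h)
```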

New data points can then be sampled from this learned distribution and passed through a decoder network, which generates outputs similar to the original encoded data.
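
Continuing the sketch above, sampling uses the reparameterization trick before a mirror-image decoder reconstructs a frame (again an illustrative architecture, not the exact one used):

```python
import torch
import torch.nn as nn

def reparameterize(mu, logvar):
    """Sample z ~ N(mu, sigma^2) while keeping the operation differentiable."""
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)

class ConvDecoder(nn.Module):
    """Maps an 8-D latent vector back to a 1x64x64 frame."""
    def __init__(self, latent_dim=8):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 32 * 16 * 16)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2), nn.ReLU(),   # -> 16x32x32
            nn.ConvTranspose2d(16, 1, kernel_size=2, stride=2), nn.Sigmoid(), # -> 1x64x64
        )

    def forward(self, z):
        h = self.fc(z).view(-1, 32, 16, 16)
        return self.deconv(h)
```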

A normal VAE uses its loss function to measure the similarity between the input and output frames, which ensures that the generated output comes from the same Gaussian into which the input is encoded. But to measure the causal influence of the latent factors, we need a quantity against which to measure causality. The framework therefore attaches the classifier (for which we want to find causal factors) to the end of the VAE decoder.
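
Schematically, the objective then combines the usual VAE terms with a classifier-dependent causal term. In the sketch below, `causal_term` and the weight `lam` are placeholders for the information-flow quantity and weighting used in the GCE paper:

```python
import torch
import torch.nn.functional as F

def vae_terms(x, x_recon, mu, logvar):
    """Standard VAE loss: reconstruction error plus KL divergence to a unit Gaussian."""
    recon = F.mse_loss(x_recon, x, reduction="sum")
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl

def gce_objective(x, x_recon, mu, logvar, causal_term, lam=0.05):
    """GCE-style objective: VAE terms minus a weighted causal term, where
    `causal_term` stands in for the information flow from the causal latent
    factors to the classifier output (estimated by decoding latent samples
    and passing them through the classifier)."""
    return vae_terms(x, x_recon, mu, logvar) - lam * causal_term
```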

Results

The convolutional variational autoencoder was first tested on the MNIST 3 vs. 8 digit subset, which the authors also use in their paper. Sweeping across some dimensions of the latent space changed the digits from 3 to 8, while sweeping across other dimensions only changed the skew or stroke thickness of the digits. In effect, this biases the network's latent space toward the 3 vs. 8 classification.
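
A small helper like the following (hypothetical, not the exact code used) produces such a sweep by varying one latent dimension while holding the others fixed:

```python
import torch

@torch.no_grad()
def sweep_latent(decoder, base_z, dim, values):
    """Decode frames obtained by varying one latent dimension while the rest stay fixed."""
    frames = []
    for v in values:
        z = base_z.clone()
        z[0, dim] = v
        frames.append(decoder(z))
    return torch.cat(frames, dim=0)

# e.g. sweep the first latent dimension from -3 to 3 in 7 steps:
# grid = sweep_latent(decoder, torch.zeros(1, 8), dim=0, values=torch.linspace(-3, 3, 7))
```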

The same model was then tested on the actual EPM frame sequence. Training was first done without the causal term, since it was not obvious what its classification criterion should be. The results were satisfactory: the generated frames matched the originals using just 8 latent parameters.

After this point, however, several issues prevented the GCE framework from being applied to the data successfully.

The classifier for which we needed to find causal factors was the RNN. However, estimating the causal factors requires generating a large number of random latent samples and passing each of them through the classifier, which is expensive when that classifier is an RNN operating on whole frame sequences.
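
To make the cost concrete, here is a rough illustration of that sampling loop, assuming the decoder sketched above and an RNN that takes a batch-first sequence of flattened frames (all names and shapes here are assumptions):

```python
import torch

@torch.no_grad()
def sample_classifier_outputs(decoder, rnn_classifier, latent_dim=8, n_samples=1000, seq_len=50):
    """Each Monte Carlo sample requires decoding a latent vector and running the
    RNN over a full frame sequence, so the cost grows with n_samples * seq_len."""
    outputs = []
    for _ in range(n_samples):
        z = torch.randn(1, latent_dim)                              # random latent sample
        frame = decoder(z)                                          # decoded EPM-like frame
        seq = frame.flatten(1).unsqueeze(1).repeat(1, seq_len, 1)   # placeholder sequence
        outputs.append(rnn_classifier(seq))                         # RNN prediction for that sequence
    return outputs
```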

Another major issue was that the RNN produced a regression output: the spike timings of the 44 different muscles. In the original GCE setting, a binary output or a probability value for the digits integrates easily into the framework, but a multivariate regression output required further research into modelling the output as a Gaussian distribution; otherwise the classifier output could not be fed back into the network to compute the causal term.
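
One way to express that idea is to treat the regression output as the mean of a Gaussian, so that a log-likelihood can play the role the class probability plays in the original framework. A minimal sketch, with an assumed fixed variance:

```python
import math
import torch

def gaussian_log_likelihood(y_pred, y_target, sigma=1.0):
    """Treat the RNN's multivariate regression output as the mean of an isotropic
    Gaussian so a (log-)probability, rather than a class label, can enter the
    causal term. The fixed sigma is an illustrative assumption."""
    var = sigma ** 2
    return -0.5 * (((y_pred - y_target) ** 2) / var + math.log(2 * math.pi * var)).sum(dim=-1)
```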
