# Highlights

• Replace the RNN hidden state by a finite set of learnable states using a stochastic state transition mechanism
• “Stabilizes” the hidden state for long sequences
• “Better interpretability and explainability”

# Method

## Method summary

The hidden state of the RNN is replaced by $$S$$, a $$d \times k$$ matrix whose columns are the $$k$$ possible states (centroids), each of dimension $$d$$ (the hidden-state dimension).

A probability distribution over the centroids is computed using Eq. 3, where $$\mathbf{u}_t$$ is the “regular” output of an RNN timestep.

Note: the computation is a softmax with a temperature parameter $$\tau$$; as $$\tau$$ decreases, the distribution is pushed towards a one-hot encoding.
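For reference, a plausible form of Eq. 3, assuming the score of the $$i$$-th centroid $$\mathbf{s}_i$$ (the $$i$$-th column of $$S$$) is its inner product with $$\mathbf{u}_t$$:

$$\alpha_{t,i} = \frac{\exp\left(\mathbf{u}_t^{\top}\mathbf{s}_i / \tau\right)}{\sum_{j=1}^{k}\exp\left(\mathbf{u}_t^{\top}\mathbf{s}_j / \tau\right)}, \qquad i = 1, \dots, k$$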

In theory, a new state would then be sampled from this distribution, but sampling would make the RNN non-differentiable (see top of Fig. 2); instead, the new state is the probability-weighted sum of the centroids (see bottom of Fig. 2, and the sketch after this list). This has several consequences:

• For memoryless RNNs (such as the GRU), it means that they can be represented by a deterministic finite automaton (DFA) as $$\tau \rightarrow 0$$.
• For RNNs with memory (such as the LSTM), it pushes the representational power to the cell state instead of the hidden state.
• In all cases, it limits drifting (the values of the hidden state growing in magnitude with sequence length): since the new state is a convex combination of the centroids, its magnitude stays bounded.
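To make the transition concrete, here is a minimal NumPy sketch of one step, assuming Eq. 3 scores each centroid by its inner product with $$\mathbf{u}_t$$ (as above); the function and variable names are illustrative, not from the paper:

```python
import numpy as np

def state_regularized_step(u_t, S, tau=1.0):
    """One state-regularized transition: map the raw RNN output u_t (shape (d,))
    onto the k learnable centroids stored as the columns of S (shape (d, k))."""
    scores = S.T @ u_t / tau   # (k,) inner-product scores, sharpened by low tau
    scores -= scores.max()     # subtract max for numerical stability
    alpha = np.exp(scores)
    alpha /= alpha.sum()       # softmax: probability distribution over centroids (Eq. 3)
    return S @ alpha           # weighted sum of centroids: differentiable surrogate
                               # for sampling; as tau -> 0 this approaches the single
                               # centroid S[:, alpha.argmax()]

# Toy usage with illustrative sizes
d, k = 4, 3
rng = np.random.default_rng(0)
S = rng.normal(size=(d, k))    # centroids; learned jointly with the RNN in practice
u_t = rng.normal(size=d)       # "regular" RNN output at timestep t
h_t = state_regularized_step(u_t, S, tau=0.1)
```

Lowering `tau` makes `alpha` nearly one-hot, so the step behaves like a hard jump between the $$k$$ states; this is the regime in which the DFA view of memoryless RNNs applies.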

# Experiments

• IMDB (sentiment analysis)
• Sequential MNIST