Rationalizing Neural Predictions

By Alexandre Allauzen

\( \global\def\normal{\mathcal{N}} \global\def\m{\mu} \global\def\dkl{D_{kl}} \global\def\X{\mathbf{X}} \global\def\Y{\mathbf{Y}} \global\def\Z{\mathbf{Z}} \global\def\x{\mathbf{x}} \global\def\y{\mathbf{y}} \global\def\z{\mathbf{z}} \global\def\M{\boldsymbol{\mu}} \global\def\C{\boldsymbol{\Sigma}} \global\def\pis{\boldsymbol{\Pi}} \global\def\pa{\boldsymbol{\theta}} \global\def\paold{\pa^{old}} \global\def\panew{\pa^{new}} \global\newcommand\lb{\mathcal{L}(q;\pa)} \global\newcommand\elbo{\textrm{elbo}} \global\newcommand\lsrc{I} \global\newcommand\ltrg{J} \)


Reading notes on "Rationalizing Neural Predictions", by Tao Lei, Regina Barzilay and Tommi Jaakkola, published at EMNLP 2016. You can download the paper and the slides.

Summary

The goal is a model that can both represent a text for classification purposes and explain its decisions. The model first learns to extract pieces of the input text as justifications (called rationales), tailored to be short, coherent, and yet sufficient for making the prediction. The model can be decomposed into two components:

  • The generator specifies a distribution over text fragments as candidate rationales. In practice, each word of the input text is associated with a binary hidden random variable indicating whether it is kept for the next step.
  • The encoder takes the output of the generator to make the prediction.

The rationale extraction can be understood as a type of stochastic attention although architectures and objectives differ.

Generator

The goal is to associate with an input sequence of words a sequence of hidden random variables, where each hidden variable indicates whether the associated word should be considered as part of the rationale (\(z=1\)) or not (\(z=0\)).

Assume the input text is a sequence of words \(\X = x_{1}^{\lsrc}\). The model associates with each input word a binary variable: \(\Z=z_{1}^{\lsrc}\). The generator reads the input text with a BiLSTM. To model the probability of the sequence, \(P(\Z|\X)\), both independent and recurrent (dependent) predictions of the \(z_t\) are explored.
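
As an illustration, here is a minimal sketch of such a generator for the independent-selection case, written in PyTorch (my choice, not the authors' code; the class name and the embedding/hidden dimensions are arbitrary). A BiLSTM reads the text and a sigmoid layer outputs \(p(z_t=1|\x)\) for each position.

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Reads the input text with a BiLSTM and outputs, for each word,
    the probability p(z_t = 1 | x) of keeping it in the rationale
    (independent-selection variant)."""

    def __init__(self, vocab_size, emb_dim=100, hidden_dim=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.bilstm = nn.LSTM(emb_dim, hidden_dim,
                              batch_first=True, bidirectional=True)
        self.scorer = nn.Linear(2 * hidden_dim, 1)

    def forward(self, x):                    # x: (batch, seq_len) word ids
        h, _ = self.bilstm(self.embed(x))    # (batch, seq_len, 2 * hidden_dim)
        return torch.sigmoid(self.scorer(h)).squeeze(-1)   # p(z_t = 1 | x)
```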

For NMT people this component could be called the encoder; selector or simply filter might be clearer names. The choice of terminology in this paper is, for me, confusing. However…

Encoder

Rationales are defined as the set of \(x_t\) such that \(z_t=1\). The input of the encoder is therefore a selection of \(\X\). You can then pick your favorite architecture to deal with this input. In the paper, the authors use RNNs and take the last hidden state to make the final prediction.
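
A matching sketch of the encoder, under the same assumptions as above (PyTorch, arbitrary dimensions). For simplicity the non-selected words are zeroed out with the mask \(\z\) instead of being physically removed, which is only an approximation of feeding the selected words alone.

```python
class Encoder(nn.Module):
    """Reads the rationale (the words with z_t = 1) and predicts the target
    from the last hidden state of an RNN."""

    def __init__(self, vocab_size, emb_dim=100, hidden_dim=128, n_outputs=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, n_outputs)

    def forward(self, x, z):                 # z: (batch, seq_len) binary mask
        e = self.embed(x) * z.unsqueeze(-1)  # zero out the non-selected words
        h, _ = self.rnn(e)
        return self.out(h[:, -1, :])         # prediction from the last state
```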

Joint optimization

From an input \(\X\) of length \(\lsrc\), the generator produces \(\lsrc\) binary variables \(\Z\) and estimates \(P(\Z|\X)\).

The authors first define a cost function as follows:

\begin{align} cost(\x,\z,\y) &= || \y - f_{\pa_{e}}(\x,\z) ||^2 + \lambda_1 ||\z|| + \lambda_2 \sum_t |z_t - z_{t-1} | \\ P(\Z=\z | \X=\x) &= g_{\pa_g}(\x) \end{align}

The cost function depends therefore on the value of \(\z\) in three ways:

  • First the term \(||\y - f_{\pa_{e}}(\x,\z) ||^2\) is the reconstruction error. The target is \(\y\) while the encoder predicts \(f_{\pa_{e}}(\x,\z)\).
  • Then, the term \(||\z||\) ensures that the selection (the number of \(z\) set to one) is as small as possible.
  • The last term \(\sum_t |z_t - z_{t-1} |\) favors contiguous selection (phrases).
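
Written out in code (same hypothetical PyTorch setting as above, with \(\lambda_1\) and \(\lambda_2\) left as hyper-parameters whose default values are my own placeholders), the cost for a given \(\z\) is simply:

```python
def cost(y, y_hat, z, lambda1=1e-2, lambda2=1e-2):
    """cost(x, z, y): prediction error + sparsity + coherence regularizers."""
    error = ((y - y_hat) ** 2).sum(dim=-1)                # ||y - f(x, z)||^2
    sparsity = z.sum(dim=-1)                              # ||z||: number of selected words
    coherence = (z[:, 1:] - z[:, :-1]).abs().sum(dim=-1)  # sum_t |z_t - z_{t-1}|
    return error + lambda1 * sparsity + lambda2 * coherence
```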

The loss function to be optimized for each training example is:

\begin{align} \mathcal{L}(\pa_g, \pa_e, \x, \y) &= E_{\z\sim P(\Z|\X)}\left[ cost(\x,\z,\y) \right] \\ &=\sum_{\z} P(\Z=\z | \X=\x) cost(\x,\z,\y) \\ &= \sum_{\z} g_{\pa_g}(\x) cost(\x,\z,\y) \end{align}

This expected cost is a workaround to deal with the hidden variables. Minimizing it is challenging since it involves summing over all the possible choices of rationales \(\z\), i.e. \(2^{\lsrc}\) terms. The authors therefore propose to sample \(\z\) from the generator to approximate the expectation.

However, the derivatives look a bit weird. The cost function is treated as a constant with respect to the generator parameters, which amounts to a score-function (REINFORCE-style) estimator. The term related to the norm of \(\z\), for instance, implies the expected norm of \(\z\); this expectation depends on the same parameters, so could it be included in the gradient?

In fact, the assumption made throughout the paper is: given \(p(\Z|\X)\), \(\z\) is sampled and then becomes deterministic.

To summarize the inference step during training:

  • Forward propagation of \(\x\) through the generator gives you \(p(\Z|\X)\).
  • Sample a bunch of \(\z\) from this distribution.
  • Given each sampled \(\z\), build the input of the encoder and compute the expected cost.
  • Update the parameters of the whole model given the expected gradients.
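
Putting the pieces together, here is a sketch of one training step under my reading of the paper (reusing the hypothetical Generator, Encoder and cost above): the sampled cost is treated as a constant that weights \(\log p(\z|\x)\) for the generator, in a score-function (REINFORCE-style) update, while the encoder receives the usual gradient. The number of samples is my own choice, since the paper does not specify it.

```python
def training_step(gen, enc, x, y, optimizer, n_samples=4):
    """One parameter update with sampled rationales (sketch, not the authors' code)."""
    probs = gen(x)                                    # p(z_t = 1 | x)
    dist = torch.distributions.Bernoulli(probs=probs)
    loss = 0.0
    for _ in range(n_samples):                        # Monte-Carlo approximation
        z = dist.sample()                             # once sampled, z is deterministic
        c = cost(y, enc(x, z), z)                     # gradient flows into the encoder
        logp = dist.log_prob(z).sum(dim=-1)           # log p(z | x)
        loss = loss + (c.detach() * logp + c).mean()  # generator + encoder terms
    loss = loss / n_samples
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```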

Inference

At inference time:

  • Forward propagation of \(\x\) through the generator gives you \(p(\Z|\X)\).
  • Compute \(\z\) and then get the rationales (how?).
  • Given \(\z\), build the input of the encoder and compute the answer.
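
The second step is not spelled out in the paper. One plausible implementation (my assumption, echoed in the questions below) is to take the most probable \(\z\) under the independent model, i.e. threshold \(p(z_t=1|\x)\) at 0.5, and then run the encoder on that selection:

```python
def predict_with_rationale(gen, enc, x, words, threshold=0.5):
    """Hypothetical inference: threshold the per-word probabilities to get z
    (the most probable choice under the independent model), then encode.
    `words` is the list of token strings for the single example in x."""
    with torch.no_grad():
        probs = gen(x)                        # p(z_t = 1 | x)
        z = (probs > threshold).float()       # deterministic rationale mask
        y_hat = enc(x, z)
    rationale = [w for w, keep in zip(words, z[0].tolist()) if keep]
    return rationale, y_hat
```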

Questions / Comments

Rationale

As written in the paper, the notion of what counts as a rationale may be ambiguous in some contexts, and the task of selecting rationales may therefore be challenging to evaluate. The authors focus on two domains where this ambiguity is minimal (or can be minimized).

Training and inference

  • For training, maybe I missed it, but there is no mention in the paper of the number of samples used to approximate the expectation.
  • The inference step raises a similar question: how is the second step implemented? Just by applying a threshold on the probability? 0.5?

Attention-based model for classification

The rationale extraction can be understood as attention, even though architectures and objectives differ. The discussion in the paper is not so convincing to me. It could be interesting to investigate that point, starting with this paper on stochastic attention, followed by Attention Is All You Need.

Encoder

The encoder is a "simple" RNN and the last hidden state is taken as input for classification, while we expect a short sequence after the rationale extraction step. This could bias the whole model to select words at the end of the sequence and then, through backpropagation, favor rationale extraction at the end.

Loss function etc …

Note that the first regularization term could be an l1 norm instead of an l2 norm, to favor sparsity. Maybe a good idea for long documents.

More generally, is this formulation the best option? Since \(\Z\) are hidden variables, could we adapt the Variational Auto-Encoder to this task, and could we use the reparametrization trick?
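
For instance, a reparametrized relaxation of the binary \(\z\), such as the binary Concrete / Gumbel-Sigmoid trick (which is not what the paper does), would make the sampling step differentiable. A sketch, in the same hypothetical PyTorch setting:

```python
def relaxed_sample(probs, temperature=0.5):
    """Binary Concrete / Gumbel-Sigmoid relaxation of z ~ Bernoulli(probs):
    differentiable in probs, so the reparametrization trick applies.
    As the temperature goes to 0, samples approach hard 0/1 values."""
    logits = torch.log(probs) - torch.log1p(-probs)  # log(p / (1 - p))
    u = torch.rand_like(probs)
    noise = torch.log(u) - torch.log1p(-u)           # Logistic(0, 1) noise
    return torch.sigmoid((logits + noise) / temperature)
```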