Softmax Gradient

By Alexandre Allauzen

\(\newcommand{\nclasses}{N}\) \(\newcommand{\nfeats}{D}\) \(\newcommand{\gold}{\tilde{k}}\)

\(\newcommand{\x}{\mathbf{x}}\) \(\newcommand{\w}{\mathbf{w}}\) \(\newcommand{\W}{\mathbf{W}}\) \(\newcommand{\a}{a}\) \(\newcommand{\pa}{\mathbf{\theta}}\) \(\newcommand{\loss}{\mathcal{L}(\pa)}\) \(\newcommand{\Z}{Z(\pa)}\)

The Softmax function is widely used in many machine learning models: maximum-entropy classifiers, the last layer of neural networks, … Assume, for instance, a classification task with \(\nclasses\) output classes. To infer a probability distribution over these classes given an input vector \(\x\) of dimension \(\nfeats\), the softmax function is a common choice. The random variable \(C\) denotes the class to assign, and its outcome set is a range of integers: \(c \in [1:\nclasses]\). For simplicity, the explicit mention of this random variable is dropped, and \(P(k|\x)\) denotes the probability of \(C=k\) given \(\x\). This distribution is parametrized by a set of parameters \(\pa\). We also omit the bias (or intercept) terms for the sake of clarity: equivalently, you can assume an additional input component fixed to 1.

\begin{align} \x &\in \mathbb{R}^{\nfeats}\\ \w_k &\in \mathbb{R}^{\nfeats},\ \textrm{the parameter vector for class }k,\ \forall k \in [1:\nclasses] \\ \a_k &= \w_k^t \x,\ \textrm{the dot product between } \w_k \textrm{ and } \x \textrm{, or the pre-activation for a neural network,}\\ P(C=k|\x) &= \frac{\exp(\a_k)}{\sum_{k'} \exp(\a_{k'})},\ \textrm{the probability of class $k$ given the input }\x\\ &= \frac{\exp(\a_k)}{\Z},\ \textrm{where $\Z$ is the normalization (or partition) function,} \\ \log P(C=k|\x) &= \a_k - \log(\Z),\ \textrm{in its log form}. \end{align}
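To make these definitions concrete, here is a minimal sketch in Python with NumPy (the names `softmax`, `W`, and `x` are illustrative, not taken from the text above). It computes the pre-activations \(\a_k = \w_k^t \x\) and the resulting distribution; the pre-activations are shifted by their maximum before exponentiation, which changes \(\Z\) but not \(P(k|\x)\) and avoids numerical overflow.

```python
import numpy as np

def softmax(a):
    """Map pre-activations a_k to probabilities exp(a_k) / Z."""
    a = a - np.max(a)          # shift for stability: Z changes, P(k|x) does not
    e = np.exp(a)
    return e / e.sum()

rng = np.random.default_rng(0)
D, N = 4, 3                    # input dimension and number of classes (illustrative)
x = rng.normal(size=D)         # input vector
W = rng.normal(size=(N, D))    # one parameter vector w_k per row

a = W @ x                      # pre-activations a_k = w_k^t x
p = softmax(a)                 # P(k|x) for k = 1..N
print(p, p.sum())              # a valid distribution: sums to 1
```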

1 Conditional maximum likelihood

First, let us define the loss function for one training example:

\[ \loss = -\log( P(\gold|\x)), \]

where \(\gold\) is the supervision information, the right class to predict, the gold standard. With several training examples, simply sum their individual losses (the negative log probabilities).

This loss function that we want to minimize has been called the log-loss or the cross-entropy loss. That is just because the trend is to forget the well-known past: it is simply the negative log-likelihood, and its minimization is equivalent to maximizing the likelihood of the parameters on a given training set. When considering the whole training set, there is no closed-form solution to this optimization problem, and gradient descent is therefore an efficient solution. In an online training scenario, the parameters are updated after each training example to increase its log-likelihood.
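As a small illustration of this loss (reusing the hypothetical NumPy setup above, with `gold` denoting the index of the correct class \(\gold\)), the computation is safer in the log domain, using \(\log P(k|\x) = \a_k - \log(\Z)\) directly:

```python
import numpy as np

def log_softmax(a):
    """log P(k|x) = a_k - log Z, computed stably via the log-sum-exp trick."""
    a = a - np.max(a)
    return a - np.log(np.exp(a).sum())

def nll_loss(W, x, gold):
    """Negative log-likelihood of the gold class for a single example."""
    a = W @ x
    return -log_softmax(a)[gold]

# Illustrative values, not from the text
rng = np.random.default_rng(0)
D, N = 4, 3
x, W, gold = rng.normal(size=D), rng.normal(size=(N, D)), 2
print(nll_loss(W, x, gold))
```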

2 Gradient of the softmax function w.r.t. the parameters

The parameter set \(\pa\) gathers all the vectors \(\w_k\). This is a matrix \(\W\) in which \(w_{kj}\) is the cell of row \(k\) and column \(j\), i.e. the component \(j\) of the vector associated with class \(k\). Let us consider the gradient of the loss function w.r.t. \(w_{kj}\):

\begin{align} \frac{\partial \loss}{\partial w_{kj}} &= \nabla_{kj} \\ &= -\left(\frac{\partial \a_{\gold}}{\partial w_{kj}} - \frac{\partial \log(\Z)}{\partial w_{kj}}\right)\\ &= -\left(\frac{\partial \a_{\gold}}{\partial w_{kj}} - \frac{\partial \log(\sum_{k'} \exp(\a_{k'}))}{\partial w_{kj}}\right) \end{align}

Note that for a given \(k\) and \(\forall j\), only \(\a_k\) depends on \(w_{kj}\), so for the first term there are two cases:

  • \(k=\gold\): \(\frac{\partial \a_{\gold}}{\partial w_{kj}} = x_j\), because \(\w_k^t \x = \sum_{j'} w_{kj'} x_{j'}\)
  • \(k\ne\gold\): \(\frac{\partial a_{\gold}}{\partial w_{kj}} = 0\)

Let us define the Kronecker delta function \(\delta_{k,k'}\) as:

  • \(\delta_{k,k'}= 1\) if and only if \(k=k'\),
  • \(\delta_{k,k'}= 0\) otherwise.

Then we can compactly write: \[ \frac{\partial \a_{\gold}}{\partial w_{kj}} = \delta_{k,\gold}x_j \] Now let us process the derivative of the log-partition term.

\begin{align} \frac{\partial \log(\Z)}{\partial w_{kj}} &= \frac{1}{\Z} \frac{\partial \Z}{\partial w_{kj}} \\ \frac{\partial \Z}{\partial w_{kj}} &= \frac{\partial\sum_{k'} \exp(\a_{k'})}{\partial w_{kj}} = \frac{\partial \exp(\a_{k})}{\partial w_{kj}},\ \textrm{since only the term $k'=k$ depends on $w_{kj}$,} \\ &= \exp(\a_{k})\frac{\partial \a_{k}}{\partial w_{kj}} = x_j \exp(\a_{k}) \\ \frac{\partial \log(\Z)}{\partial w_{kj}} &= \frac{\exp(\a_{k})}{\Z}x_j = x_j P(k|\x) \end{align}
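This identity is easy to sanity-check numerically: perturb a single weight \(w_{kj}\), estimate \(\partial \log(\Z)/\partial w_{kj}\) by a finite difference, and compare it with \(x_j P(k|\x)\). The sketch below reuses the same illustrative NumPy setup.

```python
import numpy as np

def log_Z(W, x):
    """log of the partition function, log sum_k exp(w_k^t x), computed stably."""
    a = W @ x
    m = np.max(a)
    return m + np.log(np.exp(a - m).sum())

rng = np.random.default_rng(0)
D, N = 4, 3
x, W = rng.normal(size=D), rng.normal(size=(N, D))
p = np.exp(W @ x - log_Z(W, x))           # P(k|x)

k, j, eps = 1, 2, 1e-6
W_plus = W.copy(); W_plus[k, j] += eps
numeric = (log_Z(W_plus, x) - log_Z(W, x)) / eps
analytic = x[j] * p[k]
print(numeric, analytic)                  # should agree up to O(eps)
```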

Now we merge the two terms and rearrange a bit to get:

\[\frac{\partial \loss}{\partial w_{kj}} = - (\delta_{k,\gold} - P(k|\x))x_j\]
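Stacking these components, the full gradient is the outer product \((P(\cdot|\x) - \delta_{\cdot,\gold})\,\x^t\), a matrix of the same shape as \(\W\). The sketch below (same illustrative NumPy names, with hypothetical helpers `nll_loss` and `nll_grad`) compares this analytic gradient against a finite-difference estimate of the loss.

```python
import numpy as np

def softmax(a):
    e = np.exp(a - np.max(a))
    return e / e.sum()

def nll_loss(W, x, gold):
    """Loss for one example: -log P(gold|x)."""
    return -np.log(softmax(W @ x)[gold])

def nll_grad(W, x, gold):
    """dL/dW[k, j] = -(delta_{k,gold} - P(k|x)) x_j, as derived above."""
    p = softmax(W @ x)
    delta = np.zeros_like(p); delta[gold] = 1.0
    return np.outer(p - delta, x)

rng = np.random.default_rng(0)
D, N = 4, 3
x, W, gold = rng.normal(size=D), rng.normal(size=(N, D)), 0

# Finite-difference check of every component of the gradient
eps = 1e-6
num = np.zeros_like(W)
for k in range(N):
    for j in range(D):
        Wp = W.copy(); Wp[k, j] += eps
        num[k, j] = (nll_loss(Wp, x, gold) - nll_loss(W, x, gold)) / eps
print(np.allclose(num, nll_grad(W, x, gold), atol=1e-4))  # True
```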