Multinomial Regression and the Softmax Activation Function
Gary Cottrell
Notation reminder
We have N data points, or patterns, in the training set, with the pattern number as a superscript: $\{(x^1,t^1), (x^2,t^2), \ldots, (x^n,t^n), \ldots, (x^N,t^N)\}$, and $t^n$ is the target for pattern n. The x's are (usually) vectors of dimension d, so the n-th input pattern is:
$$x^n = (x_1^n, x_2^n, \ldots, x_d^n)$$
For the output, the weighted sum of the input (a.k.a. the net input) is written:
$$a = \sum_{j=1}^{d} w_j x_j$$
And if there are multiple outputs, k = 1, ..., c, we write:
$$a_k = \sum_{j=1}^{d} w_{kj} x_j$$
and then the k-th output is:
$$y_k = g(a_k)$$
where g is the activation function.
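As a concrete check on the notation, here is a minimal pure-Python sketch (no NumPy, to stay self-contained) that computes the net inputs $a_k$ and the outputs $y_k = g(a_k)$ for one pattern. The function names `net_inputs` and `outputs`, and the weight and input values, are made up for illustration; `outputs` assumes g acts on each net input separately, which holds for logistic units but not for softmax.

```python
import math

def net_inputs(W, x):
    """Net input a_k = sum_j w_kj * x_j for each output k.

    W is a list of c rows (one per output unit), each of length d; x has length d.
    """
    return [sum(w_kj * x_j for w_kj, x_j in zip(w_k, x)) for w_k in W]

def outputs(W, x, g):
    """k-th output y_k = g(a_k), applying g to each net input separately."""
    return [g(a_k) for a_k in net_inputs(W, x)]

def logistic(a):
    return 1.0 / (1.0 + math.exp(-a))

# Hypothetical example: c = 2 outputs, d = 3 inputs.
W = [[0.5, -1.0, 2.0],
     [0.0, 1.0, 1.0]]
x = [1.0, 2.0, 3.0]
print(net_inputs(W, x))  # [4.5, 5.0]
print(outputs(W, x, logistic))
```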
What if we have multiple classes?
E.g., 10 classes, as in the MNIST problem (to pick a random example). Then it would be great to force the network to have outputs that are positive and sum to 1, i.e., to represent a probability distribution over the output categories. The softmax activation function does this.
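Softmax
$$y_k = \frac{\exp(a_k)}{\sum_{k'=1}^{c} \exp(a_{k'})}$$
Note that since the denominator sums over all categories, the outputs sum to 1 over all categories, and each output is positive because exp is. This is a generalization of the logistic to multiple categories, because softmax can be written as:
$$y_k = \frac{1}{1 + \sum_{k' \neq k} \exp\left(-(a_k - a_{k'})\right)}$$
(proof left to the reader). With c = 2, this is exactly the logistic function of $a_1 - a_2$.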
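A minimal pure-Python sketch of softmax follows. Subtracting max(a) before exponentiating is a standard numerical safeguard that is not on the slide; it cancels between numerator and denominator, so the result is unchanged. The block also checks that the outputs sum to 1, that the rewritten form gives the same value, and that two-category softmax equals the logistic of $a_1 - a_2$. The input values are arbitrary.

```python
import math

def softmax(a):
    """y_k = exp(a_k) / sum_k' exp(a_k'), computed stably by shifting by max(a)."""
    m = max(a)
    exps = [math.exp(a_k - m) for a_k in a]
    total = sum(exps)
    return [e / total for e in exps]

def logistic(a):
    return 1.0 / (1.0 + math.exp(-a))

a = [2.0, 1.0, 0.1]
y = softmax(a)
print(sum(y))  # 1 up to floating-point rounding

# Rewritten form: y_0 = 1 / (1 + sum_{k' != 0} exp(a_k' - a_0)).
alt = 1.0 / (1.0 + sum(math.exp(a[j] - a[0]) for j in range(1, len(a))))
print(y[0], alt)

# Two categories: softmax reduces to the logistic of the difference a_1 - a_2.
a1, a2 = 0.7, -0.3
print(softmax([a1, a2])[0], logistic(a1 - a2))
```

Without the max shift, `softmax([1000.0, 1000.0])` would overflow in `math.exp`; with it, the result is `[0.5, 0.5]`.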
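What is the right error function?
The network looks like this: [figure: network diagram, not reproduced]
We use 1-out-of-c encoding for c categories, meaning the target is 1 for the correct category and 0 everywhere else. So for 10 outputs, the target vector would look like (0,0,1,0,0,0,0,0,0,0) if the target was the third category. Using Kronecker's delta, we can write this compactly as:
$$t_k^n = \delta_{kj} = \begin{cases} 1 & \text{if } k = j \\ 0 & \text{if } k \neq j \end{cases}$$
where j is the correct category for pattern n.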
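The 1-out-of-c target can be sketched as a one-line helper. The name `one_hot` is hypothetical, and Python's 0-based indexing means the third category is index 2.

```python
def one_hot(j, c):
    """1-out-of-c target: t_k = delta_kj, i.e. 1 at (0-based) index j, 0 elsewhere."""
    return [1 if k == j else 0 for k in range(c)]

# Third category out of ten.
print(one_hot(2, 10))  # [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
```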
Now, the conditional probability for the n-th example is:
$$P(t^n \mid x^n) = \prod_{k=1}^{c} \left(y_k^n\right)^{t_k^n}$$
Note that this picks out the j-th output if the correct category is j, since $t_k^n$ will be 1 for the j-th category and 0 for everything else (and any $y^0 = 1$).
So, the likelihood for the entire data set is:
$$L = \prod_{n=1}^{N} \prod_{k=1}^{c} \left(y_k^n\right)^{t_k^n}$$
And the error, the negative log likelihood, is:
$$E = -\ln L = -\sum_{n=1}^{N} \sum_{k=1}^{c} t_k^n \ln y_k^n$$
This is also called the cross-entropy error.
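The sketch below computes the likelihood and the cross-entropy error for a tiny made-up data set and confirms that $E = -\ln L$. The output vectors are invented rather than produced by a trained network, and the helper names are hypothetical. Terms with $t_k = 0$ are skipped in the error, since they contribute nothing and skipping them avoids evaluating ln(0).

```python
import math

def likelihood(Y, T):
    """L = prod_n prod_k (y_k^n)^(t_k^n) over all (output, target) pairs."""
    L = 1.0
    for y, t in zip(Y, T):
        for y_k, t_k in zip(y, t):
            L *= y_k ** t_k
    return L

def cross_entropy(Y, T):
    """E = -sum_n sum_k t_k^n ln y_k^n (terms with t_k = 0 contribute nothing)."""
    return -sum(t_k * math.log(y_k)
                for y, t in zip(Y, T)
                for y_k, t_k in zip(y, t) if t_k != 0)

# Two patterns, three categories; outputs are made-up probability vectors.
Y = [[0.7, 0.2, 0.1],
     [0.1, 0.1, 0.8]]
T = [[1, 0, 0],
     [0, 0, 1]]
print(likelihood(Y, T))     # same as 0.7 * 0.8: only the correct-category outputs survive
print(cross_entropy(Y, T))  # equals -(ln 0.7 + ln 0.8)
```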
The minimum of this error (attained when y = t) is:
$$E_{\min} = -\sum_{n=1}^{N}\sum_{k=1}^{c} t_k^n \ln t_k^n$$
which will be 0 if each t is 0 or 1 (taking $0 \ln 0 = 0$). But this still applies if t is an actual probability between 0 and 1, and then the above quantity is not 0. So we can subtract it off of the error to get an error function that will hit 0 when minimized (i.e., when y = t):
$$E = -\sum_{n=1}^{N}\sum_{k=1}^{c} t_k^n \ln \frac{y_k^n}{t_k^n}$$
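Here is a minimal sketch of the shifted error, using the convention $0 \ln(y/0) = 0$ for zero targets. The soft target distribution is made up for illustration, and the helper names are hypothetical. With soft targets the error is exactly zero when the outputs equal the targets. With one-hot targets it coincides with the ordinary cross-entropy, since the subtracted term is then 0.

```python
import math

def t_log_y_over_t(t, y):
    """t * ln(y / t), with the convention that the term is 0 when t = 0."""
    return 0.0 if t == 0 else t * math.log(y / t)

def shifted_error(Y, T):
    """E = -sum_n sum_k t_k^n ln(y_k^n / t_k^n): cross-entropy minus its minimum."""
    return -sum(t_log_y_over_t(t_k, y_k)
                for y, t in zip(Y, T)
                for y_k, t_k in zip(y, t))

T = [[0.6, 0.3, 0.1]]                        # a "soft" target distribution
print(shifted_error(T, T))                   # zero: outputs match targets
print(shifted_error([[0.4, 0.4, 0.2]], T))   # positive: outputs differ from targets
print(shifted_error([[0.7, 0.2, 0.1]], [[1, 0, 0]]), -math.log(0.7))  # one-hot case
```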
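So, we need to go downhill in this error. Recall that gradient descent says:
$$w_{kj} \leftarrow w_{kj} - \alpha \frac{\partial E}{\partial w_{kj}}, \qquad \frac{\partial E^n}{\partial w_{kj}} = \frac{\partial E^n}{\partial a_k}\,\frac{\partial a_k}{\partial w_{kj}}$$
The second factor is just $x_j$, as before. For the first factor, for one pattern n, the derivative with respect to the net input has to take into account all of the outputs, because changing the net input to one output changes the activations of all the outputs (through the shared softmax denominator), so:
$$\frac{\partial E^n}{\partial a_k} = \sum_{k'=1}^{c} \frac{\partial E^n}{\partial y_{k'}}\,\frac{\partial y_{k'}}{\partial a_k}$$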
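So, in order to do gradient descent, we need this derivative. The first factor, using $E^n = -\sum_{k'} t_{k'}^n \ln y_{k'}^n$, is:
$$\frac{\partial E^n}{\partial y_{k'}} = -\frac{t_{k'}^n}{y_{k'}^n}$$
The second factor, using the definition of softmax, is:
$$\frac{\partial y_{k'}}{\partial a_k} = y_{k'}\left(\delta_{kk'} - y_k\right)$$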
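The softmax derivative $\partial y_{k'}/\partial a_k = y_{k'}(\delta_{kk'} - y_k)$ can be checked numerically. The sketch below uses central differences; the step size h = 1e-6 and the test point are arbitrary choices. In the code, `J[k][j]` holds $\partial y_k / \partial a_j$, which is the same formula with the indices renamed.

```python
import math

def softmax(a):
    m = max(a)
    e = [math.exp(v - m) for v in a]
    s = sum(e)
    return [v / s for v in e]

def softmax_jacobian(a):
    """Analytic J[k][j] = dy_k/da_j = y_k * (delta_kj - y_j)."""
    y = softmax(a)
    c = len(a)
    return [[y[k] * ((1.0 if k == j else 0.0) - y[j]) for j in range(c)]
            for k in range(c)]

def numerical_jacobian(a, h=1e-6):
    """Central-difference estimate of dy_k/da_j, used only to check the formula."""
    c = len(a)
    J = [[0.0] * c for _ in range(c)]
    for j in range(c):
        plus = list(a)
        plus[j] += h
        minus = list(a)
        minus[j] -= h
        y_plus, y_minus = softmax(plus), softmax(minus)
        for k in range(c):
            J[k][j] = (y_plus[k] - y_minus[k]) / (2 * h)
    return J

a = [0.5, -1.0, 2.0]
A, N = softmax_jacobian(a), numerical_jacobian(a)
print(max(abs(A[k][j] - N[k][j]) for k in range(3) for j in range(3)))  # tiny
```

Each column of the Jacobian sums to 0, reflecting that the outputs always sum to 1.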
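Which leads to:
$$\frac{\partial E^n}{\partial a_k} = -\sum_{k'=1}^{c} \frac{t_{k'}^n}{y_{k'}^n}\, y_{k'}^n\left(\delta_{kk'} - y_k^n\right) = -t_k^n + y_k^n \sum_{k'=1}^{c} t_{k'}^n = y_k^n - t_k^n$$
since the targets sum to 1. So,
$$-\frac{\partial E^n}{\partial w_{kj}} = \left(t_k^n - y_k^n\right) x_j^n, \qquad w_{kj} \leftarrow w_{kj} + \alpha\left(t_k^n - y_k^n\right) x_j^n$$
We get the delta rule again!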
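The derivation's end result, $\partial E^n/\partial a_k = y_k - t_k$, can be checked against finite differences of the single-pattern cross-entropy. The sketch also writes out one delta-rule step on the weights. The helper names, the learning rate `lr` (playing the role of $\alpha$), and the numbers are illustrative assumptions.

```python
import math

def softmax(a):
    m = max(a)
    e = [math.exp(v - m) for v in a]
    s = sum(e)
    return [v / s for v in e]

def pattern_error(a, t):
    """E^n = -sum_k t_k ln y_k, written as a function of the net inputs a."""
    y = softmax(a)
    return -sum(t_k * math.log(y_k) for t_k, y_k in zip(t, y) if t_k != 0)

def delta_rule_step(W, x, t, lr):
    """Return new weights w_kj + lr * (t_k - y_k) * x_j (W itself is not modified)."""
    y = softmax([sum(w * x_j for w, x_j in zip(row, x)) for row in W])
    return [[w + lr * (t_k - y_k) * x_j for w, x_j in zip(row, x)]
            for row, t_k, y_k in zip(W, t, y)]

# Compare the analytic dE/da_k = y_k - t_k with central differences.
a, t, h = [0.2, 1.5, -0.4], [0, 1, 0], 1e-6
analytic = [y_k - t_k for y_k, t_k in zip(softmax(a), t)]
numeric = []
for k in range(len(a)):
    a_plus, a_minus = list(a), list(a)
    a_plus[k] += h
    a_minus[k] -= h
    numeric.append((pattern_error(a_plus, t) - pattern_error(a_minus, t)) / (2 * h))
print(analytic)
print(numeric)  # should agree with analytic to several decimal places
```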
So, the right way to do multinomial regression is:
- Start with a net like this: [figure: network diagram, not reproduced]
- Initialize the weights to 0.
- Use the softmax activation function: $y_k = \exp(a_k) / \sum_{k'} \exp(a_{k'})$
- For each pass through the data (an epoch):
  - Randomize the order of the patterns.
  - Present a pattern and compute the output.
  - Update the weights according to the delta rule: $w_{kj} \leftarrow w_{kj} + \alpha (t_k^n - y_k^n) x_j^n$
- Repeat until the error stops decreasing (enough) or a maximum number of epochs has been reached.
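The recipe above can be sketched in plain Python roughly as follows. Several details are assumptions, not from the slides: the toy data set, the constant-1 last input acting as a bias, the learning rate `lr=0.1`, `max_epochs=200`, the seed, and reading "stops decreasing (enough)" as "the per-epoch decrease in total error falls below `tol`".

```python
import math
import random

def softmax(a):
    m = max(a)
    e = [math.exp(v - m) for v in a]
    s = sum(e)
    return [v / s for v in e]

def forward(W, x):
    return softmax([sum(w * x_j for w, x_j in zip(row, x)) for row in W])

def total_error(W, X, T):
    # Cross-entropy summed over the data set (targets are one-hot).
    return -sum(math.log(forward(W, x)[t.index(1)]) for x, t in zip(X, T))

def train(X, T, lr=0.1, max_epochs=200, tol=1e-6, seed=0):
    rng = random.Random(seed)
    d, c = len(X[0]), len(T[0])
    W = [[0.0] * d for _ in range(c)]       # initialize the weights to 0
    order = list(range(len(X)))
    prev = total_error(W, X, T)
    for _ in range(max_epochs):
        rng.shuffle(order)                  # randomize the pattern order each epoch
        for n in order:
            x, t = X[n], T[n]
            y = forward(W, x)
            for k in range(c):              # delta rule: w_kj += lr * (t_k - y_k) * x_j
                for j in range(d):
                    W[k][j] += lr * (t[k] - y[k]) * x[j]
        err = total_error(W, X, T)
        if prev - err < tol:                # error has stopped decreasing (enough)
            break
        prev = err
    return W

# Hypothetical toy data: 3 classes; the last input is a constant 1 (bias).
X = [[2.0, 0.0, 1.0], [1.8, 0.2, 1.0],
     [0.0, 2.0, 1.0], [0.2, 1.9, 1.0],
     [-2.0, -2.0, 1.0], [-1.8, -2.1, 1.0]]
T = [[1, 0, 0], [1, 0, 0],
     [0, 1, 0], [0, 1, 0],
     [0, 0, 1], [0, 0, 1]]
W = train(X, T)
print([max(range(3), key=lambda k: forward(W, x)[k]) for x in X])
```

Updating after every pattern (rather than summing the gradient over an epoch) follows the slide's "present a pattern, ... update the weights" ordering.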
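Activation functions and "Forward propagation"
We've already seen three kinds of activation functions:
- Binary threshold units: $y = 1$ if $a > 0$, else $y = 0$
- Logistic units: $y = \frac{1}{1 + e^{-a}}$
- Softmax units: $y_k = \frac{e^{a_k}}{\sum_{k'} e^{a_{k'}}}$
[plots of output vs. net input not reproduced]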
Some more:
- Rectified linear units (ReLU): $g(a) = \max(0, a)$
- Leaky ReLU: $g(a) = a$ if $a > 0$, else $\epsilon a$ for a small constant $\epsilon > 0$
- Tanh: $g(a) = \tanh(a) = \frac{e^{a} - e^{-a}}{e^{a} + e^{-a}}$
- Linear: $g(a) = a$
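The element-wise activation functions listed so far can be sketched in plain Python. The function names are hypothetical, the binary threshold is assumed to fire for a > 0, and the leaky-ReLU slope `eps=0.01` is an arbitrary illustrative value, not one taken from the slides.

```python
import math

def binary_threshold(a):
    return 1.0 if a > 0 else 0.0

def logistic(a):
    return 1.0 / (1.0 + math.exp(-a))

def relu(a):
    return max(0.0, a)

def leaky_relu(a, eps=0.01):
    # eps: small slope for negative net inputs (0.01 is an assumed, illustrative value).
    return a if a > 0 else eps * a

def tanh(a):
    return math.tanh(a)

def linear(a):
    return a

for g in (binary_threshold, logistic, relu, leaky_relu, tanh, linear):
    print(g.__name__, [round(g(a), 3) for a in (-2.0, 0.0, 2.0)])
```

Tanh is a rescaled, shifted logistic: tanh(a) = 2 logistic(2a) - 1, so it ranges over (-1, 1) instead of (0, 1).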
Stochastic units: the logistic is treated as the probability of outputting 1 (else 0):
$$P(y = 1) = \frac{1}{1 + e^{-a}}, \qquad P(y = 0) = 1 - P(y = 1)$$
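A stochastic unit can be sketched by drawing a uniform random number and comparing it with the logistic of the net input. The function name, the seed, and the sample count are illustrative assumptions. Averaging many samples should come close to the logistic value.

```python
import math
import random

def logistic(a):
    return 1.0 / (1.0 + math.exp(-a))

def stochastic_unit(a, rng=random):
    """Output 1 with probability logistic(a), else 0."""
    return 1 if rng.random() < logistic(a) else 0

rng = random.Random(42)
samples = [stochastic_unit(1.0, rng) for _ in range(10000)]
mean = sum(samples) / len(samples)
print(mean, logistic(1.0))  # the two should be close (logistic(1.0) is about 0.73)
```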
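These functions can be applied recursively, layer by layer, with the outputs of one layer serving as the inputs to the next:
$$z^{(0)} = x, \qquad z_k^{(l)} = g^{(l)}\Big(\sum_j w_{kj}^{(l)} z_j^{(l-1)}\Big), \quad l = 1, \ldots, L, \qquad y = z^{(L)}$$
This is called forward propagation.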
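Forward propagation can be sketched as a loop over (weights, activation) pairs, feeding each layer's outputs into the next layer. The 2-3-1 network, its weights, and the function names are hypothetical; here every layer uses logistic units, applied element-wise.

```python
import math

def logistic(a):
    return 1.0 / (1.0 + math.exp(-a))

def layer(W, z, g):
    """One layer: net inputs a_k = sum_j w_kj z_j, then outputs g(a_k)."""
    return [g(sum(w * z_j for w, z_j in zip(row, z))) for row in W]

def forward_propagate(layers, x):
    """z^(0) = x, then z^(l) = g(W^(l) z^(l-1)) for each (W, g) in turn."""
    z = x
    for W, g in layers:
        z = layer(W, z, g)
    return z

# Hypothetical 2-3-1 network with logistic units throughout.
layers = [
    ([[1.0, -1.0], [0.5, 0.5], [-1.0, 2.0]], logistic),
    ([[1.0, 1.0, -1.0]], logistic),
]
print(forward_propagate(layers, [1.0, 0.0]))
```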
Different layers need not have the same activation function. One popular combination (not shown here): ReLU in the hidden layers, softmax at the output.
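Because softmax depends on all of a layer's net inputs at once, the sketch below passes each layer's whole net-input vector to its activation function. This lets ReLU (applied element-wise) and softmax be mixed freely. The 3-4-2 network, its weights, and the function names are hypothetical.

```python
import math

def relu_layer(a):
    return [max(0.0, v) for v in a]

def softmax(a):
    m = max(a)
    e = [math.exp(v - m) for v in a]
    s = sum(e)
    return [v / s for v in e]

def net(W, z):
    """Net inputs for one layer: a_k = sum_j w_kj z_j."""
    return [sum(w * z_j for w, z_j in zip(row, z)) for row in W]

def forward(layers, x):
    """Each activation receives the layer's full net-input vector."""
    z = x
    for W, g in layers:
        z = g(net(W, z))
    return z

# Hypothetical 3-4-2 network: ReLU hidden layer, softmax output layer.
layers = [
    ([[1.0, 0.0, -1.0], [0.5, 0.5, 0.5], [-1.0, 1.0, 0.0], [0.0, -0.5, 1.0]], relu_layer),
    ([[1.0, -1.0, 0.5, 0.0], [-0.5, 1.0, 0.0, 1.0]], softmax),
]
y = forward(layers, [1.0, 2.0, 0.5])
print(y, sum(y))
```

For this input the hidden activations are [0.5, 1.75, 1.0, 0.0] (the last unit is cut to 0 by ReLU), and the output net inputs are -0.75 and 1.5.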