NLP Zero to One: Transformers (Part 13/30)
Scaled Dot-Product Attention, Multi-Headed Self-Attention
Introduction..
The Transformer is a neural architecture that has recently proved very successful in machine translation. Like earlier encoder-decoder models, it transforms one sequence into another using an encoder and a decoder. The difference from the previous RNN-based sequence-to-sequence models is that the Transformer does not use any recurrent networks (GRU, LSTM, etc.) as either the encoder or the decoder. In other words, the Transformer eliminates the need for recurrent connections in the encoder and decoder networks.
Instead of using an RNN to accumulate memory step by step, the Transformer applies multi-headed attention directly to the whole input sequence. Like a feed-forward network, it has no recurrence, so the computation for all positions of the sequence can be performed in parallel. The original Transformer work showed that attention mechanisms alone, without any RNNs, can significantly improve results on translation and other tasks.
Self-Attention..
The idea of self-attention is to model how each word of a sequence is influenced by all the other words in that sequence. To do this we compute self-attention weights “a”, which tell us how strongly a given word is influenced by each word in the sequence. Let us represent the word of interest by a query Q and the words of the sequence by keys K. In scaled dot-product attention, the raw scores are the dot products of Q with each key, divided by √d_k, where d_k is the dimension of the key vectors.
The softmax function is applied to these scores so that the self-attention weights “a” lie between 0 and 1 and sum to 1. The weights “a” are then used to take a weighted sum over the words of the sequence, which are represented here as the values “V”.
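The computation described above, Attention(Q, K, V) = softmax(QKᵀ / √d_k) V, can be sketched in a few lines of plain Python. This is a minimal, unoptimized illustration assuming small vectors stored as lists of lists; real implementations use matrix libraries and process whole batches at once. The function names and the toy vectors are hypothetical.

```python
import math


def softmax(scores):
    """Turn a list of raw scores into weights in (0, 1) that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(Q, K, V):
    """Q: n_q x d_k, K: n_k x d_k, V: n_k x d_v (lists of lists).

    Returns (outputs, weights): outputs is n_q x d_v, weights is n_q x n_k.
    """
    d_k = len(K[0])
    outputs, all_weights = [], []
    for q in Q:
        # Similarity of this query with every key, scaled by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        a = softmax(scores)  # self-attention weights "a" for this word
        # Weighted sum of the value vectors V
        out = [sum(a[j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))]
        outputs.append(out)
        all_weights.append(a)
    return outputs, all_weights


# Toy sequence of 3 words with 2-dimensional vectors. In self-attention,
# Q, K and V are all derived from the same sequence (here: used as-is).
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, weights = scaled_dot_product_attention(X, X, X)
for row in weights:
    print([round(w, 3) for w in row])
```

Each printed row is the distribution of attention that one word pays to the three words of the sequence.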
Multi-Headed Self-Attention..
This self-attention mechanism is run several times in parallel, side by side, as shown in the right-hand plot. Each repetition (called a head) applies its own linear projections to Q, K and V, and the outputs of all heads are concatenated and projected once more. This parallelised version of self-attention is called multi-headed self-attention. It allows the model to learn from several different representations of Q, K and V at the same time, so that different heads can attend to information from different representation subspaces. The linear projections are obtained by multiplying Q, K and V by weight matrices W that are learned during training[2].
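A minimal pure-Python sketch of multi-headed self-attention, assuming the common setup in which each of the h heads projects the d_model-dimensional input down to d_model / h dimensions, the head outputs are concatenated, and a final matrix W_O mixes them. The weights here are random stand-ins for learned ones, and the function names and sizes are hypothetical.

```python
import math
import random


def matmul(A, B):
    """Multiply an (n x m) matrix A by an (m x p) matrix B (lists of rows)."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(K[0])
    K_T = [list(col) for col in zip(*K)]
    scores = matmul(Q, K_T)
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V)


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


def multi_head_self_attention(X, num_heads, rng):
    """X: seq_len x d_model. Random weights here; a real model learns them."""
    d_model = len(X[0])
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_head = d_model // num_heads
    heads = []
    for _ in range(num_heads):
        # Each head has its own projection matrices W_Q, W_K, W_V
        W_Q = random_matrix(d_model, d_head, rng)
        W_K = random_matrix(d_model, d_head, rng)
        W_V = random_matrix(d_model, d_head, rng)
        heads.append(attention(matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)))
    # Concatenate the head outputs for each token along the feature axis...
    concat = [sum((head[i] for head in heads), []) for i in range(len(X))]
    # ...and mix them with a final output projection W_O
    W_O = random_matrix(d_model, d_model, rng)
    return matmul(concat, W_O)


rng = random.Random(0)
X = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(4)]  # 4 tokens, d_model = 8
Y = multi_head_self_attention(X, num_heads=2, rng=rng)
print(len(Y), len(Y[0]))  # 4 8
```

Splitting d_model across the heads keeps the total cost close to that of a single full-width head while still giving each head its own view of Q, K and V.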
Transformers..
In this section, let's try to understand the Transformer architecture. As already mentioned, it is composed of attention mechanisms without any RNNs. Both the encoder and the decoder are built from a stack of identical layers, each centered on multi-headed self-attention followed by a position-wise feed-forward network (the decoder layers additionally attend to the encoder's output). The number of times this layer is repeated is denoted Nx in the figure.
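As a rough illustration of the Nx stacking, here is a minimal pure-Python sketch of an encoder, assuming residual sub-layers followed by layer normalization as in the original Transformer paper. To stay short it uses a single attention head instead of multi-headed attention, omits learned layer-norm parameters, biases and dropout, and uses random weights in place of trained ones. The helper names (`encoder_layer`, `make_layer`, and so on) are hypothetical, and the decoder stack is not shown.

```python
import math
import random


def matmul(A, B):
    """(n x m) @ (m x p) for matrices stored as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def add(A, B):
    """Element-wise sum, used for the residual connections."""
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def layer_norm(X, eps=1e-5):
    """Normalize each token vector to zero mean, unit variance (no learned gain/bias)."""
    out = []
    for row in X:
        mean = sum(row) / len(row)
        var = sum((x - mean) ** 2 for x in row) / len(row)
        out.append([(x - mean) / math.sqrt(var + eps) for x in row])
    return out


def self_attention(X, p):
    """Single-head scaled dot-product self-attention (one head for brevity)."""
    Q, K, V = matmul(X, p["W_Q"]), matmul(X, p["W_K"]), matmul(X, p["W_V"])
    K_T = [list(col) for col in zip(*K)]
    scale = math.sqrt(len(K[0]))
    weights = [softmax([s / scale for s in row]) for row in matmul(Q, K_T)]
    return matmul(weights, V)


def feed_forward(X, p):
    """Position-wise feed-forward network ReLU(X W1) W2, applied to every token."""
    hidden = [[max(0.0, h) for h in row] for row in matmul(X, p["W1"])]
    return matmul(hidden, p["W2"])


def make_layer(d_model, d_ff, rng):
    """Random weights for one layer; a real model learns these."""
    def w(rows, cols):
        return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]
    return {"W_Q": w(d_model, d_model), "W_K": w(d_model, d_model),
            "W_V": w(d_model, d_model), "W1": w(d_model, d_ff), "W2": w(d_ff, d_model)}


def encoder_layer(X, p):
    X = layer_norm(add(X, self_attention(X, p)))   # attention + residual + norm
    return layer_norm(add(X, feed_forward(X, p)))  # feed-forward + residual + norm


def encoder(X, layers):
    """Apply the same layer structure N times (the "Nx" in the figure)."""
    for p in layers:
        X = encoder_layer(X, p)
    return X


rng = random.Random(0)
N, d_model, d_ff = 6, 8, 16  # N = 6 as in the original paper; sizes are toy values
layers = [make_layer(d_model, d_ff, rng) for _ in range(N)]
X = [[rng.gauss(0.0, 1.0) for _ in range(d_model)] for _ in range(5)]  # 5 tokens
Z = encoder(X, layers)
print(len(Z), len(Z[0]))  # 5 8
```

Because every layer maps a seq_len x d_model input to an output of the same shape, the layers can be stacked any number of times.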
There are several important components and details in the Transformer architecture. Let's discuss them briefly:
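1. Positional Encoding: We cannot feed strings to the network directly, so the input and output sequences are first passed through an embedding layer that maps each token to a vector in an n-dimensional space. Because attention by itself has no notion of word order, a positional encoding is added to each embedding so that the model knows where each token sits in the sequence.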
2. Outputs (shifted right): It is important to understand why the outputs are shifted to the right. To train the model we need a large number of sentence pairs. Say we want to translate from English to Telugu: the encoder input is an English sentence, and the decoder input is the corresponding Telugu sentence shifted one position to the right.
In an RNN-based encoder-decoder we do not have to pass a shifted output sentence explicitly: the decoder works one step at a time and uses the output from time step “t” as its input at time step “t+1”. In other words, the shift already happens implicitly inside the recurrence.
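In the Transformer, by contrast, the decoder is not sequential: during training it processes all target positions in parallel. We therefore feed it the whole target sentence shifted one position to the right (with a start token in front), so that the model has to predict the target word/character at position i using only the words before position i. A look-ahead mask in the decoder's self-attention prevents each position from attending to later positions.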
3. Prediction: Prediction works a little differently from training, because the output sequence is not available at prediction time. Instead, starting from a start token, we take the decoder's output at each step and feed it back in as decoder input for the next step. This means the decoder has to be run once per output token to produce the complete output sequence.
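The three components above (adding positional information to the embeddings, shifting the target right for training, and feeding predictions back in at prediction time) can be sketched in plain Python. This assumes the sinusoidal positional encoding from the original Transformer paper (learned position embeddings are a common alternative), hypothetical `<s>`/`</s>` start and end tokens, and greedy decoding, which always takes the single most likely next token (beam search is another common choice). `toy_predict_next` is a hypothetical stand-in for a trained decoder, and `w1`, `w2`, `w3` are placeholder tokens rather than real vocabulary.

```python
import math

BOS, EOS = "<s>", "</s>"  # hypothetical start/end-of-sentence tokens


def sinusoidal_positional_encoding(seq_len, d_model):
    """PE[pos][2i] = sin(pos / 10000^(2i/d_model)), PE[pos][2i+1] = cos(same angle)."""
    pe = [[0.0] * d_model for _ in range(seq_len)]
    for pos in range(seq_len):
        for i in range(0, d_model, 2):  # i plays the role of 2i in the formula
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe


def add_positions(embeddings):
    """Add the positional encoding to each token embedding (same dimension)."""
    pe = sinusoidal_positional_encoding(len(embeddings), len(embeddings[0]))
    return [[e + p for e, p in zip(emb, pos)] for emb, pos in zip(embeddings, pe)]


def shift_right(target):
    """Training-time decoder input: prepend BOS and drop the last token, so
    position i is asked to predict target[i] having seen only target[:i]."""
    return [BOS] + target[:-1]


def greedy_decode(predict_next, max_len=20):
    """Prediction time: no target is available, so each predicted token is fed
    back as decoder input, and the decoder runs once per output token."""
    output = [BOS]
    for _ in range(max_len):
        next_token = predict_next(output)
        output.append(next_token)
        if next_token == EOS:
            break
    return output[1:]


# Toy usage with placeholder tokens.
target = ["w1", "w2", "w3", EOS]
print(shift_right(target))  # ['<s>', 'w1', 'w2', 'w3']


def toy_predict_next(prefix):
    # Hypothetical stand-in for a trained decoder: it simply "knows" the target.
    return target[len(prefix) - 1]


print(greedy_decode(toy_predict_next))  # ['w1', 'w2', 'w3', '</s>']
print(add_positions([[0.5, 0.5, 0.5, 0.5]])[0])  # [0.5, 1.5, 0.5, 1.5]
```

During training the decoder sees `shift_right(target)` all at once and is scored against `target`; at prediction time `greedy_decode` rebuilds that same kind of input one token at a time.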