Train your first Neural Network with Attention for Abstractive Summarisation
Encoder-decoder models combined with seq2seq have been game-changing for NLP research and now form the basis for most sequence tasks. The idea is to encode the complete input sequence at once and then use that encoding as the context for generating the decoded, or target, sequence.
seq2seq with RNNs works well, but with attention it works even better. The main issue with RNNs is that they cannot be parallelised: processing is sequential, i.e. we cannot compute the next time step until we have the output of the current one. This makes RNN-based approaches slow, and even replacing the RNN with a CNN and adding GPU acceleration does not help much. This is where Transformers come in.
What are Transformers?
The Transformer is a novel NLP architecture that solves sequence-to-sequence tasks while handling long-range dependencies with ease, and it is both effective and computationally efficient because it allows parallelisation. It was proposed in the paper Attention Is All You Need by a research team at Google.

Now, what is Abstractive Summarisation?
Abstractive summarisation uses approaches that train the system to understand the whole context and generate a summary based on that understanding, rather than copying sentences from the source. This is a more human-like way of generating summaries, and the resulting summaries tend to be more effective than those produced by extractive approaches.
For this tutorial, we train an encoder-decoder network with Transformers for abstractive summarisation. We will use the Inshorts dataset, which pairs news article text with short summaries (of about 60 words), and we will train and evaluate on Kaggle with GPU acceleration.
For the scope of this tutorial, we will skip some smaller building blocks and their code: Positional Encoding (to retain information about word order), Masking (padding to the maxlen of the sequences, plus the look-ahead mask so that only the words generated so far are used to predict the next word) and CustomSchedule, a custom learning rate scheduler that helps the model converge faster. (The public Kaggle notebook with the full code is linked at the end of this blog!)
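Since positional_encoding and create_masks are referenced by the model and decoding code further down, here is a minimal sketch of these helpers, adapted from the standard TensorFlow Transformer tutorial; treat it as one possible implementation rather than the exact notebook code:
import numpy as np
import tensorflow as tf

def get_angles(pos, i, d_model):
    # angle rates for the sinusoidal positional encoding
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates

def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])  # sine on even indices
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])  # cosine on odd indices
    return tf.cast(angle_rads[np.newaxis, ...], dtype=tf.float32)

def create_padding_mask(seq):
    # 1.0 wherever the token id is 0 (padding), so attention can ignore those positions
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return seq[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    # upper-triangular mask that hides future positions from the decoder
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

def create_masks(inp, tar):
    enc_padding_mask = create_padding_mask(inp)
    dec_padding_mask = create_padding_mask(inp)
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
    return enc_padding_mask, combined_mask, dec_padding_mask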
Step 1: Preprocessing
To mark the start and end of the target sequences, we pad them with <go> and <stop> tokens. We then fit a Tokenizer on the sequences; it filters punctuation marks, converts the text to lower case, maintains a vocabulary dict, ordered by word frequency, that maps each word to its token equivalent, and finally converts the textual data to token sequences that can be fed directly to the model.
import time
import numpy as np
import pandas as pd
import tensorflow as tf

news = pd.DataFrame(pd.read_excel("../input/inshorts-news-data/Inshorts Cleaned Data.xlsx", engine='openpyxl'))
news.drop(['Source ', 'Time ', 'Publish Date'], axis=1, inplace=True)

document = news['Short']      # article text (model input)
summary = news['Headline']    # headline (target summary)
summary = summary.apply(lambda x: '<go> ' + x + ' <stop>')

# keep '<' and '>' out of the filters so the <go>/<stop>/<unk> tokens survive tokenisation
filters = '!"#$%&()*+,-./:;=?@[\\]^_`{|}~\t\n'
oov_token = '<unk>'
document_tokenizer = tf.keras.preprocessing.text.Tokenizer(oov_token=oov_token)
summary_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters=filters, oov_token=oov_token)
document_tokenizer.fit_on_texts(document)
summary_tokenizer.fit_on_texts(summary)
inputs = document_tokenizer.texts_to_sequences(document)
targets = summary_tokenizer.texts_to_sequences(summary)

# encoder_maxlen, decoder_maxlen, BUFFER_SIZE and BATCH_SIZE are set in Step 4
inputs = tf.keras.preprocessing.sequence.pad_sequences(inputs, maxlen=encoder_maxlen, padding='post', truncating='post')
targets = tf.keras.preprocessing.sequence.pad_sequences(targets, maxlen=decoder_maxlen, padding='post', truncating='post')
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
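For reference, the vocabulary sizes hard-coded in Step 4 can be derived from the fitted tokenizers; this is an assumption about how those numbers were obtained, but it matches the standard Keras Tokenizer attributes:
# +1 because index 0 is reserved for padding
encoder_vocab_size = len(document_tokenizer.word_index) + 1
decoder_vocab_size = len(summary_tokenizer.word_index) + 1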
Step 2: Model Components
1. Scaled Dot-Product Attention: the base for attention computation in the model. Queries are matched against keys, the scores are scaled by the square root of the key dimension and softmaxed, and the result is used to weight the values; the optional mask hides padded or future positions.
def scaled_dot_product_attention(q, k, v, mask):
    matmul_qk = tf.matmul(q, k, transpose_b=True)            # (..., seq_len_q, seq_len_k)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)   # scale by sqrt(d_k)
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)             # push masked positions towards -inf
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    output = tf.matmul(attention_weights, v)                 # (..., seq_len_q, depth_v)
    return output, attention_weights
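As a quick sanity check (a hypothetical toy example, not part of the original notebook), the function can be exercised on small tensors where the query clearly matches one key:
temp_k = tf.constant([[10.0, 0.0, 0.0],
                      [0.0, 10.0, 0.0],
                      [0.0, 0.0, 10.0],
                      [0.0, 0.0, 10.0]])     # (4, 3)
temp_v = tf.constant([[1.0, 0.0],
                      [10.0, 0.0],
                      [100.0, 5.0],
                      [1000.0, 6.0]])        # (4, 2)
temp_q = tf.constant([[0.0, 10.0, 0.0]])     # (1, 3) -- aligned with the second key
out, weights = scaled_dot_product_attention(temp_q, temp_k, temp_v, None)
print(weights.numpy())  # almost all attention falls on the second key
print(out.numpy())      # approximately [[10., 0.]]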
2. Multi-Head Attention: we project the inputs, split them into multiple heads, compute the attention weights of each head using scaled dot-product attention and finally concatenate the outputs of all the heads.
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        # (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, depth)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)
        # merge the heads back: (batch_size, seq_len_q, d_model)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))
        output = self.dense(concat_attention)
        return output, attention_weights
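The encoder and decoder layers below also call point_wise_feed_forward_network, whose code is not shown in this post; here is a minimal sketch of the usual two-layer definition (an assumption, but consistent with how it is used here):
def point_wise_feed_forward_network(d_model, dff):
    # two dense layers applied to every position independently
    return tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
        tf.keras.layers.Dense(d_model)                  # (batch_size, seq_len, d_model)
    ])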
3. Encoder
class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        # self-attention followed by a residual connection and layer norm
        attn_output, _ = self.mha(x, x, x, mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)
        # position-wise feed-forward network, again with residual + layer norm
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)
        return out2

class Encoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, self.d_model)
        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        seq_len = tf.shape(x)[1]
        # token embedding + positional encoding
        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            x = self.enc_layers[i](x, training, mask)
        return x
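A quick, hypothetical smoke test (not part of the original notebook) is to push a random token batch through a small encoder and check the output shape:
sample_encoder = Encoder(num_layers=2, d_model=128, num_heads=8, dff=512,
                         input_vocab_size=8500, maximum_position_encoding=10000)
temp_input = tf.random.uniform((64, 62), dtype=tf.int64, minval=0, maxval=200)
sample_output = sample_encoder(temp_input, training=False, mask=None)
print(sample_output.shape)  # (64, 62, 128) -- (batch_size, input_seq_len, d_model)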
4. Decoder
class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        # masked self-attention over the target sequence
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(attn1 + x)
        # cross-attention over the encoder output
        attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(attn2 + out1)
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(ffn_output + out2)
        return out3, attn_weights_block1, attn_weights_block2

class Decoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, maximum_position_encoding, rate=0.1):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        seq_len = tf.shape(x)[1]
        attention_weights = {}
        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            x, block1, block2 = self.dec_layers[i](x, enc_output, training, look_ahead_mask, padding_mask)
            attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
            attention_weights['decoder_layer{}_block2'.format(i+1)] = block2
        return x, attention_weights
Step 3: The Model
class Transformer(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, pe_input, pe_target, rate=0.1):
        super(Transformer, self).__init__()
        self.encoder = Encoder(num_layers, d_model, num_heads, dff, input_vocab_size, pe_input, rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff, target_vocab_size, pe_target, rate)
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)  # projects to vocabulary logits

    def call(self, inp, tar, training, enc_padding_mask, look_ahead_mask, dec_padding_mask):
        enc_output = self.encoder(inp, training, enc_padding_mask)
        dec_output, attention_weights = self.decoder(tar, enc_output, training, look_ahead_mask, dec_padding_mask)
        final_output = self.final_layer(dec_output)
        return final_output, attention_weights
Step 4: Training
num_layers = 4
d_model = 128
dff = 512
num_heads = 8

EPOCHS = 20
BUFFER_SIZE = 20000
BATCH_SIZE = 64

# maximum sequence lengths for the articles and the summaries
encoder_maxlen = 400
decoder_maxlen = 75

# vocabulary sizes of the fitted document and summary tokenizers
encoder_vocab_size, decoder_vocab_size = 76362, 29661
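The loop below also relies on a few objects that are skipped in this post: the loss, the optimizer, the transformer instance, the checkpoint manager and the train_step function. What follows is a rough, self-contained sketch of that setup in the spirit of the standard TensorFlow Transformer tutorial; the notebook (for example its use of CustomSchedule) may differ in the details:
# the notebook uses the CustomSchedule mentioned earlier; a fixed rate keeps this sketch self-contained
optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
train_loss = tf.keras.metrics.Mean(name='train_loss')

def loss_function(real, pred):
    # ignore padded positions when averaging the loss
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_sum(loss_) / tf.reduce_sum(mask)

# pe_input / pe_target just size the positional-encoding tables; any value >= the max sequence length works
transformer = Transformer(num_layers, d_model, num_heads, dff,
                          encoder_vocab_size, decoder_vocab_size,
                          pe_input=encoder_vocab_size, pe_target=decoder_vocab_size)

ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, "checkpoints", max_to_keep=5)

@tf.function
def train_step(inp, tar):
    tar_inp = tar[:, :-1]   # decoder input: all tokens except the last
    tar_real = tar[:, 1:]   # targets: the same tokens shifted one step ahead
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
    with tf.GradientTape() as tape:
        predictions, _ = transformer(inp, tar_inp, True, enc_padding_mask, combined_mask, dec_padding_mask)
        loss = loss_function(tar_real, predictions)
    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
    train_loss(loss)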
for epoch in range(EPOCHS):
    start = time.time()
    train_loss.reset_states()
    for (batch, (inp, tar)) in enumerate(dataset):
        train_step(inp, tar)
        # 55k samples
        # we display 3 batch results -- 0th, middle and last one (approx)
        # 55k / 64 ~ 858; 858 / 2 = 429
        if batch % 429 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, batch, train_loss.result()))
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))
    print('Epoch {} Loss {:.4f}'.format(epoch + 1, train_loss.result()))
    print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
Step 5: Evaluation and Testing
Finally, we generate summaries: the evaluate function decodes greedily, one token at a time, and summarize converts the resulting token sequence back to text.
def evaluate(input_document):
    input_document = document_tokenizer.texts_to_sequences([input_document])
    input_document = tf.keras.preprocessing.sequence.pad_sequences(input_document, maxlen=encoder_maxlen, padding='post', truncating='post')
    encoder_input = tf.expand_dims(input_document[0], 0)

    # start decoding from the <go> token and generate one token per step
    decoder_input = [summary_tokenizer.word_index["<go>"]]
    output = tf.expand_dims(decoder_input, 0)

    for i in range(decoder_maxlen):
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoder_input, output)
        predictions, attention_weights = transformer(
            encoder_input,
            output,
            False,
            enc_padding_mask,
            combined_mask,
            dec_padding_mask
        )
        # take the logits of the last position and pick the most likely token
        predictions = predictions[:, -1:, :]
        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
        if predicted_id == summary_tokenizer.word_index["<stop>"]:
            return tf.squeeze(output, axis=0), attention_weights
        output = tf.concat([output, predicted_id], axis=-1)

    return tf.squeeze(output, axis=0), attention_weights
def summarize(input_document):
    # not considering attention weights for now, can be used to plot attention heatmaps in the future
    summarized = evaluate(input_document=input_document)[0].numpy()
    summarized = np.expand_dims(summarized[1:], 0)  # not printing the <go> token
    return summary_tokenizer.sequences_to_texts(summarized)[0]  # there is just one generated summary
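Once the model is trained, generating a summary is a single call; the input string below is only an illustrative placeholder:
print(summarize("Paste the text of any news article here to get a headline-style summary back."))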
The loss during the final epoch of training was 2.5217. Check out the performance in the screenshot below:

It can be seen that the model is not performing particularly well yet. You can tune the hyperparameters and retrain with the notebook below to produce better results, and let me know in the comments how it goes.
[Optional]
Find the public notebook with complete implementation below:

[Bonus]
Learn Transformers in-depth from the following article by Jay Alammar:

Cheers!