Train your first LSTM Model for Text Generation

Machine-Learning Jan 20, 2021

A recurrent neural network (RNN) is a class of artificial neural networks where connections between nodes form a directed graph along a temporal sequence. This allows it to exhibit temporal dynamic behaviour. Derived from feedforward neural networks, RNNs can use their internal state (memory) to process variable length sequences of inputs. This makes them applicable to tasks such as unsegmented, connected handwriting recognition or speech recognition.

RNNs have the advantage that they generalise across sequences rather than learn individual patterns. They do this by capturing the dynamics of a sequence through loop connections and shared parameters. RNNs are also not constrained to a fixed sequence size and, in theory, can take all previous steps of the sequence into account. This makes them very suitable for analysing sequential data.

Although RNNs seemed promising for learning how a time series evolves, they soon showed their limitations: a weak long-term memory combined with the vanishing/exploding gradient problem. This is when LSTM (Long Short-Term Memory) networks sparked the interest of the deep learning community.

So, what is an LSTM?

An LSTM network is a special type of RNN. LSTM networks follow the same chain-like structure of repeated network copies as RNNs. The only difference is in the internal structure of each repeated unit.

Structure of an LSTM

To overcome the problem of limited long-term memory, LSTM units keep a second state alongside the hidden state h(t): the cell state C(t). C(t) represents the network memory. Special structures called gates allow the unit to remove (forget) or add (remember) information to the cell state C(t) at each time step, based on the input values x(t) and the previous hidden state h(t−1) (Figure above).

Each gate is implemented via a sigmoid layer that decides which information to add or delete by outputting values between 0 and 1. Multiplying the gate output pointwise with a state, e.g. the cell state C(t−1), deletes information where the gate output is close to 0 and keeps it where the gate output is close to 1.
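To make the gating idea concrete, here is a minimal NumPy sketch. The cell state values and gate pre-activations are made up for illustration; in a real LSTM the pre-activations are computed from x(t) and h(t−1) with learned weights.

```python
import numpy as np

def sigmoid(z):
    """Logistic function: squashes any real number into the interval (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical previous cell state C(t-1) with four memory slots.
c_prev = np.array([2.0, -1.0, 0.5, 3.0])

# Hypothetical gate pre-activations (illustrative values, not learned):
# strongly negative gives a gate value near 0, strongly positive near 1.
gate_logits = np.array([-10.0, 10.0, 0.0, 10.0])
gate = sigmoid(gate_logits)

# Pointwise multiplication: every slot is scaled by its own gate value.
c_filtered = gate * c_prev

print(np.round(gate, 3))
print(np.round(c_filtered, 3))
```

In this toy case the first slot is almost erased, the third is halved, and the second and fourth pass through nearly unchanged.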

The figure above shows the network structure of an LSTM unit. Each LSTM unit has three gates. The “forget gate layer” at the beginning decides which information from the previous cell state C(t−1) to throw away and which to keep, based on the current input x(t) and the previous hidden state h(t−1).

Adding information to the cell state C(t) involves two layers: an “input gate layer” that decides which information we want to add, and a “tanh layer” that forces its output to lie between -1 and 1. The outputs of the two layers are multiplied pointwise and added to the previous, already filtered cell state C(t−1) to update it.

The last gate is the “output gate”. It decides which information from the updated cell state C(t) ends up in the next hidden state h(t). The hidden state h(t) is therefore a filtered version of the cell state C(t).
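The three gates can be put together into a single time step. The following is a minimal NumPy sketch of the standard LSTM formulation, not the Keras implementation. The function name `lstm_step`, the stacked weight layout and the random weights are assumptions made for this example. Note that in the standard formulation the cell state is passed through tanh before the output gate filters it.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM time step (illustrative sketch).

    W has shape (4 * hidden, hidden + inputs) and b has shape (4 * hidden,).
    The four row blocks are stacked in the order forget, input, candidate,
    output. This layout is a choice made for the sketch.
    """
    hidden = h_prev.shape[0]
    z = W @ np.concatenate([h_prev, x_t]) + b
    f = sigmoid(z[0:hidden])                 # forget gate
    i = sigmoid(z[hidden:2 * hidden])        # input gate
    g = np.tanh(z[2 * hidden:3 * hidden])    # candidate values in (-1, 1)
    o = sigmoid(z[3 * hidden:4 * hidden])    # output gate
    c_t = f * c_prev + i * g                 # filter old memory, add new
    h_t = o * np.tanh(c_t)                   # filtered view of the cell state
    return h_t, c_t

# Hypothetical sizes and random (untrained) weights.
rng = np.random.default_rng(0)
n_inputs, n_hidden = 1, 4
W = rng.normal(scale=0.1, size=(4 * n_hidden, n_hidden + n_inputs))
b = np.zeros(4 * n_hidden)

h = np.zeros(n_hidden)
c = np.zeros(n_hidden)
for x in [0.1, 0.5, 0.9]:  # a short made-up input sequence
    h, c = lstm_step(np.array([x]), h, c, W, b)

print(h.shape, c.shape)
```

The same weights `W` and `b` are reused at every time step, which is the parameter sharing mentioned earlier.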


In this tutorial, we are going to build a neural network with LSTM layers in Keras and train it on Alice's Adventures in Wonderland by Lewis Carroll, which is available from Project Gutenberg.


We need to load the ASCII text for the book into memory and convert all of the characters to lowercase to reduce the vocabulary that the network must learn.

Now that the book is loaded, we must prepare the data for modelling by the neural network. We cannot model the characters directly; instead, we must convert the characters to integers. We can do this easily by first creating a set of all of the distinct characters in the book, then creating a map of each character to a unique integer.

We will split the book text up into subsequences with a fixed length of 100 characters, an arbitrary length. We could just as easily split the data up by sentences, padding the shorter sequences and truncating the longer ones.

Each training pattern of the network consists of 100 time steps of one character (X) followed by one character output (y). When creating these sequences, we slide this window along the whole book one character at a time, allowing each character a chance to be learned from the 100 characters that preceded it.
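To see the character mapping and the sliding window at a small scale, here is a sketch on a toy string with a window of 4 characters instead of 100. The string and the `toy_` names are made up for illustration.

```python
# Toy stand-in for the book text.
toy_text = "hello world"

# Map each distinct character to a unique integer.
toy_chars = sorted(set(toy_text))
toy_char_to_int = {c: i for i, c in enumerate(toy_chars)}

# Slide a window of 4 characters over the text, one character at a time.
toy_seq_length = 4
windows = []
for i in range(len(toy_text) - toy_seq_length):
    seq_in = toy_text[i:i + toy_seq_length]   # input characters
    seq_out = toy_text[i + toy_seq_length]    # the character that follows
    windows.append((seq_in, seq_out))

# Integer-encode inputs and targets.
encoded_X = [[toy_char_to_int[ch] for ch in s] for s, _ in windows]
encoded_y = [toy_char_to_int[t] for _, t in windows]

for s, t in windows[:3]:
    print(repr(s), "->", repr(t))
```

Each window overlaps the previous one in all but one character, which is why the number of patterns is close to the number of characters in the text.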

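import numpy
from tensorflow.keras.utils import to_categorical

# load the ASCII text of the book and convert it to lowercase
filename = "../input/alice-in-wonderland-gutenbergproject/wonderland.txt"
with open(filename, 'r', encoding='utf-8') as f:
    raw_text = f.read()
raw_text = raw_text.lower()

# map each distinct character to a unique integer
chars = sorted(list(set(raw_text)))
char_to_int = dict((c, i) for i, c in enumerate(chars))

n_chars = len(raw_text)
n_vocab = len(chars)
print("Total Characters: " + str(n_chars))
print("Total Vocab: " + str(n_vocab))

# slide a 100-character window over the text, one character at a time
seq_length = 100
dataX = []
dataY = []
for i in range(0, n_chars - seq_length, 1):
    seq_in = raw_text[i:i + seq_length]
    seq_out = raw_text[i + seq_length]
    dataX.append([char_to_int[char] for char in seq_in])
    dataY.append(char_to_int[seq_out])
n_patterns = len(dataX)
print("Total Patterns: " + str(n_patterns))

# reshape X to [samples, time steps, features], as expected by the LSTM layer
X = numpy.reshape(dataX, (n_patterns, seq_length, 1))

# normalise the integers to the range 0-1
X = X / float(n_vocab)

# one-hot encode the output character
y = to_categorical(dataY)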
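To see what the reshape, normalisation and one-hot steps produce, here is a minimal sketch on tiny made-up data. It uses only NumPy; indexing into `np.eye` is a stand-in for `to_categorical` chosen for this example, and the `toy_` names and values are hypothetical.

```python
import numpy as np

# Hypothetical tiny dataset: 3 patterns of 3 time steps, vocabulary of 5.
toy_n_vocab = 5
toy_dataX = [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
toy_dataY = [3, 4, 0]

# [samples, time steps, features], then scale the integers into [0, 1).
toy_X = np.reshape(toy_dataX, (len(toy_dataX), 3, 1)) / float(toy_n_vocab)

# One-hot encode the targets: row k of the identity matrix encodes class k.
toy_y = np.eye(toy_n_vocab)[toy_dataY]

print(toy_X.shape, toy_y.shape)
print(toy_y[0])
```

Each target becomes a vector with a single 1 at the position of the correct character, which matches the softmax output of the network.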

Let's have a look at the model:

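from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

model = Sequential()

# the first LSTM layer returns the full sequence so that another LSTM layer can follow it
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2]), return_sequences=True))
# a second LSTM layer returns only its last output (its size of 256 is an assumption)
model.add(LSTM(256))

# one output probability per character in the vocabulary
model.add(Dense(y.shape[1], activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X, y, epochs=50, batch_size=64)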

The simplest way to use the Keras LSTM model to make predictions is to first start off with a seed sequence as input, generate the next character then update the seed sequence to add the generated character on the end and trim off the first character. This process is repeated for as long as we want to predict new characters (e.g. a sequence of 1,000 characters in length).

We can pick a random input pattern as our seed sequence, then print generated characters as we generate them.
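The window update itself can be sketched without a trained model. In the example below, `fake_predict` is a hypothetical stand-in for calling the model and taking the most likely character; it simply returns the next character id in a tiny made-up vocabulary.

```python
def fake_predict(pattern, n_vocab):
    """Hypothetical stand-in for the model: returns the id that follows
    the last id in the pattern, wrapping around the vocabulary."""
    return (pattern[-1] + 1) % n_vocab

toy_chars = ['a', 'b', 'c', 'd']
toy_int_to_char = dict(enumerate(toy_chars))

pattern = [0, 1, 2]   # seed sequence "abc" as integers
generated = []
for _ in range(6):
    index = fake_predict(pattern, len(toy_chars))
    generated.append(toy_int_to_char[index])
    # Append the new character and trim the oldest one,
    # so the window length stays constant.
    pattern = pattern[1:] + [index]

print("".join(generated))
```

The seed window always keeps the same length, so it always matches the input shape the network was trained on.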

int_to_char = dict((i, c) for i, c in enumerate(chars))
# pick a random seed pattern (copied so that dataX is not modified)
start = numpy.random.randint(0, len(dataX)-1)
pattern = list(dataX[start])
print("\"" + ''.join([int_to_char[value] for value in pattern]) + "\"")
res = ""
# generate characters
for i in range(1000):
    x = numpy.reshape(pattern, (1, len(pattern), 1))
    x = x / float(n_vocab)
    prediction = model.predict(x, verbose=0)
    # take the most likely next character
    index = int(numpy.argmax(prediction))
    result = int_to_char[index]
    res = res + result
    # add the generated character to the end and trim off the first one
    pattern.append(index)
    pattern = pattern[1:len(pattern)]
print(res)

Let's have a look at the result:

" scaly friend replied.
 “there is another shore, you know, upon the other side.
 the further off fro"
m and saying all the lad of the little gorger of the darded of the door, and she was seriing that she was serting on the darked of the ground, and she was tereing the door and looked at the coor of the court, and she was teriing the pool as the corro in the coorersation.

‘there’s ali the door hine,’ said the daterpillar.

‘it was a good deal on the sable, alice  what had the door and looked at the coor of the court, and the mittle gouse in the white rabbit was a little bomwarsing the puher sat of the court and the pther side of the court, whth a little with the white rabbit was a little bomwarsing the puher side with a little with the while, and the thme the was a little bomw it out of the words and the coor, and she sat down and the coor of the white rabbit was a little bomwarsing the puher side with the white rabbit all the coorersation and rarier a little with the white rabbit was a little bomwarsing the perper of the while as she was a little bomwarsing the puher side with the whi


In this post you discovered how you can develop an LSTM recurrent neural network for text generation in Python with the Keras deep learning library.


Check out the public notebook:

Text Generation using LSTM
Explore and run machine learning code with Kaggle Notebooks | Using data from Alice In Wonderland GutenbergProject


The tutorial part of this blog is an excerpt from an awesome blog post on Machine Learning Mastery. Check out their website here:

Machine Learning Mastery
Making developers awesome at machine learning.


