Train Your First Neural Network for Large-Scale Text Classification

Natural Language Processing Jan 31, 2022

Classification with neural networks is common practice. In multi-label classification, where each sample can carry several labels at once, one option is to combine each set of two, three or more labels into a single label (often called flattening) and train the network on that. In 2014, the paper Large-scale Multi-label Text Classification - Revisiting Neural Networks showed that a simple neural network trained with the standard cross-entropy loss could obtain better multi-label classification results than BP-MLL (Back-Propagation for Multi-Label Learning), a traditional neural approach to the problem. Find the paper here.
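To make the two label encodings concrete, here is a small pure-Python sketch on made-up arXiv terms (the labels are hypothetical, not taken from the dataset). Flattening turns every distinct label combination into its own class, while multi-hot encoding, which this tutorial uses, keeps one output column per individual label.

```python
# Hypothetical samples: each one can carry several arXiv terms.
samples = [
    ["cs.LG", "stat.ML"],
    ["cs.CV"],
    ["cs.LG", "cs.CV"],
    ["cs.LG", "stat.ML"],
]

# Flattening: every distinct label *combination* becomes one class,
# turning the task into ordinary single-label classification.
combo_classes = sorted({tuple(sorted(s)) for s in samples})
flattened = [combo_classes.index(tuple(sorted(s))) for s in samples]

# Multi-hot: one column per individual label, set to 1 when the
# sample carries that label.
label_vocab = sorted({label for s in samples for label in s})
multi_hot = [[1 if label in s else 0 for label in label_vocab] for s in samples]

print(flattened)  # [2, 0, 1, 2]
print(multi_hot)  # [[0, 1, 1], [1, 0, 0], [1, 1, 0], [0, 1, 1]]
```

With L distinct labels, flattening can produce up to 2^L - 1 classes, many of them rare, whereas the multi-hot target always has exactly L columns.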

Photo by Jason Goodman / Unsplash

In this tutorial, we will run a similar multi-label classification experiment. We will use the arXiv Paper Abstracts dataset to train a shallow multi-layer perceptron (MLP) built from Dense layers, using Keras in a Kaggle notebook. Let's get into the tutorial.

Step 1: Preprocessing

  1. Convert the string labels to lists of strings
  2. Use stratified splits because of class imbalance
  3. Multi-label binarization
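import numpy as np
import pandas as pd
import tensorflow as tf
from ast import literal_eval
from sklearn.model_selection import train_test_split

# Load the arXiv Paper Abstracts CSV. The file path is an assumption: point it
# at your copy of the Kaggle dataset (columns "titles", "summaries", "terms").
arxiv_data = pd.read_csv("arxiv_data.csv")

# Drop label combinations that occur only once: a stratified split needs at
# least two samples per stratum.
arxiv_data_filtered = arxiv_data.groupby("terms").filter(lambda x: len(x) > 1)

# Convert the string labels (e.g. "['cs.LG', 'stat.ML']") to lists of strings
arxiv_data_filtered["terms"] = arxiv_data_filtered["terms"].apply(
    lambda x: literal_eval(x)
)

# Use stratified splits because of class imbalance
test_split = 0.1

# Initial train and test split, stratified on the label lists.
train_df, test_df = train_test_split(
    arxiv_data_filtered,
    test_size=test_split,
    stratify=arxiv_data_filtered["terms"].values,
)

# Splitting the test set further into validation
# and new test sets.
val_df = test_df.sample(frac=0.5)
test_df.drop(val_df.index, inplace=True)

# Multi-label binarization
terms = tf.ragged.constant(train_df["terms"].values)
lookup = tf.keras.layers.StringLookup(output_mode="multi_hot")
lookup.adapt(terms)  # build the label vocabulary from the training labels
vocab = lookup.get_vocabulary()

def invert_multi_hot(encoded_labels):
    """Reverse a single multi-hot encoded label to an array of vocab terms."""
    hot_indices = np.argwhere(encoded_labels == 1.0)[..., 0]
    return np.take(vocab, hot_indices)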
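StringLookup(output_mode="multi_hot") performs the binarization above inside TensorFlow. The pure-Python sketch below mimics the idea under stated assumptions: index 0 is reserved for an out-of-vocabulary token (StringLookup's default is "[UNK]"), the remaining terms are ordered by descending frequency, and ties are broken alphabetically (our own choice). encode_multi_hot and decode_multi_hot are hypothetical helpers; the latter plays the role of invert_multi_hot.

```python
from collections import Counter

OOV_TOKEN = "[UNK]"  # assumption: mirrors StringLookup's default OOV token

# Hypothetical training labels, already converted to lists by literal_eval.
train_terms = [
    ["cs.LG", "stat.ML"],
    ["cs.CV"],
    ["cs.LG", "cs.CV"],
    ["cs.LG"],
]


def build_vocab(label_lists):
    """Index 0 holds the OOV token; other terms follow by descending frequency.
    Ties are broken alphabetically here (an assumption, not StringLookup's rule)."""
    counts = Counter(term for labels in label_lists for term in labels)
    ordered = sorted(counts, key=lambda term: (-counts[term], term))
    return [OOV_TOKEN] + ordered


def encode_multi_hot(labels, vocab):
    """Return a 0/1 vector with 1.0 at the index of every term in `labels`."""
    index = {term: i for i, term in enumerate(vocab)}
    vector = [0.0] * len(vocab)
    for term in labels:
        vector[index.get(term, 0)] = 1.0  # unseen terms map to the OOV slot
    return vector


def decode_multi_hot(vector, vocab):
    """Inverse mapping, playing the role of invert_multi_hot."""
    return [vocab[i] for i, value in enumerate(vector) if value == 1.0]


vocab = build_vocab(train_terms)
encoded = encode_multi_hot(["stat.ML", "cs.LG"], vocab)
print(vocab)                             # ['[UNK]', 'cs.LG', 'cs.CV', 'stat.ML']
print(encoded)                           # [0.0, 1.0, 0.0, 1.0]
print(decode_multi_hot(encoded, vocab))  # ['cs.LG', 'stat.ML']
```

Because the vocabulary includes the OOV slot, lookup.vocabulary_size() is one larger than the number of distinct training terms, and that is the width of the model's output layer in Step 3.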

Step 2: Dataset Generator Function

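max_seqlen = 150         # not used by the TF-IDF pipeline below
batch_size = 128
padding_token = "<pad>"  # likewise unused here
auto = tf.data.AUTOTUNE  # let tf.data choose the degree of parallelism

def make_dataset(dataframe, is_train=True):
    # Encode each sample's list of terms as a multi-hot vector.
    labels = tf.ragged.constant(dataframe["terms"].values)
    label_binarized = lookup(labels).numpy()
    # Pair each raw abstract string with its multi-hot label vector.
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["summaries"].values, label_binarized)
    )
    # Shuffle only the training data; keep evaluation order fixed.
    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    return dataset.batch(batch_size)

# Prepare the dataset
train_dataset = make_dataset(train_df, is_train=True)
validation_dataset = make_dataset(val_df, is_train=False)
test_dataset = make_dataset(test_df, is_train=False)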
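Note that dataset.shuffle(batch_size * 10) does not shuffle the whole training set at once: tf.data keeps a buffer of 1,280 examples and emits a random element from it at each step. The sketch below reproduces that buffered behaviour and the subsequent batching in plain Python; buffered_shuffle and batch are hypothetical helpers written for illustration, not the tf.data implementation.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Fill a fixed-size buffer, then repeatedly emit a random buffered element
    and put the next input item in its slot (similar in spirit to
    tf.data.Dataset.shuffle)."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = item
    # Input exhausted: drain the remaining buffer in random order.
    rng.shuffle(buffer)
    yield from buffer


def batch(items, batch_size):
    """Group consecutive items into lists; the last batch may be smaller."""
    current = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


batches = list(batch(buffered_shuffle(range(10), buffer_size=4, seed=0), batch_size=3))
print(batches)
```

One consequence: the k-th emitted item always comes from the first buffer_size + k inputs. That is harmless here because train_test_split shuffles rows by default before the DataFrame reaches make_dataset.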

Step 3: The Model

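from tensorflow import keras
from tensorflow.keras import layers

def make_model():
    shallow_mlp_model = keras.Sequential(
        [
            layers.Dense(512, activation="relu"),
            layers.Dense(256, activation="relu"),
            # One sigmoid unit per label (including the OOV slot): each output is an
            # independent probability, so several labels can be active at once.
            layers.Dense(lookup.vocabulary_size(), activation="sigmoid"),
        ]
    )
    return shallow_mlp_model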
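The final layer uses sigmoid rather than softmax because an abstract can belong to several categories at once. The sketch below compares the two on one set of made-up logits: sigmoid scores each label independently, so several labels can clear a decision threshold, while softmax forces the labels to compete for a total probability of 1. The logits and the 0.5 threshold are assumptions for illustration.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def softmax(zs):
    m = max(zs)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits from the final Dense layer for one abstract, 4 labels.
logits = [2.0, 1.5, -1.0, -3.0]

sig = [sigmoid(z) for z in logits]  # independent probability per label
soft = softmax(logits)              # probabilities compete and sum to 1

threshold = 0.5  # assumed decision threshold
predicted = [i for i, p in enumerate(sig) if p >= threshold]
print([round(p, 3) for p in sig], round(sum(sig), 3))
print([round(p, 3) for p in soft], round(sum(soft), 3))
print(predicted)  # [0, 1]
```

Here two labels pass the threshold under sigmoid, whereas under softmax only the top label does, which is why softmax is a poor fit for multi-label targets.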

Step 4: Training

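# Approximate the vocabulary size as the number of unique lowercase,
# whitespace-separated tokens in the training abstracts.
vocabulary = set()
train_df["summaries"].str.lower().str.split().apply(vocabulary.update)
vocabulary_size = len(vocabulary)

# Need to vectorize data before passing to NN: TF-IDF over unigrams and bigrams.
text_vectorizer = layers.TextVectorization(
    max_tokens=vocabulary_size, ngrams=2, output_mode="tf_idf"
)

# `TextVectorization` layer needs to be adapted as per the vocabulary from our
# training set.
with tf.device("/CPU:0"):
    text_vectorizer.adapt(train_dataset.map(lambda text, label: text))

# Vectorize every split; prefetch overlaps preprocessing with training.
train_dataset = train_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
validation_dataset = validation_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
test_dataset = test_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)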
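The TextVectorization layer above is configured with ngrams=2 and output_mode="tf_idf", so each abstract becomes a vector of TF-IDF weights over unigrams and bigrams. The plain-Python sketch below shows the idea on a toy corpus. It uses the smoothed formula idf(t) = log(1 + N / (1 + df(t))); we believe this matches Keras's implementation, but treat it as an assumption. The sketch also skips Keras's punctuation stripping, the max_tokens cap and OOV handling.

```python
import math
from collections import Counter


def ngrams(text, n_max=2):
    """Lowercase, split on whitespace, and emit all 1..n_max-grams joined by spaces."""
    tokens = text.lower().split()
    grams = []
    for n in range(1, n_max + 1):
        for i in range(len(tokens) - n + 1):
            grams.append(" ".join(tokens[i:i + n]))
    return grams


def fit_idf(docs, n_max=2):
    """Smoothed IDF: log(1 + N / (1 + df)), where df counts documents containing t."""
    df = Counter()
    for doc in docs:
        df.update(set(ngrams(doc, n_max)))
    n_docs = len(docs)
    return {t: math.log(1 + n_docs / (1 + c)) for t, c in df.items()}


def tf_idf(doc, idf, n_max=2):
    """Term count in this document times its IDF; n-grams unseen during
    fitting are simply dropped in this simplified sketch."""
    counts = Counter(ngrams(doc, n_max))
    return {t: c * idf[t] for t, c in counts.items() if t in idf}


corpus = [
    "neural networks for text",
    "graph neural networks",
    "text classification with neural networks",
]
idf = fit_idf(corpus)
vec = tf_idf("neural networks for text classification", idf)
print(sorted(vec.items(), key=lambda kv: -kv[1]))
```

N-grams that appear in every document (such as "neural networks" here) get the lowest weight, while rarer ones such as "graph" score higher.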

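# Training code snippet starts here
epochs = 20

shallow_mlp_model = make_model()
# Binary cross-entropy scores each label independently, matching the sigmoid outputs.
shallow_mlp_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["categorical_accuracy"]
)

history = shallow_mlp_model.fit(
    train_dataset, validation_data=validation_dataset, epochs=epochs
)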
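Training pairs the sigmoid outputs with binary_crossentropy, which treats every label as its own yes/no problem and averages the per-label losses. The sketch below computes that average for a single sample in plain Python; the 1e-7 clipping value is an assumption modeled on Keras's default epsilon, and the predictions are made up.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over labels of -[y*log(p) + (1-y)*log(1-p)] for one sample.
    Predictions are clipped to [eps, 1 - eps] so log(0) never occurs."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


y_true = [1, 0, 1, 0]  # hypothetical multi-hot target
good = binary_crossentropy(y_true, [0.9, 0.1, 0.8, 0.2])
bad = binary_crossentropy(y_true, [0.2, 0.7, 0.3, 0.9])
print(round(good, 4), round(bad, 4))
```

Each label contributes its own term, so getting one label badly wrong raises the loss without any interaction with the other labels' outputs.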

Step 5: Evaluation

import matplotlib.pyplot as plt

def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.show()

plot_result("loss")
plot_result("categorical_accuracy")

_, categorical_acc = shallow_mlp_model.evaluate(test_dataset)
print(f"Categorical accuracy on the test set: {round(categorical_acc * 100, 2)}%.")

We obtain a categorical accuracy of around 84% on the test set.
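It is worth noting what that number measures. Keras's categorical_accuracy compares argmax(y_true) with argmax(y_pred). With a multi-hot target, argmax(y_true) is in practice the lowest-index label set to 1 (TensorFlow does not formally guarantee tie-breaking), which, given StringLookup's frequency ordering, is the sample's term that was most frequent in the training labels. The sketch below contrasts it with a per-label binary accuracy at a 0.5 threshold on two made-up samples; both functions are hypothetical plain-Python re-implementations, not the Keras code.

```python
def argmax(xs):
    """Index of the first maximum (Python's max keeps the first of equal keys)."""
    return max(range(len(xs)), key=lambda i: xs[i])


def categorical_accuracy(y_true, y_pred):
    """Fraction of samples whose top-scoring label equals argmax(y_true)."""
    hits = [argmax(t) == argmax(p) for t, p in zip(y_true, y_pred)]
    return sum(hits) / len(hits)


def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of individual label slots predicted correctly after thresholding."""
    correct = total = 0
    for t, p in zip(y_true, y_pred):
        for yt, yp in zip(t, p):
            correct += int((yp > threshold) == bool(yt))
            total += 1
    return correct / total


y_true = [[1, 1, 0], [0, 1, 1]]
y_pred = [[0.6, 0.9, 0.1], [0.2, 0.8, 0.7]]
print(categorical_accuracy(y_true, y_pred))  # 0.5
print(binary_accuracy(y_true, y_pred))       # 1.0
```

In the first sample every label is correct after thresholding, yet categorical accuracy counts it as a miss because the top score falls on the second true label. The 84% figure is therefore best read as "the top-scoring term matches the first true term" rather than "all terms were predicted correctly".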


Find the public notebook here.


This post is an excerpt from a Keras tutorial by Sayak Paul & Soumik Rakshit. The Keras site has many more tutorials and code recipes; check them out here.

