Text Classification with SimpleTransformers

William Chan
2 min read · Aug 29, 2021
Text Classification is one of many tasks SimpleTransformers is ready to handle.

Summary

Using a few lines of code from the SimpleTransformers package, we’re able to build a classification model on top of RoBERTa (a pre-trained model) that identifies spam text messages.

About SimpleTransformers

SimpleTransformers is a Natural Language Processing (NLP) package that can perform machine learning tasks like Text Classification and Conversational AI. Text Classification uses deep learning models like BERT, XLM, or RoBERTa, while Conversational AI uses GPT (Generative Pre-trained Transformer). SimpleTransformers is built on the Transformers library by HuggingFace.

The general model flow looks something like below:

  1. Initialize a task-specific model like ConvAIModel, ClassificationModel, or MultiLabelClassificationModel
  2. Train the model using train_model()
  3. Evaluate the model with eval_model()
  4. (Text Classification Only) Make predictions on (unlabelled) data with predict()
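
For the classification case, the four steps above can be sketched as one small function. The helper name run_flow is hypothetical and not part of SimpleTransformers; it only assumes it is handed an object that exposes train_model(), eval_model(), and predict() with the return shapes used later in this post.

```python
def run_flow(model, train_data, eval_data, unlabeled_texts):
    """Run train -> evaluate -> predict on an already initialized model.

    `model` is assumed to behave like a ClassificationModel:
    eval_model() returns (result, model_outputs, wrong_predictions)
    and predict() returns (predictions, raw_outputs).
    """
    # Step 2: fine-tune the pre-trained model on labelled data
    model.train_model(train_data)

    # Step 3: score the model on labelled evaluation data
    result, model_outputs, wrong_predictions = model.eval_model(eval_data)

    # Step 4: label new, unlabelled texts
    predictions, raw_outputs = model.predict(unlabeled_texts)

    return result, predictions
```

Keeping the three calls behind one function makes it easier to swap the data set later without touching the model code.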

SimpleTransformers falls into the category of open-source low-code libraries (like PyCaret), but is also flexible enough for advanced users who want full control over their models through Hyperparameter Tuning.

Once you’ve built out a proper data pipeline and successfully trained the model, all you need to do is swap out your data set to continue the classification.

Example

1. Install SimpleTransformers

pip install simpletransformers

2. Import Classification Models and sklearn accuracy metrics

from simpletransformers.classification import ClassificationModel
import pandas as pd
import logging
import sklearn.metrics

3. Import dataset and data prep

The data sets must match the format SimpleTransformers expects: the text goes in a column named text and the target in a column named labels.

train_df = pd.read_csv("/Users/williamchan/Desktop/spamraw_train.csv")
test_df = pd.read_csv("/Users/williamchan/Desktop/spamraw_test.csv")
train_df = train_df.drop(columns='id')  # drop() returns a new DataFrame, so reassign it
test_df = test_df.drop(columns='id')
train_df = train_df.rename(columns={"sms_text": "text", "spam": "labels"})
test_df = test_df.rename(columns={"sms_text": "text"})
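
As far as I know, the labels column should hold integer class ids starting at 0. The CSV used here is not shown, so the sketch below assumes the spam column might contain strings such as "spam" and "ham"; if it already holds 0/1 values, this step is unnecessary. The helper to_training_rows and the sample records are made up for illustration.

```python
def to_training_rows(records, text_key="sms_text", label_key="spam"):
    """Turn raw records into [text, label_id] pairs with integer labels."""
    # Give every distinct label value an integer id, in sorted order
    label_values = sorted({str(r[label_key]) for r in records})
    label_to_id = {value: i for i, value in enumerate(label_values)}

    rows = [
        [str(r[text_key]).strip(), label_to_id[str(r[label_key])]]
        for r in records
    ]
    return rows, label_to_id


# Made-up sample records, not taken from the real data set
records = [
    {"sms_text": "Win a free prize now", "spam": "spam"},
    {"sms_text": "See you at lunch", "spam": "ham"},
]
rows, label_to_id = to_training_rows(records)
```

The resulting rows can be wrapped with pd.DataFrame(rows, columns=["text", "labels"]) to get the two-column layout described above.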

4. Initialize the Classification Model

Here I chose RoBERTa, a retrained version of the BERT model with an improved training methodology and more training data. For more technical details, here is a great post outlining the differences.

model = ClassificationModel('roberta', 'roberta-base', use_cuda=False, args={'reprocess_input_data': True, 'overwrite_output_dir': True})
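
The args dictionary is also where the hyperparameters mentioned earlier can be set. Below is a sketch of a slightly fuller configuration; the keys are ones I believe SimpleTransformers accepts, and the values are illustrative placeholders rather than settings tuned for this spam data set.

```python
# Illustrative values only; not tuned for the spam data set
model_args = {
    "num_train_epochs": 2,         # passes over the training data
    "learning_rate": 4e-5,         # optimizer step size
    "train_batch_size": 8,         # examples per training step
    "max_seq_length": 128,         # longer texts are truncated to this many tokens
    "reprocess_input_data": True,  # rebuild features instead of using a cache
    "overwrite_output_dir": True,  # allow reuse of the output directory
}

# Passed in the same way as in the call above:
# model = ClassificationModel('roberta', 'roberta-base', use_cuda=False, args=model_args)
```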

5. Train the Model

model.train_model(train_df)

6. Evaluate the Model

result, model_outputs, wrong_predictions = model.eval_model(train_df, acc=sklearn.metrics.accuracy_score)
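
Note that this step scores the model on the same rows it was trained on, which tends to give an optimistic accuracy. One common alternative is to hold back part of the labelled data before training and evaluate on that instead. The sketch below shows the idea in plain Python; holdout_split is a hypothetical helper, and the 20% fraction and the seed are arbitrary choices.

```python
import random


def holdout_split(rows, eval_fraction=0.2, seed=42):
    """Shuffle rows and split them into (train_rows, eval_rows)."""
    if not 0 < eval_fraction < 1:
        raise ValueError("eval_fraction must be between 0 and 1")

    shuffled = list(rows)                  # copy so the input is left untouched
    random.Random(seed).shuffle(shuffled)  # seeded for a reproducible split

    n_eval = max(1, int(len(shuffled) * eval_fraction))
    return shuffled[n_eval:], shuffled[:n_eval]


# Ten placeholder rows stand in for labelled examples
train_rows, eval_rows = holdout_split(list(range(10)))
```

With pandas, a similar split could be written as eval_df = train_df.sample(frac=0.2, random_state=42) followed by train_df.drop(eval_df.index), then passing the smaller frame to train_model() and eval_df to eval_model().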

7. Predict on unlabeled Data

predictions, raw_outputs = model.predict(test_df['text'].tolist())  # predict() expects a list of strings
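
predictions holds one class id per text, while raw_outputs holds the model's unnormalized scores for each class. If probabilities are more useful than raw scores, a softmax can be applied to each row. This is a minimal sketch in plain Python; the numbers are made up and do not come from the spam model.

```python
import math


def softmax(logits):
    """Convert a list of raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up scores for two texts, in the order [class 0, class 1]
raw_outputs = [[2.0, -1.0], [-0.5, 1.5]]

probabilities = [softmax(row) for row in raw_outputs]
# The class with the highest probability is the predicted label
predicted = [row.index(max(row)) for row in probabilities]
```

A probability also makes it possible to pick a stricter threshold than 0.5 before flagging a message as spam.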

Conclusion

As you can see, it took no more than 15 lines of code (11 for setup and data prep, 4 for the actual model building) to build, train, and evaluate a full text classification model. This is just the beginning: we could tune the model further or swap in another pre-trained model. The code skeleton is there to be modified, and it’s not complicated at all.

Full GitHub Code

William Chan

CEO & Partner, Compass Data, a Canadian analytics company helping Canada advance its digital transformation.