Text Classification with SimpleTransformers
Summary
Using a few lines of code from the SimpleTransformers package, we can build a classification model on top of RoBERTa, a pre-trained model, to identify spam text messages.
About SimpleTransformers
SimpleTransformers is a Natural Language Processing (NLP) package that can perform machine learning tasks like Text Classification and Conversational AI. Text Classification uses deep learning models like BERT, XLM, or RoBERTa, while Conversational AI uses GPT (Generative Pre-trained Transformer). SimpleTransformers is based on the Transformers library by HuggingFace.
The general workflow looks like this:
- Initialize a task-specific model like ConvAIModel, ClassificationModel, or MultiLabelClassificationModel
- Train the model using train_model()
- Evaluate the model with eval_model()
- (Text Classification only) Make predictions on unlabelled data with predict()
SimpleTransformers falls into the category of open-source low-code libraries (like PyCaret), but it is also flexible enough for advanced users who want full control over their models through hyperparameter tuning.
Once you’ve built a proper data pipeline and successfully trained the model, all you need to do is swap in a new data set to keep classifying.
Example
1. Install SimpleTransformers
pip install simpletransformers
2. Import Classification Models and sklearn accuracy metrics
from simpletransformers.classification import ClassificationModel
import pandas as pd
import logging
import sklearn.metrics
3. Import dataset and data prep
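The data sets must follow the format SimpleTransformers expects: a pandas DataFrame with a text column holding the messages and, for training and evaluation, a labels column holding integer class labels. The unlabelled data only needs the text column.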
train_df = pd.read_csv("/Users/williamchan/Desktop/spamraw_train.csv")
test_df = pd.read_csv("/Users/williamchan/Desktop/spamraw_test.csv")
train_df = train_df.drop(columns='id')
test_df = test_df.drop(columns='id')
train_df = train_df.rename(columns={"sms_text": "text", "spam": "labels"})
test_df = test_df.rename(columns={"sms_text": "text"})
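To make the expected format concrete, here is a minimal sketch on a tiny in-memory DataFrame. The rows are made up for illustration, and the assumption that the spam column holds the strings "spam" and "ham" is mine; the real CSV may already store 0/1, in which case the mapping step is unnecessary. The helper name prepare is hypothetical.

```python
import pandas as pd

# Made-up rows standing in for spamraw_train.csv; the real file's values may differ.
raw_df = pd.DataFrame({
    "id": [1, 2, 3, 4],
    "sms_text": ["Free prize, call now!", "See you at 6?",
                 "WIN cash today", "Running late, sorry"],
    "spam": ["spam", "ham", "spam", "ham"],
})

def prepare(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy with only a `text` column and an integer `labels` column."""
    out = df.drop(columns="id")  # drop() returns a new frame, so keep the result
    out = out.rename(columns={"sms_text": "text", "spam": "labels"})
    # Assumption: labels are the strings "ham"/"spam"; map them to 0/1.
    out["labels"] = out["labels"].map({"ham": 0, "spam": 1}).astype(int)
    return out[["text", "labels"]]

train_df = prepare(raw_df)
print(train_df)
```

Note that drop() and rename() do not modify the DataFrame in place by default, which is why each result is assigned back to a variable.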
4. Initialize the Classification Model
Here I chose RoBERTa, a retrained version of the BERT model with an improved training methodology and more training data. For more technical details, here is a great post outlining the differences. Setting use_cuda=False runs the model on the CPU.
model = ClassificationModel('roberta', 'roberta-base', use_cuda=False, args={'reprocess_input_data': True, 'overwrite_output_dir': True})
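The args dictionary is where most of the tuning happens. The sketch below spells out the two options used above and adds a few commonly adjusted ones; the numeric values are illustrative assumptions, not recommendations from this post. The helpers cuda_available and build_model are hypothetical names, and the library imports are deferred so the configuration can be inspected without loading a model.

```python
def cuda_available() -> bool:
    """Return True only if PyTorch is installed and can see a GPU."""
    try:
        import torch
    except ImportError:
        return False
    return torch.cuda.is_available()

model_args = {
    "reprocess_input_data": True,   # re-tokenize the data instead of reusing cached features
    "overwrite_output_dir": True,   # allow writing over a previous run's output directory
    # The values below are illustrative assumptions, not tuned settings.
    "num_train_epochs": 1,
    "train_batch_size": 8,
    "learning_rate": 4e-5,
    "max_seq_length": 128,
}

def build_model(args=model_args):
    """Create the RoBERTa classifier; downloads weights on first use."""
    from simpletransformers.classification import ClassificationModel
    return ClassificationModel(
        "roberta", "roberta-base", use_cuda=cuda_available(), args=args
    )
```

Calling build_model() returns a model configured like the one above, except that it uses the GPU when one is detected.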
5. Train the Model
model.train_model(train_df)
6. Evaluate the Model
result, model_outputs, wrong_predictions = model.eval_model(train_df, acc=sklearn.metrics.accuracy_score)
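The call above scores the model on the same rows it was trained on, because the test file carries no labels. That tends to overstate accuracy. One common alternative, sketched here under my own assumptions, is to hold back part of the labelled data before training and evaluate on that. The toy DataFrame and the helper name train_and_evaluate are hypothetical; model is assumed to be a ClassificationModel like the one created in step 4.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy labelled data standing in for the real training set.
labelled_df = pd.DataFrame({
    "text": [f"message {i}" for i in range(10)],
    "labels": [0, 1] * 5,
})

# Keep 20% of the rows aside; stratify so both classes appear in each split.
train_split, eval_split = train_test_split(
    labelled_df, test_size=0.2, stratify=labelled_df["labels"], random_state=42
)

def train_and_evaluate(model, train_split, eval_split):
    """Train on one split and score on the held-out split."""
    import sklearn.metrics
    model.train_model(train_split)
    return model.eval_model(eval_split, acc=sklearn.metrics.accuracy_score)

print(len(train_split), len(eval_split))
```

The returned result dictionary then reflects performance on messages the model has not seen during training.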
7. Predict on unlabeled Data
predictions, raw_outputs = model.predict(test_df['text'].tolist())
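predictions holds the predicted class for each message. If you also want a confidence score, one option is to turn raw_outputs into probabilities with a softmax. This is a sketch that assumes raw_outputs contains one row of unnormalized scores (logits) per message and that class 1 means spam; the numbers below are made up so the snippet runs without a trained model.

```python
import numpy as np

# Made-up logits for three messages and two classes (assumed: 0 = ham, 1 = spam).
raw_outputs = np.array([[2.0, -1.0], [-0.5, 1.5], [0.1, 0.0]])

def softmax(logits: np.ndarray) -> np.ndarray:
    """Convert each row of logits into probabilities that sum to 1."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

probs = softmax(raw_outputs)
predicted_class = probs.argmax(axis=1)  # highest-probability class per message
spam_probability = probs[:, 1]          # probability assigned to class 1

print(predicted_class, spam_probability.round(3))
```

A threshold on spam_probability can then be used instead of the default argmax if false positives and false negatives have different costs.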
Conclusion
As you can see, it took no more than 15 lines of code (about 11 for setup and data prep, 4 for the actual model building) to build, train, and evaluate a full text classification model. This is just the beginning: we could tune the model further or swap in another pre-trained model. The code skeleton is there to be modified, and it’s not complicated at all.