See how to do topic modeling using RoBERTa and transformers. We will use a pre-trained RoBERTa model fine-tuned on the NLI dataset.
According to Wikipedia, in machine learning and natural language processing, a topic model is a type of statistical model for discovering the abstract “topics” that occur in a collection of documents. Topic modeling is a frequently used text-mining tool for the discovery of hidden semantic structures in a text body.
– It is an unsupervised technique for discovering which “topic” a text document belongs to.
BERT (Bidirectional Encoder Representations from Transformers) is a technique for natural language processing pre-training developed by Google.
RoBERTa builds on BERT’s language masking strategy, wherein the system learns to predict intentionally hidden sections of text within otherwise unannotated language examples.
RoBERTa, which was implemented in PyTorch, modifies key hyperparameters in BERT, including removing BERT’s next-sentence pretraining objective, and training with much larger mini-batches and learning rates.
This allows RoBERTa to improve on the masked language modeling objective compared with BERT and leads to better downstream task performance.
Transformer-based models such as BERT or RoBERTa have shown state-of-the-art (SOTA) performance on various NLP tasks over the last few years.
Pre-trained models are trained on a huge corpus of text from the web, so they produce richer, more accurate representations of words and sentences than models trained from scratch on small datasets.
Words or phrases of a document are mapped to vectors of real numbers called embeddings.
There are many ways to get embeddings of text using BERT/RoBERTa.
We will use sentence-transformers here to get embeddings.
There are many pre-trained models fine-tuned on the NLI dataset.
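As a quick illustration (a minimal sketch, separate from the BERTopic pipeline below, with made-up example sentences), here is how sentence-transformers turns sentences into fixed-size embeddings using the same RoBERTa NLI model we will use later:

from sentence_transformers import SentenceTransformer

# Load the RoBERTa model fine-tuned on NLI/STS-B with mean pooling
embedder = SentenceTransformer("roberta-base-nli-stsb-mean-tokens")

# Two made-up example sentences
sentences = ["NASA launched a new probe to study Mars.",
             "The baseball season starts next week."]

# encode() returns one fixed-size vector per sentence
embeddings = embedder.encode(sentences)
print(embeddings.shape)  # (2, 768) for this model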
Now let’s jump into topic modeling with RoBERTa and transformers using BERTopic.
I strongly recommend using Google Colab with a GPU enabled for this.
1. Install BERTopic,
!pip install bertopic
2. Prepare data for topic modeling,
We will use the 20newsgroups dataset available in sklearn's datasets module. But we will filter it to keep only 5 categories: 'alt.atheism', 'rec.motorcycles', 'rec.sport.baseball', 'talk.politics.mideast', and 'sci.space'.
Import BERTopic and filter the 20newsgroups dataset,
from bertopic import BERTopic
from sklearn.datasets import fetch_20newsgroups

# Focus on 5 different news categories.
cats = ['alt.atheism', 'rec.motorcycles', 'rec.sport.baseball', 'talk.politics.mideast', 'sci.space']
docs = fetch_20newsgroups(subset='train', categories=cats)["data"]
docs is a list of news content without any labels.
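If you want to peek at what was fetched, a quick sanity check (just an illustrative addition, not part of the original steps) looks like this:

# Quick look at the data
print(len(docs))      # number of news posts across the 5 categories
print(docs[0][:200])  # first 200 characters of the first post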
3. Create topics using BERTopic
We will use the roberta-base-nli-stsb-mean-tokens model to get embeddings of the news text.
But you can use any model listed here.
model = BERTopic("roberta-base-nli-stsb-mean-tokens",
                 top_n_words=20,
                 nr_topics=5,
                 n_gram_range=(1, 2),
                 min_topic_size=30,
                 n_neighbors=15,
                 n_components=5,
                 verbose=True)
The parameters used are as follows:
top_n_words
: int, default 20. The number of words per topic to extract.
nr_topics
: int, default None. Specifying the number of topics will reduce the initial number of topics to the value specified. This reduction can take a while, as each reduction in topics (-1) triggers a c-TF-IDF calculation. If this is set to None, no reduction is applied.
I want to create 5 topics here.
n_gram_range
: Tuple[int (low), int (high)], default (1, 1). The n-gram range for the CountVectorizer. It is advised to keep the high value between 1 and 3; more would likely lead to memory issues.
min_topic_size
: int, optional (default=30). The minimum size of a topic.
n_neighbors
: int, default 15. The size of the local neighborhood (in terms of number of neighboring sample points) used for manifold approximation (UMAP).
n_components
: int, default 5. The dimension of the space to embed into when reducing dimensionality with UMAP.
verbose
: bool, optional (default=False). Changes the verbosity of the model. Set it to True if you want to track the stages of the model.
4. Fit the model.
Fit the model on the collection of documents and generate topics.
topics = model.fit(docs)
Here are the logs:
100%|██████████| 461M/461M [00:48<00:00, 9.53MB/s]
2020-10-10 06:09:08,066 - BERTopic - Loaded BERT model
2020-10-10 06:09:39,537 - BERTopic - Transformed documents to Embeddings
2020-10-10 06:09:58,499 - BERTopic - Reduced dimensionality with UMAP
2020-10-10 06:09:58,626 - BERTopic - Clustered UMAP embeddings with HDBSCAN
2020-10-10 06:10:05,005 - BERTopic - Constructed topics with c-TF-IDF
2020-10-10 06:10:08,244 - BERTopic - Constructed topics with c-TF-IDF
2020-10-10 06:10:11,466 - BERTopic - Constructed topics with c-TF-IDF
2020-10-10 06:10:14,716 - BERTopic - Constructed topics with c-TF-IDF
2020-10-10 06:10:17,846 - BERTopic - Constructed topics with c-TF-IDF
2020-10-10 06:10:17,867 - BERTopic - Reduced number of topics from 10 to 5
Here you can see that the model initially found 10 topics, which were then reduced to 5 because we set nr_topics=5.
5. Transform the docs.
topics = model.transform(docs)
The logs:
2020-10-10 06:10:41,618 - BERTopic - Loaded BERT model
2020-10-10 06:10:59,712 - BERTopic - Transformed documents to Embeddings
6. Get all topics, their top keywords, and topic ids:
model.get_topics()
Here are the topics:
{-1: [('armenian', 0.0017693323038049094), ('armenians', 0.0016565243206481473), ('turkish', 0.0015366849193318036), ('armenia', 0.0013903376480366681), ('said', 0.0013758937730116987), ('dod', 0.0012992454495922676), ('bike', 0.0012928605305784712), ('turks', 0.0012654484389126623), ('good', 0.0012586538141642354), ('didn', 0.0012238552412850187), ('did', 0.001186921328459221), ('know', 0.001176709003317458), ('cs', 0.0011692494395163801), ('say', 0.0011601316008558332), ('year', 0.0011523408652942147), ('ll', 0.0011448219890220434), ('world', 0.001133375447962003), ('going', 0.0011219456607587204), ('turkey', 0.0011034297859071024), ('right', 0.0010911655655473488)], 0: [('israel', 0.00821818919224951), ('israeli', 0.0075735962617810085), ('jews', 0.006075444179935547), ('arab', 0.004627930619646598), ('lebanese', 0.004383798402286841), ('lebanon', 0.0037706635420496856), ('israelis', 0.003671999676758539), ('cpr', 0.003480032235447003), ('arabs', 0.0034668497365718727), ('gaza', 0.003463835049483846), ('policy', 0.0034403288206129555), ('center policy', 0.0032162089430165645), ('policy research', 0.0032012302167785333), ('palestinian', 0.0031678439613086373), ('adam', 0.003163791825319031), ('palestinians', 0.0030192484441283825), ('jewish', 0.003017380023930699), ('attacks', 0.002961080059985608), ('igc', 0.0029226581263436984), ('hernlem', 0.002865323302670359)], 2: [('space', 0.004324200107322805), ('nasa', 0.003785952195071457), ('launch', 0.0035810202515314214), ('moon', 0.00351051658277184), ('orbit', 0.003391991793946083), ('lunar', 0.003102520947323984), ('gov', 0.003101427878189496), ('henry', 0.0030503220419388264), ('nasa gov', 0.0030072145921081024), ('satellite', 0.0028870175175573116), ('shuttle', 0.002851150001265263), ('earth', 0.0026528308553142323), ('jpl', 0.0025857704207529445), ('mission', 0.0025587430491707283), ('alaska', 0.002461952233665219), ('toronto', 0.0023141616598706546), ('mars', 0.002299537511379093), ('solar', 0.002273568144876374), ('access', 0.0022109157650099546), ('alaska edu', 0.002114090937880779)], 5: [('keith', 0.004109975332301287), ('god', 0.003163473905413346), ('caltech edu', 0.003102519164290355), ('caltech', 0.003052947568354468), ('edu keith', 0.002867329828850716), ('livesey', 0.0026760301095891527), ('sgi', 0.0026326257465370075), ('atheists', 0.00251283994781542), ('moral', 0.002354739579824985), ('islam', 0.002354091101939691), ('atheism', 0.0023399129214929193), ('morality', 0.0022973020821965497), ('sgi com', 0.002251649986266537), ('wpd', 0.0021668911467585413), ('wpd sgi', 0.0021668911467585413), ('solntze wpd', 0.0021572724320563946), ('solntze', 0.0021572724320563946), ('schneider', 0.0021353863002852784), ('evidence', 0.002097834227036188), ('argument', 0.00206386069211151)], 8: [('00', 0.010653711435821906), ('02', 0.00892045316032216), ('03', 0.008794185757250314), ('01', 0.007427940411598331), ('00 00', 0.006493248229890668), ('04', 0.006221657086294662), ('333', 0.0053242759123529135), ('games', 0.004815768529654299), ('lost', 0.004497531155338454), ('won', 0.0043492236370303375), ('05', 0.0043295706540779945), ('baseball', 0.004260337219654381), ('philadelphia', 0.00422342904723393), ('game', 0.0040184757010950515), ('league', 0.0038619264872576885), ('500', 0.00375257613216997), ('runs', 0.0035607986294674643), ('sox', 0.003533651848201407), ('win', 0.0033356958104620677), ('hit', 0.0032188072687711474)]}
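If the full dictionary above is hard to scan, a small loop (a convenience sketch, not part of the original steps) prints just the top keywords per topic id and drops the c-TF-IDF scores:

# Print only the top 5 keywords per topic id
for topic_id, words in model.get_topics().items():
    print(topic_id, [word for word, score in words[:5]])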
7. Get a particular topic and its top n keywords.
# Get a topic
model.get_topic(2)[:10]
[('space', 0.004324200107322805), ('nasa', 0.003785952195071457), ('launch', 0.0035810202515314214), ('moon', 0.00351051658277184), ('orbit', 0.003391991793946083), ('lunar', 0.003102520947323984), ('gov', 0.003101427878189496), ('henry', 0.0030503220419388264), ('nasa gov', 0.0030072145921081024), ('satellite', 0.0028870175175573116)]
8. Get the frequency of each topic:
model.get_topics_freq()
Topic | Count |
-1 | 1512 |
5 | 494 |
2 | 434 |
0 | 271 |
8 | 121 |
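As a cross-check (assuming, as steps 5 and 10 suggest, that transform() returns one topic id per document), you can count the topic ids yourself and compare them against the table above:

from collections import Counter

# topics is the list of per-document topic ids from step 5
print(Counter(topics))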
9. Model serialization: see how to save and load the model for topic prediction.
# Save model
model.save("my_model")

# Load model
my_model = BERTopic.load("my_model")
Check if the loaded model is working fine,
We will get topic id 2, which is about space,
my_model.get_topic(2)[:10]
Output:
[('space', 0.004324200107322805), ('nasa', 0.003785952195071457), ('launch', 0.0035810202515314214), ('moon', 0.00351051658277184), ('orbit', 0.003391991793946083), ('lunar', 0.003102520947323984), ('gov', 0.003101427878189496), ('henry', 0.0030503220419388264), ('nasa gov', 0.0030072145921081024), ('satellite', 0.0028870175175573116)]
It is working fine.
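A stricter check (just a sketch) is to compare the keywords from the loaded model against the original model directly:

# The loaded model should return the same keywords and scores as the original
print(my_model.get_topic(2)[:10] == model.get_topic(2)[:10])  # expect True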
10. Let’s predict the topic for a news text,
my_model.transform("NASA Finds Billion-year-old Sand Dunes On Mars That Reveal Climate Pattern Of The Planet")
It predicts [2], which is the topic id for space news. So the prediction is correct.
2020-10-10 06:19:20,754 - BERTopic - Loaded BERT model
2020-10-10 06:19:20,941 - BERTopic - Transformed documents to Embeddings
[2]
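Since transform() also accepts a list of documents (that is how we used it in step 5), you can predict topics for several unseen headlines at once; the second headline below is made up for illustration:

# Predict topics for multiple headlines in one call
headlines = [
    "NASA Finds Billion-year-old Sand Dunes On Mars That Reveal Climate Pattern Of The Planet",
    "The Red Sox won last night's baseball game with a walk-off home run",  # made-up example
]
print(my_model.transform(headlines))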
If you are facing any issues, let me know in the comments section.
My other articles about Huggingface/transformers/BERT,
Conversational response generation using DialoGPT
Faster transformer NLP pipeline using ONNX
Text2TextGeneration pipeline by Huggingface transformers
Question answering using transformers and BERT
How to cluster text documents using BERT
How to do semantic document similarity using BERT
Zero-shot classification using Huggingface transformers
Summarize text document using transformers and BERT
Follow me on Twitter, Instagram, Pinterest, and Tumblr for new post notifications.