With the advancement in deep learning, neural network architectures like recurrent neural networks (RNN and LSTM) and convolutional neural networks (CNN) have shown a decent improvement in performance in solving several Natural Language Processing (NLP) tasks like text classification, language modeling, machine translation, etc.
However, this performance of deep learning models in NLP pales in comparison to the performance of deep learning in Computer Vision.
One of the main reasons for this slow progress could be the lack of large labeled text datasets. Most of the labeled text datasets are not big enough to train deep neural networks because these networks have a huge number of parameters and training such networks on small datasets will cause overfitting.
Another important reason for NLP lagging behind computer vision was the lack of transfer learning in NLP. Transfer learning has been instrumental in the success of deep learning in computer vision. This was possible because huge labeled datasets like ImageNet were available: deep CNN-based models were trained on them and later reused as pre-trained models for a wide range of computer vision tasks.
That was not the case with NLP until Google introduced the Transformer model in 2017 and transformer-based pre-trained models such as BERT followed in 2018. Ever since, transfer learning in NLP has helped solve many tasks with state-of-the-art performance.
In this article, I explain how to fine-tune BERT for text classification.
Transfer learning is a technique where a deep learning model trained on a large dataset is used to perform similar tasks on another dataset. We call such a deep learning model a pre-trained model. The most renowned examples of pre-trained models are the computer vision deep learning models trained on the ImageNet dataset. So, it is better to use a pre-trained model as a starting point to solve a problem rather than building a model from scratch.
This breakthrough of transfer learning in computer vision occurred around 2012-13. With recent advances, transfer learning has become a viable option in NLP as well.
Most NLP tasks, such as text classification, language modeling, and machine translation, are sequence modeling tasks. Traditional machine learning models and feed-forward neural networks do not capture the sequential information present in the text. Therefore, people started using recurrent neural networks (RNN and LSTM), because these architectures can model that sequential information.
A typical RNN
However, these recurrent neural networks have their own set of problems. One major issue is that RNNs can not be parallelized because they take one input at a time. In the case of a text sequence, an RNN or LSTM would take one token at a time as input. So, it will pass through the sequence token by token. Hence, training such a model on a big dataset will take a lot of time.
So, the need for transfer learning in NLP was at an all-time high. In 2017, Google introduced the Transformer in the paper “Attention Is All You Need”, which turned out to be a groundbreaking milestone in NLP.
The Transformer – Model Architecture (Source: https://arxiv.org/abs/1706.03762)
Soon a wide range of transformer-based models started coming up for different NLP tasks. There are multiple advantages of using transformer-based models, but the most important ones are:
First Benefit
These models do not process an input sequence token by token. Instead, they take the entire sequence as input in one go. This is a big improvement over RNN-based models because the computation can now be parallelized and accelerated on GPUs.
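To make the contrast concrete, here is a minimal NumPy sketch, not taken from any library, that computes a toy recurrent pass and a toy self-attention pass over the same sequence. All sizes and weight matrices are made-up assumptions for illustration; a real Transformer adds multiple heads, residual connections, and more.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d = 5, 4                      # 5 tokens, 4-dimensional embeddings (toy sizes)
x = rng.normal(size=(seq_len, d))      # one embedded input sequence

# Recurrent style: each step needs the previous hidden state,
# so the loop over tokens cannot be run in parallel.
W_x = rng.normal(size=(d, d))
W_h = rng.normal(size=(d, d))
h = np.zeros(d)
rnn_states = []
for t in range(seq_len):
    h = np.tanh(x[t] @ W_x + h @ W_h)
    rnn_states.append(h)
rnn_states = np.stack(rnn_states)

# Self-attention style: all tokens are handled together by a few matrix products.
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))
q, k, v = x @ W_q, x @ W_k, x @ W_v
scores = q @ k.T / np.sqrt(d)                    # (seq_len, seq_len)
scores -= scores.max(axis=1, keepdims=True)      # for numerical stability
weights = np.exp(scores)
weights /= weights.sum(axis=1, keepdims=True)    # softmax over each row
attn_out = weights @ v                           # (seq_len, d)

print(rnn_states.shape, attn_out.shape)
```

The recurrent loop has a data dependency between iterations, while the attention block is a handful of matrix multiplications, which is the kind of work GPUs handle well.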
Second Benefit
We don’t need labeled data to pre-train these models. It means that we have to just provide a huge amount of unlabeled text data to train a transformer-based model. We can use this trained model for other NLP tasks like text classification, named entity recognition, text generation, etc. This is how transfer learning works in NLP.
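One way to see why no labels are needed: the training targets can be produced from the text itself. The sketch below is a simplified illustration of a masked-word objective of the general kind used to pre-train BERT. The function name `make_masked_example` and the masking probability are assumptions, and the real pre-training procedure has more details than shown here.

```python
import random

def make_masked_example(tokens, mask_prob=0.15, seed=0):
    """Return (masked_tokens, targets) for a masked-word prediction task.
    targets[i] holds the original token where a mask was placed, else None."""
    rng = random.Random(seed)
    masked, targets = [], []
    for tok in tokens:
        if rng.random() < mask_prob:
            masked.append("[MASK]")
            targets.append(tok)      # the model must recover this word
        else:
            masked.append(tok)
            targets.append(None)     # nothing to predict at this position
    return masked, targets

sentence = "transfer learning lets us reuse a model trained on plain text".split()
masked, targets = make_masked_example(sentence, mask_prob=0.3, seed=1)
print(masked)
print(targets)
```

Every (input, target) pair here comes from raw text alone, so any large unlabeled corpus can serve as training data.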
BERT and GPT-2 are the most popular transformer-based models and in this article, we will focus on BERT and learn how we can use a pre-trained BERT model to perform text classification.
What is Model Fine-Tuning?
BERT (Bidirectional Encoder Representations from Transformers) is a big neural network architecture, with a huge number of parameters, that can range from 100 million to over 300 million. So, training a BERT model from scratch on a small dataset would result in overfitting.
So, it is better to use a pre-trained BERT model that was trained on a huge dataset, as a starting point. We can then further train the model on our relatively smaller dataset and this process is known as model fine-tuning.
Different Fine-Tuning Techniques
Train the entire architecture – We can further train the entire pre-trained model on our dataset and feed the output to a softmax layer. In this case, the error is back-propagated through the entire architecture and the pre-trained weights of the model are updated based on the new dataset.
Train some layers while freezing others – Another way to use a pre-trained model is to train it partially. What we can do is keep the weights of initial layers of the model frozen while we retrain only the higher layers. We can try and test as to how many layers to be frozen and how many to be trained.
Freeze the entire architecture – We can even freeze all the layers of the model and attach a few neural network layers of our own and train this new model. Note that the weights of only the attached layers will be updated during model training.
In this tutorial, we will use the third approach. We will freeze all the layers of BERT during fine-tuning and append a dense layer and a softmax layer to the architecture.
Overview of BERT
You’ve heard about BERT, you’ve read about how incredible it is, and how it’s potentially changing the NLP landscape. But what is BERT in the first place?
Here’s how the research team behind BERT describes the NLP framework:
“BERT stands for Bidirectional Encoder Representations from Transformers. It is designed to pre-train deep bidirectional representations from unlabeled text by jointly conditioning on both left and right context. As a result, the pre-trained BERT model can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of NLP tasks.”
That sounds way too complex as a starting point. But it does summarize what BERT does pretty well so let’s break it down.
Firstly, BERT stands for Bidirectional Encoder Representations from Transformers. Each word here has a meaning to it and we will encounter that one by one in this article. For now, the key takeaway from this line is – BERT is based on the Transformer architecture. Secondly, BERT is pre-trained on a large corpus of unlabelled text including the entire Wikipedia (that’s 2,500 million words!) and Book Corpus (800 million words).
This pre-training step is half the magic behind BERT’s success. This is because as we train a model on a large text corpus, our model starts to pick up the deeper and intimate understandings of how the language works. This knowledge is the swiss army knife that is useful for almost any NLP task.
Third, BERT is a “deep bidirectional” model. Bidirectional means that BERT learns information from both the left and the right side of a token’s context during the training phase.
To learn more about the BERT architecture and its pre-training tasks, you may like to read the article below:
Now we will fine-tune a BERT model to perform text classification with the help of the Transformers library. You should have a basic understanding of defining, training, and evaluating neural network models in PyTorch. If you want a quick refresher on PyTorch then you can go through the article below:
We have a collection of SMS messages. Some of these messages are spam and the rest are genuine. Our task is to build a system that would automatically detect whether a message is spam or not.
The dataset that we will be using for this use case can be downloaded from here (right-click and click on “Save link as…”).
I suggest you use Google Colab to perform this task so that you can use the GPU. Firstly, activate the GPU runtime on Colab by clicking on Runtime -> Change runtime type -> Select GPU.
Install Transformers Library
We will then install Huggingface’s transformers library. This library lets you import a wide range of transformer-based pre-trained models. Just execute the code below to install the library.
!pip install transformers
Import Libraries
You would have to upload the downloaded spam dataset to your Colab runtime. Then read it into a pandas dataframe.
The dataset consists of two columns – “label” and “text”. The column “text” contains the message body and the “label” is a binary variable where 1 means spam and 0 means the message is not a spam.
Now we will split this dataset into three sets – train, validation, and test.
We will fine-tune the model using the train set and the validation set, and make predictions for the test set.
Import BERT Model and BERT Tokenizer
We will import the BERT-base model that has 110 million parameters. There is an even bigger BERT model called BERT-large that has 340 million parameters.
Python Code:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import transformers
from transformers import AutoModel, BertTokenizerFast
# Read dataset
df = pd.read_csv("spamdata_v2.csv")
df.head()
# Train test split
train_text, temp_text, train_labels, temp_labels = train_test_split(df['text'], df['label'], random_state = 2018, test_size = 0.3, stratify = df['label'])
val_text, test_text, val_labels, test_labels = train_test_split(temp_text, temp_labels, random_state = 2018, test_size = 0.5, stratify = temp_labels)
# import BERT-base pretrained model
bert = AutoModel.from_pretrained('bert-base-uncased')
# Load the BERT tokenizer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
# sample data
text = ["this is a bert model tutorial", "we will fine-tune a bert model"]
# encode text
sent_id = tokenizer.batch_encode_plus(text, padding=True)
# output
print(sent_id)
Let’s see how this BERT tokenizer works. We will try to encode a couple of sentences using the tokenizer.
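As you can see, the output is a dictionary. The two items we need are described below (depending on the library version, the output may also contain a ‘token_type_ids’ entry, which this tutorial does not use).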
‘input_ids’ contains the integer sequences of the input sentences. The integers 101 and 102 are the IDs of the special tokens [CLS] and [SEP], which the tokenizer adds to the start and end of every sequence, and 0 is the ID of the padding token.
‘attention_mask’ contains 1’s and 0’s. It tells the model to pay attention to the tokens corresponding to the mask value of 1 and ignore the rest.
Tokenize the Sentences
Since the messages (text) in the dataset are of varying length, we will use padding to make all the messages have the same length. We could pad every message to the maximum sequence length. However, looking at the distribution of sequence lengths in the train set helps us find a more suitable padding length.
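A minimal sketch of this inspection step follows. The `messages` list is made up and stands in for the text column of the train set; in the tutorial's setting it would be replaced by the real training messages.

```python
import numpy as np

# Hypothetical messages standing in for the train set's text column
messages = [
    "ok",
    "see you at 5",
    "call me when you get home",
    "WINNER!! claim your free prize now by calling this number",
    "are we still on for lunch tomorrow",
]

# Length of each message in whitespace-separated words
lengths = np.array([len(m.split()) for m in messages])

print("max length:", lengths.max())
print("median length:", np.median(lengths))
# A high percentile is one way to pick a padding length that covers most messages
print("90th percentile:", np.percentile(lengths, 90))
```

A histogram of `lengths` would show the same information graphically.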
We can clearly see that most of the messages have a length of 25 words or less, whereas the maximum length is 175. If we select 175 as the padding length, all the input sequences will have length 175 and most of the tokens in those sequences will be padding tokens. These do not help the model learn anything useful, and they make training slower.
Therefore, we will set 25 as the padding length.
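To illustrate what a fixed padding length means for the model inputs, here is a small pure-Python sketch. The helper `encode_fixed_length` and the token IDs passed to it are hypothetical. In practice the Transformers tokenizer handles this itself when called with its `max_length`, `padding`, and `truncation` arguments.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0   # special-token IDs mentioned earlier

def encode_fixed_length(token_ids, max_len=25):
    """Wrap token_ids in [CLS]/[SEP], then truncate or pad to max_len.
    Returns (input_ids, attention_mask), both lists of length max_len."""
    body = token_ids[: max_len - 2]            # leave room for the 2 special tokens
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)      # 1 = real token
    pad = max_len - len(input_ids)
    return input_ids + [PAD_ID] * pad, attention_mask + [0] * pad   # 0 = padding

# Hypothetical word-piece IDs for a short and a long message (max_len=8 for readability)
short_ids, short_mask = encode_fixed_length([2023, 2003, 1037], max_len=8)
long_ids, long_mask = encode_fixed_length(list(range(1000, 1030)), max_len=8)
print(short_ids, short_mask)
print(long_ids, long_mask)
```

Short messages are padded with zeros that the attention mask tells the model to ignore, and long messages are cut off at the chosen length.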
So, we have now converted the messages in train, validation, and test set to integer sequences of length 25 tokens each.
Next, we will convert the integer sequences to tensors.
Now we will create dataloaders for both train and validation set. These dataloaders will pass batches of train data and validation data as input to the model during the training phase.
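The tensor conversion and the dataloaders could be sketched as follows. The encoded data, labels, and batch size are made-up assumptions so that the block runs on its own; with the real data, the lists would come from the tokenizer output and the label columns.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# Hypothetical encoded data: 10 messages, each padded to 25 token IDs
num_msgs, max_len = 10, 25
input_ids = [[101] + [2000 + i] * 3 + [102] + [0] * (max_len - 5) for i in range(num_msgs)]
attention_mask = [[1] * 5 + [0] * (max_len - 5) for _ in range(num_msgs)]
labels = [0, 1, 0, 0, 1, 0, 0, 0, 1, 0]

# Convert the Python lists to tensors
train_seq = torch.tensor(input_ids)
train_mask = torch.tensor(attention_mask)
train_y = torch.tensor(labels)

batch_size = 4  # assumed value for illustration

# Training batches are drawn in random order
train_data = TensorDataset(train_seq, train_mask, train_y)
train_dataloader = DataLoader(train_data, sampler=RandomSampler(train_data), batch_size=batch_size)

# For validation a fixed order is enough; the same tensors are reused here for brevity
val_dataloader = DataLoader(train_data, sampler=SequentialSampler(train_data), batch_size=batch_size)

seq, mask, y = next(iter(train_dataloader))
print(seq.shape, mask.shape, y.shape)
```

Each batch is a triple of token IDs, attention mask, and labels, which is the form a training loop would consume.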
If you can recall, earlier I mentioned in this article that I would freeze all the layers of the model before fine-tuning it. So, let’s do it first.
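A minimal sketch of this freezing step is shown here on a small stand-in module, so that it runs without downloading any weights. With the tutorial's objects, the same loop would run over `bert.parameters()`.

```python
import torch.nn as nn

# Stand-in for the pre-trained encoder (an assumption made for this sketch)
encoder = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 768))

# Freeze every parameter: no gradients are computed for them,
# so the optimizer cannot change their values
for param in encoder.parameters():
    param.requires_grad = False

trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
total = sum(p.numel() for p in encoder.parameters())
print(f"trainable parameters: {trainable} of {total}")
```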
This will prevent updating of model weights during fine-tuning. If you wish to fine-tune even the pre-trained weights of the BERT model then you should not execute the code above.
Moving on, let’s define our model architecture.
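One possible shape for such an architecture is sketched below: an encoder followed by a dense layer and a (log-)softmax output. The class names `ClassifierOnEncoder` and `DummyEncoder`, the head size of 512, and the dropout rate are assumptions for illustration. The dummy encoder stands in for BERT so the block runs offline. With a real BERT model from Transformers, the pooled sentence vector is typically read from the `pooler_output` field of the model output, which is worth checking against your library version.

```python
import torch
import torch.nn as nn

class ClassifierOnEncoder(nn.Module):
    """Hypothetical wrapper: an encoder followed by a small trainable head."""

    def __init__(self, encoder, hidden_size=768, head_size=512, num_classes=2):
        super().__init__()
        self.encoder = encoder
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(hidden_size, head_size)   # dense layer
        self.fc2 = nn.Linear(head_size, num_classes)   # output layer
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        # The encoder is assumed to return one pooled vector per sequence
        pooled = self.encoder(sent_id, mask)
        x = self.dropout(torch.relu(self.fc1(pooled)))
        return self.log_softmax(self.fc2(x))

class DummyEncoder(nn.Module):
    """Stand-in for BERT: embeds token IDs and averages over non-padding positions."""

    def __init__(self, vocab_size=1000, hidden_size=768):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_size)

    def forward(self, sent_id, mask):
        m = mask.unsqueeze(-1).float()
        return (self.emb(sent_id) * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)

model = ClassifierOnEncoder(DummyEncoder())
sent_id = torch.randint(1, 1000, (4, 25))        # batch of 4 sequences, 25 token IDs each
mask = torch.ones(4, 25, dtype=torch.long)
log_probs = model(sent_id, mask)
print(log_probs.shape)
```

Because the output is a log-softmax, it pairs naturally with a negative log-likelihood loss.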
We will use AdamW as our optimizer. It is an improved version of the Adam optimizer. To learn more about it do check out this paper.
There is a class imbalance in our dataset. The majority of the observations are not spam. So, we will first compute class weights for the labels in the train set and then pass these weights to the loss function so that it takes care of the class imbalance.
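As a rough sketch of this idea, the block below computes "balanced" class weights by hand and passes them to a weighted loss. The labels and model outputs are made up. The formula used, n_samples / (n_classes * count_per_class), is the same one scikit-learn's `compute_class_weight` applies for `class_weight='balanced'`.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical imbalanced labels: 8 non-spam (0) and 2 spam (1)
train_labels = np.array([0, 0, 0, 0, 1, 0, 0, 1, 0, 0])

# "Balanced" weighting: n_samples / (n_classes * count_per_class)
counts = np.bincount(train_labels)
class_weights = len(train_labels) / (len(counts) * counts)
print(class_weights)   # the rarer class gets the larger weight

# Pass the weights to a negative log-likelihood loss (pairs with log-softmax outputs)
weights = torch.tensor(class_weights, dtype=torch.float)
loss_fn = nn.NLLLoss(weight=weights)

# Random log-probabilities standing in for model outputs
log_probs = torch.log_softmax(torch.randn(10, 2), dim=1)
loss = loss_fn(log_probs, torch.tensor(train_labels, dtype=torch.long))
print(loss.item())
```

With these weights, a mistake on a spam message contributes four times as much to the loss as a mistake on a non-spam message.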
So, till now we have defined the model architecture, we have specified the optimizer and the loss function, and our dataloaders are also ready. Now we have to define a couple of functions to train (fine-tune) and evaluate the model, respectively.
We will use the following function to evaluate the model. It will use the validation set data.
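The training and evaluation functions described above might look roughly like the sketch below. It uses a tiny toy model and random data instead of BERT so that it runs quickly. The function names `train_epoch` and `evaluate`, the learning rate, and the gradient-clipping threshold are assumptions, not values taken from the original code.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Toy stand-ins: 2 features per sample instead of token IDs and masks
X = torch.randn(64, 2)
y = (X[:, 0] + X[:, 1] > 0).long()
loader = DataLoader(TensorDataset(X, y), batch_size=16)

model = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 2), nn.LogSoftmax(dim=1))
loss_fn = nn.NLLLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)  # learning rate is an assumption

def train_epoch(model, loader):
    """One pass over the training data; returns the average batch loss."""
    model.train()
    total = 0.0
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        # Clip gradients to guard against exploding updates
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total += loss.item()
    return total / len(loader)

def evaluate(model, loader):
    """Average loss with dropout disabled and no gradient tracking."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            total += loss_fn(model(xb), yb).item()
    return total / len(loader)

before = evaluate(model, loader)
for _ in range(20):
    train_epoch(model, loader)
after = evaluate(model, loader)
print(f"loss before: {before:.3f}, after: {after:.3f}")
```

In the BERT setting, each batch would carry token IDs, attention masks, and labels, and the model call would take the first two as inputs.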
Now we will finally start fine-tuning of the model.
Training Loss: 0.592
Validation Loss: 0.567
Epoch 5 / 10
Batch 50 of 122.
Batch 100 of 122.
Evaluating...
Training Loss: 0.566
Validation Loss: 0.543
Epoch 6 / 10
Batch 50 of 122.
Batch 100 of 122.
Evaluating...
Training Loss: 0.552
Validation Loss: 0.525
Epoch 7 / 10
Batch 50 of 122.
Batch 100 of 122.
Evaluating...
Training Loss: 0.525
Validation Loss: 0.498
Epoch 8 / 10
Batch 50 of 122.
Batch 100 of 122.
Evaluating...
Training Loss: 0.507
Validation Loss: 0.477
Epoch 9 / 10
Batch 50 of 122.
Batch 100 of 122.
Evaluating...
Training Loss: 0.488
Validation Loss: 0.461
Epoch 10 / 10
Batch 50 of 122.
Batch 100 of 122.
Evaluating...
Training Loss: 0.474
Validation Loss: 0.454
You can see that the validation loss is still decreasing at the end of the 10th epoch. So, you may try a higher number of epochs. Now let’s see how well it performs on the test dataset.
Make Predictions
To make predictions, we will first of all load the best model weights which were saved during the training process.
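A minimal sketch of saving the best weights during training and loading them afterwards is given below. The file name `saved_weights.pt`, the stand-in model, and the validation losses are all made up for illustration.

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 2)   # stand-in for the fine-tuned model
path = os.path.join(tempfile.mkdtemp(), "saved_weights.pt")   # hypothetical file name

# During training: keep the weights from the epoch with the lowest validation loss
best_val_loss = float("inf")
for epoch, val_loss in enumerate([0.57, 0.52, 0.55]):   # made-up validation losses
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), path)

# Before predicting: load the saved weights into a model of the same architecture
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path))
restored.eval()
print(best_val_loss)
```

The model that receives the weights must have the same architecture as the one that was saved.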
Once the weights are loaded, we can use the fine-tuned model to make predictions on the test set.
Both recall and precision for class 0 are quite high, which means that the model predicts this class pretty well. However, our objective was to detect spam messages, so misclassifying class 1 (spam) samples is a bigger concern than misclassifying class 0 samples. If you look at the recall for class 1, it is 0.90, which means that the model was able to correctly classify 90% of the spam messages. However, precision is a bit on the lower side for class 1. It means that the model misclassifies some of the class 0 messages (not spam) as spam.
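To make the difference between precision and recall concrete, here is a small worked sketch on made-up predictions. The numbers are hypothetical and are not the tutorial's results. Recall for a class is the share of its true members that the model found. Precision is the share of the model's predictions for that class that were right.

```python
import numpy as np

# Hypothetical model outputs (log-probabilities for classes 0 and 1) and true labels
log_probs = np.log(np.array([
    [0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.7, 0.3],
    [0.1, 0.9], [0.45, 0.55], [0.8, 0.2], [0.3, 0.7],
]))
y_true = np.array([0, 1, 0, 0, 1, 0, 0, 1])

# Predicted class = index of the larger log-probability
y_pred = np.argmax(log_probs, axis=1)

def precision_recall(y_true, y_pred, cls):
    tp = np.sum((y_pred == cls) & (y_true == cls))   # correctly predicted as cls
    fp = np.sum((y_pred == cls) & (y_true != cls))   # wrongly predicted as cls
    fn = np.sum((y_pred != cls) & (y_true == cls))   # members of cls that were missed
    return tp / (tp + fp), tp / (tp + fn)

for cls in (0, 1):
    p, r = precision_recall(y_true, y_pred, cls)
    print(f"class {cls}: precision={p:.2f} recall={r:.2f}")
```

In this toy case every spam message is caught (recall 1.0 for class 1), but two genuine messages are also flagged as spam, which pulls precision for class 1 down to 0.6. This is the same pattern as described above.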
To summarize, in this article, we fine-tuned a pre-trained BERT model to perform text classification on a very small dataset. I urge you to fine-tune BERT on a different dataset and see how it performs. You can even perform multiclass or multi-label classification with the help of BERT. You can also train the entire BERT architecture if you have a bigger dataset.
In case you are looking for a roadmap to becoming an expert in NLP read the following article-
Data Scientist at Analytics Vidhya with multidisciplinary academic background. Experienced in machine learning, NLP, graphs & networks. Passionate about learning and applying data science to solve real world problems.
Hi,
Thanks a lot for very detailed explanation. I have some doubt around precision/recall inference you made. You mentioned recall for class 1 is high (0.90) so model will correctly identify spam 90% of time , however doesn't that metric should be accuracy ? Also you mentioned precision is on lower side (0.39) which means we would misclassify non-spam as spam , however how do you interpret high precision for class 0 in same context ? if possible please expand little more on precision/recall for both spam (1) & ham(class 0).
Thanks & Regards
John ODonovan
Hi Dinesh,
Nice tutorial, thanks! I got similar results to you. ( I used my own GPU box instead of Colab)
              precision    recall  f1-score   support
           0       0.97      0.87      0.91       724
           1       0.48      0.81      0.61       112
    accuracy                           0.86       836
   macro avg       0.73      0.84      0.76       836
weighted avg       0.90      0.86      0.87       836
Tegene
Hi Prateek !
Thank you very much for your hands on explanation on such comple concept. I am enjoying your tutorials as well.
Currently, am working a project on transfer learning for netx word prediction for one of my local languages. the language uses latin letters(english letters) and it uses a very long suffixes when it is inflected, it is also a low resource language.
so which method would you recommend?
if you can please put the steps for me.
thanks in advance
Hi Dinesh, I have said that the model was able to correctly classify 90% of the spam messages. It means that if there were 100 spam messages in the unseen dataset then the model would have classified 90 of them as spam. Same logic can be applied for the ham (class 0).
Hi, We can solve the next word prediction problem with the help of a language model. The good thing is that you don't need a labeled dataset to train a language model. So, you can either train a language model from scratch or use a pre-trained model such as GPT-2. But for your language I don't think there would be any pre-trained model available. So, you should try to collect as much data as possible and train a language model from scratch to predict the next word.