.. _sec_natural-language-inference-bert:

Natural Language Inference: Fine-Tuning BERT
============================================

In earlier sections of this chapter, we designed an attention-based
architecture (in :numref:`sec_natural-language-inference-attention`)
for the natural language inference task on the SNLI dataset (as
described in :numref:`sec_natural-language-inference-and-dataset`).
Now we revisit this task by fine-tuning BERT. As discussed in
:numref:`sec_finetuning-bert`, natural language inference is a
sequence-level text pair classification problem, and fine-tuning BERT
only requires an additional MLP-based architecture, as illustrated in
:numref:`fig_nlp-map-nli-bert`.

.. _fig_nlp-map-nli-bert:

.. figure:: ../img/nlp-map-nli-bert.svg

   This section feeds pretrained BERT to an MLP-based architecture for
   natural language inference.

In this section, we will download a pretrained small version of BERT,
then fine-tune it for natural language inference on the SNLI dataset.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import json
    import multiprocessing
    import os
    import torch
    from torch import nn
    from d2l import torch as d2l

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import json
    import multiprocessing
    import os
    from mxnet import gluon, np, npx
    from mxnet.gluon import nn
    from d2l import mxnet as d2l

    npx.set_np()

Loading Pretrained BERT
-----------------------

We have explained how to pretrain BERT on the WikiText-2 dataset in
:numref:`sec_bert-dataset` and :numref:`sec_bert-pretraining` (note
that the original BERT model is pretrained on much larger corpora). As
discussed in :numref:`sec_bert-pretraining`, the original BERT model
has hundreds of millions of parameters. In the following, we provide two
versions of pretrained BERT: “bert.base” is about as big as the original
BERT base model, which requires a lot of computational resources to
fine-tune, while “bert.small” is a small version to facilitate
demonstration.
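
To get a feel for how far apart the two versions are, the following
sketch estimates the number of parameters of a BERT-style encoder from
its hyperparameters. It assumes the layout of the original BERT (token,
position, and segment embeddings followed by a layer normalization, and
linear layers with biases in attention and in the feed-forward network),
ignores the pretraining heads, and uses a hypothetical vocabulary size of
30522 (that of the original uncased BERT). The d2l implementation and the
vocabulary of the downloaded models differ in details, so treat the
numbers as rough estimates only.

.. code:: python

    def count_bert_params(vocab_size, num_hiddens, ffn_num_hiddens, num_blks,
                          max_len=512, num_segments=2):
        """Rough parameter count of a BERT-style encoder (assumed layout)."""
        # Token, position, and segment embeddings plus one layer norm
        # (a scale vector and a shift vector)
        embeddings = ((vocab_size + max_len + num_segments) * num_hiddens
                      + 2 * num_hiddens)
        # Query, key, value, and output projections, each with a bias
        attention = 4 * (num_hiddens * num_hiddens + num_hiddens)
        # Position-wise feed-forward network: two linear layers with biases
        ffn = (num_hiddens * ffn_num_hiddens + ffn_num_hiddens
               + ffn_num_hiddens * num_hiddens + num_hiddens)
        # Two layer norms per encoder block
        layer_norms = 2 * 2 * num_hiddens
        return embeddings + num_blks * (attention + ffn + layer_norms)

    vocab_size = 30522  # Hypothetical vocabulary size
    base = count_bert_params(vocab_size, 768, 3072, 12)
    small = count_bert_params(vocab_size, 256, 512, 2)
    # Prints: base: 108,891,648, small: 8,999,936
    print(f'base: {base:,}, small: {small:,}')

Under these assumptions the base configuration has roughly 109 million
parameters, about twelve times as many as the small one. The number of
attention heads does not appear in the formula: splitting the
projections into heads changes how they are used, not their size.
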

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
                                 '225d66f04cae318b841a13d32af3acc165f253ac')
    d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
                                  'c72329e68a732bef0452e4b96a1c341c8910f81f')

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.zip',
                                 '7b3820b35da691042e5d34c0971ac3edbd80d3f4')
    d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.zip',
                                  'a4e718a47137ccd1809c9107ab4f5edd317bae2c')

Each pretrained BERT model contains a “vocab.json” file that defines
the vocabulary and a “pretrained.params” file of the pretrained
parameters. We implement the following ``load_pretrained_model``
function to load pretrained BERT parameters.
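
The loading code implies that ``vocab.json`` holds a JSON list of tokens
in which each token's position is its index, so that ``token_to_idx`` is
simply the inverse of that list. The following self-contained sketch
illustrates this with a tiny hypothetical vocabulary and a minimal
``TinyVocab`` class (a stand-in, not d2l's ``Vocab``); as an assumption
modeled on d2l's vocabulary, unknown tokens are mapped to the index of
``'<unk>'``.

.. code:: python

    import json

    class TinyVocab:
        """Hypothetical stand-in for a vocabulary read from vocab.json."""
        def __init__(self, idx_to_token):
            self.idx_to_token = idx_to_token
            # Invert the list: token -> its position in the list
            self.token_to_idx = {token: idx
                                 for idx, token in enumerate(idx_to_token)}

        def __getitem__(self, tokens):
            # Look up a list of tokens element-wise
            if isinstance(tokens, (list, tuple)):
                return [self[token] for token in tokens]
            # Unknown tokens fall back to the index of '<unk>'
            return self.token_to_idx.get(tokens, self.token_to_idx['<unk>'])

    # Hypothetical contents of a vocab.json file
    vocab_json = '["<unk>", "<pad>", "<cls>", "<sep>", "a", "dog", "runs"]'
    vocab = TinyVocab(json.loads(vocab_json))
    print(vocab[['<cls>', 'a', 'cat', 'runs', '<sep>']])  # [2, 4, 0, 6, 3]
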

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                              num_heads, num_blks, dropout, max_len, devices):
        data_dir = d2l.download_extract(pretrained_model)
        # Define an empty vocabulary to load the predefined vocabulary
        vocab = d2l.Vocab()
        with open(os.path.join(data_dir, 'vocab.json')) as f:
            vocab.idx_to_token = json.load(f)
        vocab.token_to_idx = {token: idx for idx, token in enumerate(
            vocab.idx_to_token)}
        # Pass the hyperparameters through instead of hard-coding them, so
        # that the same function also works for 'bert.base'
        bert = d2l.BERTModel(
            len(vocab), num_hiddens, ffn_num_hiddens=ffn_num_hiddens,
            num_heads=num_heads, num_blks=num_blks, dropout=dropout,
            max_len=max_len)
        # Load pretrained BERT parameters onto the CPU; the model is moved to
        # `devices` during training
        bert.load_state_dict(torch.load(
            os.path.join(data_dir, 'pretrained.params'), map_location='cpu'))
        return bert, vocab

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                              num_heads, num_blks, dropout, max_len, devices):
        data_dir = d2l.download_extract(pretrained_model)
        # Define an empty vocabulary to load the predefined vocabulary
        vocab = d2l.Vocab()
        with open(os.path.join(data_dir, 'vocab.json')) as f:
            vocab.idx_to_token = json.load(f)
        vocab.token_to_idx = {token: idx for idx, token in enumerate(
            vocab.idx_to_token)}
        bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens,
                             num_heads, num_blks, dropout, max_len)
        # Load pretrained BERT parameters
        bert.load_parameters(os.path.join(data_dir, 'pretrained.params'),
                             ctx=devices)
        return bert, vocab

To facilitate demonstration on most machines, we will load and
fine-tune the small version (“bert.small”) of the pretrained BERT in
this section. In the exercises, you can fine-tune the much larger
“bert.base” to significantly improve the test accuracy.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    devices = d2l.try_all_gpus()
    bert, vocab = load_pretrained_model(
        'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
        num_blks=2, dropout=0.1, max_len=512, devices=devices)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    devices = d2l.try_all_gpus()
    bert, vocab = load_pretrained_model(
        'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
        num_blks=2, dropout=0.1, max_len=512, devices=devices)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Downloading ../data/bert.small.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.zip...
    [21:49:07] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
    [21:49:08] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
    [21:49:08] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU

The Dataset for Fine-Tuning BERT
--------------------------------

For the downstream task of natural language inference on the SNLI
dataset, we define a customized dataset class ``SNLIBERTDataset``. In
each example, the premise and hypothesis form a pair of text sequences
that is packed into one BERT input sequence as depicted in
:numref:`fig_bert-two-seqs`. Recall from :numref:`subsec_bert_input_rep`
that segment IDs are used to distinguish the premise and the hypothesis
in a BERT input sequence. Given the predefined maximum length of a BERT
input sequence (``max_len``), the last token of the longer sequence in
the pair is removed repeatedly until the pair, together with the three
special tokens “<cls>”, “<sep>”, and “<sep>”, fits into ``max_len``
tokens; shorter inputs are padded with “<pad>” tokens up to ``max_len``.
To accelerate generation of the SNLI dataset for fine-tuning BERT, we
use 4 worker processes to generate training or testing examples in
parallel.
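
The truncation and packing performed by ``SNLIBERTDataset`` can be traced
on plain Python lists. In the sketch below, ``truncate_pair`` mirrors
``_truncate_pair_of_tokens``, while ``pack_pair`` is a hypothetical helper
that combines it with the layout produced by ``d2l.get_tokens_and_segments``
(“<cls>”, premise, “<sep>”, hypothesis, “<sep>”, with segment ID 0 up to
and including the first “<sep>” and 1 afterwards) and with padding up to
``max_len``. Unlike the dataset class, it pads with the ``'<pad>'`` string
rather than its token ID so that no vocabulary is needed.

.. code:: python

    def truncate_pair(p_tokens, h_tokens, max_len):
        """Pop the last token of the longer sequence until the pair fits."""
        # Reserve 3 slots for one '<cls>' and two '<sep>' tokens
        while len(p_tokens) + len(h_tokens) > max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def pack_pair(p_tokens, h_tokens, max_len, pad='<pad>'):
        """Hypothetical helper: truncate, add special tokens, and pad."""
        p_tokens, h_tokens = list(p_tokens), list(h_tokens)  # Keep inputs intact
        truncate_pair(p_tokens, h_tokens, max_len)
        tokens = ['<cls>'] + p_tokens + ['<sep>'] + h_tokens + ['<sep>']
        # Segment 0 covers '<cls>', the premise, and the first '<sep>'
        segments = [0] * (len(p_tokens) + 2) + [1] * (len(h_tokens) + 1)
        valid_len = len(tokens)
        # Pad tokens with '<pad>' and segments with 0 up to max_len
        tokens += [pad] * (max_len - valid_len)
        segments += [0] * (max_len - valid_len)
        return tokens, segments, valid_len

    # Tokens are popped from the longer sequence (the hypothesis on ties)
    premise = 'a man is playing a guitar on stage'.split()
    hypothesis = 'a man plays music'.split()
    tokens, segments, valid_len = pack_pair(premise, hypothesis, max_len=10)
    print(tokens)
    print(segments, valid_len)

    # A short pair needs no truncation and is padded instead
    print(pack_pair(['a', 'dog'], ['it', 'runs'], max_len=8))

Because truncation stops as soon as the pair fits, a truncated example
always fills ``max_len`` exactly, whereas only untruncated examples
receive padding.
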

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class SNLIBERTDataset(torch.utils.data.Dataset):
        def __init__(self, dataset, max_len, vocab=None):
            all_premise_hypothesis_tokens = [[
                p_tokens, h_tokens] for p_tokens, h_tokens in zip(
                *[d2l.tokenize([s.lower() for s in sentences])
                  for sentences in dataset[:2]])]

            self.labels = torch.tensor(dataset[2])
            self.vocab = vocab
            self.max_len = max_len
            (self.all_token_ids, self.all_segments,
             self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
            print('read ' + str(len(self.all_token_ids)) + ' examples')

        def _preprocess(self, all_premise_hypothesis_tokens):
            # Use 4 worker processes; the context manager shuts the pool down
            with multiprocessing.Pool(4) as pool:
                out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
            all_token_ids = [
                token_ids for token_ids, segments, valid_len in out]
            all_segments = [segments for token_ids, segments, valid_len in out]
            valid_lens = [valid_len for token_ids, segments, valid_len in out]
            return (torch.tensor(all_token_ids, dtype=torch.long),
                    torch.tensor(all_segments, dtype=torch.long),
                    torch.tensor(valid_lens))

        def _mp_worker(self, premise_hypothesis_tokens):
            p_tokens, h_tokens = premise_hypothesis_tokens
            self._truncate_pair_of_tokens(p_tokens, h_tokens)
            tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
            # Pad token IDs with '<pad>' and segment IDs with 0 up to max_len
            token_ids = (self.vocab[tokens] + [self.vocab['<pad>']]
                         * (self.max_len - len(tokens)))
            segments = segments + [0] * (self.max_len - len(segments))
            valid_len = len(tokens)
            return token_ids, segments, valid_len

        def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
            # Reserve slots for '<cls>', '<sep>', and '<sep>' tokens for the
            # BERT input
            while len(p_tokens) + len(h_tokens) > self.max_len - 3:
                if len(p_tokens) > len(h_tokens):
                    p_tokens.pop()
                else:
                    h_tokens.pop()

        def __getitem__(self, idx):
            return (self.all_token_ids[idx], self.all_segments[idx],
                    self.valid_lens[idx]), self.labels[idx]

        def __len__(self):
            return len(self.all_token_ids)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class SNLIBERTDataset(gluon.data.Dataset):
        def __init__(self, dataset, max_len, vocab=None):
            all_premise_hypothesis_tokens = [[
                p_tokens, h_tokens] for p_tokens, h_tokens in zip(
                *[d2l.tokenize([s.lower() for s in sentences])
                  for sentences in dataset[:2]])]

            self.labels = np.array(dataset[2])
            self.vocab = vocab
            self.max_len = max_len
            (self.all_token_ids, self.all_segments,
             self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
            print('read ' + str(len(self.all_token_ids)) + ' examples')

        def _preprocess(self, all_premise_hypothesis_tokens):
            # Use 4 worker processes; the context manager shuts the pool down
            with multiprocessing.Pool(4) as pool:
                out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
            all_token_ids = [
                token_ids for token_ids, segments, valid_len in out]
            all_segments = [segments for token_ids, segments, valid_len in out]
            valid_lens = [valid_len for token_ids, segments, valid_len in out]
            return (np.array(all_token_ids, dtype='int32'),
                    np.array(all_segments, dtype='int32'),
                    np.array(valid_lens))

        def _mp_worker(self, premise_hypothesis_tokens):
            p_tokens, h_tokens = premise_hypothesis_tokens
            self._truncate_pair_of_tokens(p_tokens, h_tokens)
            tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
            # Pad token IDs with '<pad>' and segment IDs with 0 up to max_len
            token_ids = (self.vocab[tokens] + [self.vocab['<pad>']]
                         * (self.max_len - len(tokens)))
            segments = segments + [0] * (self.max_len - len(segments))
            valid_len = len(tokens)
            return token_ids, segments, valid_len

        def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
            # Reserve slots for '<cls>', '<sep>', and '<sep>' tokens for the
            # BERT input
            while len(p_tokens) + len(h_tokens) > self.max_len - 3:
                if len(p_tokens) > len(h_tokens):
                    p_tokens.pop()
                else:
                    h_tokens.pop()

        def __getitem__(self, idx):
            return (self.all_token_ids[idx], self.all_segments[idx],
                    self.valid_lens[idx]), self.labels[idx]

        def __len__(self):
            return len(self.all_token_ids)

After downloading the SNLI dataset, we generate training and testing
examples by instantiating the ``SNLIBERTDataset`` class. Such examples
will be read in minibatches during training and testing of natural
language inference.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # Reduce `batch_size` if there is an out of memory error. In the original BERT
    # model, `max_len` = 512
    batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
    data_dir = d2l.download_extract('SNLI')
    train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
    test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                             num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                            num_workers=num_workers)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    read 549367 examples
    read 9824 examples

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # Reduce `batch_size` if there is an out of memory error. In the original BERT
    # model, `max_len` = 512
    batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
    data_dir = d2l.download_extract('SNLI')
    train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
    test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
    train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True,
                                       num_workers=num_workers)
    test_iter = gluon.data.DataLoader(test_set, batch_size,
                                      num_workers=num_workers)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    read 549367 examples
    read 9824 examples

Fine-Tuning BERT
----------------

As :numref:`fig_bert-two-seqs` indicates, fine-tuning BERT for natural
language inference requires only an extra MLP consisting of two fully
connected layers (see ``self.hidden`` and ``self.output`` in the
following ``BERTClassifier`` class). This MLP transforms the BERT
representation of the special “<cls>” token, which encodes the
information of both the premise and the hypothesis, into three outputs
of natural language inference: entailment, contradiction, and neutral.
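
To make the shapes concrete, the following NumPy sketch applies the
computation of ``BERTClassifier.forward`` to a random stand-in for the
encoder output: it selects the representation at position 0 (the
“<cls>” token), passes it through a hidden layer, and maps the result to
three logits, one per class. The tanh hidden layer is an assumption
modeled on ``bert.hidden`` in d2l's ``BERTModel``; all sizes and weights
are hypothetical, and in the real model the encoder output comes from
``bert.encoder``.

.. code:: python

    import numpy as np

    rng = np.random.default_rng(0)
    batch_size, seq_len, num_hiddens, num_classes = 2, 6, 8, 3

    # Random stand-in for the encoder output: (batch_size, seq_len, num_hiddens)
    encoded_X = rng.normal(size=(batch_size, seq_len, num_hiddens))

    # Hypothetical parameters of the hidden layer and the new output layer
    W_hidden = rng.normal(size=(num_hiddens, num_hiddens))
    b_hidden = np.zeros(num_hiddens)
    W_out = rng.normal(size=(num_hiddens, num_classes))
    b_out = np.zeros(num_classes)

    cls_repr = encoded_X[:, 0, :]  # (batch_size, num_hiddens)
    hidden = np.tanh(cls_repr @ W_hidden + b_hidden)
    logits = hidden @ W_out + b_out  # (batch_size, num_classes)
    print(logits.shape)  # (2, 3)
    print(logits.argmax(axis=1))  # Predicted class index for each example
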

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class BERTClassifier(nn.Module):
        def __init__(self, bert):
            super(BERTClassifier, self).__init__()
            self.encoder = bert.encoder
            self.hidden = bert.hidden
            self.output = nn.LazyLinear(3)

        def forward(self, inputs):
            tokens_X, segments_X, valid_lens_x = inputs
            encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
            # Classify from the representation of the '<cls>' token (position 0)
            return self.output(self.hidden(encoded_X[:, 0, :]))

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class BERTClassifier(nn.Block):
        def __init__(self, bert):
            super(BERTClassifier, self).__init__()
            self.encoder = bert.encoder
            self.hidden = bert.hidden
            self.output = nn.Dense(3)

        def forward(self, inputs):
            tokens_X, segments_X, valid_lens_x = inputs
            encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
            # Classify from the representation of the '<cls>' token (position 0)
            return self.output(self.hidden(encoded_X[:, 0, :]))

In the following, the pretrained BERT model ``bert`` is fed into the
``BERTClassifier`` instance ``net`` for the downstream application. In
common implementations of BERT fine-tuning, only the parameters of the
output layer of the additional MLP (``net.output``) will be learned from
scratch. All the parameters of the pretrained BERT encoder
(``net.encoder``) and the hidden layer of the additional MLP
(``net.hidden``) will be fine-tuned.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net = BERTClassifier(bert)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    net = BERTClassifier(bert)
    net.output.initialize(ctx=devices)

Recall that in :numref:`sec_bert` both the ``MaskLM`` class and the
``NextSentencePred`` class have parameters in their MLPs. These
parameters are part of the pretrained BERT model ``bert``, but they are
only used to compute the masked language modeling loss and the next
sentence prediction loss during pretraining. Since these two losses play
no role when fine-tuning for downstream applications, the parameters of
the MLPs in ``MaskLM`` and ``NextSentencePred`` are not updated (their
gradients become stale) when BERT is fine-tuned.

To allow parameters with stale gradients, the MXNet implementation of
``d2l.train_batch_ch13`` passes the flag ``ignore_stale_grad=True`` to
the trainer's ``step`` function. We train and evaluate the model ``net``
with ``d2l.train_ch13``, which calls ``d2l.train_batch_ch13``, using the
training set (``train_iter``) and the testing set (``test_iter``) of
SNLI. Since we fine-tune only the small model for a few epochs to keep
the computational cost low, the training and test accuracy can be
improved further; we leave this to the exercises.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    lr, num_epochs = 1e-4, 5
    # Run one forward pass on a minibatch to initialize the lazy output layer
    # before handing the parameters to the optimizer
    net(next(iter(train_iter))[0])
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss(reduction='none')
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    loss 0.520, train acc 0.791, test acc 0.786
    10588.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]

.. figure:: output_natural-language-inference-bert_1857e6_75_1.svg

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    lr, num_epochs = 1e-4, 5
    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices,
                   d2l.split_batch_multi_inputs)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    loss 0.477, train acc 0.811, test acc 0.789
    4652.5 examples/sec on [gpu(0), gpu(1)]

.. figure:: output_natural-language-inference-bert_1857e6_78_1.svg

Summary
-------

- We can fine-tune the pretrained BERT model for downstream
  applications, such as natural language inference on the SNLI dataset.
- During fine-tuning, the BERT model becomes part of the model for the
  downstream application. Parameters that are only related to
  pretraining loss will not be updated during fine-tuning.

Exercises
---------

1. Fine-tune a much larger pretrained BERT model that is about as big as
   the original BERT base model if your computational resources allow.
   When calling the ``load_pretrained_model`` function, replace
   ‘bert.small’ with ‘bert.base’ and increase the values of
   ``num_hiddens=256``, ``ffn_num_hiddens=512``, ``num_heads=4``, and
   ``num_blks=2`` to 768, 3072, 12, and 12, respectively. By increasing
   the number of fine-tuning epochs (and possibly tuning other
   hyperparameters), can you get a test accuracy higher than 0.86?
2. How would you truncate a pair of sequences according to the ratio of
   their lengths? Compare this pair truncation method with the one used
   in the ``SNLIBERTDataset`` class. What are their pros and cons?