Surprisal/Shannon Information¶

The Shannon information, also called the self-information, information content, or surprisal, was introduced by Claude Shannon as a measure of the information gained by observing an outcome; by design, it is a monotonically decreasing function of that outcome's probability.

That is, as a probability becomes larger, the Shannon information becomes smaller, and vice versa.

The intuitive idea behind this measurement is that an unlikely event is more surprising and therefore "contains more information," while a likely event is obvious and "contains less information."

Definition¶

Shannon defined surprisal as follows:

$$ \text{I}(X=x) := -\log[\text{Pr}(X=x)] = \log\left(\frac{1}{\text{Pr}(X=x)}\right) $$

where $\text{Pr}$ means "probability of," $X$ is a discrete random variable, and $x$ is some particular outcome of the random variable $X$.

Note that, by the rules of logarithms, the two expressions (the negative log of the probability and the log of the reciprocal of the probability) are equivalent.
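For a quick worked example (using base 2, so the unit is bits): a fair coin landing heads is less surprising than a fair die showing a six, and the surprisal reflects that.

$$ \text{I}(\text{heads}) = -\log_2\left(\frac{1}{2}\right) = 1 \text{ bit}, \qquad \text{I}(\text{rolling a six}) = -\log_2\left(\frac{1}{6}\right) \approx 2.585 \text{ bits} $$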

Some motivation behind the definition¶

Let's say we learn not one, but two pieces of information, which are independent of each other.

Shouldn't the total information gained be the sum of the information contents of the two events?

Well, if we have two independent events $x$ and $y$ (observed instances of random variables $X$ and $Y$), we know that:

$$ \text{Pr}(X=x, Y=y) = \text{Pr}(X=x) \cdot \text{Pr}(Y=y) $$

If we take the logarithm of both sides (and simplify the notation a bit):

$$ \log\bigg(\text{Pr}(x, y)\bigg) = \log\bigg(\text{Pr}(x) \cdot \text{Pr}(y)\bigg) $$

Applying a rule of logarithms:

$$ \log\bigg(\text{Pr}(x, y)\bigg) = \log\bigg(\text{Pr}(x)\bigg) + \log\bigg(\text{Pr}(y)\bigg) $$

Now an issue is that the logarithm of a number strictly between 0 and 1 (i.e. all valid probabilities besides 0 and 1) is negative. Since positive numbers are more friendly, we'll just take the negative of both sides:

$$ -\log\bigg(\text{Pr}(x, y)\bigg) = -\log\bigg(\text{Pr}(x)\bigg) - \log\bigg(\text{Pr}(y)\bigg) $$

And that is the (non-negative) information content of two independent observations $x$ and $y$!
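For example, using base 2 for concreteness: two independent fair coin flips each contribute 1 bit, and their joint outcome (probability $\frac{1}{4}$) contributes exactly 2 bits:

$$ -\log_2\left(\frac{1}{2}\cdot\frac{1}{2}\right) = -\log_2\left(\frac{1}{4}\right) = 2 = \left[-\log_2\left(\frac{1}{2}\right)\right] + \left[-\log_2\left(\frac{1}{2}\right)\right] $$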

Something great about this definition is that it agrees with the intuitive "surprising" vs. "obvious" distinction we were talking about earlier.

Consider an event whose probability is exactly 1 - we are certain the event will occur. This is completely obvious, and thus the information gained from observing it is 0:

$$ -\log(1) = \log\left(\frac{1}{1}\right) = \log(1) = 0 $$

Now consider an event whose probability approaches 0 - we are nearly certain the event will NOT occur. As such, observing this event is extremely surprising, and the information gained from observing it is substantial:

$$ \lim_{\text{Pr}(x)\to0^+}-\log\big(\text{Pr}(x)\big) = \lim_{\text{Pr}(x)\to0^+}\log\left(\frac{1}{\text{Pr}(x)}\right) = \infty $$

That is, as an event's probability shrinks toward 0, the information gained from observing it grows without bound. You can keep plugging in smaller and smaller probabilities, and the surprisal just keeps climbing.
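Concretely, if we use base 2, every halving of the probability adds exactly one more unit of information:

$$ \text{Pr}(x) = 2^{-k} \implies -\log_2\big(\text{Pr}(x)\big) = k $$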

Visual proof of logarithm rules¶

In case you don't believe the logarithm rules, I made some interactive graphs where you can toggle each curve on and off; there is also a quick numerical check below.
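Here is that numerical check, a small sketch (not part of the original graphs) that verifies both rules for a couple of randomly chosen probabilities:

import numpy as np

# sanity-check the two logarithm rules used above on random values in (0, 1)
rng = np.random.default_rng(0)
a, b = rng.uniform(0.01, 1, size=2)

print(np.isclose(np.log(a * b), np.log(a) + np.log(b)))  # log of a product = sum of logs
print(np.isclose(-np.log(a), np.log(1 / a)))             # negative log = log of reciprocal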

Numpy Implementation¶

Here are two simple implementations of Shannon information, computed element-wise over a list of probabilities.

The choice of base (2, $e$, or 10) affects the "units" of the information gained. When using base 2, the unit of information is called a shannon (or, more commonly, a bit); for base $e$, it is a nat; and for base 10, it is a hartley.

In [1]:
from numpy import log2, array

def info(probs):
    # surprisal as the negative log (base 2) of each probability
    return -log2(array(probs))

def also_info(probs):
    # equivalent form: log (base 2) of the reciprocal of each probability
    return log2(1 / array(probs))
In [2]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd

# powers of 1/2, so the surprisal in bits comes out to whole numbers
probs = [1, 1/2, 1/32, 1/32768, 1/1073741824]

df = pd.DataFrame({
    'probs': probs,
    'info': info(probs),
    'also_info': also_info(probs)
})

print(df)
          probs  info  also_info
0  1.000000e+00  -0.0        0.0
1  5.000000e-01   1.0        1.0
2  3.125000e-02   5.0        5.0
3  3.051758e-05  15.0       15.0
4  9.313226e-10  30.0       30.0
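Since the choice of base only changes the unit, converting between shannons, nats, and hartleys is just a change-of-base rescaling. A small sketch (not from the original notebook) illustrating this:

import numpy as np

p = 1/32

bits = -np.log2(p)       # shannons (base 2)
nats = -np.log(p)        # nats (base e)
hartleys = -np.log10(p)  # hartleys (base 10)

print(bits, nats, hartleys)
print(np.isclose(nats, bits * np.log(2)))        # 1 shannon = ln(2) nats
print(np.isclose(hartleys, bits * np.log10(2)))  # 1 shannon = log10(2) hartleys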

Computing GPT-2's surprisal over a few sequences¶

GPT-2 is a generative, decoder-only/autoregressive LM. As such, we can compute surprisal at each position by feeding the whole sequence through the model once, taking a softmax over the logits from the preceding position, and taking the negative log (base 2) of the probability assigned to the token that actually appears there.

In [3]:
import matplotlib.pyplot as plt
In [4]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

# load GPT2 and its tokenizer
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model.eval()

def gpt_surprisal(sentence, num):

    # tokenize sentence into a torch tensor
    inputs = tokenizer(sentence, return_tensors="pt")
    input_ids = inputs['input_ids']

    # pass token sequence through model
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits

    # softmax over the vocabulary dimension of the outputs
    softmax = torch.nn.functional.softmax(logits, dim=-1)

    df = pd.DataFrame(columns=['Word', 'Surprisal'])

    for i in range(input_ids.size(1)):

        # i-th token's ID in our token sequence
        word_id = input_ids[0, i].item()

        # probability the model assigns to this token given the preceding context
        # (the logits at position i-1 predict the token at position i; for i == 0
        # this index wraps around to the final position, so the first value should
        # be taken with a grain of salt)
        word_prob = softmax[0, i-1, word_id]

        # surprisal in bits: negative log (base 2) of that probability
        surprisal = -torch.log2(word_prob)

        # get the string representation of the word back again
        word = tokenizer.decode([word_id])

        # record model's surprisal for this word
        df.loc[len(df)] = [word, surprisal.item()]

    plt.subplot(1, 4, num)
    plt.plot(df['Surprisal'], marker='.')
    plt.xticks(ticks=range(len(df)), labels=df['Word'], rotation=75)
    plt.yticks(ticks=range(0, 25, 2))
    plt.grid()
    plt.title('...' + sentence.split()[-1][:-1])
    plt.xlabel('Word')
    plt.ylabel('Surprisal (bits)')

    print(df)
    print('\n\n')
In [5]:
plt.figure(figsize=(14, 4))
plt.tight_layout()
gpt_surprisal("The quick brown fox jumps over the lazy dog.", 1)
gpt_surprisal("The quick brown fox jumps over the lazy corrosion.", 2)
gpt_surprisal("The quick brown fox jumps over the lazy of.", 3)
gpt_surprisal("The quick brown fox jumps over the lazy perro.", 4)
plt.show()
     Word  Surprisal
0     The  11.515246
1   quick  12.569286
2   brown  12.846173
3     fox   9.033367
4   jumps   7.005901
5    over   3.964435
6     the   1.242006
7    lazy  10.458117
8     dog   5.507222
9       .   3.469883



         Word  Surprisal
0         The  11.760913
1       quick  12.569286
2       brown  12.846173
3         fox   9.033367
4       jumps   7.005901
5        over   3.964435
6         the   1.242006
7        lazy  10.458117
8   corrosion  20.361383
9           .   5.036692



     Word  Surprisal
0     The  10.904470
1   quick  12.569286
2   brown  12.846173
3     fox   9.033367
4   jumps   7.005901
5    over   3.964435
6     the   1.242006
7    lazy  10.458117
8      of  11.482323
9       .  13.755750



      Word  Surprisal
0      The  12.224942
1    quick  12.569286
2    brown  12.846173
3      fox   9.033367
4    jumps   7.005901
5     over   3.964435
6      the   1.242006
7     lazy  10.458117
8      per  13.856278
9       ro  10.085159
10       .   5.876067



[Figure: per-token GPT-2 surprisal curves for the four sentences, one panel per sentence.]
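As an aside, the per-token loop in `gpt_surprisal` can be collapsed into one vectorized computation. The sketch below is not the notebook's original method, just an equivalent formulation reusing the `model` and `tokenizer` loaded above; note that it simply skips the first token, which has no preceding context to condition on.

import math
import torch

def gpt_surprisal_vectorized(sentence):
    input_ids = tokenizer(sentence, return_tensors="pt")["input_ids"]
    with torch.no_grad():
        logits = model(input_ids).logits

    # logits at position i are the prediction for the token at position i + 1,
    # so score tokens 1..n-1 with the logits at positions 0..n-2
    log_probs = torch.log_softmax(logits[0, :-1], dim=-1)
    targets = input_ids[0, 1:]

    # natural-log probabilities of the observed tokens, converted to bits
    surprisal_bits = -log_probs[torch.arange(targets.size(0)), targets] / math.log(2)

    return [(tokenizer.decode([t]), s) for t, s in zip(targets.tolist(), surprisal_bits.tolist())]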

Computing BERT's surprisal over a few sequences¶

BERT is a discriminative, encoder-only/masked LM. As such, we can compute surprisal at each position by masking out each token, one at a time, and taking a softmax over the logits produced at the masked position. Note that, unlike GPT-2's left-to-right surprisal, this value conditions on both the left and the right context.

In [6]:
from transformers import BertTokenizer, BertForMaskedLM
import torch

# load BERT and its tokenizer
model_name = "bert-base-uncased"
model = BertForMaskedLM.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
model.eval()

def bert_surprisal(sentence, num):

    # tokenize sentence into a torch tensor
    inputs = tokenizer(sentence, return_tensors="pt")
    input_ids = inputs['input_ids']

    df = pd.DataFrame(columns=['Word', 'Surprisal'])

    # skip the [CLS] and [SEP] special tokens at the ends of the sequence
    for i in range(1, input_ids.size(1) - 1):

        # mask a token
        masked_input_ids = input_ids.clone()
        masked_input_ids[0, i] = tokenizer.mask_token_id

        with torch.no_grad():
            outputs = model(masked_input_ids)
            logits = outputs.logits

        # softmax over the vocabulary dimension of the outputs
        # restrict to the masked position to save a bit of compute
        softmax = torch.nn.functional.softmax(logits[0, i], dim=-1)

        # i-th word's token ID in our token sequence
        word_id = input_ids[0, i].item()

        # probability the model assigns to the masked token given its
        # surrounding (bidirectional) context
        word_prob = softmax[word_id]

        # surprisal in bits: negative log (base 2) of that probability
        surprisal = -torch.log2(word_prob)

        # get the string representation of the word back again
        word = tokenizer.decode([word_id])

        # record model's surprisal for this word
        df.loc[len(df)] = [word, surprisal.item()]

    plt.subplot(1, 4, num)
    plt.plot(df['Surprisal'], marker='.')
    plt.xticks(ticks=range(len(df)), labels=df['Word'], rotation=75)
    plt.yticks(ticks=range(0, 25, 2))
    plt.grid()
    plt.title('...' + sentence.split()[-1][:-1])
    plt.xlabel('Word')
    plt.ylabel('Surprisal (bits)')

    print(df)
    print('\n\n')
In [7]:
plt.figure(figsize=(14, 4))
plt.tight_layout()
bert_surprisal("The quick brown fox jumps over the lazy dog.", 1)
bert_surprisal("The quick brown fox jumps over the lazy corrosion.", 2)
bert_surprisal("The quick brown fox jumps over the lazy of.", 3)
bert_surprisal("The quick brown fox jumps over the lazy perro.", 4)
plt.show()
    Word  Surprisal
0    the   2.023821
1  quick  16.252876
2  brown   7.069638
3    fox   6.454735
4  jumps   9.147573
5   over   4.555444
6    the   0.270881
7   lazy   9.540742
8    dog   7.727884
9      .   0.029186



        Word  Surprisal
0        the   3.166586
1      quick  15.813129
2      brown   7.115983
3        fox   7.785961
4      jumps  10.516631
5       over   3.302016
6        the   0.180737
7       lazy  13.221567
8  corrosion  20.197489
9          .   0.023570



    Word  Surprisal
0    the   2.915615
1  quick  15.318841
2  brown   7.797459
3    fox   9.323952
4  jumps   9.435372
5   over   6.029813
6    the   0.941009
7   lazy  19.864275
8     of  15.490068
9      .   0.126942



     Word  Surprisal
0     the   2.107822
1   quick  15.464628
2   brown   6.714638
3     fox   6.556987
4   jumps   9.226711
5    over   3.212596
6     the   0.298637
7    lazy  10.006992
8     per   9.213134
9    ##ro  13.693556
10      .   0.047083



[Figure: per-token BERT surprisal curves for the four sentences, one panel per sentence.]
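The per-position loop in `bert_surprisal` runs one forward pass per token. A hedged alternative sketch (not in the original notebook), reusing the BERT `model` and `tokenizer` loaded above, batches all of the masked copies into a single forward pass:

import math
import torch

def bert_surprisal_batched(sentence):
    input_ids = tokenizer(sentence, return_tensors="pt")["input_ids"][0]
    n = input_ids.size(0)

    # one row per maskable position (skip [CLS] at 0 and [SEP] at n - 1)
    rows = torch.arange(n - 2)
    positions = torch.arange(1, n - 1)
    batch = input_ids.repeat(n - 2, 1)
    batch[rows, positions] = tokenizer.mask_token_id

    with torch.no_grad():
        logits = model(batch).logits

    # log-probabilities at each masked position, converted to bits
    log_probs = torch.log_softmax(logits[rows, positions], dim=-1)
    surprisal_bits = -log_probs[rows, input_ids[positions]] / math.log(2)

    return [(tokenizer.decode([t]), s) for t, s in zip(input_ids[positions].tolist(), surprisal_bits.tolist())]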

Character Surprisal¶

Cleaning up the Brown corpus (stripping its part-of-speech tags, punctuation, digits, and stray whitespace) doesn't leave us with perfectly clean text, but the character statistics it yields still give some interesting insights.

In [8]:
from string import punctuation
import nltk
nltk.download('brown')
from nltk.corpus import brown
import re

text = ''

# brown.raw() attaches a part-of-speech tag to each token (e.g. "said/vbd"),
# so strip the "/tag " suffixes; also strip newlines, tabs, and carriage returns
pos_tag_regex = r'/[`\'\$A-Za-z\-\.]* '
white_space_not_space_char_regex = r'[\n\t\r]'

for fileid in brown.fileids():
    text += re.sub(pos_tag_regex, ' ', brown.raw(fileid))

text = re.sub(white_space_not_space_char_regex, '', text)

# string.punctuation minus the hyphen, which gets special treatment below
punctuation = '!"#$%&\'()*+,./:;<=>?@[\\]^_`{|}~'

digits = '0123456789'

# drop digits and punctuation entirely
for char in digits:
    text = text.replace(char, '')

for char in punctuation:
    text = text.replace(char, '')

# turn hyphens and dashes into spaces so hyphenated words become separate words
text = text.replace(' -', ' ')
text = text.replace(' - ', ' ')
text = text.replace('- ', ' ')
text = text.replace('-', ' ')

# collapse runs of spaces into single spaces
while '  ' in text:
    text = text.replace('  ', ' ')

text = text.lower().strip()

text[:1000]
[nltk_data] Downloading package brown to
[nltk_data]     C:\Users\danie\AppData\Roaming\nltk_data...
[nltk_data]   Package brown is already up-to-date!
Out[8]:
'the fulton county grand jury said friday an investigation of atlantas recent primary election produced no evidence that any irregularities took place the jury further said in term end presentments that the city executive committee which had over all charge of the election deserves the praise and thanks of the city of atlanta for the manner in which the election was conducted the september october term jury had been charged by fulton superior court judge durwood pye to investigate reports of possible irregularities in the hard fought primary which was won by mayor nominate ivan allen jr only a relative handful of such reports was received the jury said considering the widespread interest in the election the number of voters and the size of this city the jury said it did find that many of georgias registration and election laws are outmoded or inadequate and often ambiguous it recommended that fulton legislators act to have these laws studied and revised to the end of modernizing and imp'

Below, we build a vector of unigram counts and a matrix of bigram counts.

In [9]:
import string
import numpy as np

# character -> index: 0 for the space character, 1-26 for 'a' through 'z'
mapping = {
    c: i for i, c in enumerate(' ' + string.ascii_lowercase)
}

unigram_char_count_vector = np.zeros(len(mapping), dtype=int)
bigram_char_count_matrix = np.zeros((len(mapping), len(mapping)), dtype=int)

# count every character and every adjacent character pair
for i in range(len(text) - 1):
    unigram_char_count_vector[mapping[text[i]]] += 1
    bigram_char_count_matrix[mapping[text[i]], mapping[text[i + 1]]] += 1

# the loop stops one character early, so count the final character as well
unigram_char_count_vector[mapping[text[-1]]] += 1
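Before visualizing everything at once, here is a small sketch (not in the original notebook) that turns the raw counts into probabilities and reads off the surprisal of a few individual characters:

# assumes `mapping`, `unigram_char_count_vector`, and `np` from the cell above
unigram_probs = unigram_char_count_vector / unigram_char_count_vector.sum()

for char in [' ', 'e', 'z']:
    print(repr(char), -np.log2(unigram_probs[mapping[char]]), 'bits')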

Heatmaps for unigrams¶

In [10]:
from seaborn import heatmap

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))

normed_vector = unigram_char_count_vector / sum(unigram_char_count_vector)

heatmap(
    [normed_vector],
    xticklabels=list(mapping.keys()),
    yticklabels=[],
    ax=ax1,
    vmin=0,
    vmax=0.2
)
ax1.set_title('Unigram Probabilities')
ax1.set_xlabel('Character')
ax1.set_ylabel('Frequency')
ax1.set_xticklabels(ax1.get_xticklabels(), rotation=0)

heatmap(
    [info(normed_vector)],
    xticklabels=list(mapping.keys()),
    yticklabels=[],
    ax=ax2,
    vmin=0,
    vmax=16
)
ax2.set_title('Unigram Surprisal')
ax2.set_xlabel('Character')
ax2.set_ylabel('Information')
ax2.set_xticklabels(ax2.get_xticklabels(), rotation=0)

plt.tight_layout()
plt.show()
[Figure: heatmaps of the unigram character probabilities (left) and the corresponding surprisals in bits (right).]

Heatmaps for bigrams¶

In [11]:
from seaborn import heatmap

fig, (ax3, ax4) = plt.subplots(1, 2, figsize=(10, 4))

normed_matrix = bigram_char_count_matrix / bigram_char_count_matrix.sum(axis=1, keepdims=True)

heatmap(
    normed_matrix.transpose(),
    xticklabels=list(mapping.keys()),
    yticklabels=list(mapping.keys()),
    ax=ax3,
    vmin=0,
    vmax=0.8
)
ax3.set_title('Bigram Probabilities')
ax3.set_xlabel('Character 1')
ax3.set_ylabel('Character 2')
ax3.set_xticklabels(ax3.get_xticklabels(), rotation=0)

heatmap(
    info(normed_matrix.transpose()),
    xticklabels=list(mapping.keys()),
    yticklabels=list(mapping.keys()),
    ax=ax4,
    vmin=0,
    vmax=16
)
ax4.set_title('Surprisal of Character 2 Given Character 1')
ax4.set_xlabel('Character 1')
ax4.set_ylabel('Character 2')
ax4.set_xticklabels(ax4.get_xticklabels(), rotation=0)

plt.tight_layout()
plt.show()
[Figure: heatmaps of the bigram probabilities Pr(character 2 | character 1) (left) and the corresponding conditional surprisals in bits (right).]
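Finally, a small sketch (assuming the `mapping` and row-normalized `normed_matrix` from the cells above) for reading a single value off the bigram heatmap, for example how surprising 'u' is after 'q' compared with 'h' after 't':

# conditional surprisal -log2(Pr(char2 | char1)) for a couple of specific bigrams
for first, second in [('q', 'u'), ('t', 'h')]:
    p = normed_matrix[mapping[first], mapping[second]]
    print(f"surprisal of '{second}' after '{first}': {-np.log2(p):.2f} bits")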