Entropy optimization using PyTorch

First, let's import some things from PyTorch that we'll need to perform this optimization. We'll use PyTorch's optimization framework to create a distribution with our desired entropy.

  • We'll rely on some functions from the root torch module.
  • We'll use the nn module to create an nn.Parameter that stores the distribution itself.
  • Finally, I like to import Tensor directly so that I can use it in type hints (this also plays nicely with a Python language server).
In [1]:
import torch
from torch import nn
from torch import Tensor

Let's define a custom loss function that we can use to push our distribution's entropy toward the target entropy. A simple recipe (with gradients enabled throughout) is:

  • Map the unconstrained parameter vector to a valid probability distribution with softmax (non-negative, summing to 1).
  • Calculate the entropy of the distribution.
  • Use a distance metric such as Mean Squared Error to compare the distribution's entropy to the target entropy.

Despite its simplicity, this works shockingly well!
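
In symbols: if $z$ is the unconstrained parameter vector and $H^{*}$ is the target entropy (measured in nats, since we use the natural log), the loss we minimize is

$$
\mathcal{L}(z) = \Big( -\sum_i p_i \log p_i \;-\; H^{*} \Big)^{2}, \qquad p = \operatorname{softmax}(z).
$$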

In [2]:
class MSEAgainstEntropyLoss(nn.Module):
    
    def __init__(self):
        super().__init__()
        
    def forward(self, dist: Tensor, true_entropy: Tensor) -> Tensor:
        
        # map the raw parameter values to a probability distribution (softmax)
        dist = torch.softmax(dist, dim=0)
        
        # calculate the entropy (in nats, since we use the natural log)
        approx_entropy = -(dist * dist.log()).sum()
        
        # calculate the mean squared error
        mse = (approx_entropy - true_entropy) ** 2
        
        # return the loss tensor
        return mse
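
As a quick sanity check of this loss (a minimal sketch using only the class above and plain torch): equal logits softmax to the uniform distribution, whose entropy is $\log n$ nats, so a target of $\log 4$ over 4 outcomes should give a loss of (numerically) zero, while any other target should not.

criterion = MSEAgainstEntropyLoss()
logits = torch.zeros(4, dtype=torch.float64)                      # softmax -> [0.25, 0.25, 0.25, 0.25]
target = torch.log(torch.tensor(4.0, dtype=torch.float64))        # entropy of the uniform distribution over 4 outcomes
print(criterion(logits, target))                                  # ~ tensor(0., dtype=torch.float64)
print(criterion(logits, torch.tensor(1.0, dtype=torch.float64)))  # nonzero for any other target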

Now, we'll define a function that returns a distribution with the desired entropy (or as close as we can get). It will take the following arguments:

  • criterion (nn.Module): The loss function to use to optimize the entropy (in this case, it will just be the class we defined above).
  • support_size (int): The number of outcomes we want our random variable to have.
  • desired_entropy (float): The entropy we want our distribution to have.
  • lr (float): The learning rate for the optimization algorithm.
  • tol (float): The convergence tolerance. We stop once the loss drops below this value (checked every log_freq iterations).
  • max_iter (int): The maximum number of iterations to run the optimization algorithm.
  • do_logging (bool): Whether to log the loss during optimization.
  • log_freq (int): How often (in iterations) to log the loss and check for convergence.
In [3]:
import warnings

def get_dist(
    criterion: nn.Module,
    support_size: int,
    desired_entropy: float,
    lr: float = 0.001,
    tol: float = 1e-6,
    max_iter: int = 100_000,
    do_logging: bool = True,
    log_freq: int = 200
) -> Tensor:
    
    # define a parameter (so gradient updates are possible) with the right support size;
    # nn.Parameter sets requires_grad=True by default
    dist = nn.Parameter(torch.randn((support_size,), dtype=torch.float64))
    
    # make a plain tensor (requires_grad=False by default) holding the target entropy
    DE = torch.tensor(desired_entropy, dtype=torch.float64)
    
    # define an optimizer over the parameter
    optimizer = torch.optim.AdamW([dist], lr=lr)
    
    i = 0
    if do_logging:
        print('-----------------------------------------------------')
    while True:
        
        # optimize the parameter
        optimizer.zero_grad()
        loss = criterion(dist, DE)
        loss.backward()
        optimizer.step()
        
        # every log_freq iterations, log the loss and check for convergence
        if (i % log_freq == 0):
            loss_val = loss.item()
            if do_logging:
                print(f'loss: {loss_val:.4}')
            if loss_val < tol: # we are done if the loss is small enough
                break
        i += 1 # count iterations
        
        # give up if max_iter is reached
        if i > max_iter:
            msg = 'Optimization did not converge!'
            warnings.warn(msg)
            break
    
    # renormalize
    final_dist = torch.softmax(dist, dim=0)
    
    # summary of results
    if do_logging:
        print('-----------------------------------------------------')
        print(f'sum of probabilities (should be 1): {final_dist.sum()}')
        approx_entropy = -(final_dist * final_dist.log()).sum()
        print('-----------------------------------------------------')
        print(f'desired entropy:    {desired_entropy}')
        print(f'true entropy:       {approx_entropy.item()}')
        print('-----------------------------------------------------')
    
    return final_dist
In [4]:
get_dist(
    criterion=MSEAgainstEntropyLoss(),
    support_size=10,
    desired_entropy=1.5707963267948966, # pi/2 because why not
    lr=0.001,
    tol=1e-12,
    max_iter=100_000,
    do_logging=True,
    log_freq=200
)
-----------------------------------------------------
loss: 0.3188
loss: 0.2091
loss: 0.1094
loss: 0.04938
loss: 0.02052
loss: 0.008044
loss: 0.002983
loss: 0.001043
loss: 0.0003462
loss: 0.0001118
loss: 3.726e-05
loss: 1.406e-05
loss: 6.524e-06
loss: 3.784e-06
loss: 2.589e-06
loss: 1.94e-06
loss: 1.515e-06
loss: 1.204e-06
loss: 9.639e-07
loss: 7.749e-07
loss: 6.247e-07
loss: 5.048e-07
loss: 4.086e-07
loss: 3.313e-07
loss: 2.69e-07
loss: 2.186e-07
loss: 1.778e-07
loss: 1.448e-07
loss: 1.179e-07
loss: 9.614e-08
loss: 7.841e-08
loss: 6.398e-08
loss: 5.223e-08
loss: 4.265e-08
loss: 3.483e-08
loss: 2.846e-08
loss: 2.326e-08
loss: 1.901e-08
loss: 1.554e-08
loss: 1.271e-08
loss: 1.039e-08
loss: 8.501e-09
loss: 6.954e-09
loss: 5.689e-09
loss: 4.654e-09
loss: 3.808e-09
loss: 3.116e-09
loss: 2.55e-09
loss: 2.087e-09
loss: 1.708e-09
loss: 1.398e-09
loss: 1.144e-09
loss: 9.368e-10
loss: 7.668e-10
loss: 6.277e-10
loss: 5.139e-10
loss: 4.207e-10
loss: 3.444e-10
loss: 2.82e-10
loss: 2.308e-10
loss: 1.89e-10
loss: 1.547e-10
loss: 1.267e-10
loss: 1.037e-10
loss: 8.493e-11
loss: 6.954e-11
loss: 5.694e-11
loss: 4.662e-11
loss: 3.817e-11
loss: 3.126e-11
loss: 2.559e-11
loss: 2.096e-11
loss: 1.716e-11
loss: 1.405e-11
loss: 1.151e-11
loss: 9.422e-12
loss: 7.715e-12
loss: 6.318e-12
loss: 5.173e-12
loss: 4.237e-12
loss: 3.469e-12
loss: 2.841e-12
loss: 2.327e-12
loss: 1.905e-12
loss: 1.56e-12
loss: 1.278e-12
loss: 1.047e-12
loss: 8.572e-13
-----------------------------------------------------
sum of probabilities (should be 1): 1.0
-----------------------------------------------------
desired entropy:    1.5707963267948966
true entropy:       1.5707972521604139
-----------------------------------------------------
Out[4]:
tensor([0.0203, 0.0098, 0.2433, 0.0118, 0.0225, 0.0226, 0.0166, 0.3209, 0.0217,
        0.3105], dtype=torch.float64, grad_fn=<SoftmaxBackward0>)
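
As a final usage sketch: the returned tensor still carries a grad_fn, so detach it before treating it as a plain probability vector, for example to draw samples with torch.multinomial.

dist = get_dist(
    criterion=MSEAgainstEntropyLoss(),
    support_size=10,
    desired_entropy=1.5707963267948966,
    do_logging=False,
)
samples = torch.multinomial(dist.detach(), num_samples=1000, replacement=True)
print(samples[:10])  # indices in [0, 9], drawn with the probabilities above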