Entropy optimization using PyTorch
First, let's import some things from PyTorch that we'll need to perform this optimization. We'll use PyTorch's optimization framework to create a distribution with our desired entropy.
- We'll rely on some functions from the root torch module.
- We'll use the nn module to create an nn.Parameter to store the distribution itself.
- Finally, for compatibility with a Python language server, I find it's nice to import Tensor so that I can use it in type hints.
In [1]:
 import torch
from torch import nn
from torch import Tensor
Let's define a custom loss function that we can use to push our distribution's entropy toward the target entropy. A simple way to do this (with gradients enabled) is to:
- Normalize the distribution to sum to 1.
- Calculate the entropy of the distribution.
- Use a distance metric such as Mean Squared Error to compare the distribution's entropy to the target entropy.
Despite its simplicity, this works shockingly well!
In [2]:
 class MSEAgainstEntropyLoss(nn.Module):
    
    def __init__(self):
        super().__init__()
        
    def forward(self, dist: Tensor, true_entropy: Tensor) -> Tensor:
        
        # normalize the distribution
        dist = torch.softmax(dist, dim=0)
        
        # calculate the entropy (in nats, i.e. using the natural log)
        approx_entropy = -(dist * dist.log()).sum()
        
        # calculate the mean squared error
        mse = (approx_entropy - true_entropy) ** 2
        
        # return the loss tensor
        return mse
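Before wiring this loss into an optimization loop, a quick sanity check is reassuring. A uniform distribution over n outcomes has entropy ln(n) (in nats, matching the natural log used in the forward pass above), so uniform logits scored against a target of ln(n) should give a loss of essentially zero. A minimal sketch (the variable names here are just for illustration):

loss_fn = MSEAgainstEntropyLoss()
n = 10
uniform_logits = torch.zeros(n, dtype=torch.float64)        # softmax of these is the uniform distribution
target = torch.log(torch.tensor(float(n), dtype=torch.float64))  # entropy of the uniform distribution is ln(n)
print(loss_fn(uniform_logits, target))  # should print a value very close to 0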
Now, we'll define a function that returns a distribution with the desired entropy (or as close as we can get). It will take the following arguments:
- criterion (nn.Module): The loss function to use to optimize the entropy (in this case, it will just be the class we defined above).
- support_size (int): The number of outcomes we want our random variable to have.
- desired_entropy (float): The entropy we want our distribution to have.
- lr (float): The learning rate for the optimization algorithm.
- tol (float): The tolerance for the optimization algorithm. We will stop when the loss is less than this value.
- max_iter (int): The maximum number of iterations to run the optimization algorithm.
- do_logging (bool): Whether to log the loss during optimization.
- log_freq (int): How often to log the loss during optimization.
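One caveat before running this: with the natural-log entropy used above, a distribution over support_size outcomes can only have entropy between 0 (all mass on a single outcome) and ln(support_size) (uniform), so desired_entropy should fall inside that range or the loss can never reach tol. A small, hypothetical guard one could add before calling the function (check_entropy_feasible is not part of the function defined below):

import math

def check_entropy_feasible(support_size: int, desired_entropy: float) -> None:
    # entropy (in nats) of a distribution over `support_size` outcomes lies in [0, ln(support_size)]
    max_entropy = math.log(support_size)
    if not 0.0 <= desired_entropy <= max_entropy:
        raise ValueError(
            f'desired_entropy must be in [0, {max_entropy:.4f}] for support_size={support_size}'
        )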
In [3]:
 import warnings

def get_dist(
    criterion: nn.Module,
    support_size: int,
    desired_entropy: float,
    lr: float = 0.001,
    tol: float = 1e-6,
    max_iter: int = 100_000,
    do_logging: bool = True,
    log_freq: int = 200
) -> Tensor:
    
    # define a learnable parameter holding raw logits (softmax of these gives the distribution)
    dist = nn.Parameter(torch.randn((support_size,), dtype=torch.float64))
    dist.requires_grad = True
    
    # make a torch.Tensor with the desired entropy to compute loss
    DE = torch.tensor(desired_entropy, dtype=torch.float64)
    DE.requires_grad = False
    
    # define an optimizer over the parameter
    optimizer = torch.optim.AdamW([dist], lr=lr)
    
    i = 0
    if do_logging:
        print('-----------------------------------------------------')
    while True:
        
        # optimize the parameter
        optimizer.zero_grad()
        loss = criterion(dist, DE)
        loss.backward()
        optimizer.step()
        
        # every `log_freq` iterations, log the loss and check for convergence
        if (i % log_freq == 0):
            loss_val = loss.item()
            if do_logging:
                print(f'loss: {loss_val:.4}')
            if loss_val < tol: # we are done if the loss is small enough
                break
        i += 1 # count iterations
        
        # give up if max_iter is reached
        if i > max_iter:
            msg = 'Optimization did not converge!'
            warnings.warn(msg)
            break
    
    # renormalize
    final_dist = torch.softmax(dist, dim=0)
    
    # summary of results
    if do_logging:
        print('-----------------------------------------------------')
        print(f'sum of probabilities (should be 1): {final_dist.sum()}')
        approx_entropy = -(final_dist * final_dist.log()).sum()
        print('-----------------------------------------------------')
        print(f'desired entropy:    {desired_entropy}')
        print(f'true entropy:       {approx_entropy.item()}')
        print('-----------------------------------------------------')
    
    return final_dist
In [4]:
 get_dist(
    criterion=MSEAgainstEntropyLoss(),
    support_size=10,
    desired_entropy=1.5707963267948966, # pi/2 because why not
    lr=0.001,
    tol=1e-12,
    max_iter=100_000,
    do_logging=True,
    log_freq=200
)
-----------------------------------------------------
loss: 0.3188
loss: 0.2091
loss: 0.1094
loss: 0.04938
loss: 0.02052
loss: 0.008044
loss: 0.002983
loss: 0.001043
loss: 0.0003462
loss: 0.0001118
loss: 3.726e-05
loss: 1.406e-05
loss: 6.524e-06
loss: 3.784e-06
loss: 2.589e-06
loss: 1.94e-06
loss: 1.515e-06
loss: 1.204e-06
loss: 9.639e-07
loss: 7.749e-07
loss: 6.247e-07
loss: 5.048e-07
loss: 4.086e-07
loss: 3.313e-07
loss: 2.69e-07
loss: 2.186e-07
loss: 1.778e-07
loss: 1.448e-07
loss: 1.179e-07
loss: 9.614e-08
loss: 7.841e-08
loss: 6.398e-08
loss: 5.223e-08
loss: 4.265e-08
loss: 3.483e-08
loss: 2.846e-08
loss: 2.326e-08
loss: 1.901e-08
loss: 1.554e-08
loss: 1.271e-08
loss: 1.039e-08
loss: 8.501e-09
loss: 6.954e-09
loss: 5.689e-09
loss: 4.654e-09
loss: 3.808e-09
loss: 3.116e-09
loss: 2.55e-09
loss: 2.087e-09
loss: 1.708e-09
loss: 1.398e-09
loss: 1.144e-09
loss: 9.368e-10
loss: 7.668e-10
loss: 6.277e-10
loss: 5.139e-10
loss: 4.207e-10
loss: 3.444e-10
loss: 2.82e-10
loss: 2.308e-10
loss: 1.89e-10
loss: 1.547e-10
loss: 1.267e-10
loss: 1.037e-10
loss: 8.493e-11
loss: 6.954e-11
loss: 5.694e-11
loss: 4.662e-11
loss: 3.817e-11
loss: 3.126e-11
loss: 2.559e-11
loss: 2.096e-11
loss: 1.716e-11
loss: 1.405e-11
loss: 1.151e-11
loss: 9.422e-12
loss: 7.715e-12
loss: 6.318e-12
loss: 5.173e-12
loss: 4.237e-12
loss: 3.469e-12
loss: 2.841e-12
loss: 2.327e-12
loss: 1.905e-12
loss: 1.56e-12
loss: 1.278e-12
loss: 1.047e-12
loss: 8.572e-13
-----------------------------------------------------
sum of probabilities (should be 1): 1.0
-----------------------------------------------------
desired entropy:    1.5707963267948966
true entropy:       1.5707972521604139
-----------------------------------------------------
Out[4]:
 tensor([0.0203, 0.0098, 0.2433, 0.0118, 0.0225, 0.0226, 0.0166, 0.3209, 0.0217,
        0.3105], dtype=torch.float64, grad_fn=<SoftmaxBackward0>)
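As an extra check (not shown in the cells above), we can verify the achieved entropy with PyTorch's own machinery: torch.distributions.Categorical computes entropy in nats, so it should agree with the target to within the tolerance. A small sketch that re-runs the optimization quietly and cross-checks the result:

quiet_dist = get_dist(
    criterion=MSEAgainstEntropyLoss(),
    support_size=10,
    desired_entropy=1.5707963267948966,
    do_logging=False,
)
# Categorical.entropy() uses the natural log, matching the loss above
print(torch.distributions.Categorical(probs=quiet_dist.detach()).entropy())  # should be very close to pi/2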