Entropy optimization using PyTorch
First, let's import some things from PyTorch that we'll need to perform this optimization. We'll use PyTorch's optimization framework to create a distribution with our desired entropy.
- We'll rely on some functions from the root `torch` module.
- We'll use the `nn` module to create an `nn.Parameter` to store the distribution itself.
- Finally, for compatibility with a Python language server, I find it's nice to import `Tensor` so that I can use it in type hints.
In [1]:
import torch
from torch import nn
from torch import Tensor
Let's define a custom loss function that we can use to push our distribution's entropy toward the target entropy. A simple way to do this (with gradients enabled) is to:
- Normalize the raw parameter vector into a valid probability distribution (a softmax handles both positivity and summing to 1).
- Calculate the entropy of the distribution.
- Use a distance metric such as Mean Squared Error to compare the distribution's entropy to the target entropy.
Despite its simplicity, this works shockingly well!
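Concretely, if $p$ is the normalized distribution and $H^*$ is the target entropy, the loss we're about to implement compares the Shannon entropy (in nats, since we use the natural log) against the target:

$$H(p) = -\sum_i p_i \log p_i, \qquad \mathcal{L} = \bigl(H(p) - H^*\bigr)^2$$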
In [2]:
class MSEAgainstEntropyLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, dist: Tensor, true_entropy: Tensor) -> Tensor:
# normalize the distribution
dist = torch.softmax(dist, dim=0)
# calculate the entropy
approx_entropy = -(dist * dist.log()).sum()
# calculate the mean squared error
mse = (approx_entropy - true_entropy) ** 2
# return the loss tensor
return mse
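As a quick sanity check (a minimal snippet, assuming the cells above have been run), a uniform distribution over 10 outcomes has entropy $\log 10$, so feeding equal logits together with that target should give a loss of essentially zero:

```python
# sanity check: equal logits -> uniform distribution -> entropy log(10)
criterion = MSEAgainstEntropyLoss()
uniform_logits = torch.zeros(10, dtype=torch.float64)
target = torch.log(torch.tensor(10.0, dtype=torch.float64))
print(criterion(uniform_logits, target))  # expect a value very close to 0
```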
Now, we'll define a function that returns a distribution with the desired entropy (or as close as we can get). It will take the following arguments:
criterion (nn.Module)
: The loss function to use to optimize the entropy (in this case, it will just be the class we defined above).

support_size (int)
: The number of outcomes we want our random variable to have.

desired_entropy (float)
: The entropy we want our distribution to have.

lr (float)
: The learning rate for the optimization algorithm.

tol (float)
: The tolerance for the optimization algorithm. We will stop when the loss is less than this value.

max_iter (int)
: The maximum number of iterations to run the optimization algorithm.

do_logging (bool)
: Whether to log the loss during optimization.

log_freq (int)
: How often to log the loss during optimization.
In [3]:
import warnings


def get_dist(
criterion: nn.Module,
support_size: int,
desired_entropy: float,
lr: float = 0.001,
tol: float = 1e-6,
max_iter: int = 100_000,
do_logging: bool = True,
log_freq: int = 200
) -> Tensor:
# define a parameter (gradient updates possible) with the right support size
dist = nn.Parameter(torch.randn((support_size,), dtype=torch.float64))
dist.requires_grad = True
# make a torch.Tensor with the desired entropy to compute loss
DE = torch.tensor(desired_entropy, dtype=torch.float64)
DE.requires_grad = False
# define an optimizer over the parameter
optimizer = torch.optim.AdamW([dist], lr=lr)
i = 0
if do_logging:
print('-----------------------------------------------------')
while True:
# optimize the parameter
optimizer.zero_grad()
loss = criterion(dist, DE)
loss.backward()
optimizer.step()
# log the loss
if (i % log_freq == 0):
loss_val = loss.item()
if do_logging:
print(f'loss: {loss_val:.4}')
if loss_val < tol: # we are done if the loss is small enough
break
i += 1 # count iterations
# give up if max_iter is reached
if i > max_iter:
            msg = 'Optimization did not converge!'
            warnings.warn(msg)  # warn the user, but still return the best distribution found
break
# renormalize
final_dist = torch.softmax(dist, dim=0)
# summary of results
if do_logging:
print('-----------------------------------------------------')
print(f'sum of probabilities (should be 1): {final_dist.sum()}')
approx_entropy = -(final_dist * final_dist.log()).sum()
print('-----------------------------------------------------')
print(f'desired entropy: {desired_entropy}')
print(f'true entropy: {approx_entropy.item()}')
print('-----------------------------------------------------')
return final_dist
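One design note before running it: the parameter holds unconstrained logits rather than probabilities, and softmax maps any real-valued vector onto the probability simplex, so the optimizer can take arbitrary gradient steps without ever producing negative or non-normalized probabilities. A tiny illustration (the numbers here are arbitrary):

```python
# any real-valued logits become a valid probability vector under softmax
logits = torch.tensor([-3.0, 0.0, 5.0])
probs = torch.softmax(logits, dim=0)
print(probs, probs.sum())  # entries in (0, 1), sum equal to 1 up to float error
```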
In [4]:
get_dist(
criterion=MSEAgainstEntropyLoss(),
support_size=10,
desired_entropy=1.5707963267948966, # pi/2 because why not
lr=0.001,
tol=1e-12,
max_iter=100_000,
do_logging=True,
log_freq=200
)
-----------------------------------------------------
loss: 0.3188
loss: 0.2091
loss: 0.1094
loss: 0.04938
loss: 0.02052
loss: 0.008044
loss: 0.002983
loss: 0.001043
loss: 0.0003462
loss: 0.0001118
loss: 3.726e-05
loss: 1.406e-05
loss: 6.524e-06
loss: 3.784e-06
loss: 2.589e-06
loss: 1.94e-06
loss: 1.515e-06
loss: 1.204e-06
loss: 9.639e-07
loss: 7.749e-07
loss: 6.247e-07
loss: 5.048e-07
loss: 4.086e-07
loss: 3.313e-07
loss: 2.69e-07
loss: 2.186e-07
loss: 1.778e-07
loss: 1.448e-07
loss: 1.179e-07
loss: 9.614e-08
loss: 7.841e-08
loss: 6.398e-08
loss: 5.223e-08
loss: 4.265e-08
loss: 3.483e-08
loss: 2.846e-08
loss: 2.326e-08
loss: 1.901e-08
loss: 1.554e-08
loss: 1.271e-08
loss: 1.039e-08
loss: 8.501e-09
loss: 6.954e-09
loss: 5.689e-09
loss: 4.654e-09
loss: 3.808e-09
loss: 3.116e-09
loss: 2.55e-09
loss: 2.087e-09
loss: 1.708e-09
loss: 1.398e-09
loss: 1.144e-09
loss: 9.368e-10
loss: 7.668e-10
loss: 6.277e-10
loss: 5.139e-10
loss: 4.207e-10
loss: 3.444e-10
loss: 2.82e-10
loss: 2.308e-10
loss: 1.89e-10
loss: 1.547e-10
loss: 1.267e-10
loss: 1.037e-10
loss: 8.493e-11
loss: 6.954e-11
loss: 5.694e-11
loss: 4.662e-11
loss: 3.817e-11
loss: 3.126e-11
loss: 2.559e-11
loss: 2.096e-11
loss: 1.716e-11
loss: 1.405e-11
loss: 1.151e-11
loss: 9.422e-12
loss: 7.715e-12
loss: 6.318e-12
loss: 5.173e-12
loss: 4.237e-12
loss: 3.469e-12
loss: 2.841e-12
loss: 2.327e-12
loss: 1.905e-12
loss: 1.56e-12
loss: 1.278e-12
loss: 1.047e-12
loss: 8.572e-13
-----------------------------------------------------
sum of probabilities (should be 1): 1.0
-----------------------------------------------------
desired entropy: 1.5707963267948966
true entropy: 1.5707972521604139
-----------------------------------------------------
Out[4]:
tensor([0.0203, 0.0098, 0.2433, 0.0118, 0.0225, 0.0226, 0.0166, 0.3209, 0.0217, 0.3105], dtype=torch.float64, grad_fn=<SoftmaxBackward0>)
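Since the returned tensor is already normalized, we can double-check its entropy independently, for example with `torch.distributions.Categorical` (a quick check, assuming the cells above have been run; `d` is just a placeholder name, not part of the run above):

```python
# independent entropy check on a freshly optimized distribution
d = get_dist(
    criterion=MSEAgainstEntropyLoss(),
    support_size=10,
    desired_entropy=1.5707963267948966,
    do_logging=False,
)
print(torch.distributions.Categorical(probs=d).entropy())  # should be ~pi/2
```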