Single-Label Cross-Entropy Loss in PyTorch
machine-learning
In this notebook, I look at loss functions for the binary and multi-class cases and, within each case, show their equivalence. See a previous blog post for an explanation of why you cannot show equivalence across the two cases. Within each case, PyTorch provides two types of loss functions, each of which can be invoked in two ways: by instantiating a class or by calling the corresponding function directly.
There are two cases for loss calculation: (1) binary and (2) multi-class classification.
Imports¶
import torch
from torch.nn import functional as F
1. Binary Case¶
1a. Synthetic Data¶
logits = torch.randn(10,1)
probs = torch.sigmoid(logits)
log_probs = torch.log(probs)
preds = (probs>0.5).to(torch.float32)
targets = torch.randint(high=2, size=(10,1), dtype=torch.float32)
lst_vars = [logits, probs, log_probs, preds, targets]
[x.shape for x in lst_vars], [x.dtype for x in lst_vars]
([torch.Size([10, 1]), torch.Size([10, 1]), torch.Size([10, 1]), torch.Size([10, 1]), torch.Size([10, 1])], [torch.float32, torch.float32, torch.float32, torch.float32, torch.float32])
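As an aside, taking torch.log of torch.sigmoid can underflow for very negative logits; F.logsigmoid computes the same quantity in a single, numerically stable step. A minimal check, as a sketch:
# log(sigmoid(x)) computed stably in one call
stable_log_probs = F.logsigmoid(logits)
print(torch.allclose(log_probs, stable_log_probs))  # expect: True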
1b: Loss Functions¶
There are two classes and corresponding functions:
BCELoss, which uses the function binary_cross_entropy
BCEWithLogitsLoss, which uses the function binary_cross_entropy_with_logits
Importantly, these take an input with a single value per example (here shaped 10×1). The with_logits version applies a sigmoid, not a softmax, over that input.
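To make explicit what these losses compute, here is a hand-rolled version of binary cross-entropy, the mean of -(y*log(p) + (1-y)*log(1-p)) over the batch; a minimal sketch that should match the library calls below:
# manual binary cross-entropy over the batch
manual_bce = -(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs)).mean()
print(manual_bce)  # should match F.binary_cross_entropy(probs, targets)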
BCELoss and binary_cross_entropy
print(f'The Class Instantiation: {torch.nn.BCELoss()(probs, targets)}\n'+
f'The Function the class uses: {F.binary_cross_entropy(probs, targets)}')
The Class Instantiation: 0.9423926472663879
The Function the class uses: 0.9423926472663879
BCEWithLogitsLoss and binary_cross_entropy_with_logits
print(f'The Class Instantiation: {torch.nn.BCEWithLogitsLoss()(logits, targets)}\n'+
f'The Function the class uses: {F.binary_cross_entropy_with_logits(logits, targets)}')
The Class Instantiation: 0.9423925280570984
The Function the class uses: 0.9423925280570984
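Since BCEWithLogitsLoss simply applies the sigmoid internally, the two binary losses should agree on the same data, up to floating-point error (as the last digits above hint). A quick check, as a sketch:
# both paths compute the same loss, up to floating-point error
print(torch.allclose(F.binary_cross_entropy(torch.sigmoid(logits), targets),
                     F.binary_cross_entropy_with_logits(logits, targets)))  # expect: True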
2. Multi-Class Case¶
2a. Synthetic Data¶
num_classes = 2
logits = torch.randn(10,num_classes)
probs = torch.softmax(logits, dim=1)
log_probs = torch.log(probs)
preds = probs.argmax(dim=1).to(torch.float32)
targets = torch.randint(high=num_classes, size=(10,1), dtype=torch.int64).squeeze()
lst_vars = [logits, probs, log_probs, preds, targets]
[x.shape for x in lst_vars], [x.dtype for x in lst_vars]
([torch.Size([10, 2]), torch.Size([10, 2]), torch.Size([10, 2]), torch.Size([10]), torch.Size([10])], [torch.float32, torch.float32, torch.float32, torch.float32, torch.int64])
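As in the binary case, torch.log(torch.softmax(...)) can be numerically fragile; F.log_softmax fuses the two steps. A minimal check, as a sketch:
# log-softmax computed stably in one call
stable_log_probs = F.log_softmax(logits, dim=1)
print(torch.allclose(log_probs, stable_log_probs))  # expect: True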
2b: Loss Functions¶
There are two classes and corresponding functions:
CrossEntropyLoss, which uses the function cross_entropy
NLLLoss, which uses the function nll_loss
Note that CrossEntropyLoss takes logits as input, while NLLLoss takes log probabilities.
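To spell out what cross-entropy computes, here is a hand-rolled version: take each example's log probability at its target class and average the negatives. A minimal sketch that should match the library calls below:
# manual cross-entropy: negative mean of the target-class log probabilities
manual_ce = -log_probs[torch.arange(len(targets)), targets].mean()
print(manual_ce)  # should match F.cross_entropy(logits, targets)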
CrossEntropyLoss and cross_entropy
print(f'The Class Instantiation: {torch.nn.CrossEntropyLoss()(logits, targets)}\n'+
f'The Function the class uses: {F.cross_entropy(logits, targets)}')
The Class Instantiation: 0.8172963261604309
The Function the class uses: 0.8172963261604309
NLLLoss and nll_loss
print(f'The Class Instantiation: {torch.nn.NLLLoss()(log_probs, targets)}\n'+
f'The Function the class uses: {F.nll_loss(log_probs, targets)}')
The Class Instantiation: 0.8172963261604309
The Function the class uses: 0.8172963261604309
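Putting the two together: cross_entropy on logits is just log_softmax followed by nll_loss, which is why the two multi-class losses agree. A quick check, as a sketch:
# cross_entropy == log_softmax + nll_loss
print(torch.allclose(F.cross_entropy(logits, targets),
                     F.nll_loss(F.log_softmax(logits, dim=1), targets)))  # expect: True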