Tutorial 3: Learn about criterions¶
The folder “fgvclib/criterions” contains the loss functions for fgvclib.
Four loss functions are provided: cross_entropy_loss, binary_cross_entropy_loss, mean_square_error_loss, and mutual_channel_loss.
You choose which loss function to use by setting it in “./configs”. For more details about the configs, please see FGVC Configs.
| Loss function | Name |
|---|---|
| cross entropy loss | cross_entropy_loss |
| binary cross entropy loss | binary_cross_entropy_loss |
| mean square error loss | mean_square_error_loss |
| mutual channel loss | mutual_channel_loss |
Base loss function¶
cross_entropy_loss, binary_cross_entropy_loss, and mean_square_error_loss are the base loss functions, and they are from PyTorch.
“fgvclib/criterions/base_loss.py” provides the base loss functions.

cross_entropy_loss: Build the cross entropy loss function.
Args:
cfg (CfgNode): The root node of config.
Return:
nn.Module: The loss function.

binary_cross_entropy_loss: Build the binary cross entropy loss function.
Args:
cfg (CfgNode): The root node of config.
Return:
nn.Module: The loss function.

mean_square_error_loss: Build the mean square error loss function.
Args:
cfg (CfgNode): The root node of config.
Return:
nn.Module: The loss function.
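For intuition about what the three base losses compute, here is a minimal pure-Python sketch for a single sample. It is not the library's code: the helper names are hypothetical, the real builders return PyTorch modules that operate on batched tensors, and the binary cross entropy here is assumed to take a probability rather than a raw logit (this page does not say which form the library uses).

```python
import math

def cross_entropy(logits, target):
    """Negative log-probability of the target class under softmax(logits)."""
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]

def binary_cross_entropy(prob, target):
    """BCE for one predicted probability in (0, 1) and a target in {0, 1}."""
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))

def mean_square_error(preds, targets):
    """Mean of the squared differences over all elements."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)

# Hypothetical inputs, chosen only for illustration.
ce = cross_entropy([2.0, 0.5, -1.0], target=0)
bce = binary_cross_entropy(0.9, target=1)
mse = mean_square_error([1.0, 2.0], [0.0, 4.0])
```

A confident, correct prediction gives a small cross entropy, and the value grows as the probability assigned to the target class shrinks.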
Mutual channel loss¶
“fgvclib/criterions/mutual_channel_loss.py” provides the mutual channel loss function, which was proposed in “The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification”.

class MutualChannelLoss: The mutual channel loss function.
Args:
height (int): The kernel size of average pooling.
cnum (int): The number of channels per class.
div_weight (float): The weight for the diversity part of the loss.
dis_weight (float): The weight for the discriminality part of the loss.
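The cnum argument is easier to picture with a small example. The sketch below assumes that feature channels are split into contiguous, equal-sized groups of cnum channels, one group per class, as the description of "channel numbers per class" suggests. The helper channel_groups is hypothetical and is not part of fgvclib.

```python
def channel_groups(num_classes, cnum):
    """Return, for each class, the range of feature-channel indices assigned to it.

    Assumption: groups are contiguous and every class gets exactly `cnum` channels.
    """
    return [range(c * cnum, (c + 1) * cnum) for c in range(num_classes)]

# Hypothetical setting: 4 classes with 3 channels each, so 12 channels in total.
groups = channel_groups(num_classes=4, cnum=3)
total_channels = 4 * 3

# Under this layout, channel 7 belongs to class 7 // 3.
class_of_channel_7 = 7 // 3
```

Under this assumption the feature map fed to the loss needs num_classes × cnum channels, which is worth checking against the backbone's output width when choosing cnum.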
Utils¶
In “fgvclib/criterions/utils.py”, we define a class named LossItem and two functions, compute_loss_value and detach_loss_value.
LossItem: A dataclass object for storing a training loss.
Args:
name (string): The loss item name.
value (torch.Tensor): The value of the loss.
weight (float, optional): The weight of the current loss item. Default is 1.0.
compute_loss_value: Compute the total loss value from a list of loss items.
Args:
items (List[LossItem]): The loss items.
Return:
Tensor: The total loss value.
detach_loss_value: Detach loss value from GPU.
Args:
items (List[LossItem]): The loss items.
Return:
Dict: A loss information dict whose key is loss name, value is loss value.
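The way these three utilities fit together can be sketched in plain Python. This is a stand-in under stated assumptions, not the library's implementation: the Toy* and toy_* names are hypothetical, plain floats replace torch.Tensor values, and the total is assumed to be the weighted sum of the item values (the page only says that a total is returned and that each item carries a weight).

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class ToyLossItem:
    """Stand-in for LossItem; `value` is a float here instead of a torch.Tensor."""
    name: str
    value: float
    weight: float = 1.0

def toy_compute_loss_value(items: List[ToyLossItem]) -> float:
    # Assumption: the total loss is the weighted sum of all item values.
    return sum(item.weight * item.value for item in items)

def toy_detach_loss_value(items: List[ToyLossItem]) -> Dict[str, float]:
    # Map each loss name to a plain number that is safe to log.
    return {item.name: float(item.value) for item in items}

items = [
    ToyLossItem("cross_entropy_loss", 0.5),
    ToyLossItem("mutual_channel_loss", 2.0, weight=0.5),
]
total = toy_compute_loss_value(items)   # 1.0 * 0.5 + 0.5 * 2.0
info = toy_detach_loss_value(items)
```

Keeping the per-item names alongside the total lets a training loop backpropagate through one scalar while still logging each loss term separately.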
The example of using the criterions¶
Build loss functions for training.¶
In “fgvclib/apis/build.py”, “fgvclib.criterions” is used to build the loss functions for training. Choose the loss function name criterion_cfg['name'] from cross_entropy_loss, binary_cross_entropy_loss, mean_square_error_loss, and mutual_channel_loss.
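import torch.nn as nn
from yacs.config import CfgNode

from fgvclib.criterions import get_criterion

def build_criterion(criterion_cfg: CfgNode) -> nn.Module:
    # Look up the builder function registered under the configured name.
    criterion_builder = get_criterion(criterion_cfg['name'])
    # tltd is a helper defined elsewhere in fgvclib; its import is not shown in this excerpt.
    criterion = criterion_builder(cfg=tltd(criterion_cfg['args']))
    return criterion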
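How get_criterion resolves a name is not shown on this page. One common way to implement such a lookup is a dictionary registry, sketched below in plain Python. Every name in the sketch (register, toy_get_criterion, toy_mean_square_error_loss, _TOY_REGISTRY) is hypothetical, and the config is a plain dict rather than a CfgNode.

```python
from typing import Callable, Dict

# Hypothetical registry that maps a loss name to its builder function.
_TOY_REGISTRY: Dict[str, Callable] = {}

def register(name: str):
    """Decorator that stores a builder under `name`."""
    def decorator(fn: Callable) -> Callable:
        _TOY_REGISTRY[name] = fn
        return fn
    return decorator

def toy_get_criterion(name: str) -> Callable:
    """Return the builder registered under `name`, or fail with a clear message."""
    if name not in _TOY_REGISTRY:
        raise KeyError(f"Unknown criterion '{name}'. Available: {sorted(_TOY_REGISTRY)}")
    return _TOY_REGISTRY[name]

@register("mean_square_error_loss")
def toy_mean_square_error_loss(cfg=None):
    # The builder returns a callable loss; `cfg` is unused in this toy version.
    def loss(preds, targets):
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
    return loss

# Mirrors the two-step pattern of build_criterion: look up the builder, then call it.
criterion_cfg = {"name": "mean_square_error_loss", "args": {}}
builder = toy_get_criterion(criterion_cfg["name"])
criterion = builder(cfg=criterion_cfg["args"])
value = criterion([1.0, 2.0], [0.0, 4.0])
```

With this pattern, switching the loss only requires changing the name in the config, which matches how the tutorial describes choosing a loss function.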
Calculate loss functions.¶
The following shows how to calculate the loss. You can replace cross_entropy_loss with any of the other loss functions.
from fgvclib.criterions.utils import LossItem
losses = list()
losses.append(LossItem(name='cross_entropy_loss', value=self.criterions['cross_entropy_loss']['fn'](x, targets)))
Define the forward.¶
Take ResNet50 as an example.
from fgvclib.criterions.utils import LossItem

def forward(self, x, targets=None):
    x = self.infer(x)
    if self.training:
        # In training mode, also return the loss items computed on the outputs.
        losses = list()
        losses.append(LossItem(name='cross_entropy_loss', value=self.criterions['cross_entropy_loss']['fn'](x, targets)))
        return x, losses
    # In evaluation mode, return only the model outputs.
    return x