torch.nn.utils.fuse_linear_bn_weights
- torch.nn.utils.fuse_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b)
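Fuse the parameters of a linear module and the BatchNorm module applied to its output into new linear module parameters. The BatchNorm running statistics are folded into the result, so the fused linear module reproduces the pair only when BatchNorm is in evaluation mode.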
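The fusion follows from writing BatchNorm as a per-feature affine map. As a sketch, assuming BatchNorm normalizes each output feature i of the linear layer y = Wx + b with its running statistics, write W = linear_w, b = linear_b, \mu = bn_rm, \sigma^2 = bn_rv, \epsilon = bn_eps, \gamma = bn_w and \beta = bn_b:

```latex
\begin{aligned}
s_i &= \frac{\gamma_i}{\sqrt{\sigma_i^2 + \epsilon}} \\
z_i &= \gamma_i \, \frac{(Wx + b)_i - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_i
     = s_i (Wx)_i + s_i (b_i - \mu_i) + \beta_i \\
W'_{ij} &= s_i W_{ij}, \qquad b'_i = s_i (b_i - \mu_i) + \beta_i
\end{aligned}
```

Each row of the weight is scaled by s_i, and the bias absorbs the mean shift and the BatchNorm bias, so z = W'x + b'.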
- Parameters
linear_w (torch.Tensor) – Linear weight.
linear_b (Optional[torch.Tensor]) – Linear bias. If None, a zero bias is used.
bn_rm (torch.Tensor) – BatchNorm running mean.
bn_rv (torch.Tensor) – BatchNorm running variance.
bn_eps (float) – BatchNorm epsilon.
bn_w (torch.Tensor) – BatchNorm weight.
bn_b (torch.Tensor) – BatchNorm bias.
- Returns
Fused linear weight and bias.
- Return type
Tuple[torch.nn.Parameter, torch.nn.Parameter]