TanhNormal
- class torchrl.modules.TanhNormal(loc: torch.Tensor, scale: torch.Tensor, upscale: torch.Tensor | Number = 5.0, low: torch.Tensor | Number = -1.0, high: torch.Tensor | Number = 1.0, event_dims: int | None = None, tanh_loc: bool = False, safe_tanh: bool = True)
Implements a TanhNormal distribution with location scaling.
Location scaling prevents the location from being "too far" from 0 when a TanhTransform is applied, but ultimately leads to numerically unstable samples and poor gradient computation (e.g. gradient explosion). In practice, with location scaling the location is computed according to
\[\mathrm{loc} = \tanh(\mathrm{loc} / \mathrm{upscale}) \cdot \mathrm{upscale}.\]
- Parameters:
loc (torch.Tensor) – normal distribution location parameter
scale (torch.Tensor) – normal distribution sigma parameter (standard deviation, i.e. the square root of the variance)
upscale (torch.Tensor or number) – scaling factor in the formula \[\mathrm{loc} = \tanh(\mathrm{loc} / \mathrm{upscale}) \cdot \mathrm{upscale}.\] Default is 5.0;
low (torch.Tensor or number, optional) – minimum value of the distribution. Default is -1.0;
high (torch.Tensor or number, optional) – maximum value of the distribution. Default is 1.0;
event_dims (int, optional) – number of dimensions describing the action. Default is 1. Setting event_dims to 0 will result in a log-probability that has the same shape as the input, 1 will reduce (sum over) the last dimension, 2 the last two, etc.;
tanh_loc (bool, optional) – if True, the above formula is used for the location scaling, otherwise the raw value is kept. Default is False;
safe_tanh (bool, optional) – if True, the Tanh transform is done "safely", to avoid numerical overflows. This will currently break with torch.compile(). Default is True.
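The location-scaling formula used when tanh_loc is True can be illustrated on plain Python floats. This is only a sketch of the arithmetic, not torchrl's implementation: the helper name scale_location is hypothetical, and the default upscale of 5.0 is taken from the class signature.

```python
import math


def scale_location(loc: float, upscale: float = 5.0) -> float:
    """Hypothetical helper: loc = tanh(loc / upscale) * upscale."""
    return math.tanh(loc / upscale) * upscale


# Near 0 the mapping is close to the identity; far from 0 it saturates
# so that the scaled location never exceeds `upscale` in magnitude.
for raw in (0.1, 1.0, 5.0, 50.0):
    print(f"raw loc = {raw:5.1f} -> scaled loc = {scale_location(raw):.4f}")
```

Small locations are left almost unchanged, while large ones are squashed towards ±upscale, which is the sense in which the location is kept from being "too far" from 0.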
- property mean
Returns the mean of the distribution.
- property mode
Returns the mode of the distribution.
- property support
Returns a Constraint object representing this distribution's support.
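The effect of event_dims on the shape of the log-probability, as described in the parameter list, can be sketched in pure Python. This assumes the default bounds low = -1.0 and high = 1.0 and a two-dimensional [batch][action] input. The functions tanh_normal_log_prob and reduce_log_prob are hypothetical helpers written for illustration and do not reflect how torchrl computes the value internally.

```python
import math


def tanh_normal_log_prob(y: float, loc: float, scale: float) -> float:
    """Elementwise log-density of y = tanh(x), x ~ Normal(loc, scale), y in (-1, 1)."""
    x = math.atanh(y)
    normal_lp = (
        -0.5 * ((x - loc) / scale) ** 2
        - math.log(scale)
        - 0.5 * math.log(2.0 * math.pi)
    )
    # Change of variables: subtract log|dy/dx| = log(1 - y^2).
    return normal_lp - math.log(1.0 - y * y)


def reduce_log_prob(elementwise, event_dims: int):
    """Sum a [batch][action] nested list over its last `event_dims` dimensions."""
    if event_dims == 0:
        return [row[:] for row in elementwise]  # same shape as the input
    if event_dims == 1:
        return [sum(row) for row in elementwise]  # one value per batch element
    if event_dims == 2:
        return sum(sum(row) for row in elementwise)  # a single scalar
    raise ValueError("this sketch only handles event_dims in {0, 1, 2}")


actions = [[0.0, 0.5, -0.5], [0.1, 0.2, 0.3]]
elementwise = [[tanh_normal_log_prob(y, 0.0, 1.0) for y in row] for row in actions]

per_element = reduce_log_prob(elementwise, 0)  # 2 x 3 nested list
per_sample = reduce_log_prob(elementwise, 1)   # list of length 2
total = reduce_log_prob(elementwise, 2)        # single float
print(len(per_element), len(per_element[0]), len(per_sample), total)
```

With event_dims set to 1, each row of actions yields one log-probability, which is the usual choice when the last dimension holds the components of a single action.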