
Flatten

class torch.nn.Flatten(start_dim=1, end_dim=-1)

Flattens a contiguous range of dimensions of a tensor into a single dimension.

Intended for use inside Sequential; see torch.flatten() for details.
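A typical use is to sit between convolutional layers and a fully connected layer inside nn.Sequential. The sketch below is illustrative only: the layer sizes (a 1-channel 5×5 input, 4 output channels, 10 classes) are arbitrary choices, not part of this API. It also checks that nn.Flatten() with default arguments gives the same result as torch.flatten(x, start_dim=1).

```python
import torch
from torch import nn

# Illustrative model: conv features are flattened per sample before the Linear layer.
model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3),  # (N, 1, 5, 5) -> (N, 4, 3, 3)
    nn.Flatten(),                                             # (N, 4, 3, 3) -> (N, 36)
    nn.Linear(4 * 3 * 3, 10),                                 # (N, 36) -> (N, 10)
)

x = torch.randn(8, 1, 5, 5)
logits = model(x)
print(logits.shape)  # torch.Size([8, 10])

# With default arguments, nn.Flatten() matches the functional torch.flatten(..., start_dim=1).
features = model[0](x)
assert torch.equal(nn.Flatten()(features), torch.flatten(features, start_dim=1))
```

Because start_dim defaults to 1, the batch dimension N is preserved, which is what the following Linear layer expects.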

Shape:
  • Input: $(*, S_{\text{start}}, \ldots, S_{i}, \ldots, S_{\text{end}}, *)$, where $S_{i}$ is the size at dimension $i$ and $*$ means any number of dimensions including none.

  • Output: $(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)$.

Parameters
  • start_dim (int) – first dim to flatten (default = 1).

  • end_dim (int) – last dim to flatten (default = -1).
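Both parameters accept negative values, which count from the last dimension as in Python indexing. The following sketch, reusing the (32, 1, 5, 5) input shape from the examples, shows how the choice of start_dim and end_dim decides which dimensions are merged and whether the batch dimension survives.

```python
import torch
from torch import nn

x = torch.randn(32, 1, 5, 5)

# start_dim=-2, end_dim=-1 merges only the last two dims (5 * 5 = 25).
last_two = nn.Flatten(start_dim=-2, end_dim=-1)(x)
print(last_two.shape)  # torch.Size([32, 1, 25])

# start_dim=0 also folds in dim 0, giving a 1-D tensor of 32 * 1 * 5 * 5 = 800 elements.
everything = nn.Flatten(start_dim=0)(x)
print(everything.shape)  # torch.Size([800])

# The default start_dim=1 keeps dim 0 (usually the batch) intact,
# so each row is the corresponding sample flattened on its own.
per_sample = nn.Flatten()(x)
print(per_sample.shape)  # torch.Size([32, 25])
assert torch.equal(per_sample[3], x[3].reshape(-1))
```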

Examples:
>>> import torch
>>> from torch import nn
>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])
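Flattening only changes the shape, not the values: elements are laid out in row-major order, and nn.Unflatten can split the merged dimension back out. The sketch below checks both properties. The choice of nn.Unflatten(1, (1, 5, 5)) is specific to this input shape and is just one illustrative way to invert the default nn.Flatten().

```python
import torch
from torch import nn

x = torch.randn(32, 1, 5, 5)
flat = nn.Flatten()(x)                        # (32, 25)
restored = nn.Unflatten(1, (1, 5, 5))(flat)   # split dim 1 back into (1, 5, 5)
print(restored.size())  # torch.Size([32, 1, 5, 5])
assert torch.equal(restored, x)

# Elements keep their row-major (last dim fastest) order after flattening.
ordered = nn.Flatten()(torch.arange(6).reshape(1, 2, 3))
print(ordered)  # tensor([[0, 1, 2, 3, 4, 5]])
```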