torch.clamp()函数用于对输入张量进行截断操作,将张量中的每个元素限制在指定的范围内。

其语法为:

torch.clamp(input, min, max, out=None) -> Tensor

其中,参数的含义如下:

  • input:输入张量。
  • min:张量中的最小值。如果为None,则表示不对最小值进行限制。
  • max:张量中的最大值。如果为None,则表示不对最大值进行限制。
  • out:输出张量。

torch.clamp()函数返回一个新的张量,其中每个元素都被截断在[min, max]的范围内。如果minmaxNone,则对应的限制条件被忽略。

下面是一个使用torch.clamp()函数的示例:

import torchx = torch.randn(2, 3)print(x)y = torch.clamp(x, min=-0.5, max=0.5)print(y)

输出结果为:

tensor([[-0.3138, -0.1604, -0.4374],[-1.0861, -0.2837,1.1688]])tensor([[-0.3138, -0.1604, -0.4374],[-0.5000, -0.2837,0.5000]])

可以看到,torch.clamp()函数将x张量中的元素限制在了[-0.5, 0.5]的范围内。