torch.clamp함수.
WGAN에서 clipping gradient 수행할 때 사용한 함수입니다.
torch.clamp(input, min=None, max=None, *, out=None)
return Tensor
input의 모든 요소들을 [min, max]내로 clamps한다.
clamp가 고정시키다, 꽉죄다 라는 뜻이기에 어떠한 범주 내로 고정시킨다, 꾸겨넣는다, 이런 느낌으로 생각하면 될 것 같습니다
input의 모든 요소()들을 저 식에 대입해 값들을 갖는 Tensor를 return
파라미터 min이 None이면 하한값은 없다고 합니다. max이하의 값으로만 채운다는 거겠죠
min > max이면 모든 값은 max로 채워질 것입니다.
a=torch.randn(4)
# tensor([-1.7120, 0.1734, -0.0478, -0.0922])
torch.clamp(a, min=-0.5, max=0.5)
# tensor([-0.5000, 0.1734, -0.0478, -0.0922])
min=torch.linspace(-1, 1,steps=4)
torch.clamp(a, min=min)
# tensor([-1.0000, 0.1734, 0.3333, 1.0000])
Python
복사