Search

torch.autograd.Function 클래스를 상속받아 autograd function 구현하기

이 방법을 통해 pytorch autograd로 gradient를 계산할 수 있습니다

예시

P3(x)=12(5x33x)P_3(x) = \frac{1}{2}(5x^3-3x)
이 식(3차 르장드르 다항식)의 backpropagation 을 수행하기 위해 autograd Function을 구현한다
미분한 식 → def backward()
P3(x)=32(5x21)P'_3(x)=\frac{3}{2}(5x^2-1)

forward

import torch import math class LegendrePolynomial3(torch.autograd.Function): @staticmethod def forward(ctx, input): ctx.save_for_backward(input) return 0.5 * (5 * input ** 3 - 3 * input)
Python
복사
ctx : context object, 역전파 연산을 위한 정보 저장에 사용
ctx.save_for_backward(input) 으로 backward단계에서 사용할 어떠한 객체도 저장(cache)해둘 수 있다

backward

import torch import math class LegendrePolynomial3(torch.autograd.Function): @staticmethod def forward(ctx, input): ctx.save_for_backward(input) return 0.5 * (5 * input ** 3 - 3 * input) @staticmethod def backward(ctx, grad_output): input, = ctx.saved_tensors return grad_output * 1.5 * (5 * input ** 2 - 1)
Python
복사
backward에서는 출력에 대한 loss의 gradient를 갖는 tensor를 받아 입력에 대한 loss의 변화도를 계산한다.

학습시

# 입력값과 출력값을 갖는 텐서들을 생성합니다. # requires_grad=False가 기본값으로 설정되어 역전파 단계 중에 이 텐서들에 대한 변화도를 계산할 # 필요가 없음을 나타냅니다. x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) y = torch.sin(x) # 가중치를 갖는 임의의 텐서를 생성합니다. 3차 다항식이므로 4개의 가중치가 필요합니다: # y = a + b * P3(c + d * x) # 이 가중치들이 수렴(convergence)하기 위해서는 정답으로부터 너무 멀리 떨어지지 않은 값으로 # 초기화가 되어야 합니다. # requires_grad=True로 설정하여 역전파 단계 중에 이 텐서들에 대한 변화도를 계산할 필요가 # 있음을 나타냅니다. a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True) c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True) learning_rate = 5e-6 for t in range(2000): # 사용자 정의 Function을 적용하기 위해 Function.apply 메소드를 사용합니다. # Pytorch docs를 보면 forward()호출을 바로 하지 말고 apply메소드를 적용시켜주라고 나와있습니다 # 여기에 'P3'라고 이름을 붙였습니다. P3 = LegendrePolynomial3.apply # 순전파 단계: 연산을 하여 예측값 y를 계산합니다; # 사용자 정의 autograd 연산을 사용하여 P3를 계산합니다. y_pred = a + b * P3(c + d * x) # 손실을 계산하고 출력합니다. loss = (y_pred - y).pow(2).sum() if t % 100 == 99: print(t, loss.item()) # autograd를 사용하여 역전파 단계를 계산합니다. loss.backward() # 경사하강법(gradient descent)을 사용하여 가중치를 갱신합니다. with torch.no_grad(): a -= learning_rate * a.grad b -= learning_rate * b.grad c -= learning_rate * c.grad d -= learning_rate * d.grad # 가중치 갱신 후에는 변화도를 직접 0으로 만듭니다. a.grad = None b.grad = None c.grad = None d.grad = None print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')
Python
복사