실습: MSE Loss
이제 앞서 배운 MSE 손실 함수를 파이토치로 직접 구현해볼 차례입니다. 좀 더 쉽게 구현할 수 있게 손실 함수의 수식도 함께 써 놓겠습니다.
\[\begin{gathered} \text{MSE}(\hat{x}_{1:N},x_{1:N})=\frac{1}{N\times{n}}\sum_{i=1}^N{ \|x_i-\hat{x}_i\|_2^2 } \end{gathered}\]이제 이 수식을 코드로 옮기면 다음과 같습니다.
def mse(x_hat, x):
# |x_hat| = (batch_size, dim)
# |x| = (batch_size, dim)
y = ((x - x_hat)**2).mean()
return y
매우 간단하게 MSE 손실 함수를 구현하였습니다. 그럼 실제 두 텐서 사이의 MSE 손실 값을 구해보도록 할까요?
>>> x = torch.FloatTensor([[1, 1],
... [2, 2]])
>>> x_hat = torch.FloatTensor([[0, 0],
... [0, 0]])
>>> print(mse(x_hat, x))
tensor(2.5000)
torch.nn.functional 사용하기
사실 당연히 MSE 손실 함수도 파이토치에서 기본적으로 제공합니다. 파이토치 내장 MSE 손실 함수는 다음과 같이 활용할 수 있습니다.
>>> import torch.nn.functional as F
>>> F.mse_loss(x_hat, x)
tensor(2.5000)
그리고 해당 함수는 reduction이라는 인자를 통해 MSE 손실 값을 구할 때 차원 감소 연산(e.g. 평균)에 대한 설정을 할 수 있습니다. sum과 none 등을 선택하여 원하는대로 MSE 손실 함수의 출력 값을 얻을 수 있습니다.
>>> F.mse_loss(x_hat, x, reduction='sum')
tensor(10.)
>>> F.mse_loss(x_hat, x, reduction='none')
tensor([[1., 1.],
[4., 4.]])
torch.nn 사용하기
torch.nn.functional 이외에도 torch.nn 을 활용하는 방법이 있습니다.
>>> import torch.nn as nn
>>> mse_loss = nn.MSELoss()
>>> mse_loss(x_hat, x)
tensor(2.5000)
사실 이 두 방법의 차이는 거의 없는데, 이 방법을 사용하게 되면 nn.Module 하위 클래스 내부에 선언하여, 계층layer의 하나처럼 취급할 수 있습니다.