Link

실습: MSE Loss

이제 앞서 배운 MSE 손실 함수를 파이토치로 직접 구현해볼 차례입니다. 좀 더 쉽게 구현할 수 있게 손실 함수의 수식도 함께 써 놓겠습니다.

\[\begin{gathered} \text{MSE}(\hat{x}_{1:N},x_{1:N})=\frac{1}{N\times{n}}\sum_{i=1}^N{ \|x_i-\hat{x}_i\|_2^2 } \end{gathered}\]

이제 이 수식을 코드로 옮기면 다음과 같습니다.

def mse(x_hat, x):
    # |x_hat| = (batch_size, dim)
    # |x| = (batch_size, dim)
    y = ((x - x_hat)**2).mean()
    
    return y

매우 간단하게 MSE 손실 함수를 구현하였습니다. 그럼 실제 두 텐서 사이의 MSE 손실 값을 구해보도록 할까요?

>>> x = torch.FloatTensor([[1, 1],
...                        [2, 2]])
>>> x_hat = torch.FloatTensor([[0, 0],
...                            [0, 0]])
>>> print(mse(x_hat, x))
tensor(2.5000)

torch.nn.functional 사용하기

사실 당연히 MSE 손실 함수도 파이토치에서 기본적으로 제공합니다. 파이토치 내장 MSE 손실 함수는 다음과 같이 활용할 수 있습니다.

>>> import torch.nn.functional as F
>>> F.mse_loss(x_hat, x)
tensor(2.5000)

그리고 해당 함수는 reduction이라는 인자를 통해 MSE 손실 값을 구할 때 차원 감소 연산(e.g. 평균)에 대한 설정을 할 수 있습니다. sum과 none 등을 선택하여 원하는대로 MSE 손실 함수의 출력 값을 얻을 수 있습니다.

>>> F.mse_loss(x_hat, x, reduction='sum')
tensor(10.)
>>> F.mse_loss(x_hat, x, reduction='none')
tensor([[1., 1.],
        [4., 4.]])

torch.nn 사용하기

torch.nn.functional 이외에도 torch.nn 을 활용하는 방법이 있습니다.

>>> import torch.nn as nn
>>> mse_loss = nn.MSELoss()
>>> mse_loss(x_hat, x)
tensor(2.5000)

사실 이 두 방법의 차이는 거의 없는데, 이 방법을 사용하게 되면 nn.Module 하위 클래스 내부에 선언하여, 계층layer의 하나처럼 취급할 수 있습니다.