Diffusion model에서 DSM (Denoising Score Matching) loss를 사용하는 이유에 대해

Diffusion model에서 DSM (Denoising Score Matching) loss를 사용하는 이유에 대해알아보자. 이 글을 읽기전에 아래 글을 숙지해오길 바란다.

SDE(Stochastic Differential Equation)를 활용한 Diffusion model에 대하여

Diffusion model의 score 학습 방법

Diffusion model의 목적

Diffusion model에서는 각 시점 t 에 대한 x_t 의 score \nabla_{x_t} \log p_t (x_t) 를 score model s_\theta (x_t, t) 를 이용해서 추정하는 것을 목적으로 한다.

실제로 다뤄야 하는 loss, SM (score matching) loss

그러면 \nabla{x_t} \log p_t (x_t) 를 추정하기 위하여 아래와 같은 loss를 사용하는 것이 합당할 것이다.

\begin{equation} \mathcal{L}_{SM} (\theta) = E_{t,x_t \sim p_t(x_t)} \lambda(t) \lVert s_\theta (x_t,t)-\nabla_{x_t} \log p_t(x_t) \rVert^2 \end{equation} \tag{SM loss}

여기서 \lambda(t) 는 loss에 대한 weight로 사용자가 설정해야 할 값이다. (SM Loss)를 학습하면 될것 같은데 안타깝게도 (SM loss)에서 \nabla{x_t} \log p_t(x_t) 는 대다수의 경우 알턱이 없다. SM loss 를 사용하는 대신에 아래와 같은 (DSM loss)를 사용한다.

학습에 사용하는 DSM loss

이렇게 (DSM loss)를 정의했다.

\begin{equation} \mathcal{L}_{DSM} (\theta) = E_{t,x_0 \sim p_{data}(x_0) , x_t \sim p_t(x_t|x_0)} \lambda(t) \lVert s_\theta (x_t,t)-\nabla_{x_t} \log p_t(x_t|x_0) \rVert^2 \end{equation} \tag{SM loss}

그러면 (DSM loss)와 (SM loss)의 관계는 무엇일까? (DSM loss)와 (SM loss)는 동일할까? 여기서 p_{data}(x_0) = p_0(x_0) 이라 가정하자.

SM loss 대신 DSM loss 사용해도 되는 이유

SM loss를 전개해서 보자. < , > 는 inner product 표시이다.

\begin{equation} \mathcal{L}_{SM}(\theta) = E_{t, x_t \sim p_t(x_t)} \lambda(t) \lVert s_\theta (x_t, t) \rVert^2 - 2 E_{t,x_t \sim p_t(x_t)} \lambda(t) < s_\theta (x_t,t), \nabla_{x_t} \log p_t (x_t) > + E_{t,x_t \sim p_t(x_t)} \lambda(t) \lVert \nabla_{x_t} \log p_t (x_t) \rVert^2 \end{equation} \tag{SM loss 전개}

DSM loss를 전개해서 보자.

\begin{equation} \mathcal{L}_{DSM}(\theta) = E_{t, x_t \sim p_t(x_t)} \lambda(t) \lVert s_\theta (x_t, t) \rVert^2 - 2 E_{t, x_0 \sim p_{data}(x_0), x_t \sim p_t(x_t|x_0)} \lambda(t) < s_\theta (x_t,t), \nabla_{x_t} \log p_t (x_t|x_0) > + E_{t, x_0 \sim p_{data}(x_0) , x_t \sim p_t(x_t|x_0)} \lambda(t) \lVert \nabla_{x_t} \log p_t (x_t|x_0) \rVert^2 \end{equation} \tag{DSM loss 전개}

여기서 두번째 항에서 expectation을 보자. 몇가지 조건하에 조작을 하면 아래와 같이 변한다.

E_{t, x_0 \sim p_{data} (x_0) , x_t \sim p_t(x_t |x_0)} \lambda(t) <s_\theta(x_t, t), \nabla_{x_t} \log p_{0t} (x_t | x_0)> = E_{t} \int \int \lambda(t) <s_\theta (x_t, t) , \nabla_{x_t} \log p_{0t}(x_t|x_0)> p_{data}(x_0) p_{0t}(x_t|x_0) dx_t dx_0 = E_t \int \int \lambda(t) < s_\theta(x_t, t), \frac{ \nabla_{x_t} p_{0t} (x_t |x_0)}{p_{0t} (x_t |x_0)}> p_{0t}(x_t | x_0) p_{data}(x_0) dx_t dx_0

여기서  마지막 term에서 p_{0t}(x_t|x_0) = p_{0,t} (x_0, x_t) / p_{data} (x_0) , \nabla_{x_t} \log p_t(x_t) = (\nabla{x_t}(p_t(x_t)))/p_t(x_t)라는 사실을 사용하면

E_t \int \int \lambda(t) < s_\theta(x_t, t), \frac{ \nabla_{x_t} p_{0t} (x_t |x_0)}{p_{0t} (x_t |x_0)}> p_{0t}(x_t | x_0) p_{data}(x_0) dx_t dx_0 = E_t \int \lambda(t) <s_\theta(x_t,t) , \nabla_{x_t} \log p_t(x_t) > p_t (x_t) dx_t= E_{t,x_t\sim p_t(x_t) } \lambda(t) < s_\theta (x_t,t) , \nabla_{x_t} \log p_t (x_t)>

그리고 DSM loss를 다시 써보자.

\begin{equation} \mathcal{L}_{DSM}(\theta) = E_{t, x_t \sim p_t(x_t)} \lambda(t) \lVert s_\theta (x_t, t) \rVert^2 - 2 E_{t,x_t\sim p_t(x_t) } \lambda(t) < s_\theta (x_t,t) , \nabla_{x_t} \log p_t (x_t)> + E_{t, x_0 \sim p_{data}(x_0) , x_t \sim p_t(x_t|x_0)} \lambda(t) \lVert \nabla_{x_t} \log p_t (x_t) \rVert^2 \end{equation} \tag{DSM loss 전개}

결국엔 (SM loss 전개)와 (DSM loss 전개)에서 DSM loss와 SM loss의 parameter \theta 에 대한 미분값이 동일함을 알수 있다. 딥러닝 모델을 학습 할 때 미분을 해서 gradient descent 기반의 최적화를 사용하므로 DSM loss 를 미분해서 최적화를 하나 SM loss를 미분해서 최적화를 하나 같은 효과를 낸다. 따라서 DSM loss를 SM loss 대신 사용해도 된다는 것을 알 수 있다.

Leave a Comment