Flow matching 정리

Flow matching에 대해 알아보자. 여러 논문이 나오고 variation이 나오지만 나는 가장 쉬운(?) 버전으로 설명을 해보겠다.

Flow matching 기본 세팅

p(x) = N\left(x \mid \bar{\mu}_0, \bar{\sigma}_0^2 I\right) : 아주 쉬운 분포

p_data : data에 관한 분포

Flow matching의 목표

아래와 같은 미분 방정식을 생각하자.

dx_t = v_t(x_t) dt , 0 \leq t \leq 1,  x_0 \sim p(x_0) \tag{1}

위의 미분 방정식에 의해서 x_t 가 만들어진다.

x_t = \int_{0}^t v_\tau(x_\tau) d\tau + x_0 \tag{2}

p_t 를 바로 윗줄에서 나오는 x_t 의 분포라 하자.

Flow matching의 목표는 p_1 = p_{data} 이길 바란다.

Flow matching loss

p_1=p_{data}가 되도록 하는 p_t,v_t가 있다면 아래와 같은 loss를 사용하면 된다.

L_{FM} = E_{t \sim U[0,1], x_t \sim p_t(x_t)} \lVert v_\theta(x_t,t) - v_t(x_t) \rVert^2 \tag{3}

위와 같은 loss (3)를 Flow matching loss라 하는데, 사용하기에는 문제가 있다.

  • p_t 알 턱이 없다.
  • v_t 알 턱이 없다.

따라서 디퓨전 모델에서 했던 것처럼 다른 작업을 친다.

Conditional Flow matching

디퓨전에서 했던것 처럼 비슷한 작업을 하려고 합니다.

data distribution p_{data} 를 이용해서 고정된 z \sim p_{data}을 조건으로 하는 아래와 같은 미분방정식을 생각합니다.

dx_t = v_t(x_t \mid z) dt , x_0 \sim p(x_0) \tag{4}

여기서 p(x_0) = N\left(x_0 \mid \bar{\mu}_0, \bar{\sigma}_0^2 I \right)

(4)에서 x_0 에 대한 transformation이 아래와 같이 만들어진다.

x_t = \int_{0}^t v_\tau(x_\tau \mid z) d\tau + x_0 \tag{5}

(4)와 (5)를 만족하는 x_t 의 분포를 p_t(x_t\mid z) 라 하자. 그러면 p_0(x_0\mid z) = p(x_0) 이된다.

아래와 같이 v_t, p_t 를 정의하자.

p_t(x_t) := \int p_t(x_t \mid z ) p_{data} (z) dz \tag{6} v_t(x_t) := \int \frac{v_t(x_t\mid z) p_t (x_t \mid z) p_{data} (z) } {p_t(x_t)} dz \tag{7}

그러면 (6)에서 나온 v_t 를 가지고 아래와 같은 미분방정식과 x_0 의 transformation x_t 을 생각하자.

dx_t = v_t (x_t) dt, x_0 \sim p_0(x_0)(=p(x_0)) \tag{8} x_t = \int_{0}^t v_\tau(x_\tau) d\tau + x_0 \tag{9}

중요한 사실 1

  • (8),(9)를 만족하는 x_t 는 (6)의 p_t 분포를 갖는다.
  • (6)의 분포는 p_t p_t(x_t) = \int p_t(x_t \mid z) p_{data}(z) dz 이다.
    • 따라서 p_1(x_1 \mid z ) \approx \delta(x_1 -z) 를 따르도록 하면 아래와 같이 데이터 분포를 근사할 수 있다.
p_1(x_1) =\int p_1(x_1 \mid z) p_{data}(z) dz \approx \approx p_{data}(x_1) \tag{10}

Conditional Flow matching loss

위에서 $p_t( \x_t \mid z) , v_t(x_t \mid z)$을 이용해서 아래와 같은 loss 를 만들자.

L_{CFM} = E_{t \sim U[0,1], z \sim p_{data}(z), x_t \sim p_t(x_t\mid z) } \lVert v_\theta (x_t,t) - v_t (x_t \mid z) \rVert^2 \tag{11}

다시 기억을 상기하면 Flow matching을 학습하기 위해 아래와 같은 L_{FM} 을 학습하려고 했었고 불가능 했다.

L_{FM} = E_{t \sim U[0,1], x_t \sim p_t(x_t)} \lVert v_\theta(x_t,t) - v_t(x_t) \rVert^2 \tag{12}

그런데!

중요한 사실 2: L_{CFM}, L_{FM} 의 gradient 동일

그런데 (6),(7)이 만족할 때 아래와 같은 결과를 만족한다.

\nabla_{\theta}(L_{CFM}-L_{FM}) = 0 \tag{13}

즉 힘들게 L_{FM} 최적화할게 아니라 L_{CFM} 구해서 미분해서 최적화 하면 된다.

 

Conditional Flow matching loss 사용해서 학습하자!

아래의 loss를 최적화하기 위하여 이것을 구성하는 v_t(x_t \mid z), p_t(x_t | z) 를 그럴듯하게 구상해야 한다.

L_{CFM} = E_{t \sim U[0,1], z \sim p_{data}(z), x_t \sim p_t(x_t\mid z) } \lVert v_\theta (x_t,t) - v_t (x_t \mid z) \rVert^2 \tag{14}

v_t(x_t \mid z), p_t(x_t \mid z), p_data(z) 과 다른 term들에 대해 상기하면

dx_t = v_t(x_t \mid z) dt , x_0 \sim p(x_0) \tag{15}
  • p(x_0) = N(x_0 \mid \bar{\mu}_0, \bar{\sigma}_0^2 I )
  • p_t(x_t\mid z) (15)를 만족하는 x_t 의 pdf
  • p_t(x_t) := \int p_t(x_t \mid z ) p_{data} (z) dz 
    • 만약에 p_1(x_1 \mid z) = N(x_1 \mid \bar{\mu}_1(z), \bar{\sigma}_1^2 I ) 이고 p_1 (x_1 \mid z) \approx \delta(x_1-z) 라면
    • p_1 (x_1) = \int p_1(x_1 \mid z) p_{data}(z) dz \approx p_{data}(x_1)

주목할점!

  • p_0(x_0 \mid z) = N(x_0 \mid \bar{\mu}_0, \bar{\sigma}_0^2 I ) 으로 세팅하고
  • p_1(x_1 \mid z) = N(x_1 \mid \bar{\mu}_1(z), \bar{\sigma}_1^2 I ) 이길 바란다.
  • 그러면! p_t(x_t \mid z) = N(x_t \mid \mu_t(z), \sigma_t^2 I) 이면 그럴듯하겠다!!

 

그러기 위해 다음을 만족하는 \mu_t(z), \sigma_t 를 정의하자.

  • \mu_0 (z) = \bar{\mu}_0 , \mu_1(z) = \bar{\mu}_1(z)
  • \sigma_0(z) = \bar{\sigma}_0, \sigma_1(z) = \bar{\sigma}_1

그리고 위의 \mu_t(z) , \sigma_t 를 mean과 variance로 갖는 p_t(x_t \mid z) 정의하자.

p_t(x_t\mid z) = N(x_t \mid \mu_t(z), \sigma_t^2 I  ) \tag{16}

 

z \sim p_{data}(z) 가 조건부일때 아래와 같이 x_t 를 정의하면

x_0 \sim p_0(x_0 \mid z) , x_t = \frac{\sigma_t}{\sigma_0}(x_0 - \mu_0) + \mu_t  \tag{17}

위의 (17)에서 x_t 는 자연스럽게 p_t(x_t \mid z) 를 따른다.

이제 v_t (x_t \mid z) 를 얻기 위해 (17)의 x_t 를 아래 식에 넣고 정리하자.

dx_t = v_t(x_t \mid z) dt , x_0 \sim p(x_0) \tag{18}

(18)을 정리하면 아래와 같이 target value를 얻는다. v_t (x_t \mid z) = \frac{\sigma_t\prime}{\sigma_t}(x_t - \mu_t(z)) + \mu_t \prime(z)

이제 아래의 loss를 L_{CFM} 를 사용하기 위한 재료들이 다 모아짐

L_{CFM} = E_{t \sim U[0,1], z \sim p_{data}(z), x_t \sim p_t(x_t\mid z) } \lVert v_\theta (x_t,t) - v_t (x_t \mid z) \rVert^2 \tag{14}
  • p_0(x_0) = N(x_0 \mid \bar{\mu}_0, \bar{\sigma}_0^2 I)
  • z \sim p_{data}(z) 일때 p_t(x_t \mid z) = N(x_t \mid \mu_t(z), \sigma_t^2(z))
    • \mu_0(z) = \bar{\mu}_0 , \sigma_0 = \bar{\sigma}_0 이고 p_1 (x_1 \mid z) \approx \delta(x_1 -z) 되도록만 잡으면 됨
  • v_t(x_t \mid z) = \frac{\sigma_t \prime}{\sigma_t} ( x_t - \mu_t(z) ) + \mu_t \prime (z)

 

Conditional Flow matching 세팅하는 연습

linear mean, linear std 원조버전

p (x_0) = N(x_0 \mid 0, \sigma^2 I) , p_{data}(z) 데이터 분포

p_t(x_t \mid z) = N(x_t \mid \mu_t(z), \sigma_t^2 I) 를 세팅하고자 한다.

\mu_t(z) = tz, \sigma_t = (1-t)\sigma 으로 세팅하자.

그러면!

v_t(x_t \mid z) = \frac{\sigma_t\prime}{\sigma_t} (x_t - \mu_t(z)) + \mu_t \prime(z) , v_t(x_t \mid z) = \frac{z-x_t}{1-t}

x_t \sim p_t(x_t \mid z) 일 때 v_t(x_t \mid z) =

x_0 \sim p_0(x_0 \mid z) 일 때 v_t(x_t \mid z) =

linear mean, linear std speech enhancement 위한 버전

s,y 를 각각 clean speech와 그것의 noisy speech라 하자.

p(x_0) = N(x_0 \mid y, \sigma^2 I) , p_{clean}(s\mid y) 의 y의 clean speech 분포

p_t(x_t \mid s) = N(x_t \mid \mu_t(s), \sigma_t^2 I) 를 세팅하고자 한다.

\mu_t(s) =(1-t)y+ ts, \sigma_t = (1-t)\sigma 으로 세팅하자.

그러면!

v_t(x_t \mid s) = \frac{\sigma_t\prime}{\sigma_t} (x_t - \mu_t(s)) + \mu_t \prime(s) , v_t(x_t \mid s) = 

x_t \sim p_t(x_t \mid s) 일 때 v_t(x_t \mid s) =

x_0 \sim p_0(x_0 \mid s) 일 때 v_t(x_t \mid s) =

 

Leave a Comment