Flow matching에 대해 알아보자. 여러 논문이 나오고 variation이 나오지만 나는 가장 쉬운(?) 버전으로 설명을 해보겠다.
Flow matching 기본 세팅
p(x) = N\left(x \mid \bar{\mu}_0, \bar{\sigma}_0^2 I\right) : 아주 쉬운 분포
p_data : data에 관한 분포
Flow matching의 목표
아래와 같은 미분 방정식을 생각하자.
dx_t = v_t(x_t) dt , 0 \leq t \leq 1, x_0 \sim p(x_0) \tag{1}위의 미분 방정식에 의해서 x_t 가 만들어진다.
x_t = \int_{0}^t v_\tau(x_\tau) d\tau + x_0 \tag{2}p_t 를 바로 윗줄에서 나오는 x_t 의 분포라 하자.
Flow matching의 목표는 p_1 = p_{data} 이길 바란다.
Flow matching loss
p_1=p_{data}가 되도록 하는 p_t,v_t가 있다면 아래와 같은 loss를 사용하면 된다.
L_{FM} = E_{t \sim U[0,1], x_t \sim p_t(x_t)} \lVert v_\theta(x_t,t) - v_t(x_t) \rVert^2 \tag{3}위와 같은 loss (3)를 Flow matching loss라 하는데, 사용하기에는 문제가 있다.
- p_t 알 턱이 없다.
- v_t 알 턱이 없다.
따라서 디퓨전 모델에서 했던 것처럼 다른 작업을 친다.
Conditional Flow matching
디퓨전에서 했던것 처럼 비슷한 작업을 하려고 합니다.
data distribution p_{data} 를 이용해서 고정된 z \sim p_{data}을 조건으로 하는 아래와 같은 미분방정식을 생각합니다.
dx_t = v_t(x_t \mid z) dt , x_0 \sim p(x_0) \tag{4}여기서 p(x_0) = N\left(x_0 \mid \bar{\mu}_0, \bar{\sigma}_0^2 I \right)
(4)에서 x_0 에 대한 transformation이 아래와 같이 만들어진다.
x_t = \int_{0}^t v_\tau(x_\tau \mid z) d\tau + x_0 \tag{5}(4)와 (5)를 만족하는 x_t 의 분포를 p_t(x_t\mid z) 라 하자. 그러면 p_0(x_0\mid z) = p(x_0) 이된다.
아래와 같이 v_t, p_t 를 정의하자.
p_t(x_t) := \int p_t(x_t \mid z ) p_{data} (z) dz \tag{6} v_t(x_t) := \int \frac{v_t(x_t\mid z) p_t (x_t \mid z) p_{data} (z) } {p_t(x_t)} dz \tag{7}그러면 (6)에서 나온 v_t 를 가지고 아래와 같은 미분방정식과 x_0 의 transformation x_t 을 생각하자.
dx_t = v_t (x_t) dt, x_0 \sim p_0(x_0)(=p(x_0)) \tag{8} x_t = \int_{0}^t v_\tau(x_\tau) d\tau + x_0 \tag{9}중요한 사실 1
- (8),(9)를 만족하는 x_t 는 (6)의 p_t 분포를 갖는다.
- (6)의 분포는 p_t 는 p_t(x_t) = \int p_t(x_t \mid z) p_{data}(z) dz 이다.
- 따라서 p_1(x_1 \mid z ) \approx \delta(x_1 -z) 를 따르도록 하면 아래와 같이 데이터 분포를 근사할 수 있다.
Conditional Flow matching loss
위에서 $p_t( \x_t \mid z) , v_t(x_t \mid z)$을 이용해서 아래와 같은 loss 를 만들자.
L_{CFM} = E_{t \sim U[0,1], z \sim p_{data}(z), x_t \sim p_t(x_t\mid z) } \lVert v_\theta (x_t,t) - v_t (x_t \mid z) \rVert^2 \tag{11}다시 기억을 상기하면 Flow matching을 학습하기 위해 아래와 같은 L_{FM} 을 학습하려고 했었고 불가능 했다.
L_{FM} = E_{t \sim U[0,1], x_t \sim p_t(x_t)} \lVert v_\theta(x_t,t) - v_t(x_t) \rVert^2 \tag{12}그런데!
중요한 사실 2: L_{CFM}, L_{FM} 의 gradient 동일
그런데 (6),(7)이 만족할 때 아래와 같은 결과를 만족한다.
\nabla_{\theta}(L_{CFM}-L_{FM}) = 0 \tag{13}즉 힘들게 L_{FM} 최적화할게 아니라 L_{CFM} 구해서 미분해서 최적화 하면 된다.
Conditional Flow matching loss 사용해서 학습하자!
아래의 loss를 최적화하기 위하여 이것을 구성하는 v_t(x_t \mid z), p_t(x_t | z) 를 그럴듯하게 구상해야 한다.
L_{CFM} = E_{t \sim U[0,1], z \sim p_{data}(z), x_t \sim p_t(x_t\mid z) } \lVert v_\theta (x_t,t) - v_t (x_t \mid z) \rVert^2 \tag{14}v_t(x_t \mid z), p_t(x_t \mid z), p_data(z) 과 다른 term들에 대해 상기하면
dx_t = v_t(x_t \mid z) dt , x_0 \sim p(x_0) \tag{15}- p(x_0) = N(x_0 \mid \bar{\mu}_0, \bar{\sigma}_0^2 I )
- p_t(x_t\mid z) (15)를 만족하는 x_t 의 pdf
- p_t(x_t) := \int p_t(x_t \mid z ) p_{data} (z) dz
- 만약에 p_1(x_1 \mid z) = N(x_1 \mid \bar{\mu}_1(z), \bar{\sigma}_1^2 I ) 이고 p_1 (x_1 \mid z) \approx \delta(x_1-z) 라면
- p_1 (x_1) = \int p_1(x_1 \mid z) p_{data}(z) dz \approx p_{data}(x_1)
주목할점!
- p_0(x_0 \mid z) = N(x_0 \mid \bar{\mu}_0, \bar{\sigma}_0^2 I ) 으로 세팅하고
- p_1(x_1 \mid z) = N(x_1 \mid \bar{\mu}_1(z), \bar{\sigma}_1^2 I ) 이길 바란다.
- 그러면! p_t(x_t \mid z) = N(x_t \mid \mu_t(z), \sigma_t^2 I) 이면 그럴듯하겠다!!
그러기 위해 다음을 만족하는 \mu_t(z), \sigma_t 를 정의하자.
- \mu_0 (z) = \bar{\mu}_0 , \mu_1(z) = \bar{\mu}_1(z)
- \sigma_0(z) = \bar{\sigma}_0, \sigma_1(z) = \bar{\sigma}_1
그리고 위의 \mu_t(z) , \sigma_t 를 mean과 variance로 갖는 p_t(x_t \mid z) 정의하자.
p_t(x_t\mid z) = N(x_t \mid \mu_t(z), \sigma_t^2 I ) \tag{16}
z \sim p_{data}(z) 가 조건부일때 아래와 같이 x_t 를 정의하면
x_0 \sim p_0(x_0 \mid z) , x_t = \frac{\sigma_t}{\sigma_0}(x_0 - \mu_0) + \mu_t \tag{17}위의 (17)에서 x_t 는 자연스럽게 p_t(x_t \mid z) 를 따른다.
이제 v_t (x_t \mid z) 를 얻기 위해 (17)의 x_t 를 아래 식에 넣고 정리하자.
dx_t = v_t(x_t \mid z) dt , x_0 \sim p(x_0) \tag{18}(18)을 정리하면 아래와 같이 target value를 얻는다. v_t (x_t \mid z) = \frac{\sigma_t\prime}{\sigma_t}(x_t - \mu_t(z)) + \mu_t \prime(z)
이제 아래의 loss를 L_{CFM} 를 사용하기 위한 재료들이 다 모아짐
L_{CFM} = E_{t \sim U[0,1], z \sim p_{data}(z), x_t \sim p_t(x_t\mid z) } \lVert v_\theta (x_t,t) - v_t (x_t \mid z) \rVert^2 \tag{14}- p_0(x_0) = N(x_0 \mid \bar{\mu}_0, \bar{\sigma}_0^2 I)
- z \sim p_{data}(z) 일때 p_t(x_t \mid z) = N(x_t \mid \mu_t(z), \sigma_t^2(z))
- \mu_0(z) = \bar{\mu}_0 , \sigma_0 = \bar{\sigma}_0 이고 p_1 (x_1 \mid z) \approx \delta(x_1 -z) 되도록만 잡으면 됨
- v_t(x_t \mid z) = \frac{\sigma_t \prime}{\sigma_t} ( x_t - \mu_t(z) ) + \mu_t \prime (z)
Conditional Flow matching 세팅하는 연습
linear mean, linear std 원조버전
p (x_0) = N(x_0 \mid 0, \sigma^2 I) , p_{data}(z) 데이터 분포
p_t(x_t \mid z) = N(x_t \mid \mu_t(z), \sigma_t^2 I) 를 세팅하고자 한다.
\mu_t(z) = tz, \sigma_t = (1-t)\sigma 으로 세팅하자.
그러면!
v_t(x_t \mid z) = \frac{\sigma_t\prime}{\sigma_t} (x_t - \mu_t(z)) + \mu_t \prime(z) , v_t(x_t \mid z) = \frac{z-x_t}{1-t}
x_t \sim p_t(x_t \mid z) 일 때 v_t(x_t \mid z) =
x_0 \sim p_0(x_0 \mid z) 일 때 v_t(x_t \mid z) =
linear mean, linear std speech enhancement 위한 버전
s,y 를 각각 clean speech와 그것의 noisy speech라 하자.
p(x_0) = N(x_0 \mid y, \sigma^2 I) , p_{clean}(s\mid y) 의 y의 clean speech 분포
p_t(x_t \mid s) = N(x_t \mid \mu_t(s), \sigma_t^2 I) 를 세팅하고자 한다.
\mu_t(s) =(1-t)y+ ts, \sigma_t = (1-t)\sigma 으로 세팅하자.
그러면!
v_t(x_t \mid s) = \frac{\sigma_t\prime}{\sigma_t} (x_t - \mu_t(s)) + \mu_t \prime(s) , v_t(x_t \mid s) =
x_t \sim p_t(x_t \mid s) 일 때 v_t(x_t \mid s) =
x_0 \sim p_0(x_0 \mid s) 일 때 v_t(x_t \mid s) =