Bridge matching이란 무엇일까?

Bridge matching에 대해 알아보자. Bridge matching은 한 분포 \pi_0 로 부터 \pi_1 를 잇는 bridge를 만드는 것으로 보면 된다. 아래의 SDE를 표현하는 Markov path measure \mathbb{Q} 가 있다고 하자.

dX_t = f_t(X_t)dt+\sigma_t dB_t, X_0~\mathbb{Q}_0 = \pi_0

위의 SDE에서 t=0,1 에 각각 condition x_0, x_1 이 있다고 하자. 이에 따른 \mathbb{Q}_{|0,1} (\cdot | x_0, x_1) 는 아래와 같은 diffusion bridge를 따른다.

dX_t^{0,T} = \{f_t(X_t^{0,1})+\sigma_t^2 \nabla \log \mathbb{Q}_{1|t} (x_1 | X_t^{0,1})\}dt+ \sigma_t dB_t, X_0^{0,1}=x_0, X_1^{0,1}=x_1 

이제 \Pi_{0,1} = \pi_0 \otimes \pi_1 를 independent coupling이라 하고 \Pi = \Pi_{0,1} \mathbb{Q}_{|0,1}를 mixture of bridge라고 합시다.

Bridge matching의 목표는 모든 t 에 대해서 Y_t ~ \Pi_t 이고 다음을 만족하는 Stochastic process Y_t 를 찾는 것입니다.

dY_t = \{ f_t(Y_t)+v_t(Y_t)\} dt + \sigma_t dB_t

이렇게 되면 Y_1 ~ \pi_1 일 때 Y_0 ~ \pi_0 에서 시작해 위의 SDE를 이용해서 \pi_1 의 sample을 만들 수 있습니다.

여기서 v_t 의 optimal은 v_t^\star (x_t) = \sigma_t^2 \mathbb{E}_{\Pi_{T|t}}\left[\nabla \log \mathbb{Q}_{T|t}(X_T|X_t) | X_t = x_t \right] 이고 이것을 찾기 어렵기 때문에 아래와 같은 loss를 사용합니다.

Loss = \mathbb{E}_{\Pi_{t,T}} \lVert \sigma_t^2 \nabla \log \mathbb{Q}_{T|t}(X_T|X_t) - v_\theta (t,X_t) \rVert^2

만약에 f_t = 0 , \sigma_t = \sigma 라면 위의 regression loss는 flow matching을 위한 loss와 같아집니다.

Leave a Comment