Flow matching Knowledge distillation 이용한 아이디어.

v_\theta(x_t, . , . , t) 가 모델이고 v_t(x_t|x_0,y) 가 target일 때 아래와 같은 loss들 사용할 수 있겠다.

\lVert v_\theta(x_t,x_0,y,t) - v_t(x_t|x_0,y)\rVert^2 \lVert v_\theta(x_t,y,y,t) - v_t(x_t|x_0,y)\rVert^2 \tag{original CFM} \lVert v_\theta(x_t,y,y,t) - v_\theta(x_t,x_0,y,t) \rVert^2 \tag{KD}

위의 세가지 조합에 대해 학습 할 수 있다.

만약에 storm 처럼 predictive model D_\phi (y) 가 있다면

\lVert v_\theta(x_t,x_0,y,t) - v_t(x_t|x_0,y)\rVert^2 \lVert v_\theta(x_t,D_\phi(y),y,t) - v_t(x_t|x_0,y)\rVert^2 \lVert v_\theta(x_t,D_\phi(y),y,t) - v_\theta(x_t,x_0,y,t) \rVert^2

그리고 cascading flow를 활용한다면 만약에 D_\theta(x_1) 이라는 결과가 v_\theta 에 의해 만들어진다면 아래와 같은 loss 사용

\lVert v_\theta(x_t,x_0,y,t) - v_t(x_t|x_0,y)\rVert^2 \lVert v_\theta(x_t,y,y,t) - v_t(x_t|x_0,y)\rVert^2 \lVert v_\theta(x_t,D_\theta(x_1),y,t) - v_t(x_t|x_0,y)\rVert^2 \lVert v_\theta(x_t,D_\theta(x_1),y,t) - v_\theta(x_t,x_0,y,t) \rVert^2 \lVert v_\theta(x_t,y,y,t) - v_\theta(x_t,x_0,y,t) \rVert^2

바로 위의 식에서 x_t 를 다시 만들어줘야 하는 부분이 전부다?

Leave a Comment