딥러닝 모델 학습을 하다보면 이미 성능이 입증된 모델의 부분부분을 따와서 학습하고 싶을 때가 있다. 실제로도 이런 방식을 많이 사용한다. 성능 좋은 모델의 일부를 따와서 그곳의 일부 혹은 전체를 Freezing (동결)한담에 다른 task를 위한 layer 붙히고 학습하는 방법이 있다. 원리자체는 간단한데 이것을 코드로 실현하자고 하니 골치가 아팠다. 그렇지만 검색을 하다보니 파이토치를 이용해 파인튜닝을 하는 방법에 대해 알게 되었다. 어떻게 하는 것일까? 이제 슬슬 알아보도록 하자.
모델의 state_dict 만 저장 된 상황
만약에 내가 사용할 pre-trained 모델이 아래와 같이 state_dict 형태로 저장되어있다고 하자.
pre_trained_moel = ... #pre-trained model
saved_dict = {'model': pre_trained_model.state_dict()}
torch.save(saved_dict, path) #path 는 저장경로
이럴 때 위에 저장된 모델의 구조에서 내가 원하는 layer만 따오고 그 layer는 얼린 상태에서 학습을 진행하려고 한다.
model = myModel() # 학습할 모델
checkpoint = torch.load(path) #pre-trained의 모델의 위치를 입력하여 모델이 입력되어있는 dictionary 불러온다.
pre_state_dict = checkpoint('model') #모델을 불러온다.
model_state_dict = model.state_dict() # 모델의 state를 불러온다.
위와 같이 하면 내가 학습하고자 하는 모델을 불러왔고 학습 시 사용하게 될 pre-trained model을 불러왔다. 만약에 내가 특정 layer 만 사용하려고 한다면 아래와 같이 하면된다. 내가 원하는 layer의 이름을 입력해서 pre-trained 모델에서 layer를 따오고 내가 학습하고자 하는 모델에서 원하는 layer를 불러와서 그곳에 넣으면 된다. 그런데 여기서 pre-trained에서 사용하고자 하는 layer의 이름과 동일한 이름의 layer가 내가 학습하고자 하는 모델의 layer에도 있어야 한다.
layer_name = "pre-trained에서따올거"
for key in pre_state_dict.keys(): #이미 학습된 모델에서 구성요소 검색
if layer_name in key: #pre trained 모델의 layer중 내가 찾고자 하는 layer와 관련있는 곳 찾아서
model_state_dict[key].copy_(pre_state_dict[key]) #있다면 state_dict 업데이트 하고
model.load_state_dict(model_state_dict) #state_dict 업데이트 한것을 업데이트 한다.
여기서 gradient 가 안 흘ㄹ르게 하려면 아래와 같이 param.requires_grad=False를 지정하면된다.
for name, param in model.named_parameters():
if layer_name in name:
param.requires_grad = False