[pytorch, 파이토치] 모델의 parameter 학습 안하기

import torch
import torch.nn as nn
import torch.optim as optim


model = ...  # 모델 정의

for param in model.parameters(): #모델의 모든 parameter 에 대한 gradient 추적 안하게함
    param.requires_grad = False

위에 처럼 하면, 모델의 모든 parameter가 학습되지 않는다.

모델의 특정 모듈의 parameter만 업데이트 하고 싶지 않다면 아래와 같은 코드를 사용하자.

for param in model.submodule.parameters():
    param.requires_grad = True

#학습가능한 parameter만 optimizer 가 업데이트
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

Leave a Comment