import torch
import torch.nn as nn
import torch.optim as optim
model = ... # 모델 정의
for param in model.parameters(): #모델의 모든 parameter 에 대한 gradient 추적 안하게함
param.requires_grad = False
위에 처럼 하면, 모델의 모든 parameter가 학습되지 않는다.
모델의 특정 모듈의 parameter만 업데이트 하고 싶지 않다면 아래와 같은 코드를 사용하자.
for param in model.submodule.parameters():
param.requires_grad = True
#학습가능한 parameter만 optimizer 가 업데이트
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)