PyTorch 코드를 보면 자주 나오는 model.train(), model.eval(), 그리고 torch.no_grad()에 대해서 간단히 정리해봤습니다.
model.train()
학습할 때와 추론할 때 다르게 동작하는 Layer들을 Training mode로 바꿔줍니다. 예를 들어
- Batch Normalization Layer는 Batch Statistics를 이용하게 되고,
- Dropout Layer가 주어진 확률에 따라 활성화됩니다.
model.eval()
학습할 때와 추론할 때 다르게 동작하는 Layer들을 Evaluation(Inference) mode로 바꿔줍니다. 예를 들어
- Batch Normalization Layer는 학습할 때 사용된 Batch Statistics를 통해 결정된 Running Statistics를 이용하게 되고,
- Dropout Layer는 비활성화됩니다.
참고로 model.eval()과 model.train(False)는 동일한 기능을 합니다.
torch.no_grad()
PyTorch의 Autograd Engine을 비활성화하여 Gradient를 계산하지 않도록 합니다. 따라서 Gradient를 계산할 필요가 없는 경우 torch.no_grad()를 통해 메모리 사용량을 줄이고 계산 속도를 빠르게 만들 수 있습니다.
참고로 데코레이터로 사용할 수도 있기 때문에 Gradient 계산이 필요없는 연산을 수행하는 함수를 사전에 지정할 수도 있습니다.
@torch.no_grad()
def func_without_gradient_tracking():
....
f = func_without_gradient_tracking()
f.requires_grad # False
Summary
따라서 상황별로 사용할 메소드를 다음과 같이 정리할 수 있습니다.
1. 모델 학습
model.train()
# Code related to training
2. 모델 추론 및 평가
model.eval()
with torch.no_grad():
# Code related to inference
잘못된 내용, 오타, 부정확한 문장 등 어떤 피드백이든 환영합니다. 감사합니다.