Data Science/Framework

    [PyTorch] model.train() vs. model.eval() vs. torch.no_grad()

    PyTorch 코드를 보면 자주 나오는 model.train(), model.eval(), 그리고 torch.no_grad()에 대해서 간단히 정리해봤습니다. model.train() 학습할 때와 추론할 때 다르게 동작하는 Layer들을 Training mode로 바꿔줍니다. 예를 들어 Batch Normalization Layer는 Batch Statistics를 이용하게 되고, Dropout Layer가 주어진 확률에 따라 활성화됩니다. model.eval() 학습할 때와 추론할 때 다르게 동작하는 Layer들을 Evaluation(Inference) mode로 바꿔줍니다. 예를 들어 Batch Normalization Layer는 학습할 때 사용된 Batch Statistics를 통해 결정된 Run..