Deep Learning study

PyTorch에서 발생하는 ‘expected scalar type Float but found Double’ 에러 해결 본문

AI/Pytorch

PyTorch에서 발생하는 ‘expected scalar type Float but found Double’ 에러 해결

illinaire 2025. 4. 6. 11:49
반응형


문제 정의:
PyTorch 코드를 실행할 때 종종 RuntimeError: expected scalar type Float but found Double라는 에러 메시지가 나타날 수 있습니다. 이 에러는 텐서의 데이터 타입이 기대하는 타입과 맞지 않을 때 발생합니다.

에러 케이스 설명:
예를 들어, 아래 코드에서는 torch.FloatTensor를 기대하는 모델이 torch.DoubleTensor로 변환된 데이터를 입력받아 문제가 발생합니다:

import torch

# 모델 정의
model = torch.nn.Linear(10, 1)

# 데이터 생성 (잘못된 데이터 타입)
x = torch.rand(32, 10, dtype=torch.double)

# 모델에 데이터 전달
output = model(x)  # 여기서 RuntimeError 발생


이 경우, 모델의 파라미터는 float32 데이터 타입으로 초기화되었으나 입력 데이터가 float64로 되어 있어 호환되지 않습니다.

올바른 사용법과 해결 예제:
문제를 해결하려면 데이터 타입을 맞추어야 합니다. 보통 PyTorch에서는 float32(torch.float)를 주로 사용합니다.

# 데이터 타입을 명시적으로 지정
x = x.float()

# 또는 처음부터 올바른 dtype으로 생성
x = torch.rand(32, 10, dtype=torch.float)

# 모델에 데이터 전달
output = model(x)  # 정상적으로 동작


정리 및 결론:
PyTorch의 텐서 타입 불일치 문제는 흔하지만, 해결하기도 쉽습니다. 입력 데이터의 타입을 모델의 파라미터 타입과 일치시키는 것으로 대부분의 에러를 방지할 수 있습니다. 이 과정에서 tensor.float()와 같은 타입 변환 함수가 유용하며, 데이터를 생성할 때부터 올바른 타입을 설정하는 습관을 가지는 것이 중요합니다.

반응형
Comments