인프런 커뮤니티 질문&답변

킴허클베리님의 프로필 이미지
킴허클베리

작성한 질문수

[PyTorch] 쉽고 빠르게 배우는 GAN

[실습] DCGAN/LSGAN

runtime error : Found dtype Long but expected Float

작성

·

1.4K

0

안녕하세요~ 실습 DCGAN 부분에서 D 네트워크를 업데이트 하는 부분에서 에러가 나는데, 라인별로 타입을 확인하면서 수정하려 했는데, 왜 에러가 나는지 모르겠네요 ㅠㅠ 도움 부탁 드립니다. 

답변 3

2

output, label 둘 다 추가해주니 실행되네요.

---------------------------------------------------------------------------

output = netD(real_cpu).view(-1)

(추가) output = output.type(torch.FloatTensor).cuda()

(추가) label = label.type(torch.FloatTensor).cuda()

errD_real = criterion(output, label)

---------------------------------------------------------------------------

cuda 안 쓰시면 뒤에 .cuda()는 빼셔야 하구요.

1

아~! 답을 찾았습니다. 아래 처럼 바꾸니 되네요 ^^

output = netD(real_cpu).view(-1)

#추가하는 줄 

output = output.type(torch.FloatTensor)

0

        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        output = netD(real_cpu).view(-1)
        
        output = output.type(torch.FloatTensor)
추가해도 저는 똑같네요
Starting Training Loop...
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-16-7c0f18f95470> in <module> 23 output = output.type(torch.FloatTensor) 24 ---> 25 errD_real = criterion(output, label) 26 errD_real.backward() 27 D_x = output.mean().item() ~\anaconda3\envs\pytorch3.7\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs) 725 result = self._slow_forward(*input, **kwargs) 726 else: --> 727 result = self.forward(*input, **kwargs) 728 for hook in itertools.chain( 729 _global_forward_hooks.values(), ~\anaconda3\envs\pytorch3.7\lib\site-packages\torch\nn\modules\loss.py in forward(self, input, target) 528 529 def forward(self, input: Tensor, target: Tensor) -> Tensor: --> 530 return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction) 531 532 ~\anaconda3\envs\pytorch3.7\lib\site-packages\torch\nn\functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction) 2524 2525 return torch._C._nn.binary_cross_entropy( -> 2526 input, target, weight, reduction_enum) 2527 2528 RuntimeError: Found dtype Long but expected Float
킴허클베리님의 프로필 이미지
킴허클베리

작성한 질문수

질문하기