본문 바로가기
카테고리 없음

RuntimeError: mat1 and mat2 shapes cannot be multiplied 오류 해결하기

by 차분한 공돌이 2023. 12. 18.

pytorch에서 모델을 학습시킬 때 RuntimeError: mat1 and mat2 shapes cannot be multiplied 에러가 발생할 때가 있습니다.

이 오류는 선형 레이어의 가중치 행렬과 입력 행렬 간의 크기가 호환되지 않을 때 발생합니다.

model = nn.Sequential(         
    #데이터의 shape : (32,3,32,32)   <= (배치 사이즈, 채널 수, 높이, 너비)
    regnet_bifpn_object,
    #데이터의 shape : (32,320,32,32)
    nn.Conv2d(320,80, kernel_size=3, stride=1, padding=1), 
    #데이터의 shape : (32,80,32,32)
    nn.Conv2d(80,10, kernel_size=3, stride=1, padding=1),
    #데이터의 shape : (32,10,32,32)
    nn.Flatten(),
    #데이터의 shape : (32,10240)   <=10240  = 10*32*32
    nn.ReLU(),
     #데이터의 shape : (32,10240)
    nn.Linear(2000,200),
    
    nn.ReLU(),
    nn.Linear(200,10),
     
    

)

해결방법

입력 행렬의 크기는 10240 이면(32는 배치 사이즈이므로 고려하지 않습니다), 선형 레이어(nn.Linear)를
nn.Linear(2000,200 => nn.Linear(10240,200)
으로 바꿔준다.

 

#참고

nn.Linear의 기능

#데이터의 shape : (32,10240)
 nn.Linear(10240,200),
#데이터의 shape : (32,200)