728x90
반응형
torch.matmul
vector 및 matrix 간의 다양한 곱을 수행한다.
broadcast 기능을 제공하며 가장 일반적으로 사용되나, broadcast 기능이 도리어 debug point가 될 수 있다.
broadcast 기능은 아래의 예제와 같이 T1(10, 3, 4) T2(4)을 곱할 때, 맨 앞의 dim이 3개 일 때는 첫 dim을 batch로 간주하고 T1 (3, 4) tensor의 10개의 batch와 각각 T2(4)랑 곱을 해주는 것이다.
torch.matmul(input, other, *, out=None) → Tensor
torch.mm
torch.matmul과 차이점은 broadcast가 안 된다는 점이다.
즉 mm은 정확하게 matrix 곱의 사이즈가 맞아야 사용이 가능하다.
따라서 내가 작성한 코드가 의도대로 작동하는 지 확인을 위해서 mm의 사용이 적절하다는 생각이 든다. (debug point)
torch.mm(input, mat2, *, out=None) → Tensor
input의 size: (n x m)
mat2의 size: (m x p)
output의 size: (n x p)
torch.bmm
torch.matmul과 차이점은 broadcast가 안 된다는 점이다.
즉 mm은 정확하게 matrix 곱의 사이즈가 맞아야 사용이 가능하다.
따라서 내가 작성한 코드가 의도대로 작동하는 지 확인을 위해서 mm의 사용이 적절하다는 생각이 든다. (debug point)
반응형
'AI > PyTorch' 카테고리의 다른 글
텐서에서 Top-K 결과를 받아오는 방법 (0) | 2023.10.26 |
---|---|
cuDNN benchmark 활성화를 통한 최적의 알고리즘 선택 (0) | 2023.10.14 |
PyTorch에서 이미지 데이터에 대해 normalize를 할 때, mean=[0.485, 0.456, 0.406]과 std=[0.229, 0.224, 0.225]를 쓰는 이유는? (2) | 2023.10.09 |
Softmax 결과의 총합이 1 이하로 나오는 경우 (0) | 2022.07.18 |