해결된 질문
작성
·
275
1
안녕하세요. 큰 질문은 아니고 사소한 질문일 수도 있습니다만..
다름이 아니라, 행렬곱 강의에서 구현한 코드에서는 곱해주는 행렬 크기가 커질수록 오차가 누적되는 듯한(정확히 말하면 파이토치 내장 matmul과 계산 결과가 점점 더 달라지는듯한) 현상이 관찰되어 질문드립니다.
먼저, 실습에서 정의한 코드에서부터
x = torch.randn(16,16,device = 'cuda')
y = torch.randn(16,16,device = 'cuda')
a = matmul(x,y)
b = torch.matmul(x,y)
assert torch.allclose(a,b)
torch.allclose
의 기본 인자(atol=1e-8, rtol=1e-5) 세팅에서는 assertion error가 발생하여 조건을 완화시켜야(atol=1e-5, rtol=1e-5) assertion이 통과되는 모습을 보였고
x = torch.randn(2048,1024,device = 'cuda')
y = torch.randn(1024,256,device = 'cuda')
x, y의 크기를 이와 같이 키웠을 경우엔 atol=1e-4, rtol=1e-4로 조건을 완화시켜야 assertion을 통과하는 모습을 보였습니다.
triton kernel로 구현한 행렬곱 연산과 PyTorch 내장 matmul 연산 모두 fp32로 연산이 이루어지고 있는데, 이러한 오차가 발생할 수 있는 원인에 무엇이 있는지 궁금해서 질문 드립니다.
답변 2
2
안녕하세요? 아담한 고슴도치님,
먼저 강의를 수강해주셔서 감사합니다. 계산의 결과가 다른 이유는 크게 2가지가 있습니다.
첫째, 다른 데이터 타입을 사용함에 따라 오차가 발생할 수 있습니다. 예제의 경우 행렬을 곱을 tl.dot(x, y, allow_tf32=False)를 호출해서 계산했습니다. 만약 allow_tf32에 True가 설정되어 있거나 allow_tf32가 정의되어 있지 않는 경우에 오차가 발생할 수 있습니다. Triton이 행렬을 빠르게 계산하기 위해 float32를 tf32로 변환한 뒤 Tensor Core를 사용하기 때문입니다. tf32의 경우 float32보다 정밀도가 낮은데, 이 차이로 인해 계산의 오차가 발생할 수 있습니다.
둘째, 계산 순서에 따라 결과가 달라질 수 있습니다. float32는 IEEE 754 표준에 맞춰서 구현되어 있습니다. 지수에 8비트가 사용되고 가수에 23비트가 사용됩니다. 그러므로 float32는 실수를 다 표현할 수 없습니다. 이러한 한계 때문에 계산 순서에 따라 오차가 발생할 수 있습니다. 이러한 현상은 쉽게 확인할 수 있습니다. 크기가 20000인 배열에 실수가 저장되어 있는 경우, 순서대로 실수의 합을 더할때와 역순으로 실수의 합을 더할때의 결과가 다른 것을 확인할 수 있습니다.
이 2가지의 경우 하드웨어의 한계로 발생하는 오차입니다. 개인적으로 저는 이러한 오차를 오차라고 생각하지 않습니다.
마지막으로 예제 코드의 경우 경계 검사가 되어있지 않습니다. 텐서의 크기와 블록의 크기가 정확히 나누어 떨어지지 않는다면 쓰레기 값이 임시 텐서에 로드되게 되고 이것이 잘못된 결과를 만들 수 있습니다.
요즘 제가 바빠서 답변이 늦었는데 죄송합니다. 궁금하신점 있으시면 계속 물어봐주세요. 감사합니다!
0
상세한 답변 감사드립니다!!