Tensor란 Numpy의 ndarray와 같은 다차원 배열을 담는 자료구조이다.
1. Tensor 사용법
■ Tensor 생성
Tensor는 아래의 그림과 같이 list와 ndarray를 이용하여 생성할 수 있다.
※ torch.Tensor()와 torch.tensor() 차이
· torch.Tensor는 tensor자료구조의 클래스이다. 즉, 이 클래스를 이용하여 인스턴스를 생성할 수 있다. T = torch.Tensor()라 하면 T는 Tensor클래스의 인스턴스가 된다. T의 경우 data를 입력하지 않았으니 빈 tensor가 생성된다.
· torch.tensor는 어떤 data를 tensor로 copy해주는 함수이다. torch.tensor에 data를 넣었을 때, 그 data가 tensor가 아니면 torch.Tensor클래스를 적용하여 복사한다. 따라서 t = torch.tensor([1, 2, 3])처럼 data를 꼭 넣어주어야 한다. 그렇지 않으면 copy할 데이터가 없으니 에러가 난다.
■ Operation
1) 덧셈, 뺄셈, slicing, flatten, ones_like, shape등 기본적인 Operation은 Numpy와 똑같다.
2) Tensor는 GPU에 올려서 사용할 수 있다.
.device를 통해 현재 연결된 장치를 확인할 수 있고, 기본적으로는 cpu에 연동 되어있다.
.to('cuda')로 Tensor를 GPU에 올릴 수 있다.
3) 행렬 곱셈 함수
pytorch에서 행렬 곱셈은 mm과 matmul을 통해 할 수 있다. dot은 벡터 곱셈만 가능하고 mm과 matmul으로 행렬 곱셈을 할 수 있다. mm은 matmul과 다르게 broacast가 안 된다는 특징이 있다.
dot | 벡터 X 벡터 |
mm | 행렬 X 행렬 , broadcasting 지원 안 됨 |
matmul | 행렬 X 행렬 , 행렬 X 벡터 , broadcasting 지원됨 |
■ Tensor handling
view와 squeeze/unsqueeze를 통해 tensor를 handling할 수 있다.
1) view
view는 shape을 변환하는 함수로 reshape과 유사하다. 그러나 이 두 가지는 contiguity 보장에 있어서 차이가 있다.
※ view의 contiguity
a에 변화를 주면 a를 대입하여 만든 b에도 변화가 생긴다. 따라서 view의 경우 copy를 하지 않고 같은 메모리 주소를 그대로 쓰는 방식인 것을 알 수 있다.
※ reshape의 contiguity
2) squeeze/unsqueeze
squeeze는 1인 차원을 삭제하여 data를 압축해주고, unsqueeze는 1인 차원을 원하는 곳에 추가해준다.
■ nn.functional 모듈을 통한 DL/ML formula
import torch
import torch.nn.functional as F
tensor = FloatTensor([0.5, 0.7, 0.1])
#softmax 적용
h_tensor = F.softmax(tensor, dim=0)
h_tensor #tensor([0.3458, 0.4224, 0.2318])
#argmax 적용
y = torch.randint(5, (10, 5))
y_label = y.argmax(dim=1)
#one-hot encoding
torch.nn.functional.one_hot(y_label)
■ 자동 미분 AugoGrad
w = torch.tensor(2.0, requires_grad = True)
y = w**2
z = 10*y + 25
z.backward() #.backward()를 통해 자동 미분
w.grad #tensor(40,)
a = torch.tensor([2., 3.], requires_grad=True)
b = torch.tensor([6., 4.], requires_grad=True)
Q = 3*a**3 - b**2 external_grad = torch.tensor([1., 1.])
Q.backward(gradient=external_grad)
print(a.grad()) print(b.grad())
변수가 여러개인 함수의 역전파 시, 각각의 변수에 대해 grad()를 취해야 한다.
댓글