프로그래밍 언어/Python
Torch의 unsqueeze(), unsqueeze_()
Cuoong
2025. 3. 19. 00:14
PyTorch에서 unsqueeze()와 unsqueeze_()는 텐서의 차원을 확장하는 함수입니다. 두 함수의 주요 차이점은 연산 방식에 있습니다.
unsqueeze()
- 새로운 차원을 지정된 위치(인덱스)에 추가합니다
- 원본 텐서를 변경하지 않고 새로운 텐서를 반환합니다 (비파괴적 연산)
- 사용 예시:
import torch x = torch.tensor([1, 2, 3, 4]) # 형태: [4] y = x.unsqueeze(0) # 형태: [1, 4] z = x.unsqueeze(1) # 형태: [4, 1]
unsqueeze_()
- 기능적으로는 unsqueeze()와 동일합니다
- 원본 텐서를 직접 변경합니다 (파괴적/인플레이스 연산)
- 밑줄(_)은 PyTorch에서 인플레이스 연산을 의미합니다
- 사용 예시:
import torch x = torch.tensor([1, 2, 3, 4]) # 형태: [4] x.unsqueeze_(0) # x의 형태가 [1, 4]로 직접 변경됩니다
주요 용도
- 배치 차원 추가하기 (단일 샘플을 배치로 변환)
- 브로드캐스팅 연산을 위한 차원 맞추기
- CNN이나 RNN 등의 신경망에 입력할 때 필요한 차원 형태로 변환
unsqueeze_()는 메모리를 절약할 수 있지만, 원본 데이터가 변경된다는 점에 주의해야 합니다.