프로그래밍 언어/Python

Torch의 unsqueeze(), unsqueeze_()

Cuoong 2025. 3. 19. 00:14

PyTorch에서 unsqueeze()와 unsqueeze_()는 텐서의 차원을 확장하는 함수입니다. 두 함수의 주요 차이점은 연산 방식에 있습니다.

unsqueeze()

  • 새로운 차원을 지정된 위치(인덱스)에 추가합니다
  • 원본 텐서를 변경하지 않고 새로운 텐서를 반환합니다 (비파괴적 연산)
  • 사용 예시:
    import torch
    x = torch.tensor([1, 2, 3, 4])  # 형태: [4]
    y = x.unsqueeze(0)  # 형태: [1, 4]
    z = x.unsqueeze(1)  # 형태: [4, 1]

unsqueeze_()

  • 기능적으로는 unsqueeze()와 동일합니다
  • 원본 텐서를 직접 변경합니다 (파괴적/인플레이스 연산)
  • 밑줄(_)은 PyTorch에서 인플레이스 연산을 의미합니다
  • 사용 예시:
    import torch
    x = torch.tensor([1, 2, 3, 4])  # 형태: [4]
    x.unsqueeze_(0)  # x의 형태가 [1, 4]로 직접 변경됩니다

주요 용도

  • 배치 차원 추가하기 (단일 샘플을 배치로 변환)
  • 브로드캐스팅 연산을 위한 차원 맞추기
  • CNN이나 RNN 등의 신경망에 입력할 때 필요한 차원 형태로 변환

unsqueeze_()는 메모리를 절약할 수 있지만, 원본 데이터가 변경된다는 점에 주의해야 합니다.