대규모 AI 모델 학습을 위해 멀티 GPU를 활용할 때, 각 GPU 간의 효율적인 데이터 통신은 성능의 핵심입니다. PyTorch의 torch.distributed 모듈은 이를 위한 강력한 도구를 제공하지만, Blocking/Non-Blocking 동작 방식과 다양한 Collective Operations의 차이를 이해하는 것은 초보자에게는 쉽지 않을 수 있습니다. 이 글에서는 NCCL 백엔드를 기준으로, 각 통신 패턴을 실전 예제 코드와 함께 명확하게 설명해 드립니다. 더 깊은 이해를 위해 원문 분석도 참고하시면 좋습니다.

AI and GPU server illustration IT Technology Image

핵심 통신 패턴과 코드 예제

통신은 크게 **점대점(Point-to-Point)**과 **집합 통신(Collective)**으로 나뉩니다. 아래는 각 패턴의 동기/비동기 사용법을 보여주는 핵심 코드입니다.

1. 점대점 통신 (Send/Recv)

두 개의 특정 Rank 간 직접 통신입니다.

import torch
import torch.distributed as dist

# 랭크 초기화는 생략. 각 프로세스의 랭크는 dist.get_rank()로 얻습니다.
rank = dist.get_rank()
device = torch.device(f'cuda:{rank}')

# 동기(Synchronous) 방식
if rank == 0:
    data = torch.tensor([1.0, 2.0, 3.0], device=device)
    dist.send(data, dst=1)  # 랭크 1로 데이터 전송
    print("Rank 0: 데이터 전송 완료(큐에 등록됨)")
elif rank == 1:
    data = torch.empty(3, device=device)
    dist.recv(data, src=0)  # 랭크 0으로부터 데이터 수신
    print(f"Rank 1: 수신 데이터 = {data}")

# 비동기(Asynchronous) 방식
if rank == 0:
    data = torch.tensor([4.0, 5.0], device=device)
    req = dist.isend(data, dst=1)  # 비동기 전송 시작
    # 여기서 CPU는 다른 작업을 할 수 있음
    req.wait()  # 통신 완료 대기 (스트림 차단)
    print("Rank 0: 비동기 전송 완료")
elif rank == 1:
    data = torch.empty(2, device=device)
    req = dist.irecv(data, src=0)  # 비동기 수신 시작
    req.wait()
    print(f"Rank 1: 비동기 수신 데이터 = {data}")

2. 집합 통신: Broadcast & All-Reduce

한 랭크의 데이터를 모든 랭크에 복사하거나, 모든 랭크의 데이터를 연산 후 모두에게 배포합니다.

# Broadcast: src 랭크의 텐서를 모든 랭크에 복사
rank = dist.get_rank()
if rank == 0:
    tensor = torch.tensor([1, 2, 3], dtype=torch.int64, device=device)
else:
    tensor = torch.empty(3, dtype=torch.int64, device=device)  # 빈 텐서 준비

dist.broadcast(tensor, src=0)
print(f'Rank {rank}: {tensor}')  # 모든 랭크가 [1,2,3] 출력

# All-Reduce: 모든 랭크의 텐서를 합산(SUM) 후 결과를 모든 랭크에 저장
tensor = torch.tensor([rank + 1], dtype=torch.float32, device=device)  # 랭크0: [1], 랭크1: [2]
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
print(f'Rank {rank} after all_reduce: {tensor}')  # 모든 랭크가 [3] 출력 (1+2)

Data center server rack Programming Illustration

실무 적용 시 주의사항 및 성능 팁

  1. 첫 번째 recv()는 특별합니다: NCCL 내부 워밍업 때문에, 프로세스 생애 주기에서 첫 번째 torch.distributed.recv() 호출은 데이터 전송이 완료될 때까지 호스트를 완전히 차단(Block)합니다. 두 번째 호출부터는 커널이 CUDA 스트림에 등록만 되면 반환하므로 주의가 필요합니다.
  2. wait() vs synchronize():
    • request.wait(): 비동기 통신 작업에 사용. 해당 통신이 완료될 때까지 현재 활성 CUDA 스트림을 차단합니다. 호스트 CPU는 커널 등록만 기다립니다.
    • torch.cuda.synchronize(): 호스트 CPU를 GPU의 모든 대기 작업이 완료될 때까지 완전히 정지시킵니다. 성능 측정(Benchmark) 시 정확한 시간을 재려면 필요하지만, 과용하면 성능 병목을 일으킬 수 있습니다.
  3. 통신과 계산 중첩(Overlap): 비동기 통신(isend, irecv)을 사용하고 통신이 진행되는 동안 GPU가 다른 계산을 수행하게 하면 전체 처리 속도를 크게 높일 수 있습니다. 이는 대규모 분산 학습의 핵심 최적화 기술입니다.
  4. 텐서 메모리 관리: scatter 같은 작업 후, 송신 측(src)의 원본 데이터 메모리는 자동으로 해제되지 않습니다. 불필요한 메모리 점유를 방지하려면 명시적으로 del 또는 None 할당을 고려하세요.

Network nodes connection diagram System Abstract Visual

마무리: 실전 적용 조언

이번 글에서는 PyTorch torch.distributed의 기본 통신 원리와 사용법을 코드 중심으로 살펴봤습니다. 실제 분산 학습 코드(예: Distributed Data Parallel)는 내부적으로 이러한 Collective Operations, 특히 all_reduce를 광범위하게 사용합니다. 개념을 이해하면 모델이 여러 GPU에 걸쳐 어떻게 동기화되는지 디버깅하고 성능을 분석하는 데 큰 도움이 될 것입니다.

다음 단계로는 torch.distributed.launch 또는 torchrun을 이용해 실제 멀티 프로세스 학습 스크립트를 실행해보고, NCCL_DEBUG=INFO 환경 변수를 설정하여 통신 과정을 로그로 확인해보는 것을 추천합니다. 통신이 병목이 되는지 프로파일링하고, 필요하다면 비동기 통신으로 전환하는 실전 연습을 해보세요. 본 글의 내용은 원문 분석을 참고하여 실무 중심으로 재구성되었습니다.