大規模なAIモデル学習において複数GPUを活用する場合、デバイス間の効率的なデータ通信はパフォーマンスの鍵となります。PyTorchのtorch.distributedモジュールはこれを実現する強力なプリミティブを提供しますが、ブロッキング/ノンブロッキングの動作の違いや各種Collective Operationsの理解は初学者にとって難しい場合があります。本稿では、NCCLバックエンドを基準に、これらの通信パターンを実行可能なコード例と共に明確に解説します。更なる詳細は原文の分析もご参照ください。

AI and GPU server illustration Developer Related Image

主要な通信パターンとコード例

分散PyTorchにおける通信は、主にポイントツーポイント通信と**Collective Operations(集合通信)**の2つに分類されます。以下に、同期/非同期両方の使い方を示す主要なコードスニペットを紹介します。

1. ポイントツーポイント通信 (Send/Recv)

2つの特定のRank間での直接通信です。

import torch
import torch.distributed as dist

# Rankの初期化は省略。各プロセスはdist.get_rank()で自身のRankを取得します。
rank = dist.get_rank()
device = torch.device(f'cuda:{rank}')

# 同期(ブロッキング)方式
if rank == 0:
    data = torch.tensor([1.0, 2.0, 3.0], device=device)
    dist.send(data, dst=1)  # Rank 1へデータ送信
    print("Rank 0: データ送信完了(キューに登録されました)")
elif rank == 1:
    data = torch.empty(3, device=device)
    dist.recv(data, src=0)  # Rank 0からデータ受信
    print(f"Rank 1: 受信データ = {data}")

# 非同期(ノンブロッキング)方式
if rank == 0:
    data = torch.tensor([4.0, 5.0], device=device)
    req = dist.isend(data, dst=1)  # 非同期送信を開始
    # ここでCPUは他の作業を実行可能
    req.wait()  # 通信完了を待機(ストリームをブロック)
    print("Rank 0: 非同期送信完了")
elif rank == 1:
    data = torch.empty(2, device=device)
    req = dist.irecv(data, src=0)  # 非同期受信を開始
    req.wait()
    print(f"Rank 1: 非同期受信データ = {data}")

2. 集合通信: Broadcast & All-Reduce

1つのRankのデータを全Rankにコピーする、または全Rankのデータを演算し結果を全Rankに配布します。

# Broadcast: src Rankのテンソルを全Rankにコピー
rank = dist.get_rank()
if rank == 0:
    tensor = torch.tensor([1, 2, 3], dtype=torch.int64, device=device)
else:
    tensor = torch.empty(3, dtype=torch.int64, device=device)  # 空テンソルを準備

dist.broadcast(tensor, src=0)
print(f'Rank {rank}: {tensor}')  # 全Rankが[1,2,3]を出力

# All-Reduce: 全Rankのテンソルを合計(SUM)し、結果を全Rankに保存
tensor = torch.tensor([rank + 1], dtype=torch.float32, device=device)  # Rank0: [1], Rank1: [2]
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
print(f'Rank {rank} after all_reduce: {tensor}')  # 全Rankが[3]を出力 (1+2)

Data center server rack Coding Session Visual

実装時の注意点とパフォーマンス向上のヒント

  1. 最初のrecv()は特別です: NCCLの内部ウォームアップのため、プロセスのライフサイクルにおける最初のtorch.distributed.recv()呼び出しは、データ転送が完了するまでホストを完全にブロックします。2回目以降の呼び出しは操作がキューに登録されると戻るため、注意が必要です。
  2. wait()synchronize() の違い:
    • request.wait(): 非同期通信操作に使用。その特定の通信が完了するまで現在アクティブなCUDAストリームをブロックします。ホストCPUはカーネルのキュー登録のみを待機します。
    • torch.cuda.synchronize(): GPU上のすべてのキューイングされた作業が完了するまで、ホストCPUスレッドを停止させるより強力なコマンドです。正確なベンチマーク測定には必要ですが、過度に使用するとパフォーマンスのボトルネックを引き起こす可能性があります。
  3. 計算と通信のオーバーラップ: 非同期通信(isend, irecv)を使用し、転送中にGPUが他の計算を実行できるようにすると、全体のスループットを大幅に向上させることができます。これは大規模分散学習における重要な最適化技術です。
  4. テンソルのメモリ管理: scatterなどの操作後、送信側Rank(src)の元データのメモリは自動的に解放されません。不要なメモリ占有を避けるため、明示的なdelまたはNoneの割り当てを検討してください。

Network nodes connection diagram Dev Environment Setup

まとめと次のステップ

本稿では、コードを中心にPyTorch torch.distributedの基本的な通信原理と使用方法を確認しました。実際の分散学習コード(例:Distributed Data Parallel)は、内部でこれらのCollective Operations、特にall_reduceを広範に使用しています。これらの概念を理解することは、複数GPUにわたるモデルの同期がどのように行われるかをデバッグし、パフォーマンスを分析する上で非常に役立ちます。

次のステップとして、torch.distributed.launchまたはtorchrunを使用して実際のマルチプロセス学習スクリプトを起動してみてください。環境変数NCCL_DEBUG=INFOを設定することで、通信プロセスの詳細なログを確認できます。通信がボトルネックになっているかプロファイリングし、必要に応じて非同期パターンへの切り替えを実践してみましょう。本稿の内容は原文の分析を参考に、実務中心に再構成したものです。