はじめに:TPUは素晴らしいが、PyTorchエコシステムとの壁が問題だった
AIインフラのパラダイムが変わっています。単一GPUでは大規模モデルを賄えず、数万個のアクセラレータが接続された分散システムが必須になりました。GoogleのTPU(Tensor Processing Unit)は、こうした大規模AIワークロードのために設計されたカスタムASICです。Gemini、VeoといったGoogleの中核AIプラットフォームから、多数のクラウド顧客のワークロードまで、TPUはすでに実証済みのインフラです。
しかし、問題が一つありました。それは PyTorchとの互換性です。世界中のAI研究者や開発者の大多数はPyTorchを使用しています。ところが、彼らがTPUの性能をフル活用するには、PyTorch/XLAという別のブリッジを経由する必要があり、その過程で数多くの制約と非効率が発生していました。
Googleはこの問題を解決するために、TorchTPUという新しいスタックを開発しました。単なるXLAの改良ではなく、PyTorchの内部構造から再設計した根本的なアプローチです。この記事では、TorchTPUがどのようなエンジニアリング原則で作られ、内部アーキテクチャがどう動作するのか、そして2026年のロードマップまで詳細に解説します。

TorchTPUアーキテクチャ:TPUハードウェアの理解から始める
TPUシステムは単なるチップではありません。1つのホスト(Host)に複数のTPUチップが接続され、各チップは ICI(Inter-Chip Interconnect) という高速ネットワークで相互接続されます。このICIはチップを2Dまたは3Dトーラス(Torus)トポロジーで構成し、従来のネットワークボトルネックなしに大規模拡張を可能にします。
各TPUチップ内部は、さらに TensorCore と SparseCore に分かれます。
- TensorCore:シングルスレッドで動作し、行列積のような密結合演算(dense matrix math)に特化
- SparseCore:埋め込み(embedding)、gather/scatterのような不規則なメモリアクセスパターンを処理
こうしたハードウェア構造を最大限活用するには、PyTorchがTPUの特性を理解し、最適化されたコードを生成できる必要があります。TorchTPUはまさにこのポイントを目標としています。
中核哲学:「ただのPyTorchのように感じさせる」
TorchTPUの最も重要な原則は ユーザビリティ(Usability) です。開発者は既存のPyTorchスクリプトのたった一行を変更するだけで、TPU上で実行できる必要があります。
# 従来のGPUコード
device = torch.device("cuda")
# TPUに変更:この一行だけで完了!
device = torch.device("tpu")
これを実現するため、TorchTPUはPyTorchの PrivateUse1 インターフェースを活用しました。サブクラスやラッパーを使わず、普通のPyTorchテンソルをTPU上でそのまま使えるようにしたのです。これは、PyTorch/XLAが強制していた静的グラフコンパイル方式から脱却し、PyTorch開発者に馴染みのある Eager実行(即時実行) 環境を最優先で提供するためです。

3つのEagerモード:開発からプロダクションまで
TorchTPUは開発ライフサイクル全体をサポートするために、3つのEagerモードを提供します。
| モード | 動作方式 | 使用目的 | パフォーマンス |
|---|---|---|---|
| Debug Eager | 一度に1演算実行後、CPUと同期 | デバッグ:shape不一致、NaN、OOM追跡 | 非常に遅い |
| Strict Eager | 単一演算ディスパッチ、非同期実行 | 基本PyTorch体験のミラーリング | 普通 |
| Fused Eager | 演算ストリームを自動分析しリアルタイム融合 | プロダクション:追加設定不要で最大性能 | Strict比 50〜100%+向上 |
最も注目すべきは Fused Eager モードです。TorchTPUは実行される演算のストリームを自動的にリフレクション(reflection)し、複数のステップを1つの大きな演算塊にフュージョン(fusion)します。これによりTensorCoreの利用率が最大化され、メモリ帯域幅のオーバーヘッドが最小化されます。ユーザーは何も設定する必要がなく、Strict Eagerより2倍近く速いパフォーマンスを得られます。
これらすべてのモードは コンパイルキャッシュ(Compilation Cache) と連携します。シングルホストだけでなく、マルチホスト環境でも持続的にキャッシュを共有できるため、繰り返し実行時のコンパイル時間を劇的に削減します。
torch.compileとの統合:XLAバックエンドの力
最高のパフォーマンスが必要な場合、TorchTPUは torch.compile インターフェースとネイティブに統合されます。PyTorch 2.0から導入された torch.compile は、FXグラフをキャプチャした後コンパイラに送って最適化する方式ですが、TorchTPUはこの過程で Torch Inductorを経由せず、直接XLAをバックエンドコンパイラとして使用します。
import torch
import torchtpu
model = MyModel().to("tpu")
# torch.compileをTPUで使用
compiled_model = torch.compile(model, backend="torchtpu")
# 以降の実行は自動的にXLA経由で最適化されたTPUバイナリとして実行される
output = compiled_model(input_tensor)
この判断は非常に意図的です。XLAはすでにTPUトポロジーに対して厳格に検証されたコンパイラであり、ICIを介した 密結合演算と集団通信(collective communication)のオーバーラップ を最適化する方法をネイティブに理解しています。TorchTPUの変換レイヤーはPyTorch演算子を直接 StableHLO(XLAの中間表現、IR)にマッピングし、最適化されたTPUバイナリを生成します。
カスタムカーネル:PallasとJAX関数をそのまま
高性能のためにカスタム演算子を記述する必要がある場合、TorchTPUは Pallas と JAX で書かれたカスタムカーネルをネイティブにサポートします。
import torchtpu
import jax
import jax.numpy as jnp
@torchtpu.pallas.custom_jax_kernel
def custom_matmul(x, y):
# JAX関数で低レベルTPU命令を記述
return jnp.dot(x, y)
# PyTorchテンソルと自然に接続
result = custom_matmul(torch_tensor_a, torch_tensor_b)
@torchtpu.pallas.custom_jax_kernel デコレータを一つ付けるだけで、PyTorchのlowering pathと直接インターフェースする低レベルハードウェア命令を記述できます。Helionカーネルのサポートも進行中です。

分散実行:DDP、FSDPv2、DTensorを完全サポート
大規模モデル学習のために、TorchTPUはPyTorchの分散APIをそのままサポートします。DDP(Distributed Data Parallel)、FSDPv2(Fully Sharded Data Parallel v2)、DTensor が即座に利用可能で、PyTorchの分散APIをベースとするサードパーティライブラリもほとんど修正なしで動作します。
MPMDサポート:現実世界の不均衡を受け入れる
PyTorch/XLAの最大の限界の一つは、純粋なSPMD(Single Program Multiple Data) のみをサポートしていた点です。実際のPyTorchコードでは、rank 0プロセスがロギングや分析のために少し異なるコードを実行するケースがよくあります。このような「不純な」入力は、TPUスタックにとって大きな課題でした。
TorchTPUは MPMD(Multiple Program Multiple Data) 実行をサポートし、rank間でコードが少しずつ異なっても正確に動作するよう設計されています。必要な場合は通信プリミティブを分離して正確性を保証しつつ、可能な限りXLAのグローバルビュー最適化を維持します。
注意点:TPUに適したモデル設計が必要
TPUは優れたパフォーマンスを提供しますが、最適なモデル設計はGPUとは異なる場合があります。例えば、多くのモデルがアテンションヘッド次元を64にハードコーディングしている一方、現行世代のTPUは 128または256次元 で最高の行列積効率を発揮します。
# GPUに最適化された設定(次元64)
num_heads = 8
head_dim = 64 # TPUでは非効率
# TPUに最適化された設定(次元128)
num_heads = 4
head_dim = 128 # TPU TensorCore活用度最大化
TorchTPUは、移植性(Portability)がハードウェアの現実を無視しないことを認識し、階層型ワークフローを提案しています:まず正確な実行を確保し、その後ディープダイブガイドを通じて最適なアーキテクチャにリファクタリングするか、カスタムカーネルを注入する方式です。
2026年のロードマップと展望
TorchTPUチームは現在、以下の課題に取り組んでいます。
- 動的シーケンス長/バッチサイズによる再コンパイルの最小化:XLA内に高度なバウンデッドダイナミズム(bounded dynamism)を実装し、shape変更時のコンパイルオーバーヘッドを削減することが中核目標です。特に反復的な次トークン予測(next-token prediction)ワークロードで重要です。
- プリコンパイル済みTPUカーネルライブラリの構築:標準演算に対するプリコンパイル済みTPUカーネルライブラリを整備し、初回実行レイテンシを劇的に短縮する予定です。
- Helionカーネルサポートの拡大:より多様なカスタム演算をカバーできるよう、Helionバックエンドのサポートを強化します。
まとめ:TorchTPUがもたらす変化
TorchTPUは、単なるPyTorch-XLAブリッジのアップデートではありません。PyTorchをTPU上で ネイティブ に実行するための、Googleによる全面エンジニアリング投資です。Eager First哲学、3段階Eagerモード、XLAベースのコンパイルパイプライン、MPMDサポートまで — これらすべての要素が組み合わさり、開発者は「ただPyTorchを使うように」TPUの圧倒的なパフォーマンスを活用できるようになりました。
国内のAI開発環境でも、Google Cloud TPUや産総研ABCIなどのTPUインフラを活用するケースが増えています。TorchTPUの登場により、PyTorchベースの研究者が追加の学習曲線なしにTPUを活用する道が開かれました。
次のステップとしての学習方向:
- 公式 TPU Developer Hub で最新のTorchTPUアップデートを確認
- 実際のColabでTunixを活用したFunctionGemmaファインチューニング演習:FunctionGemma、TPUで10分でファインチューニングする方法
- 大規模マルチテナントSaaS運用事例:6,000のAWSアカウント、3人のエンジニア、1つのプラットフォーム ProGloveのアカウント-テナント分離運用記
合わせて読みたい記事: