들어가며: TPU는 훌륭한데, PyTorch 생태계와의 벽이 문제였다
AI 인프라의 패러다임이 바뀌고 있습니다. 더 이상 단일 GPU로는 대규모 모델을 감당할 수 없고, 수만 개의 가속기가 연결된 분산 시스템이 필수가 되었죠. 구글의 TPU(Tensor Processing Unit)는 이런 대규모 AI 워크로드를 위해 설계된 커스텀 ASIC입니다. Gemini, Veo 같은 구글의 핵심 AI 플랫폼부터 수많은 클라우드 고객의 워크로드까지, TPU는 이미 검증된 인프라입니다.
하지만 문제가 하나 있었습니다. 바로 PyTorch와의 호환성이었죠. 전 세계 AI 연구자와 개발자의 대다수는 PyTorch를 사용합니다. 그런데 이들이 TPU의 성능을 제대로 활용하려면 PyTorch/XLA라는 별도의 브릿지를 거쳐야 했고, 이 과정에서 수많은 제약과 비효율이 발생했습니다.
구글은 이 문제를 해결하기 위해 TorchTPU라는 새로운 스택을 만들었습니다. 단순한 XLA 개선이 아닌, PyTorch의 내부 구조부터 다시 설계한 근본적인 접근법입니다. 이 글에서는 TorchTPU가 어떤 엔지니어링 원칙으로 만들어졌고, 내부 아키텍처가 어떻게 동작하는지, 그리고 2026년 로드맵까지 상세히 살펴보겠습니다.

TorchTPU 아키텍처: TPU 하드웨어 이해부터 시작
TPU 시스템은 단순한 칩이 아닙니다. 하나의 호스트(Host)에 여러 개의 TPU 칩이 연결되고, 각 칩은 ICI(Inter-Chip Interconnect) 라는 고속 네트워크로 서로 연결됩니다. 이 ICI는 칩들을 2D 또는 3D 토러스(Torus) 토폴로지로 구성하여, 전통적인 네트워크 병목 없이 대규모 확장을 가능하게 합니다.
각 TPU 칩 내부는 다시 TensorCore와 SparseCore로 나뉩니다.
- TensorCore: 단일 스레드로 동작하며, 행렬 곱셈 같은 밀집 연산(dense matrix math)에 특화
- SparseCore: 임베딩, gather/scatter 같은 불규칙한 메모리 접근 패턴 처리
이런 하드웨어 구조를 제대로 활용하려면, PyTorch가 TPU의 특성을 이해하고 최적화된 코드를 생성할 수 있어야 합니다. TorchTPU는 바로 이 지점을 목표로 합니다.
핵심 철학: "그냥 PyTorch처럼 느껴지게 하라"
TorchTPU의 가장 중요한 원칙은 사용성(Usability) 입니다. 개발자는 기존 PyTorch 스크립트에서 단 한 줄만 바꾸면 TPU에서 실행할 수 있어야 합니다.
# 기존 GPU 코드
device = torch.device("cuda")
# TPU로 변경: 이 한 줄만 바꾸면 끝!
device = torch.device("tpu")
이를 위해 TorchTPU는 PyTorch의 PrivateUse1 인터페이스를 활용했습니다. 서브클래스나 래퍼 없이, 평범한 PyTorch 텐서를 TPU 위에서 그대로 사용할 수 있게 만든 것입니다. 이는 PyTorch/XLA가 강제했던 정적 그래프 컴파일 방식에서 벗어나, PyTorch 개발자에게 익숙한 Eager 실행(즉시 실행) 환경을 최우선으로 제공하기 위함입니다.

3가지 Eager 모드: 개발부터 프로덕션까지
TorchTPU는 개발 생명주기 전반을 지원하기 위해 세 가지 Eager 모드를 제공합니다.
| 모드 | 동작 방식 | 사용 목적 | 성능 |
|---|---|---|---|
| Debug Eager | 한 번에 하나의 연산 실행 후 CPU와 동기화 | 디버깅: shape 불일치, NaN, OOM 추적 | 매우 느림 |
| Strict Eager | 단일 연산 디스패치, 비동기 실행 | 기본 PyTorch 경험 미러링 | 보통 |
| Fused Eager | 연산 스트림을 자동 분석하여 실시간 퓨전 | 프로덕션: 추가 설정 없이 최대 성능 | Strict 대비 50~100%+ 향상 |
가장 주목할 점은 Fused Eager 모드입니다. TorchTPU는 실행되는 연산들의 스트림을 자동으로 리플렉션(reflection)하여, 여러 단계를 하나의 큰 연산 덩어리로 퓨전(fusion)합니다. 이렇게 하면 TensorCore 활용도가 극대화되고, 메모리 대역폭 오버헤드가 최소화됩니다. 사용자는 아무것도 설정할 필요 없이, Strict Eager보다 2배 가까이 빠른 성능을 얻을 수 있습니다.
이 모든 모드는 컴파일 캐시(Compilation Cache) 와 함께 동작합니다. 단일 호스트뿐 아니라 멀티 호스트 환경에서도 지속적으로 캐시를 공유할 수 있어, 반복 실행 시 컴파일 시간을 획기적으로 줄여줍니다.
torch.compile과의 통합: XLA 백엔드의 힘
최고 성능이 필요하다면, TorchTPU는 torch.compile 인터페이스와 네이티브로 통합됩니다. PyTorch 2.0부터 도입된 torch.compile은 FX 그래프를 캡처한 후 컴파일러로 보내 최적화하는 방식인데, TorchTPU는 이 과정에서 Torch Inductor를 거치지 않고 직접 XLA를 백엔드 컴파일러로 사용합니다.
import torch
import torchtpu
model = MyModel().to("tpu")
# torch.compile을 TPU에서 사용
compiled_model = torch.compile(model, backend="torchtpu")
# 이후 실행은 자동으로 XLA를 통해 최적화된 TPU 바이너리로 실행됨
output = compiled_model(input_tensor)
이 결정은 매우 의도적입니다. XLA는 이미 TPU 토폴로지에 대해 엄격하게 검증된 컴파일러이며, ICI를 통한 밀집 연산과 집단 통신(collective communication)의 중첩(overlap) 을 최적화하는 방법을 네이티브로 이해하고 있습니다. TorchTPU의 변환 레이어는 PyTorch 연산자를 직접 StableHLO(XLA의 중간 표현, IR)로 매핑하여, 최적화된 TPU 바이너리를 생성합니다.
커스텀 커널: Pallas와 JAX 함수를 그대로
고성능을 위해 커스텀 연산자를 작성해야 한다면? TorchTPU는 Pallas와 JAX로 작성된 커스텀 커널을 네이티브로 지원합니다.
import torchtpu
import jax
import jax.numpy as jnp
@torchtpu.pallas.custom_jax_kernel
def custom_matmul(x, y):
# JAX 함수로 저수준 TPU 명령어 작성
return jnp.dot(x, y)
# PyTorch 텐서와 자연스럽게 연결
result = custom_matmul(torch_tensor_a, torch_tensor_b)
@torchtpu.pallas.custom_jax_kernel 데코레이터 하나만 붙이면, PyTorch의 lowering path와 직접 인터페이스되는 저수준 하드웨어 명령어를 작성할 수 있습니다. Helion 커널 지원도 진행 중입니다.

분산 실행: DDP, FSDPv2, DTensor 완벽 지원
대규모 모델 학습을 위해 TorchTPU는 PyTorch의 분산 API를 그대로 지원합니다. DDP(Distributed Data Parallel), FSDPv2(Fully Sharded Data Parallel v2), DTensor가 즉시 사용 가능하며, PyTorch의 분산 API를 기반으로 하는 서드파티 라이브러리들도 대부분 수정 없이 동작합니다.
MPMD 지원: 현실 세계의 불균형을 받아들이다
PyTorch/XLA의 가장 큰 한계 중 하나는 순수 SPMD(Single Program Multiple Data) 만 지원했다는 점입니다. 실제 PyTorch 코드에서는 rank 0 프로세스가 로깅이나 분석을 위해 약간 다른 코드를 실행하는 경우가 흔합니다. 이런 "불순한" 입력은 TPU 스택에 큰 도전 과제였습니다.
TorchTPU는 MPMD(Multiple Program Multiple Data) 실행을 지원하여, rank 간 코드가 약간씩 달라도 정확하게 동작하도록 설계되었습니다. 필요한 경우 통신 프리미티브를 격리하여 정확성을 보존하면서도, 가능한 한 XLA의 글로벌 뷰 최적화를 유지합니다.
주의사항: TPU에 맞는 모델 설계가 필요하다
TPU는 뛰어난 성능을 제공하지만, 최적의 모델 설계는 GPU와 다를 수 있습니다. 예를 들어, 많은 모델이 어텐션 헤드 차원을 64로 하드코딩하는 반면, 현재 세대 TPU는 128 또는 256 차원에서 최고의 행렬 곱셈 효율을 냅니다.
# GPU에 최적화된 설정 (차원 64)
num_heads = 8
head_dim = 64 # TPU에서는 비효율적
# TPU에 최적화된 설정 (차원 128)
num_heads = 4
head_dim = 128 # TPU TensorCore 활용도 극대화
TorchTPU는 이식성(Portability)이 하드웨어 현실을 무시하지 않는다는 점을 인지하고, 계층형 워크플로우를 제안합니다: 먼저 정확한 실행을 확보한 후, 딥다이브 가이드를 통해 최적의 아키텍처로 리팩터링하거나 커스텀 커널을 주입하는 방식입니다.
2026년 로드맵과 전망
TorchTPU 팀은 현재 다음과 같은 과제를 해결 중입니다.
- 동적 시퀀스 길이/배치 크기로 인한 재컴파일 최소화: XLA 내에 고급 바운디드 다이나미즘(bounded dynamism)을 구현하여, shape 변경 시 컴파일 오버헤드를 줄이는 것이 핵심 목표입니다. 특히 반복적인 다음 토큰 예측(next-token prediction) 워크로드에 중요합니다.
- 사전 컴파일된 TPU 커널 라이브러리 구축: 표준 연산에 대한 사전 컴파일된 TPU 커널 라이브러리를 만들어, 첫 번째 실행 지연 시간을 획기적으로 단축할 예정입니다.
- Helion 커널 지원 확대: 더 다양한 커스텀 연산을 커버할 수 있도록 Helion 백엔드 지원을 강화합니다.
결론: TorchTPU가 가져올 변화
TorchTPU는 단순한 PyTorch-XLA 브릿지의 업데이트가 아닙니다. PyTorch를 TPU에서 네이티브로 실행하기 위한 구글의 전폭적인 엔지니어링 투자입니다. Eager First 철학, 3단계 Eager 모드, XLA 기반 컴파일 파이프라인, MPMD 지원까지 — 이 모든 요소가 결합되어 개발자는 "그냥 PyTorch 쓰듯이" TPU의 압도적인 성능을 활용할 수 있게 되었습니다.
국내 AI 개발 환경에서도 TPU 접근성이 높아지고 있습니다. Google Cloud TPU를 사용하거나, 한국전자통신연구원(ETRI) 등 국내 연구기관의 TPU 인프라를 활용하는 사례가 늘고 있습니다. TorchTPU의 등장으로 PyTorch 기반 연구자들이 추가 학습 곡선 없이 TPU를 활용할 수 있는 길이 열렸습니다.
다음 단계 학습 방향:
- 공식 TPU Developer Hub에서 최신 TorchTPU 업데이트 확인
- 실제 Colab에서 Tunix를 활용한 FunctionGemma 파인튜닝 실습: FunctionGemma, TPU로 10분 만에 파인튜닝하기
- 대규모 멀티 테넌트 SaaS 운영 사례: 6,000개 AWS 계정, 3명의 엔지니어, 하나의 플랫폼 ProGlove의 계정-테넌트 분리 운영기
함께 보면 좋은 글: