들어가며: 왜 FunctionGemma를 TPU로 파인튜닝해야 할까?
FunctionGemma는 자연어를 API 호출로 변환하는 에이전트를 엣지 디바이스에서도 빠르게 실행할 수 있도록 설계된 경량 언어 모델입니다. 기존에는 Hugging Face TRL + GPU 조합으로 파인튜닝하는 가이드가 많았지만, GPU 비용이 부담스럽거나 TPU 인프라를 활용하고 싶은 분들에게는 다른 접근이 필요했죠.
이 글에서는 Google Tunix라는 JAX 기반 라이브러리를 사용해 Colab 무료 티어 TPU v5e-1에서 FunctionGemma를 LoRA 방식으로 파인튜닝하는 방법을 단계별로 소개합니다. GPU 대비 TPU의 장점은 무엇보다 비용 효율성과 JAX 네이티브 성능입니다. 특히 대규모 배치 처리나 반복 실험이 필요한 연구/프로토타이핑 단계에서 TPU는 강력한 대안이 됩니다.
참고: 이 튜토리얼은 메타가 공개한 RCCLX AMD 플랫폼 GPU 통신 성능을 혁신하다에서 다룬 GPU 통신 최적화와는 다른 축이지만, 하드웨어 가속기 선택의 폭을 넓힌다는 점에서 함께 보시면 좋습니다.
준비물
- Google Colab (TPU 런타임 설정: 런타임 → 런타임 유형 변경 → TPU v5e-1)
- Hugging Face 계정 (FunctionGemma 모델 접근 권한 필요)
- 기본적인 Python, JAX, Transformers 지식
![]()
실전 파인튜닝: Tunix로 FunctionGemma LoRA 학습하기
1. 모델 및 데이터셋 다운로드
Hugging Face Hub에서 FunctionGemma(270M)와 Mobile Action 데이터셋을 내려받습니다.
from huggingface_hub import snapshot_download, hf_hub_download
MODEL_ID = "google/functiongemma-270m-it"
DATASET_ID = "google/mobile-actions"
# safetensors만 다운로드 (PyTorch 가중치 제외)
local_model_path = snapshot_download(
repo_id=MODEL_ID,
ignore_patterns=["*.pth"]
)
data_file = hf_hub_download(
repo_id=DATASET_ID,
filename="dataset.jsonl",
repo_type="dataset"
)
2. TPU 메시 구성
Colab 무료 티어는 단일 코어 TPU이므로, 샤딩 없이 간단한 메시를 만듭니다.
import jax
NUM_TPUS = len(jax.devices())
# 단일 코어: (1, 1) 메시
MESH = [(1, 1), ("fsdp", "tp")]
mesh = jax.make_mesh(
*MESH,
axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0])
)
3. LoRA 어댑터 적용
Tunix는 create_model_from_safe_tensors()로 safetensors를 바로 로드하고, Qwix로 LoRA를 적용합니다.
import qwix
from tunix import params_safetensors_lib
from tunix import nnx
with mesh:
# 베이스 모델 로드
base_model = params_safetensors_lib.create_model_from_safe_tensors(
local_model_path, model_config, mesh
)
# LoRA 프로바이더: attention 레이어에 적용
lora_provider = qwix.LoraProvider(
module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj",
rank=LORA_RANK,
alpha=LORA_ALPHA,
)
model_input = base_model.get_model_input()
model = qwix.apply_lora_to_model(
base_model, lora_provider,
rngs=nnx.Rngs(0), **model_input
)
# 상태 샤딩
state = nnx.state(model)
pspecs = nnx.get_partition_spec(state)
sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
nnx.update(model, sharded_state)
4. 커스텀 데이터셋 (Completion-only Loss)
Completion-only loss를 지원하기 위해 __iter__에서 프롬프트 마스킹을 구현합니다.
class CustomDataset:
def __init__(self, data, tokenizer, max_length=1024):
self.data = data
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.data)
def __iter__(self):
for item in self.data:
template_inputs = json.loads(item['text'])
# 전체 텍스트 (프롬프트 + 완성)
prompt_and_completion = self.tokenizer.apply_chat_template(
template_inputs['messages'],
tools=template_inputs['tools'],
tokenize=False,
add_generation_prompt=False
)
# 프롬프트만 (완성 부분 제외)
prompt_only = self.tokenizer.apply_chat_template(
template_inputs['messages'][:-1],
tools=template_inputs['tools'],
tokenize=False,
add_generation_prompt=True
)
tokenized_full = self.tokenizer(
prompt_and_completion, add_special_tokens=False
)
tokenized_prompt = self.tokenizer(
prompt_only, add_special_tokens=False
)
full_ids = tokenized_full['input_ids']
prompt_len = len(tokenized_prompt['input_ids'])
if len(full_ids) > self.max_length:
full_ids = full_ids[:self.max_length]
input_tokens = np.full(
(self.max_length,), self.tokenizer.pad_token_id, dtype=np.int32
)
input_tokens[:len(full_ids)] = full_ids
input_mask = np.zeros((self.max_length,), dtype=np.int32)
if len(full_ids) > prompt_len:
mask_end = min(len(full_ids), self.max_length)
input_mask[prompt_len:mask_end] = 1
yield peft_trainer.TrainingInput(
input_tokens=jnp.array(input_tokens, dtype=jnp.int32),
input_mask=jnp.array(input_mask, dtype=jnp.int32)
)
5. 학습 실행
trainer = peft_trainer.PeftTrainer(
model,
optax.adamw(lr_schedule),
training_config
).with_gen_model_input_fn(gen_model_input_fn)
with mesh:
trainer.train(train_batches, val_batches)
6. LoRA 병합 및 내보내기
from tunix import gemma_params
merged_output_dir = os.path.join(OUTPUT_DIR, "merged")
gemma_params.save_lora_merged_model_as_safetensors(
local_model_path=local_model_path,
output_dir=merged_output_dir,
lora_model=model,
rank=LORA_RANK,
alpha=LORA_ALPHA,
)
print(f"병합된 LoRA 모델 저장 완료: {merged_output_dir}")

주의사항 및 실무 팁
TPU 사용 시 알아둘 점
- Colab 무료 티어 TPU v5e-1은 단일 코어이므로, 대규모 모델(7B 이상) 파인튜닝에는 적합하지 않습니다. FunctionGemma 270M 정도가 적당합니다.
- TPU는 JAX에 최적화되어 있으므로, PyTorch 코드를 그대로 가져오면 동작하지 않습니다.
optax,jax.numpy등 JAX 생태계에 익숙해져야 합니다. - Tunix는 아직 초기 단계(2025년 기준)이므로, 프로덕션에 적용하기 전에 충분한 테스트가 필요합니다.
한국 개발 생태계에서의 적용 맥락
- 국내 클라우드 환경(GCP, NCP 등)에서 TPU를 직접 사용하는 사례는 아직 드물지만, 비용 최적화 측면에서 GPU 대비 TPU가 매력적인 선택지가 될 수 있습니다. 특히 스타트업이나 연구소에서 반복적인 실험을 많이 하는 경우 TPU 스팟 인스턴스를 활용하면 비용을 크게 줄일 수 있습니다.
- 네이버 클라우드나 KT Cloud 등 국내 CSP에서도 TPU 서비스를 도입하는 추세이므로, 미리 JAX/Tunix 스택에 익숙해져 두면 향후 인프라 선택 폭이 넓어집니다.
이 기술의 한계
- Tunix는 아직 Hugging Face TRL만큼 커뮤니티가 크지 않아, 문제 발생 시 참고할 자료가 부족할 수 있습니다.
- TPU는 GPU와 달리 동적 shape 변경에 취약하므로, 데이터셋 전처리 시
max_length로 패딩을 고정해야 합니다. - FunctionGemma 자체가 소형 모델이므로, 복잡한 API 스키마를 학습하기에는 파라미터 수가 부족할 수 있습니다.

마무리: 다음 스텝은?
이 튜토리얼에서는 LoRA 기반 지도 파인튜닝만 다뤘지만, Tunix는 **강화학습(RL)**이나 선호도 튜닝(DPO) 등 더 고급 기법도 지원합니다. 관심 있으신 분들은 아래 자료를 참고해 보세요.
- Tunix 공식 문서: Tunix Documentation
- Tunix GitHub 저장소: Tunix Repository
또한, 이 글의 근거 자료인 Google Developers Blog 원문에서 더 많은 예제와 설명을 확인할 수 있습니다.
함께 보면 좋은 글: