들어가며: 왜 FunctionGemma를 TPU로 파인튜닝해야 할까?

FunctionGemma는 자연어를 API 호출로 변환하는 에이전트를 엣지 디바이스에서도 빠르게 실행할 수 있도록 설계된 경량 언어 모델입니다. 기존에는 Hugging Face TRL + GPU 조합으로 파인튜닝하는 가이드가 많았지만, GPU 비용이 부담스럽거나 TPU 인프라를 활용하고 싶은 분들에게는 다른 접근이 필요했죠.

이 글에서는 Google Tunix라는 JAX 기반 라이브러리를 사용해 Colab 무료 티어 TPU v5e-1에서 FunctionGemma를 LoRA 방식으로 파인튜닝하는 방법을 단계별로 소개합니다. GPU 대비 TPU의 장점은 무엇보다 비용 효율성JAX 네이티브 성능입니다. 특히 대규모 배치 처리나 반복 실험이 필요한 연구/프로토타이핑 단계에서 TPU는 강력한 대안이 됩니다.

참고: 이 튜토리얼은 메타가 공개한 RCCLX AMD 플랫폼 GPU 통신 성능을 혁신하다에서 다룬 GPU 통신 최적화와는 다른 축이지만, 하드웨어 가속기 선택의 폭을 넓힌다는 점에서 함께 보시면 좋습니다.

준비물

  • Google Colab (TPU 런타임 설정: 런타임 → 런타임 유형 변경 → TPU v5e-1)
  • Hugging Face 계정 (FunctionGemma 모델 접근 권한 필요)
  • 기본적인 Python, JAX, Transformers 지식

Developer using Google Colab TPU to fine-tune FunctionGemma model with Tunix library Development Concept Image

실전 파인튜닝: Tunix로 FunctionGemma LoRA 학습하기

1. 모델 및 데이터셋 다운로드

Hugging Face Hub에서 FunctionGemma(270M)와 Mobile Action 데이터셋을 내려받습니다.

from huggingface_hub import snapshot_download, hf_hub_download

MODEL_ID = "google/functiongemma-270m-it"
DATASET_ID = "google/mobile-actions"

# safetensors만 다운로드 (PyTorch 가중치 제외)
local_model_path = snapshot_download(
    repo_id=MODEL_ID,
    ignore_patterns=["*.pth"]
)
data_file = hf_hub_download(
    repo_id=DATASET_ID,
    filename="dataset.jsonl",
    repo_type="dataset"
)

2. TPU 메시 구성

Colab 무료 티어는 단일 코어 TPU이므로, 샤딩 없이 간단한 메시를 만듭니다.

import jax

NUM_TPUS = len(jax.devices())
# 단일 코어: (1, 1) 메시
MESH = [(1, 1), ("fsdp", "tp")]
mesh = jax.make_mesh(
    *MESH,
    axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0])
)

3. LoRA 어댑터 적용

Tunix는 create_model_from_safe_tensors()로 safetensors를 바로 로드하고, Qwix로 LoRA를 적용합니다.

import qwix
from tunix import params_safetensors_lib
from tunix import nnx

with mesh:
    # 베이스 모델 로드
    base_model = params_safetensors_lib.create_model_from_safe_tensors(
        local_model_path, model_config, mesh
    )
    # LoRA 프로바이더: attention 레이어에 적용
    lora_provider = qwix.LoraProvider(
        module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj",
        rank=LORA_RANK,
        alpha=LORA_ALPHA,
    )
    model_input = base_model.get_model_input()
    model = qwix.apply_lora_to_model(
        base_model, lora_provider,
        rngs=nnx.Rngs(0), **model_input
    )
    # 상태 샤딩
    state = nnx.state(model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)

4. 커스텀 데이터셋 (Completion-only Loss)

Completion-only loss를 지원하기 위해 __iter__에서 프롬프트 마스킹을 구현합니다.

class CustomDataset:
    def __init__(self, data, tokenizer, max_length=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        for item in self.data:
            template_inputs = json.loads(item['text'])
            # 전체 텍스트 (프롬프트 + 완성)
            prompt_and_completion = self.tokenizer.apply_chat_template(
                template_inputs['messages'],
                tools=template_inputs['tools'],
                tokenize=False,
                add_generation_prompt=False
            )
            # 프롬프트만 (완성 부분 제외)
            prompt_only = self.tokenizer.apply_chat_template(
                template_inputs['messages'][:-1],
                tools=template_inputs['tools'],
                tokenize=False,
                add_generation_prompt=True
            )
            tokenized_full = self.tokenizer(
                prompt_and_completion, add_special_tokens=False
            )
            tokenized_prompt = self.tokenizer(
                prompt_only, add_special_tokens=False
            )
            full_ids = tokenized_full['input_ids']
            prompt_len = len(tokenized_prompt['input_ids'])
            if len(full_ids) > self.max_length:
                full_ids = full_ids[:self.max_length]
            input_tokens = np.full(
                (self.max_length,), self.tokenizer.pad_token_id, dtype=np.int32
            )
            input_tokens[:len(full_ids)] = full_ids
            input_mask = np.zeros((self.max_length,), dtype=np.int32)
            if len(full_ids) > prompt_len:
                mask_end = min(len(full_ids), self.max_length)
                input_mask[prompt_len:mask_end] = 1
            yield peft_trainer.TrainingInput(
                input_tokens=jnp.array(input_tokens, dtype=jnp.int32),
                input_mask=jnp.array(input_mask, dtype=jnp.int32)
            )

5. 학습 실행

trainer = peft_trainer.PeftTrainer(
    model,
    optax.adamw(lr_schedule),
    training_config
).with_gen_model_input_fn(gen_model_input_fn)

with mesh:
    trainer.train(train_batches, val_batches)

6. LoRA 병합 및 내보내기

from tunix import gemma_params

merged_output_dir = os.path.join(OUTPUT_DIR, "merged")
gemma_params.save_lora_merged_model_as_safetensors(
    local_model_path=local_model_path,
    output_dir=merged_output_dir,
    lora_model=model,
    rank=LORA_RANK,
    alpha=LORA_ALPHA,
)
print(f"병합된 LoRA 모델 저장 완료: {merged_output_dir}")

Diagram of Tunix JAX-based LLM fine-tuning workflow on TPU accelerator Dev Environment Setup

주의사항 및 실무 팁

TPU 사용 시 알아둘 점

  • Colab 무료 티어 TPU v5e-1은 단일 코어이므로, 대규모 모델(7B 이상) 파인튜닝에는 적합하지 않습니다. FunctionGemma 270M 정도가 적당합니다.
  • TPU는 JAX에 최적화되어 있으므로, PyTorch 코드를 그대로 가져오면 동작하지 않습니다. optax, jax.numpy 등 JAX 생태계에 익숙해져야 합니다.
  • Tunix는 아직 초기 단계(2025년 기준)이므로, 프로덕션에 적용하기 전에 충분한 테스트가 필요합니다.

한국 개발 생태계에서의 적용 맥락

  • 국내 클라우드 환경(GCP, NCP 등)에서 TPU를 직접 사용하는 사례는 아직 드물지만, 비용 최적화 측면에서 GPU 대비 TPU가 매력적인 선택지가 될 수 있습니다. 특히 스타트업이나 연구소에서 반복적인 실험을 많이 하는 경우 TPU 스팟 인스턴스를 활용하면 비용을 크게 줄일 수 있습니다.
  • 네이버 클라우드나 KT Cloud 등 국내 CSP에서도 TPU 서비스를 도입하는 추세이므로, 미리 JAX/Tunix 스택에 익숙해져 두면 향후 인프라 선택 폭이 넓어집니다.

이 기술의 한계

  • Tunix는 아직 Hugging Face TRL만큼 커뮤니티가 크지 않아, 문제 발생 시 참고할 자료가 부족할 수 있습니다.
  • TPU는 GPU와 달리 동적 shape 변경에 취약하므로, 데이터셋 전처리 시 max_length로 패딩을 고정해야 합니다.
  • FunctionGemma 자체가 소형 모델이므로, 복잡한 API 스키마를 학습하기에는 파라미터 수가 부족할 수 있습니다.

AI agent generating API calls from natural language using fine-tuned FunctionGemma model Developer Related Image

마무리: 다음 스텝은?

이 튜토리얼에서는 LoRA 기반 지도 파인튜닝만 다뤘지만, Tunix는 **강화학습(RL)**이나 선호도 튜닝(DPO) 등 더 고급 기법도 지원합니다. 관심 있으신 분들은 아래 자료를 참고해 보세요.

또한, 이 글의 근거 자료인 Google Developers Blog 원문에서 더 많은 예제와 설명을 확인할 수 있습니다.

함께 보면 좋은 글:

본 콘텐츠는 신뢰할 수 있는 출처를 바탕으로 AI 도구를 활용하여 초안이 작성되었으며, 편집자의 검토를 거쳐 발행되었습니다. 전문가의 조언을 대체하지 않습니다.