はじめに: なぜ FunctionGemma を TPU でファインチューニングするのか
FunctionGemma は、自然言語を API 呼び出しに変換するエージェントをエッジデバイスでも高速に実行できるよう設計された軽量言語モデルです。これまでは Hugging Face TRL + GPU の組み合わせが一般的でしたが、GPU コストが気になる方や TPU インフラを活用したい方には別のアプローチが必要でした。
本記事では Google Tunix という JAX ベースのライブラリを使用し、Colab 無料ティア TPU v5e-1 で FunctionGemma を LoRA 方式でファインチューニングする方法をステップバイステップで紹介します。GPU に対する TPU のメリットは、何と言っても コスト効率 と JAX ネイティブのパフォーマンス です。特に大規模バッチ処理や反復実験が必要な研究・プロトタイピング段階では、TPU は強力な選択肢となります。
参考: 本チュートリアルは メタが公開した RCCLX AMD プラットフォーム GPU 通信性能を革新する で扱った GPU 通信最適化とは異なる軸ですが、ハードウェアアクセラレータ選択の幅を広げるという点で併せてご覧ください。
準備
- Google Colab (TPU ランタイム設定: ランタイム → ランタイムタイプ変更 → TPU v5e-1)
- Hugging Face アカウント (FunctionGemma モデルへのアクセス権限が必要)
- Python, JAX, Transformers の基礎知識

実践ファインチューニング: Tunix で FunctionGemma LoRA 学習
1. モデルとデータセットのダウンロード
Hugging Face Hub から FunctionGemma(270M) と Mobile Action データセットをダウンロードします。
from huggingface_hub import snapshot_download, hf_hub_download
MODEL_ID = "google/functiongemma-270m-it"
DATASET_ID = "google/mobile-actions"
# safetensors のみダウンロード (PyTorch 重みは除外)
local_model_path = snapshot_download(
repo_id=MODEL_ID,
ignore_patterns=["*.pth"]
)
data_file = hf_hub_download(
repo_id=DATASET_ID,
filename="dataset.jsonl",
repo_type="dataset"
)
2. TPU メッシュ構成
Colab 無料ティアはシングルコア TPU なので、シャーディングなしでシンプルなメッシュを作成します。
import jax
NUM_TPUS = len(jax.devices())
# シングルコア: (1, 1) メッシュ
MESH = [(1, 1), ("fsdp", "tp")]
mesh = jax.make_mesh(
*MESH,
axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0])
)
3. LoRA アダプターの適用
Tunix は create_model_from_safe_tensors() で safetensors を直接ロードし、Qwix で LoRA を適用します。
import qwix
from tunix import params_safetensors_lib
from tunix import nnx
with mesh:
# ベースモデルのロード
base_model = params_safetensors_lib.create_model_from_safe_tensors(
local_model_path, model_config, mesh
)
# LoRA プロバイダー: attention レイヤーに適用
lora_provider = qwix.LoraProvider(
module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj",
rank=LORA_RANK,
alpha=LORA_ALPHA,
)
model_input = base_model.get_model_input()
model = qwix.apply_lora_to_model(
base_model, lora_provider,
rngs=nnx.Rngs(0), **model_input
)
# 状態のシャーディング
state = nnx.state(model)
pspecs = nnx.get_partition_spec(state)
sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
nnx.update(model, sharded_state)
4. カスタムデータセット (Completion-only Loss)
Completion-only loss をサポートするため、__iter__ でプロンプトマスキングを実装します。
class CustomDataset:
def __init__(self, data, tokenizer, max_length=1024):
self.data = data
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.data)
def __iter__(self):
for item in self.data:
template_inputs = json.loads(item['text'])
# 全体テキスト (プロンプト + 完了)
prompt_and_completion = self.tokenizer.apply_chat_template(
template_inputs['messages'],
tools=template_inputs['tools'],
tokenize=False,
add_generation_prompt=False
)
# プロンプトのみ (完了部分を除外)
prompt_only = self.tokenizer.apply_chat_template(
template_inputs['messages'][:-1],
tools=template_inputs['tools'],
tokenize=False,
add_generation_prompt=True
)
tokenized_full = self.tokenizer(
prompt_and_completion, add_special_tokens=False
)
tokenized_prompt = self.tokenizer(
prompt_only, add_special_tokens=False
)
full_ids = tokenized_full['input_ids']
prompt_len = len(tokenized_prompt['input_ids'])
if len(full_ids) > self.max_length:
full_ids = full_ids[:self.max_length]
input_tokens = np.full(
(self.max_length,), self.tokenizer.pad_token_id, dtype=np.int32
)
input_tokens[:len(full_ids)] = full_ids
input_mask = np.zeros((self.max_length,), dtype=np.int32)
if len(full_ids) > prompt_len:
mask_end = min(len(full_ids), self.max_length)
input_mask[prompt_len:mask_end] = 1
yield peft_trainer.TrainingInput(
input_tokens=jnp.array(input_tokens, dtype=jnp.int32),
input_mask=jnp.array(input_mask, dtype=jnp.int32)
)
5. 学習の実行
trainer = peft_trainer.PeftTrainer(
model,
optax.adamw(lr_schedule),
training_config
).with_gen_model_input_fn(gen_model_input_fn)
with mesh:
trainer.train(train_batches, val_batches)
6. LoRA マージとエクスポート
from tunix import gemma_params
merged_output_dir = os.path.join(OUTPUT_DIR, "merged")
gemma_params.save_lora_merged_model_as_safetensors(
local_model_path=local_model_path,
output_dir=merged_output_dir,
lora_model=model,
rank=LORA_RANK,
alpha=LORA_ALPHA,
)
print(f"マージされた LoRA モデルを保存: {merged_output_dir}")

注意点と実務上のヒント
TPU 使用時の注意
- Colab 無料ティア TPU v5e-1 はシングルコアのため、大規模モデル(7B 以上)のファインチューニングには適しません。FunctionGemma 270M 程度が適切です。
- TPU は JAX に最適化されているため、PyTorch コードをそのまま流用することはできません。
optax,jax.numpyなど JAX エコシステムに慣れる必要があります。 - Tunix はまだ初期段階(2025年時点)のため、プロダクション適用前には十分なテストが必要です。
日本開発コミュニティでの適用コンテキスト
- 国内クラウド環境(GCP, さくらインターネット等)で TPU を直接使用する事例はまだ少ないですが、コスト最適化の観点で GPU より TPU が魅力的な選択肢になるケースが増えています。特にスタートアップや研究所で反復実験が多い場合、TPU スポットインスタンスを活用するとコストを大幅に削減できます。
- Qiita や Zenn でも TPU 関連の記事は増加傾向にあり、JAX/Tunix スタックに習熟しておくと今後のインフラ選択の幅が広がります。
本技術の限界
- Tunix は Hugging Face TRL ほどコミュニティが大きくないため、問題発生時に参照できる資料が限られています。
- TPU は GPU と異なり動的な shape 変更に弱いため、データセット前処理時に
max_lengthでパディングを固定する必要があります。 - FunctionGemma 自体が軽量モデルであるため、複雑な API スキーマを学習するにはパラメータ数が不足する可能性があります。

まとめ: 次のステップは?
本チュートリアルでは LoRA ベースの教師ありファインチューニングのみを扱いましたが、Tunix は 強化学習(RL) や 選好チューニング(DPO) などより高度な手法もサポートしています。興味のある方は以下の資料をご参照ください。
- Tunix 公式ドキュメント: Tunix Documentation
- Tunix GitHub リポジトリ: Tunix Repository
また、本記事の根拠資料である Google Developers Blog 原文 でさらに多くの例や説明を確認できます。
合わせて読みたい記事: