TorchTune BuiltinTrainer

How to fine-tune LLMs with TorchTune BuiltinTrainer

This guide describes how to fine-tune LLMs using BuiltinTrainer and TorchTuneConfig. TorchTuneConfig leverages TorchTune to streamline LLM fine-tuning on Kubernetes. To understand the concept of BuiltinTrainer, see the overview guide.

Prerequisites

Before exploring this guide, make sure to follow the Getting Started guide to understand the basics of Kubeflow Trainer.

# Get the torchtune-llama3.2-1b TorchTune runtime.
from kubeflow.trainer import *

client = TrainerClient()

# List all TorchTune runtimes.
for r in client.list_runtimes():
    if r.name.startswith("torchtune"):
        print(r)

runtime = client.get_runtime("torchtune-llama3.2-1b")

Output:

Runtime(name='torchtune-llama3.2-1b', trainer=Trainer(trainer_type=<TrainerType.BUILTIN_TRAINER: 'BuiltinTrainer'>, framework=<Framework.TORCHTUNE: 'torchtune'>, entrypoint=['tune', 'run'], accelerator='Unknown', accelerator_count='2.0'), pretrained_model=None)
Runtime(name='torchtune-llama3.2-3b', trainer=Trainer(trainer_type=<TrainerType.BUILTIN_TRAINER: 'BuiltinTrainer'>, framework=<Framework.TORCHTUNE: 'torchtune'>, entrypoint=['tune', 'run'], accelerator='Unknown', accelerator_count='2.0'), pretrained_model=None)

Fine-Tune LLM with TorchTune BuiltinTrainer

The guide below shows how to fine-tune Llama-3.2-1B-Instruct with the alpaca dataset using BuiltinTrainer.

Create PVCs for Models and Datasets

We need to manually create a PVC for each model we want to fine-tune. Please note that the PVC name must match the TorchTune runtime’s name. In this example, it’s torchtune-llama3.2-1b.

# Create a PersistentVolumeClaim for the TorchTune Llama 3.2 1B model.
# The PVC object models come from the official Kubernetes Python client.
from kubernetes import client as k8s_client

client.core_api.create_namespaced_persistent_volume_claim(
    namespace="default",  # Adjust to the namespace where your TrainJob runs.
    body=k8s_client.V1PersistentVolumeClaim(
        api_version="v1",
        kind="PersistentVolumeClaim",
        metadata=k8s_client.V1ObjectMeta(name="torchtune-llama3.2-1b"),
        spec=k8s_client.V1PersistentVolumeClaimSpec(
            access_modes=["ReadWriteOnce"],
            resources=k8s_client.V1ResourceRequirements(
                requests={"storage": "20Gi"}
            ),
        ),
    ),
)
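
Optionally, you can verify that the PVC exists and is bound before launching the TrainJob. The short sketch below reuses the same CoreV1Api client and assumes the default namespace:

# Check the PVC phase (for example, Bound or Pending) before starting the TrainJob.
pvc = client.core_api.read_namespaced_persistent_volume_claim(
    name="torchtune-llama3.2-1b",
    namespace="default",  # Assumption: the namespace used when creating the PVC.
)
print(pvc.status.phase)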

Use TorchTune BuiltinTrainer with train() API

You need to provide the following parameters to use the TorchTune BuiltinTrainer with the train() API:

  • Reference to existing Runtimes.
  • Initializer parameters.
  • Trainer configuration.

Kubeflow TrainJob will train the model using the referenced TorchTune runtime torchtune-llama3.2-1b.

For example, you can use the train() API to fine-tune the Llama-3.2-1B-Instruct model with the alpaca dataset from HuggingFace using the code below (requires 2 GPUs):

job_name = client.train(
    runtime=client.get_runtime("torchtune-llama3.2-1b"),
    initializer=Initializer(
        model=HuggingFaceModelInitializer(
            storage_uri="hf://meta-llama/Llama-3.2-1B-Instruct",
            access_token="<YOUR_HF_TOKEN>"  # Replace with your Hugging Face token
        )
    ),
    trainer=BuiltinTrainer(
        config=TorchTuneConfig(
            dataset_preprocess_config=TorchTuneInstructDataset(
                source=DataFormat.PARQUET,
            ),
        )
    )
)
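
The train() API returns the TrainJob name. As a minimal sketch for tracking completion (assuming the SDK exposes a get_job() API and that the returned TrainJob carries a human-readable status), you can poll until the job finishes:

import time

# Minimal polling sketch. Assumption: get_job() returns a TrainJob object with a
# status field reporting values such as Running, Complete, or Failed.
while True:
    job = client.get_job(job_name)
    print(f"TrainJob {job_name} status: {job.status}")
    if job.status in ("Complete", "Failed"):
        break
    time.sleep(30)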

Watch the TrainJob Logs

We can use the get_job_logs() API to get the TrainJob logs.

Dataset Initializer

from kubeflow.trainer.constants import constants

log_dict = client.get_job_logs(job_name, step=constants.DATASET_INITIALIZER)
print(log_dict[constants.DATASET_INITIALIZER])

Output:

2025-07-02T08:24:01Z INFO     [__main__.py:16] Starting dataset initialization
2025-07-02T08:24:01Z INFO     [huggingface.py:28] Downloading dataset: tatsu-lab/alpaca
2025-07-02T08:24:01Z INFO     [huggingface.py:29] ----------------------------------------
Fetching 3 files: 100%|██████████| 3/3 [00:01<00:00,  1.82it/s]
2025-07-02T08:24:04Z INFO     [huggingface.py:40] Dataset has been downloaded

Model Initializer

log_dict = client.get_job_logs(job_name, step=constants.MODEL_INITIALIZER)
print(log_dict[constants.MODEL_INITIALIZER])

Output:

2025-07-02T08:24:23Z INFO     [__main__.py:16] Starting pre-trained model initialization
2025-07-02T08:24:23Z INFO     [huggingface.py:26] Downloading model: meta-llama/Llama-3.2-1B-Instruct
2025-07-02T08:24:23Z INFO     [huggingface.py:27] ----------------------------------------
Fetching 8 files: 100%|██████████| 8/8 [01:02<00:00,  7.87s/it]
2025-07-02T08:25:27Z INFO     [huggingface.py:43] Model has been downloaded

Training Node

log_dict = client.get_job_logs(job_name, follow=False)
print(log_dict[f"{constants.NODE}-0"])

Output:

INFO:torchtune.utils._logging:Running FullFinetuneRecipeDistributed with resolved config:

<Complete TorchTune config used by `tune run` command>

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 3686749453. Local seed is seed + rank = 3686749453 + 0
Writing logs to /workspace/output/logs/log_1751444966.txt
INFO:torchtune.utils._logging:Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...
INFO:torchtune.utils._logging:Instantiating model and loading checkpoint took 17.31 secs
INFO:torchtune.utils._logging:Memory stats after model init:
	GPU peak memory allocation: 2.33 GiB
	GPU peak memory reserved: 2.34 GiB
	GPU peak memory active: 2.33 GiB
/opt/conda/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py:4631: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. 
  warnings.warn(  # warn only once
INFO:torchtune.utils._logging:Optimizer is initialized.
INFO:torchtune.utils._logging:Loss is initialized.
Generating train split: 52002 examples [00:00, 192755.58 examples/s]
INFO:torchtune.utils._logging:No learning rate scheduler configured. Using constant learning rate.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
1|1625|Loss: 1.5839784145355225: 100%|██████████| 1625/1625 [21:54<00:00,  1.23it/s]INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Getting full model state dict took 1.69 secs
INFO:torchtune.utils._logging:Model checkpoint of size 2.30 GiB saved to /workspace/output/epoch_0/model-00001-of-00001.safetensors
INFO:torchtune.utils._logging:Saving final epoch checkpoint.
INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully.You can now use this checkpoint for further training or inference.
INFO:torchtune.utils._logging:Saving checkpoint took 6.64 secs
1|1625|Loss: 1.5839784145355225: 100%|██████████| 1625/1625 [22:01<00:00,  1.23it/s]
Running with torchrun...

Get the Fine-Tuned Model

After the TrainJob completes the fine-tuning task, the fine-tuned model is stored in the /workspace/output directory, which can be shared across Pods by mounting the PVC. If you mount the PVC under /<mountDir> in another Pod, you can find the model in that Pod’s /<mountDir>/output directory.
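
For example, the sketch below creates a Pod that mounts the same PVC under /workspace, so the fine-tuned checkpoints are visible at /workspace/output inside the container. The Pod name, container image, and namespace here are illustrative assumptions:

# Illustrative sketch: mount the runtime's PVC into a new Pod to access the model.
# The Pod name, container image, and namespace are assumptions for this example.
from kubernetes import client as k8s_client

client.core_api.create_namespaced_pod(
    namespace="default",
    body=k8s_client.V1Pod(
        metadata=k8s_client.V1ObjectMeta(name="model-reader"),
        spec=k8s_client.V1PodSpec(
            containers=[
                k8s_client.V1Container(
                    name="reader",
                    image="python:3.11-slim",
                    command=["sleep", "infinity"],
                    volume_mounts=[
                        k8s_client.V1VolumeMount(
                            name="model-storage",
                            mount_path="/workspace",
                        )
                    ],
                )
            ],
            volumes=[
                k8s_client.V1Volume(
                    name="model-storage",
                    persistent_volume_claim=k8s_client.V1PersistentVolumeClaimVolumeSource(
                        claim_name="torchtune-llama3.2-1b"
                    ),
                )
            ],
        ),
    ),
)

Once the Pod is running, the checkpoint from the run above (for example, /workspace/output/epoch_0/model-00001-of-00001.safetensors) can be copied or served from that container.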

Parameters

Runtime

For the TorchTune BuiltinTrainer, you can find the runtime you want in the manifest folder and retrieve it with the get_runtime() API. For example:

runtime=client.get_runtime("torchtune-llama3.2-1b")

Initializer

We use the parameters in Initializer to download the dataset and model from remote storage.

Dataset Initializer

Currently, we only support downloading datasets from HuggingFace by defining HuggingFaceDatasetInitializer (see the example usage below).

Model Initializer

Currently, we only support downloading models from HuggingFace by defining HuggingFaceModelInitializer (see the example usage below).

Example Usage

initializer=Initializer(
    dataset=HuggingFaceDatasetInitializer(
        storage_uri="hf://tatsu-lab/alpaca/data",
        access_token="<YOUR_HF_TOKEN>"  # Replace with your Hugging Face token
    ),
    model=HuggingFaceModelInitializer(
        storage_uri="hf://meta-llama/Llama-3.2-1B-Instruct",
        access_token="<YOUR_HF_TOKEN>"  # Replace with your Hugging Face token
    )
)

TorchTune BuiltinTrainer

Description

The TorchTuneConfig class is used to configure the TorchTune BuiltinTrainer, which already includes the fine-tuning logic. You can find the API definition here.

Example Usage

torchtune_config = TorchTuneConfig(
    dtype=DataType.BF16,
    batch_size=10,
    epochs=10,
    loss=Loss.CEWithChunkedOutputLoss,
    num_nodes=1,
    dataset_preprocess_config=TorchTuneInstructDataset(
        source=DataFormat.PARQUET,
        split="train[:95%]",
        train_on_input=True,
        new_system_prompt="You are an AI assistant. ",
        column_map={"input": "incorrect", "output": "correct"},
    ),
    resources_per_node={"gpu": 1},
)
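
To put everything together, pass this config to BuiltinTrainer and submit the TrainJob with the train() API. Here, initializer stands for the Initializer object shown in the example above:

# Submit a TrainJob that combines the runtime, Initializer, and TorchTune config above.
job_name = client.train(
    runtime=client.get_runtime("torchtune-llama3.2-1b"),
    initializer=initializer,  # the Initializer from the Example Usage section above
    trainer=BuiltinTrainer(config=torchtune_config),
)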
