Post-Training
Post-training in Llama Stack allows you to fine-tune models using various providers and frameworks. This section covers all available post-training providers and how to use them effectively.
Overview
Llama Stack provides multiple post-training providers:
- HuggingFace SFTTrainer (inline::huggingface): fine-tuning using the HuggingFace ecosystem
- TorchTune (inline::torchtune): fine-tuning using Meta's TorchTune framework
- NVIDIA (remote::nvidia): fine-tuning using NVIDIA's platform
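The identifiers in parentheses are provider types of the form location::name. Inline providers run inside the Llama Stack server process, while remote providers delegate the work to an external service. A minimal sketch that splits such an identifier follows; parse_provider_type and ProviderType are hypothetical names, not part of Llama Stack.

```python
from typing import NamedTuple


class ProviderType(NamedTuple):
    location: str  # "inline" (runs in the stack process) or "remote" (external service)
    name: str      # framework or service name, e.g. "huggingface"


# parse_provider_type is a hypothetical helper, not a Llama Stack API
def parse_provider_type(provider_type: str) -> ProviderType:
    """Split a provider type such as "inline::torchtune" into its two parts."""
    location, sep, name = provider_type.partition("::")
    if not sep or location not in ("inline", "remote") or not name:
        raise ValueError(f"unrecognized provider type: {provider_type!r}")
    return ProviderType(location, name)


# The three post-training providers listed above
for pt in ("inline::huggingface", "inline::torchtune", "remote::nvidia"):
    parsed = parse_provider_type(pt)
    print(f"{pt:22} -> location={parsed.location}, name={parsed.name}")
```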
HuggingFace SFTTrainer
HuggingFace SFTTrainer is an inline post-training provider for Llama Stack. It runs supervised fine-tuning (SFT) with the SFTTrainer class from Hugging Face's TRL library and works with a variety of models and datasets.
Features
- Simple access through the post_training API
- Fully integrated with Llama Stack
- Runs on GPU (CUDA), CPU, or MPS (macOS Metal Performance Shaders)
Configuration
Field | Type | Required | Default | Description |
---|---|---|---|---|
device | str | No | cuda | |
distributed_backend | Literal['fsdp', 'deepspeed'] | No | | |
checkpoint_format | Literal['full_state', 'huggingface'] | No | huggingface | |
chat_template | str | No | | |
model_specific_config | dict | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
max_seq_length | int | No | 2048 | |
gradient_checkpointing | bool | No | False | |
save_total_limit | int | No | 3 | |
logging_steps | int | No | 10 | |
warmup_ratio | float | No | 0.1 | |
weight_decay | float | No | 0.01 | |
dataloader_num_workers | int | No | 4 | |
dataloader_pin_memory | bool | No | True | |
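Several of these fields are interpreted relative to the number of optimizer steps. The sketch below estimates that count and the resulting warmup length from warmup_ratio, assuming the provider follows the usual Hugging Face Trainer convention (a partial final batch still counts, and warmup is rounded up). The dataset size of 1,000 examples is an arbitrary illustration; batch_size=32, gradient_accumulation_steps=1, and n_epochs=4 match the usage example below. The function names are hypothetical.

```python
import math


# Hypothetical helpers; assumes Hugging Face Trainer-style step counting
def optimizer_steps(num_examples: int, batch_size: int,
                    grad_accum_steps: int, n_epochs: int) -> int:
    """Total optimizer updates, assuming the last partial batch is kept."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    # One update every grad_accum_steps batches, and at least one per epoch
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return updates_per_epoch * n_epochs


def warmup_steps(total_steps: int, warmup_ratio: float) -> int:
    """Learning-rate warmup length derived from a ratio, rounded up."""
    return math.ceil(total_steps * warmup_ratio)


total = optimizer_steps(num_examples=1000, batch_size=32, grad_accum_steps=1, n_epochs=4)
print("optimizer steps:", total)                  # 32 updates per epoch x 4 epochs
print("warmup steps:", warmup_steps(total, 0.1))  # default warmup_ratio from the table
```

With the default logging_steps of 10, a 128-step run like this would log roughly a dozen times.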
Sample Configuration
checkpoint_format: huggingface
distributed_backend: null
device: cpu
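To see how the sample configuration combines with the defaults in the table above, here is a plain-Python stand-in for the provider config. The class name HFTrainerSettings and its validation rules are illustrative assumptions that mirror the table; Llama Stack's real config class may validate differently. Fields omitted from the sample keep their table defaults.

```python
from dataclasses import dataclass, field
from typing import Any, Optional

# Allowed values taken from the Literal types in the configuration table
DISTRIBUTED_BACKENDS = {None, "fsdp", "deepspeed"}
CHECKPOINT_FORMATS = {"full_state", "huggingface"}


@dataclass
class HFTrainerSettings:
    """Hypothetical mirror of the HuggingFace provider config, for illustration only."""
    device: str = "cuda"
    distributed_backend: Optional[str] = None  # default not shown in the table; None assumed
    checkpoint_format: str = "huggingface"
    chat_template: Optional[str] = None  # default not shown in the table; None assumed
    model_specific_config: dict[str, Any] = field(
        default_factory=lambda: {"trust_remote_code": True, "attn_implementation": "sdpa"}
    )
    max_seq_length: int = 2048
    gradient_checkpointing: bool = False
    save_total_limit: int = 3
    logging_steps: int = 10
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    dataloader_num_workers: int = 4
    dataloader_pin_memory: bool = True

    def __post_init__(self) -> None:
        # Reject values outside the Literal types listed in the table
        if self.distributed_backend not in DISTRIBUTED_BACKENDS:
            raise ValueError(f"distributed_backend must be one of {DISTRIBUTED_BACKENDS}")
        if self.checkpoint_format not in CHECKPOINT_FORMATS:
            raise ValueError(f"checkpoint_format must be one of {CHECKPOINT_FORMATS}")


# The sample configuration as it would look after YAML parsing (null -> None)
sample = {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
settings = HFTrainerSettings(**sample)
print(settings.device, settings.max_seq_length, settings.warmup_ratio)
```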
Setup
You can access the HuggingFace trainer via the starter distribution:
llama stack build --distro starter --image-type venv
llama stack run --image-type venv ~/.llama/distributions/starter/starter-run.yaml
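The usage example below assumes the server started by llama stack run is listening on localhost:8321. If you script the setup, a plain TCP check such as this sketch can wait for the port to open before you submit a job. It only confirms that something accepts connections on the port, not that a healthy Llama Stack server is behind it; wait_for_port is a hypothetical helper.

```python
import socket
import time


# wait_for_port is a hypothetical helper; it checks TCP reachability only
def wait_for_port(host: str = "localhost", port: int = 8321,
                  timeout: float = 60.0, interval: float = 1.0) -> bool:
    """Return True once a TCP connection to host:port succeeds, False on timeout."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Connection refused or timed out: nothing is accepting yet
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)
```

For example, calling wait_for_port() before create_http_client() in the usage example avoids connection errors while the server is still starting.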
Usage Example
import time
import uuid
from llama_stack_client import LlamaStackClient
from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)
def create_http_client():
    # Connect to the Llama Stack server started in the Setup step
    return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example model
training_model = "ibm-granite/granite-3.3-8b-instruct"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job:", job_uuid)
# Poll until the job finishes, whether it succeeds or not
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break
    print(status)
    if status.status in ("completed", "failed", "cancelled"):
        break
    print("Waiting for job to complete...")
    time.sleep(5)
end_time = time.time()
print("Job finished in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
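The loop above has no upper bound on how long it waits. A more defensive sketch takes the status lookup as a function argument and gives up after a timeout. It assumes the terminal status strings are "completed", "failed", and "cancelled" (check the values your Llama Stack version reports); wait_for_job and its parameters are hypothetical names, not part of llama_stack_client.

```python
import time
from typing import Callable, Optional

# Assumed terminal states; verify against the status values your server returns
TERMINAL_STATES = {"completed", "failed", "cancelled"}


# wait_for_job is a hypothetical helper, not part of llama_stack_client
def wait_for_job(get_status: Callable[[], Optional[str]],
                 poll_interval: float = 5.0,
                 timeout: float = 3600.0) -> str:
    """Poll get_status() until a terminal state is reached and return that state.

    get_status returns the current status string, or None if the job is unknown.
    Raises TimeoutError if no terminal state is seen within `timeout` seconds.
    """
    deadline = time.monotonic() + timeout
    while True:
        status = get_status()
        if status is None:
            raise LookupError("job not found")
        if status in TERMINAL_STATES:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"job still {status!r} after {timeout} seconds")
        time.sleep(poll_interval)


# Offline demo with a scripted sequence of statuses instead of a live server
fake_statuses = iter(["scheduled", "in_progress", "completed"])
print(wait_for_job(lambda: next(fake_statuses), poll_interval=0.01, timeout=5))
```

With the client from the usage example, get_status could be written as lambda: getattr(client.post_training.job.status(job_uuid=job_uuid), "status", None); treat that wiring as a sketch as well.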
TorchTune
TorchTune is an inline post-training provider for Llama Stack. It provides a simple, efficient way to fine-tune language models with PyTorch.
Features
- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support and single-device training
- Support for LoRA
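LoRA (low-rank adaptation) keeps a pretrained weight matrix W of shape d_out × d_in frozen and trains a low-rank update B·A, where B is d_out × rank and A is rank × d_in; the update is scaled by alpha / rank before it is added to W. This is the role of the rank and alpha fields in the LoraFinetuningConfig used by the usage examples. The sketch below applies that update to tiny matrices in pure Python and counts trainable parameters. The function names are illustrative, and the 4096 × 4096 projection size is a hypothetical example, not a value taken from any particular model.

```python
# Illustrative helpers only; real trainers apply LoRA inside the model's layers
def matmul(x: list[list[float]], y: list[list[float]]) -> list[list[float]]:
    """Naive product of an (m x k) matrix and a (k x n) matrix."""
    return [[sum(x[i][p] * y[p][j] for p in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


def lora_merged_weight(w, a, b, alpha: float, rank: int):
    """Return W + (alpha / rank) * (B @ A): the effective weight after LoRA."""
    scale = alpha / rank
    delta = matmul(b, a)  # (d_out x rank) @ (rank x d_in) -> (d_out x d_in)
    return [[w[i][j] + scale * delta[i][j] for j in range(len(w[0]))]
            for i in range(len(w))]


def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in A (rank x d_in) plus B (d_out x rank); W itself stays frozen."""
    return rank * (d_in + d_out)


# Tiny 2x3 example with rank=1 and alpha=1, as in the usage examples' config
W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]]
A = [[1.0, 2.0, 3.0]]  # rank x d_in
B = [[0.5], [-1.0]]    # d_out x rank
print(lora_merged_weight(W, A, B, alpha=1, rank=1))

# Adapter size for one hypothetical 4096 x 4096 projection such as q_proj
print(lora_trainable_params(4096, 4096, rank=1), "trainable vs", 4096 * 4096, "frozen")
```

Because the trainable count grows linearly with rank, raising rank from 1 to 8 multiplies the adapter size by eight while the frozen weights stay unchanged.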
Configuration
Field | Type | Required | Default | Description |
---|---|---|---|---|
torch_seed | int \| None | No | | |
checkpoint_format | Literal['meta', 'huggingface'] | No | meta | |
Sample Configuration
checkpoint_format: meta
Setup
You can access the TorchTune trainer by writing your own run configuration YAML that points to the provider:
post_training:
- provider_id: torchtune
  provider_type: inline::torchtune
  config: {}
You can then build and run your own stack with this provider.
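If you generate run configurations from a script, the entry above is a list of mappings under the post_training key (in a full run configuration, that key typically sits under the top-level providers section). This sketch renders such entries in the same block style as the fragment above. The helper name is hypothetical, and it deliberately handles only flat string values such as checkpoint_format: meta.

```python
# render_post_training_providers is a hypothetical helper for illustration
def render_post_training_providers(providers: list[dict]) -> str:
    """Render provider entries in the same block style as the YAML fragment above.

    Handles only string IDs and a flat config of simple string values; use a real
    YAML library (for example PyYAML) for anything more complex.
    """
    lines = ["post_training:"]
    for provider in providers:
        lines.append(f"- provider_id: {provider['provider_id']}")
        lines.append(f"  provider_type: {provider['provider_type']}")
        config = provider.get("config") or {}
        if not config:
            lines.append("  config: {}")
            continue
        lines.append("  config:")
        for key, value in config.items():
            lines.append(f"    {key}: {value}")
    return "\n".join(lines)


print(render_post_training_providers([
    {"provider_id": "torchtune", "provider_type": "inline::torchtune", "config": {}},
]))
```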
Usage Example
import time
import uuid
from llama_stack_client import LlamaStackClient
from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)
def create_http_client():
    # Connect to the Llama Stack server running the TorchTune provider
    return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example model
training_model = "meta-llama/Llama-2-7b-hf"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job:", job_uuid)
# Poll until the job finishes, whether it succeeds or not
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break
    print(status)
    if status.status in ("completed", "failed", "cancelled"):
        break
    print("Waiting for job to complete...")
    time.sleep(5)
end_time = time.time()
print("Job finished in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
NVIDIA
NVIDIA (remote::nvidia) is a remote post-training provider for Llama Stack that fine-tunes models through NVIDIA's NeMo Customizer service.
Configuration
Field | Type | Required | Default | Description |
---|---|---|---|---|
api_key | str \| None | No | | The NVIDIA API key. |
dataset_namespace | str \| None | No | default | The NVIDIA dataset namespace. |
project_id | str \| None | No | test-example-model@v1 | The NVIDIA project ID. |
customizer_url | str \| None | No | | Base URL for the NeMo Customizer API. |
timeout | int | No | 300 | Timeout for the NVIDIA Post Training API. |
max_retries | int | No | 3 | Maximum number of retries for the NVIDIA Post Training API. |
output_model_dir | str | No | test-example-model@v1 | Directory to save the output model. |
Sample Configuration
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
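The sample values use the ${env.NAME:=default} form, which reads the environment variable NAME and falls back to the text after := when it is not set. The sketch below approximates that substitution with a regular expression so you can check which values a given environment produces. It assumes that an empty variable also falls back to the default and that a lone reference with an empty default (as for api_key) resolves to None, matching the optional type in the table; it ignores any other substitution forms Llama Stack may support. resolve_env_refs is a hypothetical helper.

```python
import os
import re
from typing import Mapping, Optional

# Matches ${env.NAME:=default}; the default may be empty
ENV_REF = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):=([^}]*)\}")


# resolve_env_refs is a hypothetical approximation of the config substitution
def resolve_env_refs(value: str, environ: Mapping[str, str] = os.environ) -> Optional[str]:
    """Replace ${env.NAME:=default} references in a config value.

    Assumptions: an unset or empty variable falls back to the default, and a
    value that is a single reference with an empty default resolves to None.
    """
    whole = ENV_REF.fullmatch(value)
    if whole:
        name, default = whole.groups()
        return environ.get(name) or default or None

    def substitute(match: re.Match) -> str:
        name, default = match.groups()
        return environ.get(name) or default

    return ENV_REF.sub(substitute, value)


sample = {
    "api_key": "${env.NVIDIA_API_KEY:=}",
    "dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:=default}",
    "project_id": "${env.NVIDIA_PROJECT_ID:=test-project}",
    "customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}",
}
env = {"NVIDIA_PROJECT_ID": "my-project"}  # pretend only this variable is set
print({key: resolve_env_refs(val, env) for key, val in sample.items()})
```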
Best Practices
- Choose the right provider: use HuggingFace for broad model compatibility, TorchTune for Meta's Llama models, or NVIDIA if you train through NVIDIA's NeMo Customizer service.
- Configure hardware appropriately: make sure settings such as the HuggingFace provider's device field match the hardware you have (cuda for NVIDIA GPUs, cpu, or mps on Apple silicon).
- Monitor jobs: poll job status until the job reaches a terminal state, and handle failed or cancelled jobs as well as completed ones.
- Use appropriate datasets: make sure your dataset format matches the input format your chosen provider expects.
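The usage examples register their dataset with purpose post-training/messages. Assuming rows for that purpose carry a messages list of role/content entries, a quick pre-flight check like this sketch can catch shape mismatches before a job is submitted. The column name, the allowed roles, the plain-string content requirement, and the helper name message_row_problems are all assumptions; confirm them against what your provider expects.

```python
from typing import Any

# Assumed role names; adjust if your provider accepts a different set
ALLOWED_ROLES = {"system", "user", "assistant", "tool"}


# message_row_problems is a hypothetical helper; it accepts only plain-string content
def message_row_problems(row: dict[str, Any]) -> list[str]:
    """Return a list of problems with one dataset row (an empty list means it looks OK)."""
    messages = row.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["'messages' must be a non-empty list"]
    problems = []
    for i, msg in enumerate(messages):
        if not isinstance(msg, dict):
            problems.append(f"message {i} is not a mapping")
            continue
        if msg.get("role") not in ALLOWED_ROLES:
            problems.append(f"message {i} has unexpected role {msg.get('role')!r}")
        if not isinstance(msg.get("content"), str):
            problems.append(f"message {i} is missing string 'content'")
    return problems


rows = [
    {"messages": [{"role": "user", "content": "What is 2 + 2?"},
                  {"role": "assistant", "content": "4"}]},
    {"question": "What is 2 + 2?", "answer": "4"},  # wrong shape for this purpose
]
for n, row in enumerate(rows):
    print(n, message_row_problems(row) or "ok")
```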
Next Steps
- Check out the Building Applications - Fine-tuning guide for application-level examples
- See the Providers section for detailed provider documentation
- Review the API Reference for complete API documentation