Post-Training

Post-training in Llama Stack allows you to fine-tune models using various providers and frameworks. This section covers all available post-training providers and how to use them effectively.

Overview

Llama Stack provides multiple post-training providers:

  • HuggingFace SFTTrainer (inline::huggingface) - Fine-tuning using HuggingFace ecosystem
  • TorchTune (inline::torchtune) - Fine-tuning using Meta's TorchTune framework
  • NVIDIA (remote::nvidia) - Fine-tuning using NVIDIA's platform
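Which of these providers is actually available depends on the distribution you run. As a quick check, you can ask a running stack which post-training providers it exposes; a minimal sketch, assuming a server on the default port 8321 and the providers listing API of llama_stack_client:

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Print every provider registered for the post_training API
for provider in client.providers.list():
    if provider.api == "post_training":
        print(provider.provider_id, provider.provider_type)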

HuggingFace SFTTrainer

HuggingFace SFTTrainer is an inline post-training provider for Llama Stack. It allows you to run supervised fine-tuning on a variety of models using many datasets.

Features

  • Simple access through the post_training API
  • Fully integrated with Llama Stack
  • GPU, CPU, and MPS (macOS Metal Performance Shaders) support

Configuration

  • device (str, optional, default: cuda)
  • distributed_backend (Literal['fsdp', 'deepspeed'], optional)
  • checkpoint_format (Literal['full_state', 'huggingface'], optional, default: huggingface)
  • chat_template (str, optional)
  • model_specific_config (dict, optional, default: {'trust_remote_code': True, 'attn_implementation': 'sdpa'})
  • max_seq_length (int, optional, default: 2048)
  • gradient_checkpointing (bool, optional, default: False)
  • save_total_limit (int, optional, default: 3)
  • logging_steps (int, optional, default: 10)
  • warmup_ratio (float, optional, default: 0.1)
  • weight_decay (float, optional, default: 0.01)
  • dataloader_num_workers (int, optional, default: 4)
  • dataloader_pin_memory (bool, optional, default: True)

Sample Configuration

checkpoint_format: huggingface
distributed_backend: null
device: cpu
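
In a full run configuration this block sits under the post_training section of the providers list, in the same shape as the TorchTune setup shown later. A minimal sketch, assuming the provider_id huggingface:

post_training:
- provider_id: huggingface
  provider_type: inline::huggingface
  config:
    checkpoint_format: huggingface
    distributed_backend: null
    device: cpu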

Setup

You can access the HuggingFace trainer via the starter distribution:

llama stack build --distro starter --image-type venv
llama stack run --image-type venv ~/.llama/distributions/starter/starter-run.yaml
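
Before registering datasets or launching jobs, you can verify the server came up; this sketch assumes the default port and the stack's standard health endpoint:

# Expect a small JSON status payload if the server is reachable
curl http://localhost:8321/v1/health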

Usage Example

import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "ibm-granite/granite-3.3-8b-instruct"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
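
If a job needs to be stopped before it finishes, the same job API exposes cancellation. A minimal sketch, assuming your llama_stack_client version provides post_training.job.cancel (check the client reference for the exact signature):

# Request cancellation of a job that is still running
client.post_training.job.cancel(job_uuid=job_uuid)

# The status call used above should subsequently report a terminal state
print(client.post_training.job.status(job_uuid=job_uuid))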

TorchTune

TorchTune is an inline post-training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.

Features

  • Simple access through the post_training API
  • Fully integrated with Llama Stack
  • GPU support and single device capabilities
  • Support for LoRA

Configuration

  • torch_seed (int | None, optional)
  • checkpoint_format (Literal['meta', 'huggingface'], optional, default: meta)

Sample Configuration

checkpoint_format: meta

Setup

You can access the TorchTune trainer by writing your own run configuration YAML that points to the provider:

post_training:
- provider_id: torchtune
  provider_type: inline::torchtune
  config: {}

You can then build and run your own stack with this provider.
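
For example, if you save a run configuration containing this provider entry (the path below is only illustrative), you can launch it the same way as the starter distribution:

llama stack run --image-type venv ~/.llama/distributions/custom-torchtune/run.yaml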

Usage Example

import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "meta-llama/Llama-2-7b-hf"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))

NVIDIA

NVIDIA is a remote post-training provider for Llama Stack. It fine-tunes models on NVIDIA's platform through the NeMo Customizer API.

Configuration

  • api_key (str | None, optional): The NVIDIA API key.
  • dataset_namespace (str | None, optional, default: default): The NVIDIA dataset namespace.
  • project_id (str | None, optional, default: test-example-model@v1): The NVIDIA project ID.
  • customizer_url (str | None, optional): Base URL for the NeMo Customizer API.
  • timeout (int, optional, default: 300): Timeout for the NVIDIA Post Training API.
  • max_retries (int, optional, default: 3): Maximum number of retries for the NVIDIA Post Training API.
  • output_model_dir (str, optional, default: test-example-model@v1): Directory to save the output model.

Sample Configuration

api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
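
As with the inline providers, this configuration lives under the post_training section of a run configuration. A minimal sketch, assuming the provider_id nvidia:

post_training:
- provider_id: nvidia
  provider_type: remote::nvidia
  config:
    api_key: ${env.NVIDIA_API_KEY:=}
    dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
    project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
    customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}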

Best Practices

  • Choose the right provider: Use HuggingFace for broad ecosystem compatibility, TorchTune for Meta's Llama models, or NVIDIA for NVIDIA's platform
  • Configure hardware appropriately: Ensure your configuration matches your available hardware (CPU, GPU, MPS)
  • Monitor jobs: Always monitor job status and handle completion appropriately
  • Use appropriate datasets: Ensure your dataset format matches the expected input format for your chosen provider
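
For the monitoring point above, note that the polling loops in the usage examples only check for a completed state. In practice you may also want to stop on failure and bound the wait; a sketch of such a helper, assuming the reported status values include "completed" and "failed":

import time

def wait_for_job(client, job_uuid, poll_seconds=5, timeout_seconds=3600):
    """Poll a post-training job until it reaches a terminal state or the timeout expires."""
    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        status = client.post_training.job.status(job_uuid=job_uuid)
        if status is None:
            raise RuntimeError(f"Job {job_uuid} not found")
        # Stop as soon as the job reaches a terminal state
        if status.status in ("completed", "failed"):
            return status
        time.sleep(poll_seconds)
    raise TimeoutError(f"Job {job_uuid} did not finish within {timeout_seconds} seconds")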

Next Steps