Fine-tuning allows you to customize pre-trained open-source models for your specific tasks and domain. Vertex AI provides managed fine-tuning services that handle infrastructure provisioning, distributed training, and hyperparameter optimization.
Prepare your training data in JSONL (JSON Lines) format, with one JSON object per line:
{"messages": [{"role": "user", "content": "What is the capital of France?"}, {"role": "assistant", "content": "The capital of France is Paris."}]}
{"messages": [{"role": "user", "content": "Explain photosynthesis"}, {"role": "assistant", "content": "Photosynthesis is the process by which plants convert light energy into chemical energy..."}]}
from datasets import load_dataset

# Maximum number of validation examples to keep.
MAX_VALIDATION_EXAMPLES = 5000

# Load dataset from Hugging Face
dataset = load_dataset("meta-math/MetaMathQA")["train"]

# Split into train/validation (80/20, fixed seed for reproducibility)
split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
train_split = split_dataset["train"]
validation_split = split_dataset["test"]

# Limit validation to 5000 examples.
# range(MAX_VALIDATION_EXAMPLES) selects indices 0..4999, i.e. exactly 5000
# rows; the previous range(4999) kept one row fewer than the stated limit.
if len(validation_split) > MAX_VALIDATION_EXAMPLES:
    validation_split = validation_split.shuffle(seed=42).select(
        range(MAX_VALIDATION_EXAMPLES)
    )
2
Format Data
# Prompt wrapped around every query before tuning.
# NOTE(review): the line breaks below were reconstructed so the layout matches
# the "\n\n### Instruction:\n{instruction}\n\n### Response:" prompt used at
# prediction time in this guide -- confirm against the original notebook.
METAMATH_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:"""


def format_for_tuning(example):
    """Convert one dataset row into a chat-style ``messages`` record.

    Args:
        example: Mapping with ``"query"`` and ``"response"`` entries.

    Returns:
        A dict with a single ``"messages"`` list: a user turn holding the
        templated query, followed by an assistant turn holding the response.
    """
    query = example["query"]
    response = example["response"]
    instruction = METAMATH_TEMPLATE.format(instruction=query)
    return {
        "messages": [
            {"role": "user", "content": instruction},
            # The response is prefixed with a single space, so it is
            # separated from the trailing "### Response:" of the prompt.
            {"role": "assistant", "content": f" {response}"},
        ]
    }


# Replace the original columns with the single "messages" column.
train_formatted = train_split.map(
    format_for_tuning, remove_columns=train_split.column_names
)
val_formatted = validation_split.map(
    format_for_tuning, remove_columns=validation_split.column_names
)
3
Save and Upload
import jsondef save_to_jsonl(dataset, output_path): with open(output_path, "w") as f: for example in dataset: json.dump(example, f) f.write("\n")save_to_jsonl(train_formatted, "metamath_train.jsonl")save_to_jsonl(val_formatted, "metamath_val.jsonl")# Upload to GCS!gcloud storage cp metamath_train.jsonl {BUCKET_URI}/datasets/!gcloud storage cp metamath_val.jsonl {BUCKET_URI}/datasets/
# Make prediction
# The prompt uses the same instruction/response layout as the tuning data and
# seeds the answer with "Let's think step by step."
prompt_template = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Response: Let's think step by step."
)
instruction = "James buys 5 packs of beef that are 4 pounds each. The price of beef is $5.50 per pound. How much did he pay?"

# `endpoint` is not defined in this snippet; it must already exist.
prediction = endpoint.predict(
    instances=[
        {
            "prompt": prompt_template.format(instruction=instruction),
            "max_tokens": 250,
            "temperature": 0.2,
            "top_p": 1.0,
            "top_k": 1,
        }
    ]
)
print(prediction.predictions[0])
# Full fine-tuning: 1e-5 to 5e-5
# LoRA: 1e-4 to 5e-4
learning_rates = [1e-5, 2e-5, 5e-5]

# Keep every job handle, keyed by learning rate; assigning only to `job`
# would leave just the last run reachable after the loop.
jobs_by_learning_rate = {}
for lr in learning_rates:
    job = sft.train(
        source_model=source_model,
        learning_rate=lr,
        # ... other parameters
    )
    jobs_by_learning_rate[lr] = job
# Larger batch sizes for stability
# Smaller for memory constraints
batch_sizes = [8, 16, 32]

# Keep every job handle, keyed by batch size; assigning only to `job`
# would leave just the last run reachable after the loop.
jobs_by_batch_size = {}
for bs in batch_sizes:
    job = sft.train(
        source_model=source_model,
        per_device_train_batch_size=bs,
        # ... other parameters
    )
    jobs_by_batch_size[bs] = job
# Higher rank = more parameters = better quality
# Lower rank = faster training = lower cost
lora_ranks = [4, 8, 16, 32]

# Keep every job handle, keyed by LoRA rank; assigning only to `job`
# would leave just the last run reachable after the loop.
jobs_by_lora_rank = {}
for rank in lora_ranks:
    job = sft.train(
        source_model=source_model,
        tuning_mode="LORA",
        lora_rank=rank,
        lora_alpha=rank * 2,  # Common practice
        # ... other parameters
    )
    jobs_by_lora_rank[rank] = job
# Use spot instances for training
job = sft.train(
    source_model=source_model,
    # ... other parameters
    enable_spot_vm=True,  # Up to 80% cost savings
    spot_vm_retention_time=3600,  # 1 hour retention (value is in seconds)
)