Fine-Tuning Stable Diffusion XL with LoRA on AWS SageMaker: A Production-Scale ML Engineering Journey
In the rapidly evolving landscape of AI-generated imagery, the ability to fine-tune large diffusion models for specific use cases has become increasingly crucial.
This post details our comprehensive approach to fine-tuning Stable Diffusion XL (SDXL) using Low-Rank Adaptation (LoRA) on AWS SageMaker, including our infrastructure setup, training pipeline, hyperparameter optimization experiments, and production deployment strategies.
Our goal was to create a scalable, production-ready system for fine-tuning SDXL on custom datasets, specifically focusing on generating storyboard-style images for creative content. We built a complete end-to-end pipeline that handles everything from data preparation to model deployment, with a strong emphasis on experimentation and optimization.
Architecture Overview
Our solution leverages several AWS services and modern ML engineering practices:
- AWS SageMaker for distributed training and inference endpoints
- Amazon S3 for data storage and model artifacts
- Amazon ECR for custom Docker containers
- Hugging Face Diffusers for the SDXL implementation
- LoRA for efficient fine-tuning with reduced memory requirements
The entire pipeline is orchestrated through Python scripts that manage the infrastructure, data processing, training, and deployment phases.
Cloud Infrastructure Setup
AWS Configuration
The foundation of our system starts with proper AWS configuration. We use AWS profiles to manage different environments and ensure proper IAM permissions:
import boto3
import sagemaker

# Session pinned to the profile/region defined in our YAML config
boto_session = boto3.Session(
    profile_name=config.profile_name,
    region_name=config.region_name,
)
sm_session = sagemaker.session.Session(boto_session=boto_session)
Our configuration is centralized in a YAML file that defines all the necessary parameters:
environment:
  iam_role: arn:aws:iam::990321314519:role/service-role/AmazonSageMaker-ExecutionRole-20230729T210439
  s3_bucket: tjk-thumbnail-generation
  s3_base_prefix: sdxl-finetune-lora/mvp-finetuning
  s3_images_prefix: images
  s3_dataset_prefix: dataset
  s3_models_prefix: models
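The YAML is loaded once at startup into the config object that the snippets above access as config.<field>. A minimal sketch of that loader, assuming a flat environment section like the one shown (our actual config also carries profile, region, and endpoint fields, and the load_config helper name is ours):

import yaml
from types import SimpleNamespace

def load_config(path: str) -> SimpleNamespace:
    """Read the YAML file and expose the `environment` section as attributes."""
    with open(path) as f:
        raw = yaml.safe_load(f)
    # SimpleNamespace gives config.s3_bucket-style access without a full dataclass.
    return SimpleNamespace(**raw["environment"])

config = load_config("config.yaml")
print(config.s3_bucket)  # -> tjk-thumbnail-generation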
S3 Organization
We structured our S3 buckets with clear separation of concerns:
- Raw Dataset: s3://{bucket}/{base_prefix}/raw_{dataset_prefix}/
- Processed Dataset: s3://{bucket}/{base_prefix}/proc_{dataset_prefix}/
- Model Artifacts: s3://{bucket}/{base_prefix}/{models_prefix}/
This organization allows for easy tracking of data lineage and model versions throughout the experimentation process.
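A small helper keeps this prefix scheme in one place instead of scattering string formatting across scripts. This is a sketch built on the naming convention above (the s3_uri function name is ours):

def s3_uri(config, stage: str) -> str:
    """Build the S3 URI for one pipeline stage: 'raw', 'proc', or 'models'."""
    base = f"s3://{config.s3_bucket}/{config.s3_base_prefix}"
    if stage == "raw":
        return f"{base}/raw_{config.s3_dataset_prefix}/"
    if stage == "proc":
        return f"{base}/proc_{config.s3_dataset_prefix}/"
    if stage == "models":
        return f"{base}/{config.s3_models_prefix}/"
    raise ValueError(f"Unknown stage: {stage}")

# e.g. s3://tjk-thumbnail-generation/sdxl-finetune-lora/mvp-finetuning/raw_dataset/
raw_dataset_uri = s3_uri(config, "raw")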
Data Preparation Pipeline
Image Captioning with InstructBLIP
One of the critical components of our pipeline is the automated image captioning system. We use Salesforce's InstructBLIP model to generate high-quality captions for our training images:
HF_MODEL_ID = "Salesforce/instructblip-vicuna-7b"
MODEL_GEN_CONFIG = {
    "max_length": 512,
    "min_length": 81,
    "do_sample": True,
    "num_beams": 5,
    "temperature": 1.0,
    "top_p": 0.9,
    "repetition_penalty": 1.5,
    "length_penalty": 1.0,
}
The captioning process runs on GPU-enabled SageMaker instances (ml.g4dn.12xlarge) and generates detailed descriptions that capture the essence of each image. For our storyboard dataset, we crafted specific prompts to ensure the captions included relevant details about composition, style, and narrative elements.
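In outline, the captioning step loads InstructBLIP through Hugging Face Transformers and reuses the generation config above. The prompt string and the caption_image helper below are illustrative rather than our exact processing code:

import torch
from PIL import Image
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = InstructBlipProcessor.from_pretrained(HF_MODEL_ID)
model = InstructBlipForConditionalGeneration.from_pretrained(HF_MODEL_ID).to(device)

# Instruction prompt tailored to storyboard-style imagery.
prompt = "Describe this storyboard frame, including its composition, style, and narrative elements."

def caption_image(path: str) -> str:
    image = Image.open(path).convert("RGB")
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
    output_ids = model.generate(**inputs, **MODEL_GEN_CONFIG)
    return processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()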
Containerized Data Processing
We containerized the data preparation steps using Docker to ensure reproducibility and scalability:
FROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04

RUN pip3 install --upgrade pip && \
    pip3 install --no-cache-dir \
        accelerate==0.22.0 \
        bitsandbytes==0.41.1 \
        diffusers==0.20.0 \
        safetensors==0.3.3 \
        transformers==4.32.0 \
        xformers==0.0.19
The container build process is automated through a bash script that handles ECR authentication, image building, and pushing:
#!/bin/bash
aws ecr get-login-password --region "${AWS_REGION}" | docker login --username AWS --password-stdin "${ECR_REPO_FULLNAME}"
docker build $BUILD_ARG -f ${DOCKERFILE_PATH} -t ${REPOSITORY_NAME} .
docker tag "${REPOSITORY_NAME}" "${ECR_REPO_FULLNAME}"
docker push "${ECR_REPO_FULLNAME}"
SDXL LoRA Training Implementation
Why LoRA?
LoRA (Low-Rank Adaptation) allows us to fine-tune large models efficiently by training only a small number of parameters. For SDXL, this means:
- Reduced memory requirements (crucial for fitting on single GPUs)
- Faster training times
- Smaller model artifacts (only the LoRA weights need to be stored)
- Easy switching between different fine-tuned versions
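Conceptually, LoRA freezes the pretrained weight matrix and learns a low-rank update B·A alongside it, so only the two small matrices are trained and shipped. A minimal PyTorch sketch of the idea (purely didactic, not the diffusers implementation we use in training):

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wrap a frozen linear layer with a trainable low-rank update: y = Wx + (alpha/r) * B(A(x))."""
    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 4.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False          # the pretrained weights stay frozen
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)   # down-projection A
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)  # up-projection B
        nn.init.zeros_(self.lora_b.weight)   # start as a no-op so training begins at the base model
        self.scale = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

At rank 4 on a 1024-dimensional projection, the adapter adds only 2 × 1024 × 4 = 8,192 trainable parameters per layer instead of the roughly one million in the full weight, which is why the resulting artifacts are so small.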
Training Configuration
Our training script implements several key optimizations:
hyperparameters = {
    "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
    "resolution": 1024,
    "train_batch_size": 1,
    "learning_rate": 2e-4,
    "lr_scheduler": "constant",
    "num_train_epochs": 58,
    "checkpointing_steps": 500,
    "rank": 4,  # LoRA rank
    "train_text_encoder": True,
    "gradient_accumulation_steps": 4,
    "gradient_checkpointing": True,
    "enable_xformers_memory_efficient_attention": True,
    "use_8bit_adam": True,
    "mixed_precision": "fp16",
}
Key decisions in our configuration:
- LoRA Rank: We experimented with ranks from 4 to 1024, finding that lower ranks (4-128) provided good results while keeping training efficient.
- Text Encoder Training: We chose to also fine-tune the text encoders, which improved prompt adherence for our specific use case.
- Memory Optimizations:
  - Gradient accumulation to simulate larger batch sizes
  - Gradient checkpointing to reduce memory usage
  - xformers for efficient attention computation
  - 8-bit Adam optimizer
  - FP16 mixed precision training
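These hyperparameters are ultimately handed to a SageMaker training job. A sketch of that launch using the SageMaker Python SDK's HuggingFace estimator (the entry point, instance type, channel name, and S3 paths are illustrative, and s3_uri is the hypothetical helper from the S3 section above):

from sagemaker.huggingface import HuggingFace

estimator = HuggingFace(
    entry_point="train_text_to_image_lora_sdxl.py",  # our LoRA training script
    source_dir="code/training",
    role=config.iam_role,
    instance_type="ml.g5.2xlarge",
    instance_count=1,
    transformers_version="4.28.1",
    pytorch_version="2.0.0",
    py_version="py310",
    hyperparameters=hyperparameters,
    sagemaker_session=sm_session,
)
# The processed dataset prefix in S3 is mounted as the training channel.
estimator.fit({"training": s3_uri(config, "proc")}, wait=False)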
Multi-GPU Training with DeepSpeed
For larger experiments, we configured Hugging Face Accelerate with DeepSpeed (ZeRO stage 2 with CPU offload) for distributed training across multiple GPUs:
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
mixed_precision: fp16
num_processes: 4
Grid Search Experiments
Hyperparameter Optimization
We implemented a comprehensive grid search system to find optimal hyperparameters:
parameters = {
    "learning_rate": [0.0001, 0.00015, 0.0002],
    "max_train_steps": [500, 1000, 2000, 3000],
    "rank": [4, 150, 256, 512, 1024],
}
Our grid search infrastructure included:
- Automated Job Submission: Each parameter combination spawns a separate SageMaker training job
- Rate Limiting: To avoid exceeding our AWS service quotas for concurrent training jobs, we implemented a simple rate limiter:
import time

import boto3

def count_running_jobs() -> int:
    """Return the number of SageMaker training jobs currently in progress."""
    sm_client = boto3.client("sagemaker")
    list_jobs_response = sm_client.list_training_jobs(
        StatusEquals="InProgress",
        MaxResults=100,
    )
    return len(list_jobs_response["TrainingJobSummaries"])

# Rate limiting logic: wait until we are below the concurrent-job limit.
while count_running_jobs() >= rate_limit:
    print(f"Rate limit reached. Sleeping for {sleep_time} seconds.")
    time.sleep(sleep_time)
- Systematic Naming: Jobs are named with their parameters for easy tracking:
sdxl-lora-gridsearch-run-16-lr-0-0002-mts-1500-r-4
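Putting it together, the submission loop walks the Cartesian product of the grid, builds a job name in that format, respects the rate limit, and launches one training job per combination. A condensed sketch (make_estimator is a hypothetical factory wrapping the HuggingFace estimator shown earlier, and the name mangling is illustrative):

from itertools import product

for run_idx, (lr, steps, rank) in enumerate(product(
    parameters["learning_rate"],
    parameters["max_train_steps"],
    parameters["rank"],
)):
    while count_running_jobs() >= rate_limit:      # stay under the concurrency limit
        time.sleep(sleep_time)

    # e.g. sdxl-lora-gridsearch-run-16-lr-0-0002-mts-1500-r-4
    job_name = (
        f"sdxl-lora-gridsearch-run-{run_idx}"
        f"-lr-{str(lr).replace('.', '-')}-mts-{steps}-r-{rank}"
    )
    hyperparameters.update({"learning_rate": lr, "max_train_steps": steps, "rank": rank})
    estimator = make_estimator(hyperparameters)    # hypothetical factory, see lead-in
    estimator.fit({"training": s3_uri(config, "proc")}, job_name=job_name, wait=False)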
Key Findings from Experiments
Through our extensive grid search experiments (trials 11-16), we discovered:
- Learning Rate: 2e-4 (0.0002) provided the best balance between training speed and stability
- Training Steps: 1500-2000 steps were sufficient for our dataset size
- LoRA Rank: Rank 4 was surprisingly effective, with diminishing returns beyond rank 128
- Epochs: 58 epochs gave us the best results for our final production model
Inference and Deployment
SageMaker Endpoint Configuration
We deployed our fine-tuned models as SageMaker endpoints for production inference:
import os

from sagemaker.huggingface import HuggingFaceModel

model = HuggingFaceModel(
    model_data=model_data,
    role=role,
    entry_point="inference.py",
    transformers_version="4.28.1",
    pytorch_version="2.0.0",
    py_version="py310",
    source_dir=os.path.join("code", "inference", "sdxl"),
    sagemaker_session=sm_session,
)

model.deploy(
    initial_instance_count=1,
    instance_type="ml.p4d.24xlarge",
    endpoint_name="mvp-finetuning-58-epochs-test-p4d-24xl",
)
Dynamic LoRA Loading
Our inference script supports dynamic LoRA scale adjustment, allowing us to control the influence of fine-tuning at inference time:
def predict_fn(data, model):
    # Re-fuse the LoRA weights at the scale requested in this payload.
    lora_scale = data.pop("lora_scale")
    model.unfuse_lora()
    model.load_lora_weights("/opt/ml/model", weight_name="pytorch_lora_weights.safetensors")
    model.fuse_lora(lora_scale=lora_scale)

    # Remaining generation parameters come straight from the request payload
    # (validation and default handling omitted here for brevity).
    prompt = data.pop("prompt")
    negative_prompt = data.pop("negative_prompt", None)
    height = data.pop("height")
    width = data.pop("width")
    num_inference_steps = data.pop("num_inference_steps")
    guidance_scale = data.pop("guidance_scale")
    num_images_per_prompt = data.pop("num_images_per_prompt", 1)
    generator = None  # optionally a seeded torch.Generator for reproducible outputs

    generated_images = model(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
    ).images
    return generated_images
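On the client side, the endpoint is called with a JSON payload carrying the same keys predict_fn expects. A sketch using the low-level SageMaker runtime client (the prompt and parameter values are illustrative, and response handling is elided because it depends on how inference.py serializes the generated images):

import json

import boto3

runtime = boto3.client("sagemaker-runtime")
payload = {
    "prompt": "storyboard frame of a detective entering a dimly lit office, cinematic",
    "negative_prompt": "blurry, low quality",
    "height": 1024,
    "width": 1024,
    "num_inference_steps": 40,
    "guidance_scale": 7.5,
    "num_images_per_prompt": 1,
    "lora_scale": 0.8,  # dial the fine-tuning influence up or down per request
}
response = runtime.invoke_endpoint(
    EndpointName="mvp-finetuning-58-epochs-test-p4d-24xl",
    ContentType="application/json",
    Body=json.dumps(payload),
)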
Performance Optimization
We tested various instance types for inference:
- ml.g5.2xlarge: Cost-effective for development
- ml.p4d.24xlarge: High-performance for production
- ml.p3.2xlarge: Good balance for medium workloads
Generation times averaged 17-18 seconds per image on p4d.24xlarge instances, as shown in our performance benchmarks.
Production Considerations
Cost Optimization
- Spot Instances: We use spot instances for training when possible, reducing costs by up to 70%
- Checkpoint Recovery: Regular checkpointing allows us to resume from interruptions
- Efficient Storage: Only storing LoRA weights (typically <200MB) instead of full models (>10GB)
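In practice, the spot and checkpointing levers live on the estimator itself. Managed spot training with S3 checkpointing is enabled roughly as follows (a sketch; the wait times and checkpoint path are illustrative, and the other arguments mirror the training estimator shown earlier):

spot_estimator = HuggingFace(
    entry_point="train_text_to_image_lora_sdxl.py",
    source_dir="code/training",
    role=config.iam_role,
    instance_type="ml.g5.2xlarge",
    instance_count=1,
    transformers_version="4.28.1",
    pytorch_version="2.0.0",
    py_version="py310",
    hyperparameters=hyperparameters,
    use_spot_instances=True,           # managed spot training
    max_run=24 * 3600,                 # hard cap on training time (seconds)
    max_wait=36 * 3600,                # how long to wait for spot capacity; must be >= max_run
    checkpoint_s3_uri=f"{s3_uri(config, 'models')}checkpoints/",  # synced for resume after interruption
    sagemaker_session=sm_session,
)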
Monitoring and Logging
We integrated comprehensive logging throughout the pipeline:
logger.info("Fine-tuning the image generative model and deploying an endpoint to %s is complete.", config.endpoint_name)
SageMaker CloudWatch integration provides real-time monitoring of:
- Training metrics (loss, learning rate)
- Instance utilization
- Inference latency and throughput
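Custom training metrics reach CloudWatch through regex-based metric definitions declared on the estimator. A brief sketch (the regexes assume a diffusers-style progress log and may need adjusting to the actual log format):

# Each regex is applied to the training job's log stream; the captured group
# becomes a CloudWatch metric. Passed to the estimator via metric_definitions=...
metric_definitions = [
    {"Name": "train:loss", "Regex": r"loss=([0-9.]+)"},
    {"Name": "train:lr", "Regex": r"lr=([0-9.eE+-]+)"},
]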
Scalability
Our architecture scales horizontally:
- Multiple training jobs can run concurrently
- Inference endpoints auto-scale based on load
- Data processing leverages SageMaker Processing jobs for parallelization
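Endpoint auto-scaling is configured through Application Auto Scaling on the endpoint's production variant. A minimal sketch (the variant name, capacity bounds, and target value are illustrative):

import boto3

autoscaling = boto3.client("application-autoscaling")
resource_id = "endpoint/mvp-finetuning-58-epochs-test-p4d-24xl/variant/AllTraffic"

# Register the variant's instance count as a scalable target.
autoscaling.register_scalable_target(
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    MinCapacity=1,
    MaxCapacity=4,
)

# Scale on the built-in invocations-per-instance metric.
autoscaling.put_scaling_policy(
    PolicyName="sdxl-lora-invocations-scaling",
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    PolicyType="TargetTrackingScaling",
    TargetTrackingScalingPolicyConfiguration={
        "TargetValue": 5.0,
        "PredefinedMetricSpecification": {
            "PredefinedMetricType": "SageMakerVariantInvocationsPerInstance"
        },
    },
)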
Lessons Learned
- Start Small: Begin with low-rank LoRA (rank 4-8) before scaling up
- Caption Quality Matters: Investing in high-quality captions significantly improves results
- Instance Selection: ml.g5.2xlarge offers the best price/performance for most training tasks
- Checkpoint Frequently: Save checkpoints every 500 steps to enable experimentation
- Monitor Costs: Set up AWS Cost Explorer alerts for training jobs
Future Improvements
- Automated Hyperparameter Tuning: Integrate SageMaker's built-in hyperparameter tuning
- Multi-Model Endpoints: Deploy multiple LoRA adapters on a single endpoint
- A/B Testing: Implement systematic A/B testing for different model versions
- Custom Schedulers: Experiment with more sophisticated learning rate schedules
Conclusion
Building a production-ready SDXL fine-tuning pipeline requires careful orchestration of multiple components. By leveraging AWS SageMaker's managed infrastructure, containerization for reproducibility, and systematic experimentation, we created a robust system capable of producing high-quality, domain-specific image generation models.
The combination of LoRA for efficient fine-tuning, comprehensive grid search for optimization, and scalable deployment infrastructure provides a solid foundation for any team looking to customize large diffusion models for their specific use cases.
The complete codebase demonstrates that with the right architecture and tooling, fine-tuning state-of-the-art image generation models is accessible and practical for production deployments.
This project represents months of experimentation and optimization. Special thanks to the AWS SageMaker team for their excellent documentation and the Hugging Face community for the incredible Diffusers library.