Fine-Tuning Stable Diffusion XL with LoRA on AWS SageMaker: A Production-Scale ML Engineering Journey

In the rapidly evolving landscape of AI-generated imagery, the ability to fine-tune large diffusion models for specific use cases has become increasingly crucial.

This post details our comprehensive approach to fine-tuning Stable Diffusion XL (SDXL) using Low-Rank Adaptation (LoRA) on AWS SageMaker, including our infrastructure setup, training pipeline, hyperparameter optimization experiments, and production deployment strategies.

Our goal was to create a scalable, production-ready system for fine-tuning SDXL on custom datasets, specifically focusing on generating storyboard-style images for creative content. We built a complete end-to-end pipeline that handles everything from data preparation to model deployment, with a strong emphasis on experimentation and optimization.

Architecture Overview

Our solution leverages several AWS services and modern ML engineering practices:

  • AWS SageMaker for distributed training and inference endpoints
  • Amazon S3 for data storage and model artifacts
  • Amazon ECR for custom Docker containers
  • Hugging Face Diffusers for the SDXL implementation
  • LoRA for efficient fine-tuning with reduced memory requirements

The entire pipeline is orchestrated through Python scripts that manage the infrastructure, data processing, training, and deployment phases.

Cloud Infrastructure Setup

AWS Configuration

The foundation of our system starts with proper AWS configuration. We use AWS profiles to manage different environments and ensure proper IAM permissions:

import boto3
import sagemaker

boto_session = boto3.Session(
    profile_name=config.profile_name,
    region_name=config.region_name,
)
sm_session = sagemaker.session.Session(boto_session=boto_session)

Our configuration is centralized in a YAML file that defines all the necessary parameters:

environment:
  iam_role: arn:aws:iam::990321314519:role/service-role/AmazonSageMaker-ExecutionRole-20230729T210439
  s3_bucket: tjk-thumbnail-generation
  s3_base_prefix: sdxl-finetune-lora/mvp-finetuning
  s3_images_prefix: images
  s3_dataset_prefix: dataset
  s3_models_prefix: models
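
For reference, here is a minimal sketch of how a file like this could be loaded into the `config` object used above. The flattening helper and file name are illustrative assumptions, not our exact loader:

import yaml
from types import SimpleNamespace

def load_config(path="config.yaml"):
    """Load the YAML config and expose its keys as simple attributes."""
    with open(path) as f:
        raw = yaml.safe_load(f)
    # Flatten the top-level sections (e.g. `environment`) into one namespace,
    # so values are reachable as config.s3_bucket, config.iam_role, etc.
    flat = {}
    for section in raw.values():
        flat.update(section)
    return SimpleNamespace(**flat)

config = load_config()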

S3 Organization

We structured our S3 buckets with clear separation of concerns:

  • Raw Dataset: s3://{bucket}/{base_prefix}/raw_{dataset_prefix}/
  • Processed Dataset: s3://{bucket}/{base_prefix}/proc_{dataset_prefix}/
  • Model Artifacts: s3://{bucket}/{base_prefix}/{models_prefix}/

This organization allows for easy tracking of data lineage and model versions throughout the experimentation process.
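
A small sketch of how these URIs can be assembled from the config values (the helper name is illustrative):

def s3_uri(config, prefix):
    """Build a fully qualified S3 URI under the project's base prefix."""
    return f"s3://{config.s3_bucket}/{config.s3_base_prefix}/{prefix}"

raw_dataset_uri = s3_uri(config, f"raw_{config.s3_dataset_prefix}")
proc_dataset_uri = s3_uri(config, f"proc_{config.s3_dataset_prefix}")
models_uri = s3_uri(config, config.s3_models_prefix)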

Data Preparation Pipeline

Image Captioning with InstructBLIP

One of the critical components of our pipeline is the automated image captioning system. We use Salesforce's InstructBLIP model to generate high-quality captions for our training images:

HF_MODEL_ID = "Salesforce/instructblip-vicuna-7b"
MODEL_GEN_CONFIG = {
    "max_length": 512,
    "min_length": 81,
    "do_sample": True,
    "num_beams": 5,
    "temperature": 1.0,
    "top_p": 0.9,
    "repetition_penalty": 1.5,
    "length_penalty": 1.0,
}

The captioning process runs on GPU-enabled SageMaker instances (ml.g4dn.12xlarge) and generates detailed descriptions that capture the essence of each image. For our storyboard dataset, we crafted specific prompts to ensure the captions included relevant details about composition, style, and narrative elements.
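
To make the captioning step concrete, here is a hedged sketch of driving InstructBLIP through Hugging Face Transformers, reusing the `HF_MODEL_ID` and `MODEL_GEN_CONFIG` defined above. The instruction text and file name are illustrative; the real job wraps this loop in a SageMaker script:

import torch
from PIL import Image
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor

processor = InstructBlipProcessor.from_pretrained(HF_MODEL_ID)
model = InstructBlipForConditionalGeneration.from_pretrained(
    HF_MODEL_ID, torch_dtype=torch.float16
).to("cuda")

# Illustrative instruction; the real prompt asks for composition, style,
# and narrative details relevant to storyboard frames.
instruction = "Describe this storyboard frame in detail."

image = Image.open("frame_0001.png").convert("RGB")  # illustrative file name
inputs = processor(images=image, text=instruction, return_tensors="pt").to("cuda", torch.float16)
output_ids = model.generate(**inputs, **MODEL_GEN_CONFIG)
caption = processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()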

Containerized Data Processing

We containerized the data preparation steps using Docker to ensure reproducibility and scalability:

FROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04

RUN pip3 install --upgrade pip && \
    pip3 install --no-cache-dir \
    accelerate==0.22.0 \
    bitsandbytes==0.41.1 \
    diffusers==0.20.0 \
    safetensors==0.3.3 \
    transformers==4.32.0 \
    xformers==0.0.19

The container build process is automated through a bash script that handles ECR authentication, image building, and pushing:

#!/bin/bash
aws ecr get-login-password --region "${AWS_REGION}" | docker login --username AWS --password-stdin "${ECR_REPO_FULLNAME}"
docker build $BUILD_ARG -f ${DOCKERFILE_PATH} -t ${REPOSITORY_NAME} .
docker tag "${REPOSITORY_NAME}" "${ECR_REPO_FULLNAME}"
docker push "${ECR_REPO_FULLNAME}"

SDXL LoRA Training Implementation

Why LoRA?

LoRA (Low-Rank Adaptation) lets us fine-tune large models efficiently by training only a small set of low-rank update matrices while the original weights stay frozen; a minimal sketch follows the list below. For SDXL, this means:

  • Reduced memory requirements (crucial for fitting on single GPUs)
  • Faster training times
  • Smaller model artifacts (only the LoRA weights need to be stored)
  • Easy switching between different fine-tuned versions
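
To make "a small number of parameters" concrete, here is a minimal, self-contained sketch of the LoRA idea applied to a single linear layer. In practice the Diffusers training script injects the adapters for us; this is purely illustrative:

from torch import nn

class LoRALinear(nn.Module):
    """Wrap a frozen linear layer with a trainable low-rank update: y = Wx + (alpha/r) * B(A x)."""

    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 4.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)          # the original SDXL weights stay frozen
        self.lora_A = nn.Linear(base.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)   # start as a no-op so training begins at the base model
        self.scaling = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))

# Rough parameter count for one 1280-wide attention projection in the SDXL UNet:
#   full layer:     1280 * 1280 ≈ 1.6M weights
#   rank-4 adapter: 2 * 1280 * 4 ≈ 10K trainable weights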

Training Configuration

Our training script implements several key optimizations:

hyperparameters = {
    "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
    "resolution": 1024,
    "train_batch_size": 1,
    "learning_rate": 2e-4,
    "lr_scheduler": "constant",
    "num_train_epochs": 58,
    "checkpointing_steps": 500,
    "rank": 4,  # LoRA rank
    "train_text_encoder": True,
    "gradient_accumulation_steps": 4,
    "gradient_checkpointing": True,
    "enable_xformers_memory_efficient_attention": True,
    "use_8bit_adam": True,
    "mixed_precision": "fp16"
}

Key decisions in our configuration:

  1. LoRA Rank: We experimented with ranks from 4 to 1024, finding that lower ranks (4-128) provided good results while keeping training efficient.

  2. Text Encoder Training: We chose to also fine-tune the text encoders, which improved prompt adherence for our specific use case.

  3. Memory Optimizations:

    • Gradient accumulation to simulate larger batch sizes
    • Gradient checkpointing to reduce memory usage
    • xformers for efficient attention computation
    • 8-bit Adam optimizer
    • FP16 mixed precision training
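
Putting these settings together, the training job launch looks roughly like the following sketch using the SageMaker Python SDK. The entry-point script name, source directory, and instance type are assumptions for illustration:

from sagemaker.huggingface import HuggingFace

estimator = HuggingFace(
    entry_point="train_text_to_image_lora_sdxl.py",  # assumed training script name
    source_dir="code/training/sdxl",                 # assumed local path
    role=config.iam_role,
    instance_type="ml.g5.2xlarge",                   # assumed; see instance notes later
    instance_count=1,
    transformers_version="4.28.1",
    pytorch_version="2.0.0",
    py_version="py310",
    hyperparameters=hyperparameters,
    sagemaker_session=sm_session,
)

# The processed dataset from the S3 layout above is passed in as the training channel.
estimator.fit({"training": proc_dataset_uri}, wait=False)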

Multi-GPU Training with DeepSpeed

For larger experiments, we configured DeepSpeed for distributed training across multiple GPUs:

compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
mixed_precision: fp16
num_processes: 4

Grid Search Experiments

Hyperparameter Optimization

We implemented a comprehensive grid search system to find optimal hyperparameters:

parameters = {
    "learning_rate": [0.0001, 0.00015, 0.0002],
    "max_train_steps": [500, 1000, 2000, 3000],
    "rank": [4, 150, 256, 512, 1024]
}

Our grid search infrastructure included:

  1. Automated Job Submission: Each parameter combination spawns a separate SageMaker training job
  2. Rate Limiting: To avoid overwhelming AWS quotas, we implemented a rate limiter:
import time
import boto3

def count_running_jobs():
    """Return the number of SageMaker training jobs currently in progress."""
    sm_client = boto3.client("sagemaker")
    response = sm_client.list_training_jobs(
        StatusEquals="InProgress",
        MaxResults=100,
    )
    return len(response["TrainingJobSummaries"])

# Rate limiting logic: block until we are below the concurrent-job limit
rate_limit = 8     # illustrative cap on concurrent training jobs
sleep_time = 300   # illustrative wait between checks, in seconds
while count_running_jobs() >= rate_limit:
    print(f"Rate limit reached. Sleeping for {sleep_time} seconds.")
    time.sleep(sleep_time)
  3. Systematic Naming: Jobs are named with their parameters for easy tracking:
    sdxl-lora-gridsearch-run-16-lr-0-0002-mts-1500-r-4
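
Combining these pieces, the submission loop looks roughly like the sketch below. `make_estimator` stands in for the estimator construction shown earlier and is an assumption for illustration:

import itertools

for run, (lr, steps, rank) in enumerate(
    itertools.product(
        parameters["learning_rate"],
        parameters["max_train_steps"],
        parameters["rank"],
    )
):
    # Block until we drop below the concurrent-job limit.
    while count_running_jobs() >= rate_limit:
        print(f"Rate limit reached. Sleeping for {sleep_time} seconds.")
        time.sleep(sleep_time)

    # Encode the hyperparameters into the job name, e.g.
    # sdxl-lora-gridsearch-run-16-lr-0-0002-mts-1500-r-4
    job_name = (
        f"sdxl-lora-gridsearch-run-{run}"
        f"-lr-{str(lr).replace('.', '-')}-mts-{steps}-r-{rank}"
    )

    estimator = make_estimator(learning_rate=lr, max_train_steps=steps, rank=rank)
    estimator.fit({"training": proc_dataset_uri}, job_name=job_name, wait=False)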
    

Key Findings from Experiments

Through our extensive grid search experiments (trials 11-16), we discovered:

  1. Learning Rate: 2e-4 (0.0002) provided the best balance between training speed and stability
  2. Training Steps: 1500-2000 steps were sufficient for our dataset size
  3. LoRA Rank: Rank 4 was surprisingly effective, with diminishing returns beyond rank 128
  4. Epochs: 58 epochs gave us the best results for our final production model

Inference and Deployment

SageMaker Endpoint Configuration

We deployed our fine-tuned models as SageMaker endpoints for production inference:

model = HuggingFaceModel(
    model_data=model_data,
    role=role,
    entry_point="inference.py",
    transformers_version="4.28.1",
    pytorch_version="2.0.0",
    py_version="py310",
    source_dir=os.path.join("code", "inference", "sdxl"),
    sagemaker_session=sm_session,
)

predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.p4d.24xlarge",
    endpoint_name="mvp-finetuning-58-epochs-test-p4d-24xl",
)

Dynamic LoRA Loading

Our inference script supports dynamic LoRA scale adjustment, allowing us to control the influence of fine-tuning at inference time:

import torch

def predict_fn(data, model):
    # Pull generation parameters out of the request payload (defaults are illustrative)
    lora_scale = data.pop("lora_scale")
    prompt = data.pop("prompt")
    negative_prompt = data.pop("negative_prompt", None)
    height = data.pop("height", 1024)
    width = data.pop("width", 1024)
    num_inference_steps = data.pop("num_inference_steps", 50)
    guidance_scale = data.pop("guidance_scale", 7.5)
    num_images_per_prompt = data.pop("num_images_per_prompt", 1)
    generator = torch.Generator(device="cuda").manual_seed(data.pop("seed", 0))

    # Re-fuse the LoRA weights at the scale requested for this call
    model.unfuse_lora()
    model.load_lora_weights("/opt/ml/model", weight_name="pytorch_lora_weights.safetensors")
    model.fuse_lora(lora_scale=lora_scale)

    generated_images = model(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
    ).images
    return generated_images
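
Calling the endpoint then comes down to sending a JSON payload whose keys match what `predict_fn` pops off `data`. A rough usage sketch with the predictor returned by `model.deploy(...)` above; the values are illustrative, and how the images come back depends on the container's output handler, which isn't shown here:

payload = {
    "prompt": "storyboard panel, wide shot of a rainy city street at night",
    "negative_prompt": "blurry, low quality",
    "height": 1024,
    "width": 1024,
    "num_inference_steps": 30,
    "guidance_scale": 7.5,
    "num_images_per_prompt": 1,
    "lora_scale": 0.8,
    "seed": 42,
}

# `predictor` is the object returned by model.deploy(...) above; the HuggingFace
# predictor serializes the payload to JSON and invokes the SageMaker endpoint.
response = predictor.predict(payload)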

Performance Optimization

We tested various instance types for inference:

  • ml.g5.2xlarge: Cost-effective for development
  • ml.p4d.24xlarge: High-performance for production
  • ml.p3.2xlarge: Good balance for medium workloads

Generation times averaged 17-18 seconds per image on p4d.24xlarge instances, as shown in our performance benchmarks.

Production Considerations

Cost Optimization

  1. Spot Instances: We use spot instances for training when possible, reducing costs by up to 70% (see the sketch after this list)
  2. Checkpoint Recovery: Regular checkpointing allows us to resume from interruptions
  3. Efficient Storage: Only storing LoRA weights (typically <200MB) instead of full models (>10GB)
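
As noted above, both spot training and checkpoint recovery are configured on the estimator itself. A hedged sketch, with illustrative time limits and checkpoint path:

from sagemaker.huggingface import HuggingFace

estimator = HuggingFace(
    entry_point="train_text_to_image_lora_sdxl.py",  # assumed, as in the training sketch
    source_dir="code/training/sdxl",
    role=config.iam_role,
    instance_type="ml.g5.2xlarge",
    instance_count=1,
    transformers_version="4.28.1",
    pytorch_version="2.0.0",
    py_version="py310",
    hyperparameters=hyperparameters,
    sagemaker_session=sm_session,
    # Spot training: pay for spare capacity and tolerate interruptions.
    use_spot_instances=True,
    max_run=24 * 3600,    # hard cap on training time, in seconds (illustrative)
    max_wait=36 * 3600,   # must be >= max_run; includes time spent waiting for capacity
    # Anything the training script writes to /opt/ml/checkpoints is synced here,
    # so an interrupted job can resume from its last checkpoint.
    checkpoint_s3_uri=f"{models_uri}/checkpoints",
)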

Monitoring and Logging

We integrated comprehensive logging throughout the pipeline:

logger.info("Fine-tuning the image generative model and deploying an endpoint to %s is complete.", config.endpoint_name)

SageMaker CloudWatch integration provides real-time monitoring of:

  • Training metrics (loss, learning rate), parsed from the job logs as sketched after this list
  • Instance utilization
  • Inference latency and throughput
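
Training metrics only appear in CloudWatch if SageMaker knows how to parse them from the job logs. We pass a list like the following to the estimator's `metric_definitions` argument; the regexes are assumptions about the training script's log format and must match what it actually prints:

# Regexes are assumptions about the training script's log lines; they must
# match what the script actually prints for SageMaker to extract the values.
metric_definitions = [
    {"Name": "train:loss", "Regex": r"loss=([0-9\.]+)"},
    {"Name": "train:learning_rate", "Regex": r"lr=([0-9\.e\-]+)"},
]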

Scalability

Our architecture scales horizontally:

  • Multiple training jobs can run concurrently
  • Inference endpoints auto-scale based on load (see the sketch after this list)
  • Data processing leverages SageMaker Processing jobs for parallelization
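
Endpoint auto-scaling is wired up through Application Auto Scaling rather than SageMaker directly. A hedged sketch of registering the deployed endpoint as a scalable target; the variant name, capacity bounds, and target value are illustrative:

import boto3

autoscaling = boto3.client("application-autoscaling")
endpoint_name = "mvp-finetuning-58-epochs-test-p4d-24xl"      # from the deployment above
resource_id = f"endpoint/{endpoint_name}/variant/AllTraffic"  # default variant name assumed

# Register the endpoint variant's instance count as a scalable target.
autoscaling.register_scalable_target(
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    MinCapacity=1,
    MaxCapacity=4,
)

# Scale on invocations per instance, tracking an illustrative target value.
autoscaling.put_scaling_policy(
    PolicyName="sdxl-lora-invocations-target",
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    PolicyType="TargetTrackingScaling",
    TargetTrackingScalingPolicyConfiguration={
        "TargetValue": 70.0,
        "PredefinedMetricSpecification": {
            "PredefinedMetricType": "SageMakerVariantInvocationsPerInstance"
        },
    },
)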

Lessons Learned

  1. Start Small: Begin with low-rank LoRA (rank 4-8) before scaling up
  2. Caption Quality Matters: Investing in high-quality captions significantly improves results
  3. Instance Selection: ml.g5.2xlarge offers the best price/performance for most training tasks
  4. Checkpoint Frequently: Save checkpoints every 500 steps to enable experimentation
  5. Monitor Costs: Set up AWS Cost Explorer alerts for training jobs

Future Improvements

  1. Automated Hyperparameter Tuning: Integrate SageMaker's built-in hyperparameter tuning
  2. Multi-Model Endpoints: Deploy multiple LoRA adapters on a single endpoint
  3. A/B Testing: Implement systematic A/B testing for different model versions
  4. Custom Schedulers: Experiment with more sophisticated learning rate schedules

Conclusion

Building a production-ready SDXL fine-tuning pipeline requires careful orchestration of multiple components. By leveraging AWS SageMaker's managed infrastructure, containerization for reproducibility, and systematic experimentation, we created a robust system capable of producing high-quality, domain-specific image generation models.

The combination of LoRA for efficient fine-tuning, comprehensive grid search for optimization, and scalable deployment infrastructure provides a solid foundation for any team looking to customize large diffusion models for their specific use cases.

The complete codebase demonstrates that with the right architecture and tooling, fine-tuning state-of-the-art image generation models is accessible and practical for production deployments.


This project represents months of experimentation and optimization. Special thanks to the AWS SageMaker team for their excellent documentation and the Hugging Face community for the incredible Diffusers library.