Training DeepSeek V3 on 24× A100s — Part 2: torchrun and DeepSpeed ZeRO-3

Exact launch commands, DeepSpeed configs, and how ZeRO-3 + MoE let a 671B model fine-tune stably across 3 nodes.

This post dives into how I orchestrated 24 GPUs with torchrun and stabilized memory with DeepSpeed ZeRO-3 while fine-tuning LoRA adapters on DeepSeek V3 (671B MoE).

The worker and head launch

I launched three nodes with 8 processes per node and set the environment explicitly to avoid mystery defaults and keep networking stable:

# Common environment per container
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_DEVICE_ORDER=PCI_BUS_ID
export LOCAL_WORLD_SIZE=8
export NNODES=3
export NPROC_PER_NODE=8
export WORLD_SIZE=$(( NNODES * NPROC_PER_NODE ))  # 24; torchrun also sets this for each worker
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,garbage_collection_threshold:0.8"

# Rendezvous and NCCL over TCP (IB disabled for the smoke test; with P2P and SHM also off, intra-node traffic goes over sockets too)
export MASTER_ADDR=10.18.122.130
export MASTER_PORT=39500
export NCCL_SOCKET_IFNAME=eth0
export GLOO_SOCKET_IFNAME=eth0
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1
export NCCL_SHM_DISABLE=1
export OMP_NUM_THREADS=1
export HF_ENABLE_PARALLEL_LOADING=true
export TORCH_NCCL_TIMEOUT_MS=3600000
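With a static rendezvous like this one (a fixed --node_rank on each node and the same --nproc_per_node everywhere), each node ends up owning a contiguous block of global ranks. The sketch below reproduces that arithmetic under the assumption of uniform nodes, which is handy for mapping a rank number in an NCCL error back to a node and GPU. global_rank and locate are hypothetical helpers, not torchrun APIs.

```python
"""Map (node_rank, local_rank) to global ranks for a uniform torchrun launch.

Assumes every node runs the same number of workers, as in the 3 x 8 launch
above. The helper names are hypothetical, not part of torchrun.
"""

NNODES = 3
NPROC_PER_NODE = 8


def global_rank(node_rank: int, local_rank: int, nproc_per_node: int) -> int:
    """Ranks are assigned node by node: node 0 owns 0..7, node 1 owns 8..15, ..."""
    if not 0 <= local_rank < nproc_per_node:
        raise ValueError(f"local_rank {local_rank} out of range")
    return node_rank * nproc_per_node + local_rank


def locate(rank: int, nproc_per_node: int) -> tuple[int, int]:
    """Inverse mapping: (node_rank, local_rank) that a global rank lives on."""
    return divmod(rank, nproc_per_node)


if __name__ == "__main__":
    print(f"WORLD_SIZE = {NNODES * NPROC_PER_NODE}")
    for node in range(NNODES):
        ranks = [global_rank(node, lr, NPROC_PER_NODE) for lr in range(NPROC_PER_NODE)]
        print(f"node {node}: ranks {ranks[0]}..{ranks[-1]}")
```

For example, locate(13, 8) returns (1, 5): rank 13 is GPU 5 on the second node.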

The actual run command:

torchrun \
  --nnodes=3 \
  --nproc_per_node=8 \
  --node_rank=$NODE_RANK \
  --master_addr=$MASTER_ADDR \
  --master_port=$MASTER_PORT \
  --max_restarts=0 \
  src/train.py \
  --stage sft \
  --do_train \
  --model_name_or_path /nfs/DeepSeek-V3-bf16 \
  --dataset all_creator_training \
  --template default \
  --finetuning_type lora \
  --lora_target q_a_proj,q_b_proj,kv_a_proj_with_mqa,kv_b_proj,o_proj \
  --lora_rank 16 \
  --lora_alpha 32 \
  --output_dir "$OUTPUT_DIR" \
  --overwrite_output_dir \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 1 \
  --learning_rate 1e-5 \
  --adam_beta2 0.98 \
  --weight_decay 0.01 \
  --warmup_steps 100 \
  --bf16 \
  --deepspeed "$DS_CONFIG" \
  --logging_steps 1 \
  --save_strategy steps \
  --save_steps 50 \
  --save_on_each_node false \
  --save_safetensors true \
  --save_only_model true \
  --max_steps 2000 \
  --report_to tensorboard \
  --logging_dir "$OUTPUT_DIR/logs"

Notes:

  • I kept a small per-device batch (1) with no grad accumulation for early stability. Larger effective batches are possible once stable.
  • --save_only_model true avoids the massive optimizer/ZeRO shards (see Part 5 on janitoring and save semantics).
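The "Total train batch size" the Trainer reports is per-device batch × gradient-accumulation steps × world size. A minimal sketch of that arithmetic for this run, using the example count from the log; effective_batch is a hypothetical helper, and the accumulation value of 16 comes from the no-offload DeepSpeed profile later in this post.

```python
def effective_batch(per_device: int, grad_accum: int, world_size: int) -> int:
    """Samples consumed per optimizer step across all ranks."""
    return per_device * grad_accum * world_size


WORLD_SIZE = 24            # 3 nodes x 8 GPUs
NUM_EXAMPLES = 5_437_649   # "Num examples" from the Trainer log
MAX_STEPS = 2_000          # --max_steps

launch = effective_batch(per_device=1, grad_accum=1, world_size=WORLD_SIZE)
no_offload = effective_batch(per_device=1, grad_accum=16, world_size=WORLD_SIZE)
seen = launch * MAX_STEPS

print(f"launch command: {launch} samples/step")                     # 24
print(f"no-offload profile (accum 16): {no_offload} samples/step")  # 384
print(f"{seen:,} samples in {MAX_STEPS} steps = {seen / NUM_EXAMPLES:.2%} of the dataset")
```

At the launch settings, 2,000 steps consume 48,000 samples, about 0.88% of the dataset, so the "Num Epochs = 1" in the log is a ceiling rather than a full pass: --max_steps ends the run first.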

Trainer summary from a successful launch

[INFO|trainer.py:2409] 2025-09-02 20:12:55,418 >> ** Running training **
[INFO|trainer.py:2410] 2025-09-02 20:12:55,418 >>   Num examples = 5,437,649
[INFO|trainer.py:2411] 2025-09-02 20:12:55,418 >>   Num Epochs = 1
[INFO|trainer.py:2412] 2025-09-02 20:12:55,418 >>   Instantaneous batch size per device = 1
[INFO|trainer.py:2415] 2025-09-02 20:12:55,418 >>   Total train batch size (w. parallel, distributed & accumulation) = 24
[INFO|trainer.py:2416] 2025-09-02 20:12:55,418 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2417] 2025-09-02 20:12:55,418 >>   Total optimization steps = 2,000
[INFO|trainer.py:2418] 2025-09-02 20:12:55,689 >>   Number of trainable parameters = 97,006,592

The DeepSpeed configs I actually used

I iterated on two main configs. First, a CPU-offloaded profile to reduce the GPU footprint when I was memory-bound:

{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": { "device": "cpu", "pin_memory": true },
    "offload_param": { "device": "cpu", "pin_memory": true },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 200000000,
    "stage3_prefetch_bucket_size": 200000000,
    "stage3_param_persistence_threshold": 1000000,
    "stage3_max_live_parameters": 500000000,
    "stage3_max_reuse_distance": 500000000,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "bf16": { "enabled": "auto" },
  "activation_checkpointing": {
    "partition_activations": true,
    "cpu_checkpointing": true,
    "contiguous_memory_optimization": true,
    "synchronize_checkpoint_boundary": false
  },
  "communication_data_type": "bf16",
  "wall_clock_breakdown": false
}
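The "auto" entries are placeholders that the Hugging Face Trainer fills from its own arguments before DeepSpeed initializes. The sketch below is a simplified model of that step for the four "auto" keys in this profile, not the transformers implementation. The mapping it uses (micro batch from --per_device_train_batch_size, accumulation from --gradient_accumulation_steps, clipping from max_grad_norm with its default of 1.0, bf16 from --bf16) is my understanding of the integration, and resolve_auto is a hypothetical helper.

```python
import copy

# The "auto" keys from the CPU-offload profile above.
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "bf16": {"enabled": "auto"},
}

# Values from the torchrun command; max_grad_norm is the Trainer default.
trainer_args = {
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "max_grad_norm": 1.0,
    "bf16": True,
}

# Dotted DeepSpeed key -> training argument assumed to supply it.
AUTO_SOURCES = {
    "train_micro_batch_size_per_gpu": "per_device_train_batch_size",
    "gradient_accumulation_steps": "gradient_accumulation_steps",
    "gradient_clipping": "max_grad_norm",
    "bf16.enabled": "bf16",
}


def resolve_auto(config: dict, args: dict) -> dict:
    """Return a copy of config with each "auto" replaced by its training argument."""
    resolved = copy.deepcopy(config)
    for dotted, arg_name in AUTO_SOURCES.items():
        *parents, leaf = dotted.split(".")
        node = resolved
        for key in parents:
            node = node.setdefault(key, {})
        if node.get(leaf) == "auto":
            node[leaf] = args[arg_name]
    return resolved


print(resolve_auto(ds_config, trainer_args))
```

Leaving these keys as "auto" keeps the command line as the single source of truth, which is what made it easy to tweak batch settings between runs without editing JSON.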

Later, once things were stable and I had more headroom, I switched to a no-offload profile, which removes the CPU↔GPU transfers, with larger communication buckets to improve throughput:

{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": { "device": "none" },
    "offload_param": { "device": "none" },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e8,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_param_persistence_threshold": 1e6,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 16,
  "gradient_clipping": 1.0,
  "bf16": { "enabled": true },
  "optimizer": {
    "type": "AdamW",
    "params": { "lr": 1e-5, "betas": [0.9, 0.98], "eps": 1e-8, "weight_decay": 0.01 }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": { "warmup_min_lr": 0, "warmup_max_lr": 1e-5, "warmup_num_steps": 100, "total_num_steps": 2000 }
  },
  "zero_allow_untested_optimizer": true,
  "activation_checkpointing": {
    "partition_activations": true,
    "cpu_checkpointing": false,
    "contiguous_memory_optimization": true,
    "synchronize_checkpoint_boundary": false
  },
  "moe": {
    "enabled": true,
    "moe_param_group": true,
    "expert_parallel_size": 8,
    "top_k": 2,
    "min_capacity": 4,
    "capacity_factor": 1.25
  },
  "communication_data_type": "bf16",
  "wall_clock_breakdown": true
}
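Unlike the first profile, this one hard-codes values the Trainer also receives on the command line. As far as I know, the Hugging Face DeepSpeed integration raises an error when an explicit config value disagrees with its training argument. The sketch below approximates that check for a handful of keys; the key-to-argument mapping is my assumption and find_mismatches is a hypothetical helper. Paired with the launch command above unchanged, it flags gradient_accumulation_steps.

```python
# Explicit values from the no-offload profile.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 16,
    "gradient_clipping": 1.0,
    "optimizer": {"params": {"lr": 1e-5, "betas": [0.9, 0.98], "weight_decay": 0.01}},
    "scheduler": {"params": {"warmup_num_steps": 100, "total_num_steps": 2000}},
}

# The same settings as the torchrun command passes them (adam_beta1 = 0.9 and
# max_grad_norm = 1.0 are Trainer defaults, not flags in the command).
trainer_args = {
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "max_grad_norm": 1.0,
    "learning_rate": 1e-5,
    "adam_betas": [0.9, 0.98],
    "weight_decay": 0.01,
    "warmup_steps": 100,
    "max_steps": 2000,
}

# Dotted DeepSpeed key -> training argument it must agree with (assumed).
CHECKS = {
    "train_micro_batch_size_per_gpu": "per_device_train_batch_size",
    "gradient_accumulation_steps": "gradient_accumulation_steps",
    "gradient_clipping": "max_grad_norm",
    "optimizer.params.lr": "learning_rate",
    "optimizer.params.betas": "adam_betas",
    "optimizer.params.weight_decay": "weight_decay",
    "scheduler.params.warmup_num_steps": "warmup_steps",
    "scheduler.params.total_num_steps": "max_steps",
}


def lookup(config: dict, dotted: str):
    """Follow a dotted path like "optimizer.params.lr" into nested dicts."""
    node = config
    for key in dotted.split("."):
        node = node[key]
    return node


def find_mismatches(config: dict, args: dict) -> list[str]:
    """List every explicit (non-"auto") config value that disagrees with its arg."""
    problems = []
    for dotted, arg_name in CHECKS.items():
        value = lookup(config, dotted)
        if value != "auto" and value != args[arg_name]:
            problems.append(f"{dotted}={value} vs {arg_name}={args[arg_name]}")
    return problems


for line in find_mismatches(ds_config, trainer_args):
    print(line)
```

So this profile needs --gradient_accumulation_steps 16 on the command line (an effective batch of 1 × 16 × 24 = 384), or the key set back to "auto".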

Critical flag that fixed empty LoRA saves:

"stage3_gather_16bit_weights_on_model_save": true

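Without it, the Trainer can't consolidate the ZeRO-3 shards into a 16-bit state dict at save time, so an adapter-only checkpoint written at a step boundary can come out as a safetensors file with a header and no tensors. I also checked the target list against DeepSeek V3's actual module names. V3 uses Multi-head Latent Attention (MLA), which has no q_proj, k_proj, or v_proj, so a generic Llama-style list only matches o_proj. The attention projections are:

--lora_target q_a_proj,q_b_proj,kv_a_proj_with_mqa,kv_b_proj,o_proj

At rank 16 across 61 layers, those five projections account for exactly the 97,006,592 trainable parameters in the Trainer log above.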
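To check a target list against real module names before launching, you can apply the rule PEFT uses when target_modules is a list: a module is wrapped if its full name equals a target or ends with "." plus that target. The module names below are the attention projections of one DeepSeek V3 layer as I understand the modeling code; treat them, and the helpers matches and wrapped, as illustrative assumptions.

```python
# Attention projections in one DeepSeek V3 decoder layer (MLA); assumed names.
LAYER_MODULES = [
    "model.layers.0.self_attn.q_a_proj",
    "model.layers.0.self_attn.q_b_proj",
    "model.layers.0.self_attn.kv_a_proj_with_mqa",
    "model.layers.0.self_attn.kv_b_proj",
    "model.layers.0.self_attn.o_proj",
]


def matches(module_name: str, targets: list[str]) -> bool:
    """Exact name match, or the name ends with '.' + target (PEFT's list rule)."""
    return any(module_name == t or module_name.endswith("." + t) for t in targets)


def wrapped(targets: list[str]) -> list[str]:
    """Short names of the layer's modules that a target list would wrap."""
    return [name.rsplit(".", 1)[-1] for name in LAYER_MODULES if matches(name, targets)]


generic = ["q_proj", "v_proj", "k_proj", "o_proj"]
mla = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj", "o_proj"]

print("generic list wraps:", wrapped(generic))  # ['o_proj']
print("MLA list wraps:", wrapped(mla))
```

Because the rule is a suffix match on whole dotted components, q_proj never matches q_a_proj or q_b_proj.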

Parallelism: what’s actually in play

  • Data parallel across 24 ranks (3 nodes × 8 GPUs)
  • ZeRO-3 partitions params/gradients/optimizer states across ranks
  • DeepSeek V3's MoE routes each token to 8 of 256 routed experts plus one shared expert, so only about 37B of the 671B parameters are active per token; I didn't enable extra tensor/pipeline parallelism
  • Activation checkpointing was on, with optional CPU checkpointing when memory was tight

From my notes, the observed footprint at world_size=24 was ~50–65 GB per GPU, which tracks with sharded parameters, activations, and working buffers.
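A back-of-the-envelope check on that footprint, assuming parameters stay on the GPU (the no-offload profile) and counting only the bf16 weight shard plus the LoRA adapter's mixed-precision AdamW state at 16 bytes per trainable parameter (bf16 weight and gradient, fp32 master copy and two moments), also sharded across ranks. Activations, communication buckets, and the allocator's cache come on top and aren't modeled.

```python
TOTAL_PARAMS = 671e9        # DeepSeek V3 size, as stated in this post
TRAINABLE = 97_006_592      # LoRA parameters from the Trainer log
WORLD_SIZE = 24
BF16_BYTES = 2
ADAM_MIXED_BYTES = 16       # bf16 weight + bf16 grad + fp32 master, m, v

GB, GIB, MIB = 1e9, 2**30, 2**20

weights_per_gpu = TOTAL_PARAMS * BF16_BYTES / WORLD_SIZE
lora_state_per_gpu = TRAINABLE * ADAM_MIXED_BYTES / WORLD_SIZE

print(f"bf16 weight shard: {weights_per_gpu / GB:.1f} GB ({weights_per_gpu / GIB:.1f} GiB)")
print(f"LoRA optimizer state: {lora_state_per_gpu / MIB:.0f} MiB")
```

The frozen weight shard alone is about 56 GB (52 GiB) per GPU, while the LoRA training state is only tens of MiB, so nearly all of the observed footprint is the sharded base model.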

Logging and stability flags that mattered

export TORCH_DISTRIBUTED_DEBUG=INFO
export NCCL_DEBUG=WARN      # or INFO when debugging collectives
export NCCL_ASYNC_ERROR_HANDLING=1   # newer PyTorch prefers TORCH_NCCL_ASYNC_ERROR_HANDLING
export NCCL_IGNORE_CPU_AFFINITY=1
export OMP_NUM_THREADS=1

For long rendezvous or slow first steps:

export TORCH_NCCL_TIMEOUT_MS=3600000

Takeaways from iteration

  • Start with CPU offload to stabilize memory; move to no-offload + bigger buckets once things are healthy
  • Keep batch small and accumulation minimal until you confirm step 1 completes
  • Make LoRA save semantics explicit; don’t assume defaults
  • Prefer node-local checkpoint writes; aggregate later

In Part 3 I cover the CUDA/driver/Fabric Manager mismatches across nodes that produced the dreaded cudaGetDeviceCount -> error 802 and how I fixed them cleanly.