Training DeepSeek V3 on 24× A100s — Part 2: torchrun and DeepSpeed ZeRO-3
Exact launch commands, DeepSpeed configs, and how ZeRO-3 + MoE let a 671B model fine-tune stably across 3 nodes.
This post dives into how I orchestrated 24 GPUs with torchrun
and stabilized memory with DeepSpeed ZeRO-3 while fine-tuning LoRA adapters on DeepSeek V3 (671B MoE).
The worker and head launch
I launched three nodes with 8 processes per node. The environment was explicit to avoid mystery defaults and to keep networking stable:
# Common environment per container
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_DEVICE_ORDER=PCI_BUS_ID
export LOCAL_WORLD_SIZE=8
export WORLD_SIZE=$(( NNODES * NPROC_PER_NODE ))
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,garbage_collection_threshold:0.8"
# Rendezvous and NCCL over TCP (IB disabled for the smoke test)
export MASTER_ADDR=10.18.122.130
export MASTER_PORT=39500
export NCCL_SOCKET_IFNAME=eth0
export GLOO_SOCKET_IFNAME=eth0
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1
export NCCL_SHM_DISABLE=1
export OMP_NUM_THREADS=1
export HF_ENABLE_PARALLEL_LOADING=true
export TORCH_NCCL_TIMEOUT_MS=3600000
The actual run command:
torchrun \
--nnodes=3 \
--nproc_per_node=8 \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--max_restarts=0 \
src/train.py \
--stage sft \
--do_train \
--model_name_or_path /nfs/DeepSeek-V3-bf16 \
--dataset all_creator_training \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj,k_proj,o_proj \
--lora_rank 16 \
--lora_alpha 32 \
--output_dir "$OUTPUT_DIR" \
--overwrite_output_dir \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-5 \
--adam_beta2 0.98 \
--weight_decay 0.01 \
--warmup_steps 100 \
--bf16 \
--deepspeed "$DS_CONFIG" \
--logging_steps 1 \
--save_strategy steps \
--save_steps 50 \
--save_on_each_node false \
--save_safetensors true \
--save_only_model true \
--max_steps 2000 \
--report_to tensorboard \
--logging_dir "$OUTPUT_DIR/logs"
Notes:
- I kept a small per-device batch (1) with no gradient accumulation for early stability; larger effective batches are possible once things are stable (the arithmetic is sketched just below).
- --save_only_model true avoids the massive optimizer/ZeRO shards (see Part 5 on janitoring and save semantics).
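For clarity, the effective batch size is just the product of the micro-batch, the gradient-accumulation steps, and the number of data-parallel ranks; a minimal sketch with this run's numbers:

# Effective (global) batch size for the launch above.
per_device_batch = 1
grad_accum_steps = 1
world_size = 3 * 8        # nnodes × nproc_per_node
print(per_device_batch * grad_accum_steps * world_size)  # 24, matching the trainer summary below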
Trainer summary from a successful launch
[INFO|trainer.py:2409] 2025-09-02 20:12:55,418 >> ***** Running training *****
[INFO|trainer.py:2410] 2025-09-02 20:12:55,418 >> Num examples = 5,437,649
[INFO|trainer.py:2411] 2025-09-02 20:12:55,418 >> Num Epochs = 1
[INFO|trainer.py:2412] 2025-09-02 20:12:55,418 >> Instantaneous batch size per device = 1
[INFO|trainer.py:2415] 2025-09-02 20:12:55,418 >> Total train batch size (w. parallel, distributed & accumulation) = 24
[INFO|trainer.py:2416] 2025-09-02 20:12:55,418 >> Gradient Accumulation steps = 1
[INFO|trainer.py:2417] 2025-09-02 20:12:55,418 >> Total optimization steps = 2,000
[INFO|trainer.py:2418] 2025-09-02 20:12:55,689 >> Number of trainable parameters = 97,006,592
The DeepSpeed configs I actually used
I iterated on two main configs. First, a CPU-offloaded profile to reduce the GPU footprint while I was memory-bound:
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": { "device": "cpu", "pin_memory": true },
"offload_param": { "device": "cpu", "pin_memory": true },
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 200000000,
"stage3_prefetch_bucket_size": 200000000,
"stage3_param_persistence_threshold": 1000000,
"stage3_max_live_parameters": 500000000,
"stage3_max_reuse_distance": 500000000,
"stage3_gather_16bit_weights_on_model_save": true
},
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"bf16": { "enabled": "auto" },
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": true,
"contiguous_memory_optimization": true,
"synchronize_checkpoint_boundary": false
},
"communication_data_type": "bf16",
"wall_clock_breakdown": false
}
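To see what the offload profile actually buys, it helps to put rough numbers on it; a back-of-the-envelope sketch, assuming bf16 weights (2 bytes per element), the 671B headline parameter count, and the stage3_max_live_parameters value from the config above:

# Rough tradeoff of offload_param=cpu on this 3-node × 8-GPU setup (ZeRO-3 shards across all 24 ranks).
GiB = 2**30
world_size, ranks_per_node = 24, 8
frozen_params = 671e9                                  # base model, frozen under LoRA, stored in bf16

# Without offload: each rank keeps its bf16 parameter shard resident on the GPU.
shard_on_gpu = frozen_params * 2 / world_size / GiB            # ≈ 52 GiB per GPU

# With offload_param=cpu: shards move to pinned host RAM, and the GPU holds only a
# transient working set roughly bounded by stage3_max_live_parameters (5e8 elements here).
working_set_on_gpu = 5e8 * 2 / GiB                             # ≈ 1 GiB per GPU
host_ram_per_node = ranks_per_node * shard_on_gpu              # ≈ 417 GiB per node

print(f"no offload: ~{shard_on_gpu:.0f} GiB of weights resident per GPU")
print(f"offload:    ~{working_set_on_gpu:.0f} GiB working set per GPU, ~{host_ram_per_node:.0f} GiB pinned host RAM per node")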
Later, once things were stable and I had more headroom, I switched to a no-offload profile, which removes the CPU↔GPU offload traffic, with larger buckets to cut per-collective overhead and improve throughput:
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": { "device": "none" },
"offload_param": { "device": "none" },
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 16,
"gradient_clipping": 1.0,
"bf16": { "enabled": true },
"optimizer": {
"type": "AdamW",
"params": { "lr": 1e-5, "betas": [0.9, 0.98], "eps": 1e-8, "weight_decay": 0.01 }
},
"scheduler": {
"type": "WarmupDecayLR",
"params": { "warmup_min_lr": 0, "warmup_max_lr": 1e-5, "warmup_num_steps": 100, "total_num_steps": 2000 }
},
"zero_allow_untested_optimizer": true,
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": false,
"contiguous_memory_optimization": true,
"synchronize_checkpoint_boundary": false
},
"moe": {
"enabled": true,
"moe_param_group": true,
"expert_parallel_size": 8,
"top_k": 2,
"min_capacity": 4,
"capacity_factor": 1.25
},
"communication_data_type": "bf16",
"wall_clock_breakdown": true
}
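The bucket sizes are element counts, not bytes, so the memory cost of the bigger buckets is easy to estimate; a rough sketch (assuming bf16 communication, 2 bytes per element, and ignoring allocator overhead):

# Approximate GPU memory held by the ZeRO-3 communication buckets under bf16 comms.
GiB = 2**30
for profile, elems in [("offload profile", 2e8), ("no-offload profile", 5e8)]:
    bucket_gib = elems * 2 / GiB            # one reduce bucket or one prefetch bucket
    print(f"{profile}: ~{bucket_gib:.2f} GiB reduce bucket + ~{bucket_gib:.2f} GiB prefetch bucket")
# Bigger buckets mean fewer, larger collectives (better throughput) for roughly
# an extra ~1 GiB of working memory per GPU.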
Critical flag that fixed empty LoRA saves:
"stage3_gather_16bit_weights_on_model_save": true
Without it, ZeRO-3 leaves the weights sharded at save time, and saving only the adapters at step boundaries can produce a safetensors file with an empty header; the same symptom shows up if PEFT can't find your targeted modules, so I also made sure my target list matched DeepSeek's module names:
--lora_target q_proj,v_proj,k_proj,o_proj
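When in doubt, it's cheap to confirm that a checkpoint actually contains adapter tensors; a small sketch using safetensors (the checkpoint path and filename are illustrative, adjust to your layout):

# Sanity-check that a saved LoRA checkpoint actually contains tensors.
from safetensors import safe_open

path = "checkpoint-50/adapter_model.safetensors"   # hypothetical path under $OUTPUT_DIR
with safe_open(path, framework="pt", device="cpu") as f:
    keys = list(f.keys())

assert keys, "empty adapter file: check stage3_gather_16bit_weights_on_model_save and lora_target"
print(f"{len(keys)} tensors saved; e.g. {keys[:2]}")   # expect lora_A / lora_B entries for the targeted projections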
Parallelism: what’s actually in play
- Data parallel across 24 ranks (3 nodes × 8 GPUs)
- ZeRO-3 partitions params/gradients/optimizer states across ranks
- DeepSeek V3 MoE activates a small subset of experts per token; I didn’t enable extra tensor/pipeline parallelism
- Activation checkpointing was on, with optional CPU checkpointing when memory was tight
From my notes, the observed footprint at world_size=24 was ~50–65 GB per GPU, which tracks with sharded parameters, activations, and working buffers.
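To keep an eye on that per rank while a run settles, a small helper around torch's allocator counters is enough; a sketch (call it from the training script or a logging callback; nvidia-smi from the host tells the same story):

# Per-rank GPU memory telemetry; assumes torch.distributed is already initialized by torchrun.
import torch
import torch.distributed as dist

def log_gpu_memory(tag: str = "") -> None:
    rank = dist.get_rank() if dist.is_initialized() else 0
    allocated = torch.cuda.memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    peak = torch.cuda.max_memory_allocated() / 2**30
    print(f"[rank {rank}] {tag} allocated={allocated:.1f} GiB reserved={reserved:.1f} GiB peak={peak:.1f} GiB", flush=True)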
Logging and stability flags that mattered
export TORCH_DISTRIBUTED_DEBUG=INFO
export NCCL_DEBUG=WARN # or INFO when debugging collectives
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_IGNORE_CPU_AFFINITY=1
export OMP_NUM_THREADS=1
For long rendezvous or slow first steps:
export TORCH_NCCL_TIMEOUT_MS=3600000
Takeaways from iteration
- Start with CPU offload to stabilize memory; move to no-offload + bigger buckets once things are healthy
- Keep batch small and accumulation minimal until you confirm step 1 completes
- Make LoRA save semantics explicit; don’t assume defaults
- Prefer node-local checkpoint writes; aggregate later
In Part 3 I cover the CUDA/driver/Fabric Manager mismatches across nodes that produced the dreaded cudaGetDeviceCount -> error 802
and how I fixed them cleanly.