LLM training bug fixes - Gradient accumulation was wrong

Posted by danielhanchen@reddit | LocalLLaMA

Hey r/LocalLLaMA! A few days ago, u/TheKaitchup posted an issue showing that using gradient accumulation (GA) when training and finetuning LLMs caused training losses to differ from full batch training. GA is supposed to let you mimic full batch training without using more VRAM.

Theoretically, gradient accumulation should be equivalent to full batch training if we hold bsz * ga constant. But the training losses actually diverge: with bsz=16 and ga=1, the training loss is noticeably lower than with bsz=1 and ga=16.
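
To make the expected equivalence concrete, here is a minimal PyTorch sketch (my own illustration, not Unsloth code) showing that accumulating gradients over ga micro-batches of size bsz reproduces the full batch gradient when every micro-batch contributes the same number of elements; this is exactly the assumption that breaks for variable length text:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
model_full = torch.nn.Linear(8, 1)
model_ga = torch.nn.Linear(8, 1)
model_ga.load_state_dict(model_full.state_dict())

x, y = torch.randn(16, 8), torch.randn(16, 1)   # full batch: bsz * ga = 16

# Full batch gradient
F.mse_loss(model_full(x), y).backward()

# Gradient accumulation: 4 micro-batches of size 4
ga = 4
for xb, yb in zip(x.chunk(ga), y.chunk(ga)):
    # divide each micro-batch loss by ga so the accumulated gradient
    # matches the mean over the full batch
    (F.mse_loss(model_ga(xb), yb) / ga).backward()

# Identical up to floating point error -- prints True
print(torch.allclose(model_full.weight.grad, model_ga.weight.grad, atol=1e-6))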

Using naive gradient accumulation makes the L2 norm error between the LoRA weights trained with bsz=16 and those trained with ga=16 quite large, and the error grows as the number of gradient accumulation steps increases.

After fixing it in Unsloth (https://github.com/unslothai/unsloth), the L2 norm error stays constant across accumulation steps and is an order of magnitude smaller than with standard gradient accumulation.

Our blog post https://unsloth.ai/blog/gradient has more details, but TLDR: the normalizer factor in the cross entropy loss calculation was not correct, especially when training on datasets with varying sequence lengths.
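
As a rough sketch of what goes wrong (my own toy numbers, not the Unsloth implementation): with variable length sequences, each accumulation step contributes a different number of unpadded tokens, so taking the mean within each micro-batch and then averaging over the ga steps is not the same as taking one mean over all tokens in the effective batch:

import torch

# Token-level cross entropy losses for 4 micro-batches of different lengths
per_token_losses = [torch.tensor([0.9, 1.1]),            # 2 tokens
                    torch.tensor([0.5, 0.7, 0.6, 0.8]),  # 4 tokens
                    torch.tensor([1.2]),                  # 1 token
                    torch.tensor([0.4, 0.6, 0.5])]        # 3 tokens

# Naive GA: mean within each micro-batch, then average over the ga steps
ga = len(per_token_losses)
naive = sum(l.mean() for l in per_token_losses) / ga

# Fixed normalizer: sum all token losses, divide once by the total token count
total_tokens = sum(l.numel() for l in per_token_losses)
fixed = sum(l.sum() for l in per_token_losses) / total_tokens

print(naive.item(), fixed.item())   # the two normalizations give different losses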

Once this is fixed, the training losses for bsz=16 and ga=16 all match up, as expected.

To use Unsloth's fixed GA trainer, call:

from unsloth import unsloth_train
trainer_stats = unsloth_train(trainer)
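
For context, here is a hedged end-to-end sketch of where that call fits. The model name, placeholder dataset, and hyperparameters are illustrative only, and the exact SFTTrainer arguments depend on your trl version; the Colab and Kaggle notebooks below are the reference setup.

from unsloth import FastLanguageModel, unsloth_train
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import Dataset

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-1B-Instruct",
    max_seq_length = 2048,
    load_in_4bit = True,
)
model = FastLanguageModel.get_peft_model(
    model, r = 16, lora_alpha = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
)

# Tiny placeholder dataset with a pre-formatted "text" column
dataset = Dataset.from_dict({"text": ["### Question: 2+2?\n### Answer: 4"] * 64})

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = 2048,
    args = TrainingArguments(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 16,   # bsz * ga = 16
        max_steps = 30,
        output_dir = "outputs",
    ),
)

trainer_stats = unsloth_train(trainer)   # uses the fixed GA normalization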

Also, don't forget to update Unsloth via pip install --upgrade --no-cache-dir unsloth

We also have a free Colab notebook that finetunes Llama 3.2 1B/3B on conversational data 2x faster with 70% less VRAM using our fixed trainer: https://colab.research.google.com/drive/1z0XJU2FCzDC8oyXa2Nd4jCxylRMI-o0-?usp=sharing

And a free Kaggle notebook as well: https://www.kaggle.com/code/danielhanchen/fixed-kaggle-llama-3-2-1b-3b-conversation

This issue affects all multi-GPU training as well, since gradients are accumulated across devices in the same way as in gradient accumulation. Trainers that use the naive gradient accumulation approach will have to fix it.