LLM training bug fixes - Gradient accumulation was wrong
Posted by danielhanchen@reddit | LocalLLaMA | 38 comments
Hey r/LocalLLaMA! A few days ago, u/TheKaitchup posted an issue showing that using gradient accumulation (GA) when training and finetuning LLMs caused training losses to differ. GA lets you mimic full batch training without using more VRAM.
Theoretically, using gradient accumulation should be equivalent to full batch training if we hold bsz * ga constant. But the training losses actually diverge. When bsz=16 and ga=1, the training loss seems to be much lower than when bsz=1 and ga=16, as shown below:
With naive gradient accumulation, the L2 norm of the difference between the LoRA weights trained with bsz=16 and with ga=16 is quite large, and it grows as the number of gradient accumulation steps increases.
After fixing it in Unsloth https://github.com/unslothai/unsloth, the L2 norm becomes constant, and is an order of magnitude smaller than with standard gradient accumulation.
Our blog post https://unsloth.ai/blog/gradient has more details, but TLDR: the normalizer factor in the cross entropy loss calculation was not correct, especially when training on datasets with varying sequence lengths.
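To make the normalizer issue concrete, here is a minimal pure-Python sketch. The per-token loss values and the function names are invented for illustration and are not taken from Unsloth's code. It only shows that averaging per-step means differs from the full-batch token mean once the micro-batches have different token counts:

```python
import math

# Hypothetical per-token cross entropy losses for three micro-batches
# (padding already removed). The values are invented for illustration.
micro_batches = [
    [2.0, 2.0, 2.0, 2.0],  # 4 tokens
    [1.0, 1.0],            # 2 tokens
    [4.0],                 # 1 token
]


def full_batch_loss(batches):
    """Mean over every token, as if all sequences sat in one batch."""
    total_loss = sum(sum(b) for b in batches)
    total_tokens = sum(len(b) for b in batches)
    return total_loss / total_tokens


def naive_ga_loss(batches):
    """Mean of per-step means: each step is divided by its own token count."""
    return sum(sum(b) / len(b) for b in batches) / len(batches)


def fixed_ga_loss(batches):
    """Each step's summed loss is divided by the token count of the whole batch."""
    total_tokens = sum(len(b) for b in batches)
    return sum(sum(b) / total_tokens for b in batches)


print(full_batch_loss(micro_batches))  # 2.0
print(naive_ga_loss(micro_batches))    # about 2.33, does not match
print(fixed_ga_loss(micro_batches))    # matches 2.0 up to float rounding
```

If every micro-batch had the same number of tokens, the naive and fixed versions would agree, which is why the bug is easy to miss on fixed-length data.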
Once this is fixed, the training losses below all match up (as expected) for bsz=16 and ga=16.
To use Unsloth's fixed GA trainer, call:
from unsloth import unsloth_train
trainer_stats = unsloth_train(trainer)
Also, don't forget to update Unsloth via pip install --upgrade --no-cache-dir unsloth
We also have a free Colab notebook to finetune Llama 3.2 1B/3B conversational style 2x faster with 70% less VRAM with our fixed trainer here: https://colab.research.google.com/drive/1z0XJU2FCzDC8oyXa2Nd4jCxylRMI-o0-?usp=sharing
And a free Kaggle notebook as well: https://www.kaggle.com/code/danielhanchen/fixed-kaggle-llama-3-2-1b-3b-conversation
This issue affects all multi GPU training as well, since gradients have to be accumulated like in gradient accumulation. Trainers which use the naive gradient accumulation will have to fix it.
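The multi GPU case can be simulated the same way. The sketch below is an assumption-laden toy, not Unsloth code: it assumes a data-parallel wrapper that averages across ranks (as PyTorch's DistributedDataParallel does for gradients), uses invented loss values, and averages scalar losses as a stand-in for gradients, which is valid here because the averaging is linear.

```python
import math

# Hypothetical per-token losses held by two data-parallel ranks (values invented).
rank_losses = [
    [3.0, 3.0, 3.0, 3.0, 3.0, 3.0],  # rank 0: 6 tokens
    [1.0, 1.0],                      # rank 1: 2 tokens
]
world_size = len(rank_losses)


def average_across_ranks(per_rank_values):
    """Stand-in for the averaging step a data-parallel wrapper performs."""
    return sum(per_rank_values) / len(per_rank_values)


# Naive: every rank normalizes by its own local token count.
naive = average_across_ranks([sum(l) / len(l) for l in rank_losses])

# Fixed: normalize by the global token count (in a real run this would come
# from an all-reduce) and multiply by world_size to cancel the averaging step.
global_tokens = sum(len(l) for l in rank_losses)
fixed = average_across_ranks(
    [sum(l) * world_size / global_tokens for l in rank_losses]
)

# Reference: token mean over everything, as on one big device.
reference = sum(sum(l) for l in rank_losses) / global_tokens

print(naive, fixed, reference)  # 2.0 2.5 2.5
```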
Electronic_Tune_657@reddit
Thanks for this!
_sqrkl@reddit
I upgraded using --no-cache-dir. Did you forget to update the version number? It's showing 2024.10.1
danielhanchen@reddit (OP)
Oh you'll have to use unsloth_train from Unsloth. Sadly we had to make our own PyTorch native trainer to circumvent the issue. DPO currently is not supported, sorry - will add a fix in though!
Intelligent_Let_2538@reddit
Could you provide an update on the timeline for the bug fix in DPO/ORPO training? It would be helpful to know when we can expect it to be resolved.
_sqrkl@reddit
Ok noted, thanks
tgredditfc@reddit
No wonder I always find that a larger batch size works better than the equivalent gradient accumulation, even though in theory they should have the same effect!
Great job! Thank you for the fix!
danielhanchen@reddit (OP)
Ye it was interesting why large batch sizes always seem better! Hope the fix will be helpful!
lordpuddingcup@reddit
Silly question, as this isn't an LLM question: would this also affect other trainers that use GA, like kohya_ss for flux (image gen)?
breadwineandtits@reddit
Incredible work, thank you for your contribution!
WayBig7919@reddit
Amazing work again Daniel! Do you suggest retraining any finetuned model using this update, or is the accuracy difference small, maybe a percent or two? Were you able to determine the impact this had on accuracy?
yoracale@reddit
Multimodal/vision support will be coming either this week or the next by the way! :)
FullOf_Bad_Ideas@reddit
Finetuning Qwen 2 VL 7B with unsloth "booster" in llama-factory already works btw.
UpperDog69@reddit
Now this I am very excited about. A lot of the current vision trainers seem very bad, for example llama factory by default uses nearest neighbor scaling to downsize images which is kinda insane in 2024.
_sqrkl@reddit
Not often you see a "we were wrong about this fundamental thing everyone takes for granted". That must have been a satisfying find, well done!
danielhanchen@reddit (OP)
Thanks!
FullOf_Bad_Ideas@reddit
That's pretty amazing! If I understand it right, the issue is there even when all individual samples in each gradient accumulation step have the same sequence length? I do my finetuning attempts with sample packing enabled and I think it roughly packs all steps to the same sequence length, hence the question.
danielhanchen@reddit (OP)
Yes, so the issue still exists, albeit a bit less pronounced on packed sequences. If you're training on completions only, then the issue becomes much more pronounced!
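A rough illustration of why completions-only training makes it worse, with invented numbers: even if two packed micro-batches have the same length, masking the prompt tokens leaves a different number of trained tokens in each. The sketch assumes the usual convention of marking skipped positions with the label -100, and idealizes packing as producing exactly equal lengths, which real packing only approximates.

```python
IGNORE_INDEX = -100  # label value that the loss skips (PyTorch's default ignore_index)


def naive_and_exact(labels, token_losses):
    """Return (mean of per-step means, mean over all trained tokens)."""
    sums, counts = [], []
    for label_row, loss_row in zip(labels, token_losses):
        kept = [loss for lab, loss in zip(label_row, loss_row) if lab != IGNORE_INDEX]
        sums.append(sum(kept))
        counts.append(len(kept))
    naive = sum(s / n for s, n in zip(sums, counts)) / len(sums)
    exact = sum(sums) / sum(counts)
    return naive, exact


# Two packed micro-batches, both exactly 8 tokens long (all values invented).
token_losses = [
    [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 4.0, 4.0],
    [3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
]

# Training on the whole sample: every position has a real label.
whole_sample_labels = [[1] * 8, [1] * 8]

# Completions only: prompt positions are masked with IGNORE_INDEX.
completion_labels = [
    [IGNORE_INDEX] * 6 + [1, 1],   # 2 trained tokens
    [IGNORE_INDEX] * 2 + [1] * 6,  # 6 trained tokens
]

print(naive_and_exact(whole_sample_labels, token_losses))  # (2.0, 2.0)
print(naive_and_exact(completion_labels, token_losses))    # (2.5, 1.75)
```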
FullOf_Bad_Ideas@reddit
I've always been training on the whole sample, so I am good at least there. I see in the graph you made that the error seems to get bigger the longer the training goes. So, would it make sense to assume that for finetuning on a bigger dataset, where I am doing 10000 or 40000 steps, this bug will be more influential?
danielhanchen@reddit (OP)
Ye it's possible the issue gets worse over time - there are some experiments showing it might get better over time and match full batch training as well, so it's a bit confusing.
Best to solve the issue directly!
un_passant@reddit
«This issue affects all multi GPU training as well,» I was under the impression that Unsloth was only for single GPU training. When did the multi-GPU training happen? Is it also possible with the free version?
Thx again for your incredible work !
yoracale@reddit
Not now, but it's coming, for real. Multimodal/vision support will be out first, next week!
AcanthaceaeNo5503@reddit
Thank you for the work!
danielhanchen@reddit (OP)
Thank you!!
Educational_Rent1059@reddit
GREAT JOB!!
danielhanchen@reddit (OP)
Thanks!
xadiant@reddit
1- Do you think this bug affects multi-gpu training differently?
2- Could many million $ pretraining and fine-tuning runs have performed slightly worse because of this?
3- It might be a stupid question, but do you think this error stacks up / gets worse with more steps?
As always, thank you for your contributions!
danielhanchen@reddit (OP)
1- Yes, a bit differently but similarly, since gradients get accumulated and then averaged.
2- Maybe yes. The exact accuracy difference is unclear, but the output has definitely been incorrect.
3- The error accumulates over time, and generally it favours or overweights short sequence lengths, so long context training runs might be broken.
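The overweighting of short sequences can be read off from the per-token weights. A small sketch with invented token counts: under naive accumulation a token in step i gets weight 1 / (steps * n_i), while full batch training gives every token weight 1 / total_tokens.

```python
# Token counts of the micro-batches in one accumulation window (invented).
token_counts = [2, 8, 30]
steps = len(token_counts)
total_tokens = sum(token_counts)

# Weight that a single token's loss receives in the final averaged loss.
naive_weights = [1.0 / (steps * n) for n in token_counts]
exact_weight = 1.0 / total_tokens

for n, w in zip(token_counts, naive_weights):
    print(f"{n:>2} tokens: naive weight is {w / exact_weight:.2f}x the full-batch weight")
# prints:
#  2 tokens: naive weight is 6.67x the full-batch weight
#  8 tokens: naive weight is 1.67x the full-batch weight
# 30 tokens: naive weight is 0.44x the full-batch weight
```

In this toy window, tokens from the 2-token step count almost seven times more than they should, and tokens from the 30-token step count less than half as much.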
nero10579@reddit
Wow so basically all training libraries need to be updated with this?
danielhanchen@reddit (OP)
Yep unfortunately
nero10579@reddit
These are the moments I feel like I am really on the bleeding edge of this technology.
az226@reddit
Because you are and the technology is in its infancy
nero10579@reddit
Yes exactly
TheKaitchup@reddit
Once again, incredible work!
Thanks for fixing this so quickly and for the detailed explanation!
danielhanchen@reddit (OP)
Thanks for the find!
vigneshwaran34@reddit
Nice blog.
So we need to traverse the global batch (to find out the number of actual tokens) before computing the loss on the smaller batches with gradient accumulation. Can we do this efficiently?
danielhanchen@reddit (OP)
Yep correct! Yep it needs to be coded up efficiently! Added code to do that efficiently in Unsloth's trainer :)
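For anyone curious what the two-pass idea looks like, here is a toy sketch. It is not Unsloth's actual implementation: the one-parameter model, the data and all function names are made up. The first pass only counts tokens (it needs lengths or labels, not a forward pass), and the second pass is the usual accumulation loop with each step's summed gradient divided by the global count.

```python
import math

# Toy model: prediction = w * x, per-token loss = (w * x - y) ** 2.
# Each micro-batch is a list of (x, y) "tokens"; all values are invented.
micro_batches = [
    [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0)],
    [(1.0, 1.0)],
    [(2.0, 5.0), (4.0, 7.0)],
]
w = 0.5


def summed_grad(batch, w):
    """d/dw of the summed (not averaged) loss over one micro-batch."""
    return sum(2.0 * x * (w * x - y) for x, y in batch)


def accumulate_fixed(batches, w):
    # Pass 1: cheap count of trained tokens over the whole accumulation window.
    total_tokens = sum(len(b) for b in batches)
    # Pass 2: usual accumulation loop, scaling each step by the global count.
    grad = 0.0
    for batch in batches:
        grad += summed_grad(batch, w) / total_tokens
    return grad


def accumulate_naive(batches, w):
    # Each step is averaged over its own tokens, then over the number of steps.
    grad = 0.0
    for batch in batches:
        grad += summed_grad(batch, w) / len(batch) / len(batches)
    return grad


def full_batch_grad(batches, w):
    # Reference: gradient of the token-mean loss over all tokens at once.
    tokens = [t for b in batches for t in b]
    return summed_grad(tokens, w) / len(tokens)


print(full_batch_grad(micro_batches, w))
print(accumulate_fixed(micro_batches, w))  # agrees with the full-batch gradient
print(accumulate_naive(micro_batches, w))  # differs
```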