diff --git a/examples/config_tiny_llama.yaml b/examples/config_tiny_llama.yaml
index cbc053b6..f1ca16d2 100644
--- a/examples/config_tiny_llama.yaml
+++ b/examples/config_tiny_llama.yaml
@@ -73,7 +73,7 @@ optimizer:
   learning_rate_scheduler:
     learning_rate: 0.0003
     lr_decay_starting_step: null
-    lr_decay_steps: 8
+    lr_decay_steps: 13
     lr_decay_style: cosine
     lr_warmup_steps: 2
     lr_warmup_style: linear
@@ -104,6 +104,6 @@ tokens:
   limit_test_batches: 0
   limit_val_batches: 0
   micro_batch_size: 2
-  sequence_length: 32
-  train_steps: 10
+  sequence_length: 256
+  train_steps: 15
   val_check_interval: -1
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 0d338296..1bc701bf 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -52,6 +52,9 @@
 if DISABLE_FLASH_ATTENTION:
     print("Warning: Flash attention was disabled!")
+    # Disable the fused SDP kernels so scaled_dot_product_attention falls back to the math backend
+    torch.backends.cuda.enable_mem_efficient_sdp(False)
+    torch.backends.cuda.enable_flash_sdp(False)

 RMSNorm = RMSNorm if DISABLE_FLASH_ATTENTION else TritonRMSNorm
diff --git a/tests/test_llama.py b/tests/test_llama.py
index 942c45bc..3881af93 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -1,5 +1,5 @@
 # Script to test correctness of training script by comparing loss value after 100th iteration with expected loss value
-# pytest -sv tests/test_train_llama.py or python tests/test_train_llama.py
+# pytest -sv tests/test_llama.py or python tests/test_llama.py
 import atexit
 import os
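
For context (not part of the patch): in the config hunk, `lr_decay_steps: 13` appears to be chosen so that warmup (2 steps) plus decay covers the new `train_steps: 15`, though the diff does not state that relationship explicitly. The sketch below is a minimal, purely illustrative probe of what the two `torch.backends.cuda` toggles added in the `llama.py` hunk do, assuming PyTorch 2.x: with the flash and memory-efficient kernels turned off, `scaled_dot_product_attention` dispatches to the plain math implementation.

```python
# Illustrative only -- not part of the patch. Assumes PyTorch 2.x.
import torch
import torch.nn.functional as F

# Mirror the llama.py change: turn off the fused SDPA kernels.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

# Check which backends remain available for scaled_dot_product_attention.
print("flash:        ", torch.backends.cuda.flash_sdp_enabled())          # False
print("mem-efficient:", torch.backends.cuda.mem_efficient_sdp_enabled())  # False
print("math:         ", torch.backends.cuda.math_sdp_enabled())           # True

# SDPA still works; it simply dispatches to the math implementation.
q = k = v = torch.randn(1, 4, 32, 64)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 4, 32, 64])
```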