diff --git a/tool_calling_experiment/train_colab.py b/tool_calling_experiment/train_colab.py
index 2d38897..5033329 100644
--- a/tool_calling_experiment/train_colab.py
+++ b/tool_calling_experiment/train_colab.py
@@ -89,7 +89,7 @@ def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
     return params
 
 
-def train(model_size="124M", max_steps=1000):
+def train(model_size="124M", max_steps=1000, batch_size=4, accumulation_steps=4):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")
 
@@ -130,10 +130,11 @@ def train(model_size="124M", max_steps=1000):
     model.out_head = new_head.to(device)
 
     # 4. Data Loader
-    train_loader = create_dataloader(tokenizer, batch_size=2, max_length=cfg["context_length"])
+    train_loader = create_dataloader(tokenizer, batch_size=batch_size, max_length=cfg["context_length"])
 
     # 5. Optimizer
     optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)
+    optimizer.zero_grad()  # Initialize gradients
 
     # 6. Loop
     model.train()
@@ -143,7 +144,7 @@ def train(model_size="124M", max_steps=1000):
     for input_chunk, target_chunk in train_loader:
         input_chunk, target_chunk = input_chunk.to(device), target_chunk.to(device)
 
-        optimizer.zero_grad()
+        # optimizer.zero_grad()  # Moved to accumulation step
         logits = model(input_chunk)
 
         loss = torch.nn.functional.cross_entropy(
@@ -151,14 +152,19 @@ def train(model_size="124M", max_steps=1000):
             target_chunk.flatten(0, 1)
         )
 
+        # Gradient Accumulation
+        loss = loss / accumulation_steps
         loss.backward()
-        optimizer.step()
 
-        if step % 10 == 0:
-            print(f"Step {step}: Loss {loss.item():.4f}")
+        if (step + 1) % accumulation_steps == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+        if (step + 1) % (10 * accumulation_steps) == 0:
+            print(f"Step {step + 1}: Loss {loss.item() * accumulation_steps:.4f}")
 
         step += 1
-        if step >= max_steps:
+        if step >= max_steps * accumulation_steps:
             break
 
     print("Training complete.")
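
Note: the patch above folds gradient accumulation into the training loop. For reference, the pattern boils down to the minimal sketch below, using a toy model and random data (everything here is illustrative, not taken from train_colab.py). The per-batch loss is divided by accumulation_steps so the summed gradients approximate the average over the effective batch, and the optimizer only steps (and zeroes gradients) once every accumulation_steps micro-batches, giving an effective batch size of batch_size * accumulation_steps.

    import torch

    # Toy stand-ins; the real script trains a GPT-2 with a custom head.
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    accumulation_steps = 4

    # Fake data loader: 16 micro-batches of 8 samples each (illustrative).
    loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(16)]

    optimizer.zero_grad()  # start from clean gradients, as the patch does
    for step, (x, y) in enumerate(loader):
        logits = model(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        (loss / accumulation_steps).backward()  # scale so gradients average
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()       # one update per accumulated window
            optimizer.zero_grad()  # clear for the next window

Scaling the loss (rather than the accumulated gradients) also keeps the logged value easy to interpret: multiplying loss.item() back by accumulation_steps, as the new print statement does, recovers the unscaled per-batch loss.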