Optimize training: add gradient accumulation & increase batch size

commit 4c5f9fadbb
parent 9e47e507dd
Author: Shahar Dickstein
Date:   2026-02-15 14:10:16 +02:00


@@ -89,7 +89,7 @@ def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
     return params
 
-def train(model_size="124M", max_steps=1000):
+def train(model_size="124M", max_steps=1000, batch_size=4, accumulation_steps=4):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")
@@ -130,10 +130,11 @@ def train(model_size="124M", max_steps=1000):
     model.out_head = new_head.to(device)
 
     # 4. Data Loader
-    train_loader = create_dataloader(tokenizer, batch_size=2, max_length=cfg["context_length"])
+    train_loader = create_dataloader(tokenizer, batch_size=batch_size, max_length=cfg["context_length"])
 
     # 5. Optimizer
     optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)
+    optimizer.zero_grad()  # Initialize gradients
 
     # 6. Loop
     model.train()
@@ -143,7 +144,7 @@ def train(model_size="124M", max_steps=1000):
     for input_chunk, target_chunk in train_loader:
         input_chunk, target_chunk = input_chunk.to(device), target_chunk.to(device)
-        optimizer.zero_grad()
+        # optimizer.zero_grad()  # Moved to accumulation step
         logits = model(input_chunk)
         loss = torch.nn.functional.cross_entropy(
@@ -151,14 +152,19 @@ def train(model_size="124M", max_steps=1000):
             target_chunk.flatten(0, 1)
         )
+
+        # Gradient Accumulation
+        loss = loss / accumulation_steps
         loss.backward()
-        optimizer.step()
-        if step % 10 == 0:
-            print(f"Step {step}: Loss {loss.item():.4f}")
+        if (step + 1) % accumulation_steps == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+        if (step + 1) % (10 * accumulation_steps) == 0:
+            print(f"Step {step + 1}: Loss {loss.item() * accumulation_steps:.4f}")
         step += 1
-        if step >= max_steps:
+        if step >= max_steps * accumulation_steps:
             break
     print("Training complete.")