Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2026-04-11 14:21:41 +08:00
Optimize training: add gradient accumulation & increase batch size

commit 4c5f9fadbb
parent 9e47e507dd
@@ -89,7 +89,7 @@ def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
     return params


-def train(model_size="124M", max_steps=1000):
+def train(model_size="124M", max_steps=1000, batch_size=4, accumulation_steps=4):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")

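Context for the new signature: gradients from accumulation_steps consecutive micro-batches of batch_size sequences are summed before each optimizer update, so the effective batch size is their product. A quick sanity check with the new defaults (illustration only, not code from the commit):

batch_size = 4           # sequences per forward/backward pass
accumulation_steps = 4   # micro-batches summed into one optimizer update
effective_batch_size = batch_size * accumulation_steps
print(effective_batch_size)  # 16, versus the previous hard-coded batch of 2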
@@ -130,10 +130,11 @@ def train(model_size="124M", max_steps=1000):
     model.out_head = new_head.to(device)

     # 4. Data Loader
-    train_loader = create_dataloader(tokenizer, batch_size=2, max_length=cfg["context_length"])
+    train_loader = create_dataloader(tokenizer, batch_size=batch_size, max_length=cfg["context_length"])

     # 5. Optimizer
     optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)
+    optimizer.zero_grad()  # Initialize gradients

     # 6. Loop
     model.train()
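create_dataloader itself is not part of this diff. As a reference for readers without the full file, here is a minimal sketch of a GPT-style sliding-window loader matching the call above, modeled on the create_dataloader_v1 helper found elsewhere in this repo; the txt and stride parameters are assumptions, not necessarily the script's actual API:

import torch
from torch.utils.data import Dataset, DataLoader

class GPTDataset(Dataset):
    # Slides a max_length window over the token stream; the target for each
    # window is the same window shifted one token to the right.
    def __init__(self, txt, tokenizer, max_length, stride):
        token_ids = tokenizer.encode(txt)
        self.input_ids, self.target_ids = [], []
        for i in range(0, len(token_ids) - max_length, stride):
            self.input_ids.append(torch.tensor(token_ids[i:i + max_length]))
            self.target_ids.append(torch.tensor(token_ids[i + 1:i + max_length + 1]))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

def create_dataloader(tokenizer, batch_size, max_length, txt="", stride=None):
    # stride=max_length gives non-overlapping windows (an assumption here)
    dataset = GPTDataset(txt, tokenizer, max_length, stride or max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)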
@@ -143,7 +144,7 @@ def train(model_size="124M", max_steps=1000):
     for input_chunk, target_chunk in train_loader:
         input_chunk, target_chunk = input_chunk.to(device), target_chunk.to(device)

-        optimizer.zero_grad()
+        # optimizer.zero_grad()  # Moved to accumulation step
         logits = model(input_chunk)

         loss = torch.nn.functional.cross_entropy(
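The cross_entropy call continuing into the next hunk merges the batch and sequence dimensions, because torch.nn.functional.cross_entropy expects logits of shape (N, num_classes) and targets of shape (N,). A standalone illustration of the shape handling, with toy dimensions made up for the example:

import torch

B, T, V = 4, 8, 50257            # batch size, sequence length, vocab size (toy values)
logits = torch.randn(B, T, V)    # one next-token distribution per position
targets = torch.randint(0, V, (B, T))

loss = torch.nn.functional.cross_entropy(
    logits.flatten(0, 1),        # (B, T, V) -> (B*T, V)
    targets.flatten(0, 1)        # (B, T)    -> (B*T,)
)
print(loss.shape)                # scalar: mean loss over all B*T token positions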
@@ -151,14 +152,19 @@ def train(model_size="124M", max_steps=1000):
             target_chunk.flatten(0, 1)
         )

+        # Gradient Accumulation
+        loss = loss / accumulation_steps
         loss.backward()
-        optimizer.step()

-        if step % 10 == 0:
-            print(f"Step {step}: Loss {loss.item():.4f}")
+        if (step + 1) % accumulation_steps == 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+        if (step + 1) % (10 * accumulation_steps) == 0:
+            print(f"Step {step + 1}: Loss {loss.item() * accumulation_steps:.4f}")

         step += 1
-        if step >= max_steps:
+        if step >= max_steps * accumulation_steps:
             break

     print("Training complete.")
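Taken together, the commit implements the standard gradient-accumulation pattern. Below is a self-contained paraphrase of the resulting loop, with model, train_loader, and device supplied by the caller; the outer while is an assumption (the diff shows only the inner for loop), and unlike the diff it scales the loss only at backward() time so the logged value needs no rescaling:

import torch

def train_loop(model, train_loader, device, max_steps=1000, accumulation_steps=4):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    optimizer.zero_grad()  # start from clean gradients
    model.train()

    step = 0  # counts micro-batches, not optimizer updates
    while step < max_steps * accumulation_steps:
        for input_chunk, target_chunk in train_loader:
            input_chunk = input_chunk.to(device)
            target_chunk = target_chunk.to(device)

            logits = model(input_chunk)
            loss = torch.nn.functional.cross_entropy(
                logits.flatten(0, 1), target_chunk.flatten(0, 1)
            )

            # Scale so the summed gradients match one large-batch backward pass
            (loss / accumulation_steps).backward()

            if (step + 1) % accumulation_steps == 0:
                optimizer.step()       # one update per accumulation_steps micro-batches
                optimizer.zero_grad()

            if (step + 1) % (10 * accumulation_steps) == 0:
                print(f"Step {step + 1}: Loss {loss.item():.4f}")

            step += 1
            if step >= max_steps * accumulation_steps:
                break

    print("Training complete.")

The step >= max_steps * accumulation_steps bound keeps the number of optimizer updates at max_steps, matching the pre-commit behavior; the diff's loss.item() * accumulation_steps in its log line serves the same purpose as leaving loss unscaled here, undoing the division applied before backward().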