Mirror of https://github.com/rasbt/LLMs-from-scratch.git
Add argparse for hyperparameters and support larger models
commit 2ad253570b
parent 4c5f9fadbb
@@ -7,7 +7,7 @@ import numpy as np
 import tensorflow as tf
 from tqdm import tqdm
 from architecture import GPTModel, load_weights_into_gpt
-from config import GPT_CONFIG_124M
+from config import GPT_CONFIG_124M, GPT_CONFIG_355M, GPT_CONFIG_774M, GPT_CONFIG_1558M
 from tokenizer_utils import TokenizerWrapper
 from dataset_prep import create_dataloader
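The three larger configs imported here are defined in config.py, which is not part of this diff. Assuming they follow the same dictionary shape as GPT_CONFIG_124M (the cfg["context_length"] access later in the diff confirms the key naming), a rough sketch with the published GPT-2 dimensions would look like this; the drop_rate and qkv_bias values are assumptions:

# Sketch only: config.py is not shown in this diff; dicts assumed to mirror
# GPT_CONFIG_124M, filled in with the standard GPT-2 model sizes.
GPT_CONFIG_355M = {
    "vocab_size": 50257, "context_length": 1024, "emb_dim": 1024,
    "n_heads": 16, "n_layers": 24, "drop_rate": 0.1, "qkv_bias": True,
}
GPT_CONFIG_774M = {
    "vocab_size": 50257, "context_length": 1024, "emb_dim": 1280,
    "n_heads": 20, "n_layers": 36, "drop_rate": 0.1, "qkv_bias": True,
}
GPT_CONFIG_1558M = {
    "vocab_size": 50257, "context_length": 1024, "emb_dim": 1600,
    "n_heads": 25, "n_layers": 48, "drop_rate": 0.1, "qkv_bias": True,
}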
@@ -89,7 +89,7 @@ def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
     return params


-def train(model_size="124M", max_steps=1000, batch_size=4, accumulation_steps=4):
+def train(model_size="124M", max_steps=1000, batch_size=2, accumulation_steps=8, max_length=1024):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")
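The new defaults halve the per-step batch (4 to 2) and double accumulation_steps (4 to 8), so the effective batch size (batch_size × accumulation_steps) stays at 16 while peak memory per step drops, which is what makes the larger checkpoints trainable. The training loop itself is outside this hunk; a minimal sketch of the usual accumulation pattern follows, where model, optimizer, dataloader, and calc_loss are placeholders rather than names from this file:

# Sketch of gradient accumulation, assuming a standard PyTorch training loop.
for step, (inputs, targets) in enumerate(dataloader):
    loss = calc_loss(model(inputs.to(device)), targets.to(device))
    (loss / accumulation_steps).backward()  # scale so summed grads match one full batch
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()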
@@ -101,7 +101,19 @@ def train(model_size="124M", max_steps=1000, batch_size=4, accumulation_steps=4)
     print("Initializing architecture...")
     # Map settings to our config format if needed, but we used GPT_CONFIG_124M as base.
     # We should ensure config matches loaded settings.
-    cfg = GPT_CONFIG_124M
+    if model_size == "124M":
+        cfg = GPT_CONFIG_124M
+    elif model_size == "355M":
+        cfg = GPT_CONFIG_355M
+    elif model_size == "774M":
+        cfg = GPT_CONFIG_774M
+    elif model_size == "1558M":
+        cfg = GPT_CONFIG_1558M
+    else:
+        raise ValueError(f"Unknown model size: {model_size}")
+
+    # Update max_length from args in case we want shorter context to save memory
+    cfg["context_length"] = max_length

     model = GPTModel(cfg)
     print("Loading weights into model...")
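One side note on this mapping: the GPT_CONFIG_* objects are module-level dicts, so cfg["context_length"] = max_length mutates the shared imported config. A more compact, copy-safe alternative (not what this commit does, just a sketch) would be a table lookup:

# Alternative sketch: dict lookup plus a shallow copy so the imported config is not mutated.
MODEL_CONFIGS = {
    "124M": GPT_CONFIG_124M,
    "355M": GPT_CONFIG_355M,
    "774M": GPT_CONFIG_774M,
    "1558M": GPT_CONFIG_1558M,
}
if model_size not in MODEL_CONFIGS:
    raise ValueError(f"Unknown model size: {model_size}")
cfg = dict(MODEL_CONFIGS[model_size])
cfg["context_length"] = max_length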
@@ -172,4 +184,15 @@ def train(model_size="124M", max_steps=1000, batch_size=4, accumulation_steps=4)
     print("Model saved to 'tool_llm.pth'.")

 if __name__ == "__main__":
-    train()
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_size", type=str, default="124M", help="GPT-2 model size (124M, 355M, 774M, 1558M)")
+    parser.add_argument("--batch_size", type=int, default=2, help="Batch size per step. Decrease if OOM.")
+    parser.add_argument("--accumulation_steps", type=int, default=4, help="Gradient accumulation steps.")
+    parser.add_argument("--max_steps", type=int, default=1000, help="Total training steps (batches processed).")
+    parser.add_argument("--max_length", type=int, default=1024, help="Context length.")
+
+    args = parser.parse_args()
+
+    print(f"Training with: {args}")
+    train(model_size=args.model_size, max_steps=args.max_steps, batch_size=args.batch_size, accumulation_steps=args.accumulation_steps, max_length=args.max_length)
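With the __main__ block in place, all hyperparameters can be set from the command line. A typical invocation for fine-tuning the 355M model on limited GPU memory might look like this (the script's filename is not shown in the diff; train.py is assumed):

# Assumed filename; adjust to the actual script name in the repo.
python train.py --model_size 355M --batch_size 1 --accumulation_steps 16 --max_length 512

Lowering --batch_size and --max_length reduces peak memory, while raising --accumulation_steps keeps the effective batch size (batch_size × accumulation_steps) constant. Note that the argparse default for --accumulation_steps (4) differs from the new function default (8); when the script is run this way, the argparse value is the one that takes effect.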