Add argparse for hyperparameters and support larger models

This commit is contained in:
Shahar Dickstein 2026-02-15 14:16:48 +02:00
parent 4c5f9fadbb
commit 2ad253570b

View File

@ -7,7 +7,7 @@ import numpy as np
import tensorflow as tf
from tqdm import tqdm
from architecture import GPTModel, load_weights_into_gpt
from config import GPT_CONFIG_124M
from config import GPT_CONFIG_124M, GPT_CONFIG_355M, GPT_CONFIG_774M, GPT_CONFIG_1558M
from tokenizer_utils import TokenizerWrapper
from dataset_prep import create_dataloader
@ -89,7 +89,7 @@ def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
return params
def train(model_size="124M", max_steps=1000, batch_size=4, accumulation_steps=4):
def train(model_size="124M", max_steps=1000, batch_size=2, accumulation_steps=8, max_length=1024):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
@ -101,7 +101,19 @@ def train(model_size="124M", max_steps=1000, batch_size=4, accumulation_steps=4)
print("Initializing architecture...")
# Map settings to our config format if needed, but we used GPT_CONFIG_124M as base.
# We should ensure config matches loaded settings.
cfg = GPT_CONFIG_124M
if model_size == "124M":
cfg = GPT_CONFIG_124M
elif model_size == "355M":
cfg = GPT_CONFIG_355M
elif model_size == "774M":
cfg = GPT_CONFIG_774M
elif model_size == "1558M":
cfg = GPT_CONFIG_1558M
else:
raise ValueError(f"Unknown model size: {model_size}")
# Update max_length from args in case we want shorter context to save memory
cfg["context_length"] = max_length
model = GPTModel(cfg)
print("Loading weights into model...")
@ -172,4 +184,15 @@ def train(model_size="124M", max_steps=1000, batch_size=4, accumulation_steps=4)
print("Model saved to 'tool_llm.pth'.")
if __name__ == "__main__":
train()
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model_size", type=str, default="124M", help="GPT-2 model size (124M, 355M, 774M, 1558M)")
parser.add_argument("--batch_size", type=int, default=2, help="Batch size per step. Decrease if OOM.")
parser.add_argument("--accumulation_steps", type=int, default=4, help="Gradient accumulation steps.")
parser.add_argument("--max_steps", type=int, default=1000, help="Total training steps (batches processed).")
parser.add_argument("--max_length", type=int, default=1024, help="Context length.")
args = parser.parse_args()
print(f"Training with: {args}")
train(model_size=args.model_size, max_steps=args.max_steps, batch_size=args.batch_size, accumulation_steps=args.accumulation_steps, max_length=args.max_length)