diff --git a/tool_calling_experiment/train_colab.py b/tool_calling_experiment/train_colab.py
index 5033329..112255c 100644
--- a/tool_calling_experiment/train_colab.py
+++ b/tool_calling_experiment/train_colab.py
@@ -7,7 +7,7 @@ import numpy as np
 import tensorflow as tf
 from tqdm import tqdm
 from architecture import GPTModel, load_weights_into_gpt
-from config import GPT_CONFIG_124M
+from config import GPT_CONFIG_124M, GPT_CONFIG_355M, GPT_CONFIG_774M, GPT_CONFIG_1558M
 from tokenizer_utils import TokenizerWrapper
 from dataset_prep import create_dataloader
 
@@ -89,7 +89,7 @@ def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
     return params
 
 
-def train(model_size="124M", max_steps=1000, batch_size=4, accumulation_steps=4):
+def train(model_size="124M", max_steps=1000, batch_size=2, accumulation_steps=8, max_length=1024):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")
 
@@ -101,7 +101,22 @@ def train(model_size="124M", max_steps=1000, batch_size=4, accumulation_steps=4)
     print("Initializing architecture...")
     # Map settings to our config format if needed, but we used GPT_CONFIG_124M as base.
     # We should ensure config matches loaded settings.
-    cfg = GPT_CONFIG_124M
+    if model_size == "124M":
+        cfg = GPT_CONFIG_124M
+    elif model_size == "355M":
+        cfg = GPT_CONFIG_355M
+    elif model_size == "774M":
+        cfg = GPT_CONFIG_774M
+    elif model_size == "1558M":
+        cfg = GPT_CONFIG_1558M
+    else:
+        raise ValueError(f"Unknown model size: {model_size}")
+
+    # Override the context length on a copy: the GPT_CONFIG_* presets are shared
+    # module-level objects, so assigning into them would leak into later callers.
+    # NOTE(review): GPT-2 checkpoints carry 1024 learned positions; confirm that
+    # load_weights_into_gpt accepts a model built with a different max_length.
+    cfg = {**cfg, "context_length": max_length}
     model = GPTModel(cfg)
 
     print("Loading weights into model...")
@@ -172,4 +187,16 @@ def train(model_size="124M", max_steps=1000, batch_size=4, accumulation_steps=4)
     print("Model saved to 'tool_llm.pth'.")
 
 if __name__ == "__main__":
-    train()
+    import argparse
+    parser = argparse.ArgumentParser()
+    # Defaults below mirror train()'s keyword defaults so both entry paths agree.
+    parser.add_argument("--model_size", type=str, default="124M", help="GPT-2 model size (124M, 355M, 774M, 1558M)")
+    parser.add_argument("--batch_size", type=int, default=2, help="Batch size per step. Decrease if OOM.")
+    parser.add_argument("--accumulation_steps", type=int, default=8, help="Gradient accumulation steps.")
+    parser.add_argument("--max_steps", type=int, default=1000, help="Total training steps (batches processed).")
+    parser.add_argument("--max_length", type=int, default=1024, help="Context length.")
+
+    args = parser.parse_args()
+
+    print(f"Training with: {args}")
+    train(model_size=args.model_size, max_steps=args.max_steps, batch_size=args.batch_size, accumulation_steps=args.accumulation_steps, max_length=args.max_length)