casinca 2026-01-14 09:59:05 +01:00
parent 34891ed7c4
commit 78f4492ab7
7 changed files with 7 additions and 7 deletions

View File

@@ -358,7 +358,7 @@
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # Reusuable utilities\n",
" # Reusable utilities\n",
" cos, sin = compute_rope_params(\n",
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
" theta_base=cfg[\"rope_base\"],\n",

View File

@@ -432,7 +432,7 @@
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # Reusuable utilities\n",
" # Reusable utilities\n",
" if cfg[\"head_dim\"] is None:\n",
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
" else:\n",

View File

@@ -422,7 +422,7 @@
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # Reusuable utilities\n",
" # Reusable utilities\n",
" if cfg[\"head_dim\"] is None:\n",
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
" else:\n",

View File

@@ -388,7 +388,7 @@
" self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n",
" # Reusuable utilities\n",
" # Reusable utilities\n",
" if cfg[\"head_dim\"] is None:\n",
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
" else:\n",

View File

@@ -333,7 +333,7 @@
" # self.out_head.weight = self.tok_emb.weight\n",
" # torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)\n",
"\n",
" # Reusuable utilities\n",
" # Reusable utilities\n",
" cos, sin = compute_rope_params(\n",
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
" theta_base=cfg[\"rope_base\"],\n",

View File

@@ -336,7 +336,7 @@
" # self.out_head.weight = self.tok_emb.weight\n",
" # torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)\n",
"\n",
" # Reusuable utilities\n",
" # Reusable utilities\n",
" if cfg[\"head_dim\"] is None:\n",
" head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
" else:\n",

View File

@@ -65,7 +65,7 @@ class Llama3Model(nn.Module):
self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
-# Reusuable utilities
+# Reusable utilities
cos, sin = compute_rope_params(
head_dim=cfg["emb_dim"] // cfg["n_heads"],
theta_base=cfg["rope_base"],
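
For context, the compute_rope_params calls touched above precompute the rotary-embedding cos/sin tables once in the model constructor. Below is a minimal, self-contained sketch of what such a helper typically computes; the signature, defaults, and example values are illustrative assumptions, not the repository's exact implementation.

import torch

def compute_rope_params(head_dim, theta_base=10_000.0, context_length=4096, dtype=torch.float32):
    # Inverse frequencies for each even dimension of a head (assumed standard RoPE formulation)
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim))
    positions = torch.arange(context_length, dtype=dtype)
    angles = positions[:, None] * inv_freq[None, :]  # (context_length, head_dim // 2)
    angles = torch.cat([angles, angles], dim=-1)     # (context_length, head_dim)
    # The cos/sin tables are computed once and reused by every attention layer
    return torch.cos(angles), torch.sin(angles)

# Example with assumed Llama-3-style values: emb_dim=4096, n_heads=32, rope_base=500_000
cos, sin = compute_rope_params(head_dim=4096 // 32, theta_base=500_000.0, context_length=8192)
print(cos.shape, sin.shape)  # torch.Size([8192, 128]) torch.Size([8192, 128])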