diff --git a/ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb b/ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb index 02b61e5..fe9560b 100644 --- a/ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb +++ b/ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb @@ -233,10 +233,8 @@ "source": [ "class GroupedQueryAttention(nn.Module):\n", " def __init__(\n", - " self, d_in, d_out, context_length, num_heads,\n", + " self, d_in, d_out, num_heads,\n", " num_kv_groups,\n", - " rope_base=10_000,\n", - " rope_config=None,\n", " dtype=None\n", " ):\n", " super().__init__()\n", @@ -322,11 +320,8 @@ " self.att = GroupedQueryAttention(\n", " d_in=cfg[\"emb_dim\"],\n", " d_out=cfg[\"emb_dim\"],\n", - " context_length=cfg[\"context_length\"],\n", " num_heads=cfg[\"n_heads\"],\n", " num_kv_groups=cfg[\"n_kv_groups\"],\n", - " rope_base=cfg[\"rope_base\"],\n", - " rope_config=cfg[\"rope_freq\"],\n", " dtype=cfg[\"dtype\"]\n", " )\n", " self.ff = FeedForward(cfg)\n",