diff --git a/docs_nnx/api_reference/flax.nnx/training/ema.rst b/docs_nnx/api_reference/flax.nnx/training/ema.rst new file mode 100644 index 000000000..69b22d151 --- /dev/null +++ b/docs_nnx/api_reference/flax.nnx/training/ema.rst @@ -0,0 +1,8 @@ +EMA +------------------------ + +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx + +.. autoclass:: EMA + :members: __init__, update diff --git a/docs_nnx/api_reference/flax.nnx/training/index.rst b/docs_nnx/api_reference/flax.nnx/training/index.rst index 32404f1de..c51f73283 100644 --- a/docs_nnx/api_reference/flax.nnx/training/index.rst +++ b/docs_nnx/api_reference/flax.nnx/training/index.rst @@ -8,4 +8,5 @@ Experimental API. See the `NNX page " - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# initialization\n", "model = make_model(rngs)\n", - "ema = EMA(model, decay=0.9)\n", - "\n", - "# simulate parameter update\n", - "def double(param):\n", - " param[...] *= 2.0\n", - "jax.tree.map(double, model, is_leaf=lambda x: isinstance(x, nnx.Variable))\n", - "ema.update(model)\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "ema = nnx.EMA(model, decay=0.9)\n", + "ema_model = ema.apply_to(model)\n", "\n", "@nnx.jit\n", "def train_step(model, optimizer, ema, x, y):\n", " loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)\n", " optimizer.update(model, grads)\n", " ema.update(model)\n", - " return loss\n", "\n", - "optimizer = nnx.Optimizer(\n", - " model,\n", - " tx=optax.adam(1e-3),\n", - " wrt=nnx.Param,\n", - ")\n", - "losses = []\n", + "@nnx.jit\n", + "def eval_step(model, x, y):\n", + " return loss_fn(model, x, y)\n", + "\n", "for _ in range(50):\n", - " loss = train_step(model, optimizer, ema, x, y)\n", - " losses.append(loss)\n", - "plt.plot(losses);" + " train_step(model, optimizer, ema, x, y)\n", + "\n", + "loss = eval_step(ema_model, x, y)\n", + "print(f\"final eval loss: {loss}\")" ] }, { "cell_type": "markdown", - "id": "cf02190b", + "id": "b90e900d", "metadata": {}, "source": [ "# Low Rank Adaptation" @@ -194,7 +112,7 @@ }, { "cell_type": "markdown", - "id": "73f6f98c", + "id": "84cc64ee", "metadata": {}, "source": [ "The pattern for adding low rank adaptation to an optimization loop is very similar to adding an exponential moving average. As before, we create a new pytree with the same structure as our model parameters, but here we store low rank additions to these parameters rather than weighted average values. " @@ -202,37 +120,10 @@ }, { "cell_type": "code", - "execution_count": 15, - "id": "80765fbd", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "execution_count": null, + "id": "05cc6cf7", + "metadata": {}, + "outputs": [], "source": [ "def add_rank2_lora(path, node):\n", " if isinstance(node, nnx.Linear):\n", @@ -240,28 +131,33 @@ " return node\n", "\n", "base_model = make_model(rngs)\n", - "lora_model = nnx.recursive_map(add_rank2_lora, base_model)\n", + "model = nnx.recursive_map(add_rank2_lora, base_model)\n", "nnx.display(model)" ] }, + { + "cell_type": "markdown", + "id": "8efe3ab5", + "metadata": {}, + "source": [ + "The training loop is the same as before, but we pass `wrt=nnx.LoRAParam` to the optimizer so that only the low-rank adaptation parameters are updated while the base model weights remain frozen." + ] + }, { "cell_type": "code", - "execution_count": 16, - "id": "5d509ad2", + "execution_count": null, + "id": "8807ea4e", "metadata": {}, "outputs": [], "source": [ + "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.LoRAParam)\n", + "\n", "@nnx.jit\n", "def train_step(model, optimizer, x, y):\n", " loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)\n", " optimizer.update(model, grads)\n", " return loss\n", "\n", - "optimizer = nnx.Optimizer(\n", - " model,\n", - " tx=optax.adam(1e-3),\n", - " wrt=nnx.LoRAParam,\n", - ")\n", "\n", "losses = []\n", "for _ in range(50):\n", @@ -271,7 +167,7 @@ }, { "cell_type": "markdown", - "id": "73fea71f", + "id": "12176089", "metadata": {}, "source": [ "# LBFGS" @@ -279,7 +175,7 @@ }, { "cell_type": "markdown", - "id": "bab507b1", + "id": "4d96e4de", "metadata": {}, "source": [ "So far, we've been using optax optimizers with the interface ``optimizer.update(grads, opt_state)``. This works for simple optimization algorithms like ADAM, but for algorithms that use a line search like LBFGS, we need to pass more parameters. Below, we can see how the call to ``optimizer.update`` is given additional parameters when using LBFGS." @@ -287,31 +183,10 @@ }, { "cell_type": "code", - "execution_count": 17, - "id": "9c9e332d", + "execution_count": null, + "id": "f18bfad9", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAANl5JREFUeJzt3Ql8VOW9//HfTPZAFkIISUgIm4CABAVEQCwUBPlzKdhbr3q1gEq99YJ/KVor/VfU6r241VJbXtjWJVKK4AZWVBZZRUDKJptQgkASSNiz7zPzfz1PMkMGQtaZObN83vd1OufMnDM5nJmbfH2e3/Mck81mswkAAIAXMxt9AgAAAI0hsAAAAK9HYAEAAF6PwAIAALwegQUAAHg9AgsAAPB6BBYAAOD1CCwAAMDrBYsfsFqtcvr0aYmKihKTyWT06QAAgCZQc9cWFRVJcnKymM1m/w8sKqykpqYafRoAAKAFsrOzJSUlxf8Di2pZsf+Do6OjjT4dAADQBIWFhbrBwf533O8Di70bSIUVAgsAAL6lKeUcFN0CAACvR2ABAABej8ACAAC8HoEFAAB4PQILAADwegQWAADg9QgsAADA6xFYAACA1yOwAAAAr0dgAQAAXo/AAgAAvB6BBQAAeD0CSwMKy6vktTVH5Fcf7vPcJwIAAK5CYGlAiNksr6/PlGU7syW/tLKhXQEAgBsRWBoQERokHaPD9PrJC6Xu/BwAAEADCCyNSItrox9PXiSwAABgFAJLIzq3j9SPJ8+XeOLzAAAA9SCwNCItrjaw0MICAIBhCCxNbGHJooYFAADDEFga0aW9vYaFLiEAAIxCYGlEWm0Ly5nCCimrtHjiMwEAAFcgsDQiNjJUosOD9XoWdSwAABiCwNIEafZuoQt0CwEAYAQCSzO6hWhhAQDAGASWZgQWZrsFAMAYBJZmzHZ7gi4hAAAMQWBpzlwsFN0CAGAIAkszuoROXSqTaovV3Z8JAAC4AoGlCTpGhUtYsFmqrTY5nV/elEMAAIALEViacpHMJulce08h6lgAAPA8AktzRwpRxwIAgMcRWJqoc+1IoSxGCgEA4N2BZd68eTJ48GCJioqShIQEmTx5shw5cqTBY0aOHCkmk+mqZcKECY59pk2bdtXrd9xxh3iTLvHMxQIAgFFqbpLTRJs2bZIZM2bo0FJdXS2//vWvZezYsXLo0CFp06amBeJKH3/8sVRWVjq2L1y4IOnp6XLXXXc57acCyjvvvOPYDgsLE29ir2FhaDMAAF4eWFatWuW0nZGRoVtadu3aJbfddlu9x8TFxTltL126VCIjI68KLCqgJCYmivffT6hUbDabbgUCAAA+UMNSUFBQbyhpyFtvvSX33HPPVS0yGzdu1OGnV69e8sgjj+iWmGupqKiQwsJCp8XdOsVGiNkkUlZlkXNFFW7/eQAAwAWBxWq1yqxZs2T48OHSr1+/Jh2zY8cOOXDggEyfPv2q7qBFixbJunXr5KWXXtJdT+PHjxeLxXLNWpqYmBjHkpqaKu4WGmyW5NgIvc5IIQAAPMtkU/0bLaBaQb744gvZsmWLpKSkNOmY//qv/5Jt27bJvn37Gtzv+++/l+7du8uXX34po0ePrreFRS12qoVFhRbV4hMdHS3ucv+b38iWzPPy6l3p8pOBTfs3AwCA+qm/36rhoSl/v1vUwjJz5kxZuXKlbNiwoclhpaSkRNevPPTQQ43u261bN4mPj5fMzMx6X1f1LuofVnfx5D2FTjK0GQAA7y26VY0xjz76qCxfvlzXnHTt2rXJx37wwQe6VeT+++9vdN+cnBxdw5KUlCTeJK12pJAqvAUAAJ7TrBYWNaR58eLFsmTJEj0XS15enl7Kysoc+0yZMkXmzJlTb7Gtmrelffv2Ts8XFxfLL3/5S9m+fbucOHFC17FMmjRJevToIePGjRNvwmy3AAD4QAvLwoULHZPB1aXmT1GTvylZWVliNjvnIDW5nKp1WbNmzVXvGRQUpGta3n33XcnPz5fk5GQ9t8vzzz/vdXOx2Ic2M9stAABe3iXUGNVVdCU1VPlax0ZERMjq1avFF9gnj7tUWiUFZVUSExFi9CkBABAQuJdQM7QJC5b4tjWtPlnUsQAA4DEElhbXsZS44/MAAAD1ILA0EyOFAADwPAJLiwtvGdoMAICnEFha2CV0gsnjAADwGAJLC2e7zbpICwsAAJ5CYGlhDUteYbmUV9V/c0YAAOBaBJZmimsTKlFhwaKmlcm5RCsLAACeQGBpJpPJ5OgWOnGewAIAgCcQWFqAewoBAOBZBJYW6BzHPYUAAPAkAksL0MICAIBnEVhaEViYPA4AAM8gsLRittvsS6VisTZ+B2sAANA6BJYWSIwOl9Ags1RZbHI6v6yVHwEAAGgMgaUFgswmSYmL0OvMeAsAgPsRWFqIuzYDAOA5BJZW1rGc5CaIAAC4HYGltUObLzDbLQAA7kZgaSHmYgEAwHMILC6Y7dam7oQIAADchsDSQqlxEWIyiZRUWuRCSaVrPxUAAOCEwNJCYcFBkhxTM7SZwlsAANyLwNIKneMovAUAwBMILK3ASCEAADyDwNIKne03QbzI0GYAANyJwNIKXWonjzvB5HEAALgVgcUFNSxZTB4HAIBbEVhcUMOihjUXV1S76jMBAABXILC0QlR4iMS1CdXrDG0GAMB9CCwuamWhWwgAAPchsLRSWm0dywnqWAAA8I7AMm/ePBk8eLBERUVJQkKCTJ48WY4cOdLgMRkZGWIymZyW8PBwp33UvXjmzp0rSUlJEhERIWPGjJGjR4+KL+hcO1Io62KJ0acCAIDfalZg2bRpk8yYMUO2b98ua9eulaqqKhk7dqyUlDT8xzo6Olpyc3Mdy8mTJ51ef/nll+X111+XN954Q7755htp06aNjBs3TsrLy8VXWlhO0sICAIDbBDdn51WrVl3VeqJaWnbt2iW33XbbNY9TrSqJiYn1vqZaV+bPny+/+c1vZNKkSfq5RYsWSceOHWXFihVyzz33iDdjtlsAALy8hqWgoEA/xsXFNbhfcXGxpKWlSWpqqg4lBw8edLx2/PhxycvL091AdjExMTJkyBDZtm1bve9XUVEhhYWFTotR0mq7hE4XlElFtcWw8wAAwJ+1OLBYrVaZNWuWDB8+XPr163fN/Xr16iVvv/22fPLJJ7J48WJ93LBhwyQnJ0e/rsKKolpU6lLb9tfqq6VRoca+qCBklPi2oRIZGiQ2m0jOpTLDzgMAAH/W4sCialkOHDggS5cubXC/oUOHypQpU2TAgAHygx/8QD7++GPp0KGD/PnPf27pj5Y5c+bo1h37kp2dLUZR3V32KfqPnS027DwAAPBnLQosM2fOlJUrV8qGDRskJSWlWceGhITIjTfeKJmZmXrbXtty5swZp/3U9rXqXsLCwnQhb93FSH2Sa37+wdPGdU0BAODPmhVYVIGsCivLly+X9evXS9euXZv9Ay0Wi+zfv18PYVbUe6hgsm7dOsc+qiZFjRZSrTO+oJ8jsNTU9AAAAANHCaluoCVLluh6FDUXi73GRNWRqPlTFNX906lTJ11novz2t7+VW265RXr06CH5+fnyyiuv6GHN06dPd3SpqFqYF154Qa677jodYJ5++mlJTk7W87z4gn6dYvTjgVO0sAAAYHhgWbhwoX4cOXKk0/PvvPOOTJs2Ta9nZWWJ2Xy54ebSpUvys5/9TIebdu3aycCBA2Xr1q3Sp08fxz5PPvmknsvl4Ycf1qHm1ltv1UOor5xgzltdnxQtJpNIXmG5nCuqkA5RYUafEgAAfsVkU/08Pk51IalWHlWAa1Q9y+jfbZRj50ok44HBMrJXgiHnAACAv/795l5CLu4WovAWAADXI7C4SL9kex0LhbcAALgagcVF+naqaco6wEghAABcjsDiIn1rW1iyL5ZJQWmVq94WAAAQWFwnJiJEOtfeuZn5WAAAcC1aWFyoH91CAAC4BYHFDd1CTCAHAIBrEVjcMeMthbcAALgUgcWF+tbeU+j4+RIprqh25VsDABDQCCwuFN82TJJiwkXNHfxdLvcVAgDAVQgsbqtjYQI5AABchcDirpFC3LkZAACXIbC4aYp+5mIBAMB1CCwudkNKTWA5erZYyqssrn57AAACEoHFxRKiwnTxrcVqk8N5Ra5+ewAAAhKBxcVMJlOdOhYKbwEAcAUCixtQxwIAgGsRWNyAkUIAALgWgcWNc7EcySuSymqrO34EAAABhcDiBintIiQmIkQqLVY5epbCWwAAWovA4ubC24NMIAcAQKsRWNxceMudmwEAaD0Ci5v07cQ9hQAAcBUCi5v0S67pEjqUWyjVFgpvAQBoDQKLm3Rp30bahAZJeZVVvj9f4q4fAwBAQCCwuOvCmk2O4c3MeAsAQOsQWNyor2OK/kJ3/hgAAPwegcWNGCkEAIBrEFjcqF/tSKFDpwvFarW580cBAODXCCxu1L1DGwkLNktxRbWcvFjqzh8FAIBfI7C4UXCQWa5PstexFLjzRwEA4NcILJ66c/NpAgsAAB4JLPPmzZPBgwdLVFSUJCQkyOTJk+XIkSMNHvPXv/5VRowYIe3atdPLmDFjZMeOHU77TJs2Td9/p+5yxx13iD8V3nJPIQAAPBRYNm3aJDNmzJDt27fL2rVrpaqqSsaOHSslJdeeGG3jxo1y7733yoYNG2Tbtm2Smpqqjzl16pTTfiqg5ObmOpb33ntP/KnwVrWw2GwU3gIA0BLBzdl51apVTtsZGRm6pWXXrl1y22231XvM3//+d6ftN998Uz766CNZt26dTJkyxfF8WFiYJCYmir+5rmNbCQkySX5plZzKL5OUdpFGnxIAAIFVw1JQUFOXERcX1+RjSktLdcvMlceolhgVfnr16iWPPPKIXLhw4ZrvUVFRIYWFhU6LtwoLDpKeHaP0OhPIAQDg4cBitVpl1qxZMnz4cOnXr1+Tj/vVr34lycnJupalbnfQokWLdKvLSy+9pLuexo8fLxaL5Zq1NDExMY5FdTP5RB0LhbcAALi/S6guVcty4MAB2bJlS5OPefHFF2Xp0qW6NSU8PNzx/D333ONYv+GGG6R///7SvXt3vd/o0aOvep85c+bI7NmzHduqhcWbQ4saKbRsJ0ObAQDwaAvLzJkzZeXKlbqQNiUlpUnHvPrqqzqwrFmzRgeShnTr1k3i4+MlMzOz3tdVvUt0dLTT4s36OgpvvbfrCgAAvwksapSLCivLly+X9evXS9euXZt03MsvvyzPP/+8LtodNGhQo/vn5OToGpakpCTxB9cnRovZJHKuqELOFpYbfToAAPh3YFHdQIsXL5YlS5bouVjy8vL0UlZW5thHjfxRXTZ2qibl6aeflrffflu6dOniOKa4uFi/rh5/+ctf6qHSJ06c0HUskyZNkh49esi4cePEH0SEBkmPhLZ6nQnkAABwc2BZuHChHhk0cuRI3fphX5YtW+bYJysrS8+jUveYyspK+clPfuJ0jOoiUoKCgmTfvn3yox/9SHr27CkPPfSQDBw4UL766ivd9eMv7IW3+3PoFgIAwK1Ft02Z+EwVytalWk0aEhERIatXrxZ/d0NKjHy855Tsy8k3+lQAAPA53EvIQ9JTY/XjtznMeAsAQHMRWDykT1K0BJtNcr64QnILKLwFAKA5CCweEh4SJL0Sa2a8/TabbiEAAJqDwOJB/VMudwsBAICmI7B4UHpKzUghCm8BAGgeAosH2VtY9ucUiNXa+IgrAABQg8DiQT07tpXwELMUVVTL8QslnvzRAAD4NAKLBwUHmR0TyFF4CwBA0xFYDOoW2kfhLQAATUZg8bD01NoWFma8BQCgyQgsBrWwHDpdKFUWq6d/PAAAPonA4mFd2kdKdHiwVFRb5Uhekad/PAAAPonA4mEmk8lxXyHqWAAAaBoCiwH6104gx0ghAACahsBi6BT93FMIAICmILAYIL02sBw9WyxllRYjTgEAAJ9CYDFAYky4JESFicVqk4OnuREiAACNIbAYhDs3AwDQdAQWgwywTyCXTR0LAACNIbAYPkU/gQUAgMYQWAwe2nziQqkUlFYZdRoAAPgEAotBYiNDJa19pF7fd4pWFgAAGkJgMRB3bgYAoGkILAZKr+0W2kvhLQAADSKwGOjyPYXoEgIAoCEEFgP1TY4Ws0nkTGGFnCksN/JUAADwagQWA0WGBkvPjlF6nflYAAC4NgKLlwxv3pfDFP0AAFwLgcVg3LkZAIDGEVgMNsBReFsgNpvN6NMBAMArEVgM1isxSkKDzVJQViUnL5QafToAAHglAovBQoLM0icpWq9/y/BmAABaH1jmzZsngwcPlqioKElISJDJkyfLkSNHGj3ugw8+kN69e0t4eLjccMMN8vnnnzu9rrpC5s6dK0lJSRIRESFjxoyRo0ePSqBNIEfhLQAALggsmzZtkhkzZsj27dtl7dq1UlVVJWPHjpWSkpJrHrN161a599575aGHHpI9e/bokKOWAwcOOPZ5+eWX5fXXX5c33nhDvvnmG2nTpo2MGzdOysvLA6vwlhlvAQCol8nWikrPc+fO6ZYWFWRuu+22eve5++67daBZuXKl47lbbrlFBgwYoAOK+vHJycny+OOPyxNPPKFfLygokI4dO0pGRobcc889jZ5HYWGhxMTE6OOio2u6V3xJ5tkiGfPaZgkPMcuBZ8dJcBA9dQAA/1fYjL/frfrLqH6AEhcXd819tm3bprt46lKtJ+p55fjx45KXl+e0jzr5IUOGOPa5UkVFhf5H1l18Wbf4ttI2LFjKq6xy9Gyx0acDAIDXaXFgsVqtMmvWLBk+fLj069fvmvupMKJaS+pS2+p5++v25661T321NCrU2JfU1FTxZWazSW7oZK9j4b5CAAC4LLCoWhZVh7J06VLxtDlz5ujWHfuSnZ0tvq5/ak1g+ZYZbwEAuEqwtMDMmTN1TcrmzZslJSWlwX0TExPlzJkzTs+pbfW8/XX7c2qUUN19VJ1LfcLCwvTiT9JrC29pYQEAoJUtLKpAVoWV5cuXy/r166Vr166NHjN06FBZt26d03NqhJF6XlHvoUJL3X1UTYoaLWTfJ5DuKXQ4t0jKqyxGnw4AAL4bWFQ30OLFi2XJkiV6LhZVY6KWsrIyxz5TpkzRXTZ2jz32mKxatUp+97vfyeHDh+XZZ5+VnTt36uCjmEwmXQvzwgsvyD/+8Q/Zv3+/fg81ckgNfw4UnWIjpH2bUKm22uRQrm8XEQMAYGhgWbhwoa4ZGTlypO6+sS/Lli1z7JOVlSW5ubmO7WHDhumA85e//EXS09Plww8/lBUrVjgV6j755JPy6KOPysMPP6wnpisuLtYhR000FyhUcEu331eI+VgAAHDdPCzewtfnYbGb/+W/ZP6XR+XOGzvJ7++uv34HAAB/4bF5WOBaN3Zupx93Z13i0gIAUAeBxYsMSI0Vk0n0XZvPFVUYfToAAHgNAosXiYkIkZ4JUXqdVhYAAC4jsHiZm9Jqu4VO0i0EAIAdgcXLDKwNLLsILAAAOBBYvDSw7DtVIBXVTCAHAIBCYPEyXdpHSlybUKmstsrB00wgBwCAQmDxwgnkbrIPb6ZbCAAAjcDihahjAQDAGYHFiwPLzpOX9A0nAQAIdAQWL71zc7DZpCePy7l0+caSAAAEKgKLFwoPCZK+nWL0OhPIAQBAYPFaA2sLb5mPBQAAAovXovAWAIDL6BLyUjelxerH73ILpaSi2ujTAQDAUAQWL5UUEyGdYiPEahP5Njvf6NMBAMBQBBYfuBGiGt4MAEAgI7B4sYGda7qFKLwFAAQ6AosXG5gW5xjabFV9QwAABCgCixe7PilKIkKCpKi8WjLPFRt9OgAAGIbA4sWCg8wyIJVuIQAACCxejvlYAACghcVnAstuRgoBAAIYLSxe7sbakULfny+RiyWVRp8OAACGILB4udjIUOmR0Fav08oCAAhUBBZfuhFiFhPIAQACE4HFB1B4CwAIdAQWH5qiX91TqMpiNfp0AADwOAKLD+gW30ZiI0Okotoqh04XGn06AAB4HIHFB5jNJrnJXsfC8GYAQAAisPhaHQuFtwCAANTswLJ582aZOHGiJCcni8lkkhUrVjS4/7Rp0/R+Vy59+/Z17PPss89e9Xrv3r1b9i/yU/YWFoY2AwACUbMDS0lJiaSnp8uCBQuatP8f/vAHyc3NdSzZ2dkSFxcnd911l9N+KsDU3W/Lli3NPTW/lp4aI0Fmk+QWlMvp/DKjTwcAAI8Kbu4B48eP10tTxcTE6MVOtchcunRJHnjgAecTCQ6WxMTE5p5OwIgMDZY+SdGy/1SBrmNJjo0w+pQAAPDfGpa33npLxowZI2lpaU7PHz16VHczdevWTe677z7Jysry9Kl5PeZjAQAEKo8GltOnT8sXX3wh06dPd3p+yJAhkpGRIatWrZKFCxfK8ePHZcSIEVJUVFTv+1RUVEhhYaHTEkjzseym8BYAEGCa3SXUGu+++67ExsbK5MmTnZ6v28XUv39/HWBUC8z7778vDz300FXvM2/ePHnuueckUFtYDp4ulNLKat1NBABAIPBYC4vNZpO3335bfvrTn0poaGiD+6pQ07NnT8nMzKz39Tlz5khBQYFjUYW8gSA5JlwSo8PFYrXJt9kFRp8OAAD+F1g2bdqkA0h9LSZXKi4ulmPHjklSUlK9r4eFhUl0dLTTEgjUcG97KwvdQgCAQNLswKLCxN69e/WiqHoTtW4vklWtH1OmTKm32FZ19fTr1++q15544gkdaE6cOCFbt26VO++8U4KCguTee+9t2b8qAOpYmPEWABBIml0EsXPnThk1apRje/bs2fpx6tSpunBWzaFy5Qgf1W3z0Ucf6TlZ6pOTk6PDyYULF6RDhw5y6623yvbt2/U6nA2qDSw7T1zUXUNqbhYAAPydyaaKS3ycGiWk5npRwcjfu4eqLVYZ8Nu1UlxRLZ/OvFVuSLk8xw0AAP7695t7CfmY4CCzDOkap9e3fX/e6NMBAMAjCCw+aGj39vpx67ELRp8KAAAeQWDx4cCy4/hFqbJYjT4dAADcjsDig65PjJbYyBAprbTIvhzmYwEA+D8Ciw8ym00ytFtNK8u2Y9SxAAD8H4HFR1HHAgAIJAQWHzWsto5FTSBXXmUx+nQAAHArAouP6t6hrXSICpOKaqvsyco3+nQAAHArAosP31eIOhYAQKAgsPhBt9C275mPBQDg3wgsPmxY93j9qLqESiurjT4dAADchsDiw1LjIqRTbIRUW23yzxOXjD4dAADchsDi63Us9m4hpukHAPgxAou/1LEwgRwAwI8RWHycvYVl/6kCKSyvMvp0AABwCwKLj0uKiZCu8W3EahPZ8f1Fo08HAAC3ILD4AabpBwD4OwKLH9WxbKWOBQDgpwgsfuCW2js3H84rkosllUafDgAALkdg8QPxbcOkV8covb6dWW8BAH6IwOJ3dSznjT4VAABcjsDiJyi8BQD4MwKLn7ila3sxmUS+P1ciZwrLjT4dAABcisDiJ2IiQ6RfcoxeZ5p+AIC/IbD4EepYAAD+isDiRxw3QmSkEADAzxBY/MjgLnESbDZJ9sUyyb5YavTpAADgMgQWP9I2LFjSU2P1OnUsAAB/QmDxM0NrZ72lWwgA4E8ILH58XyGbzWb06QAA4BIEFj9zU1o7CQ02y5nCCvn+fInRpwMAgEsQWPxMeEiQ3NSZOhYAQIAHls2bN8vEiRMlOTlZTCaTrFixosH9N27cqPe7csnLy3Pab8GCBdKlSxcJDw+XIUOGyI4dO5r/r4E2rHu8fqTwFgAQsIGlpKRE0tPTdcBojiNHjkhubq5jSUhIcLy2bNkymT17tjzzzDOye/du/f7jxo2Ts2fPNvf0UKeORRXeWq3UsQAAfF9wcw8YP368XppLBZTY2Jquiiu99tpr8rOf/UweeOABvf3GG2/IZ599Jm+//bY89dRTzf5Zga5/Sqy0CQ2SiyWVsu9UgQyoHeoMAICv8lgNy4ABAyQpKUluv/12+frrrx3PV1ZWyq5du2TMmDGXT8ps1tvbtm2r970qKiqksLDQacFlquh2ZO+aFqw1B5273gAA8EVuDywqpKgWk48++kgvqampMnLkSN31o5w/f14sFot07NjR6Ti1fWWdi928efMkJibGsaj3hLOxfWqu52oCCwAgELuEmqtXr156sRs2bJgcO3ZMfv/738vf/va3Fr3nnDlzdM2LnWphIbQ4G9U7QUKCTHLsXIlkni2WHgltW/wZAgBgNEOGNd98882SmZmp1+Pj4yUoKEjOnDnjtI/aTkxMrPf4sLAwiY6OdlrgLDo8RIbWjhZae8j52gIA4GsMCSx79+7VXUVKaGioDBw4UNatW+d43Wq16u2hQ4cacXp+g24hAEDAdgkVFxc7WkeU48eP6wASFxcnnTt31t01p06dkkWLFunX58+fL127dpW+fftKeXm5vPnmm7J+/XpZs2aN4z1U987UqVNl0KBBuvVFHaOGT9tHDaHlgeU3Kw7I3ux8OVNYLh2jw7mUAIDACCw7d+6UUaNGObbttSQqcGRkZOg5VrKyspxGAT3++OM6xERGRkr//v3lyy+/dHqPu+++W86dOydz587VhbZqRNGqVauuKsRF8yREh8uNnWNlT1a+7ha6/5Y0LiEAwCeZbH5whzxVdKtGCxUUFFDPcoWFG4/JS6sOy4jr4uVvDw0x5gMCAKCVf7+5l5CfG9e3o2Oa/oKyKqNPBwCAFiGw+LluHdrqIc3VVptsPMKtDgAAvonAEkCjhdYcZHgzAMA3EVgCwLi+NfPZqBaW8iqL0acDAECzEVgCwA2dYiQxOlxKKi2y9dh5o08HAIBmI7AEALPZJGNri2/pFgIA+CICS4AY26emW+jL786IxerzI9kBAAGGwBIghnSLk+jwYDlfXCm7sy4ZfToAADQLgSVAhASZZfT19m6hPKNPBwCAZiGwBOTNEM+IH0xwDAAIIASWAPKDXh0kLNgsWRdL5ciZIqNPBwCAJiOwBJDI0GB9TyGF0UIAAF9CYAnQ0UKrqWMBAPgQAkuAGX19gphNIgdPF0rOpVKjTwcAgCYhsASY9m3DZFCXOL2+9hD3FgIA+AYCS0CPFmJ4MwDANxBYAvhmiDuOX5RLJZVGnw4AAI0isASg1LhIuT4pWtQM/WqqfgAAvB2BJUCNs98MkToWAIAPILAE+PDmr46ek7JKi9GnAwBAgwgsAer6pChJjYuQ8iorxbcAAK9HYAlQJpNJ7hqYqtff3XbC6NMBAKBBBJYAdu/NnSUkyCR7svJlX06+0acDAMA1EVgCWIeoMJlwQ5Jez9hKKwsAwHsRWALc1GFd9OPKb3PlfHGF0acDAEC9CCwB7sbO7SQ9JUYqLVZZuiPL6NMBAKBeBBY4WlkWb8+SaouVKwIA8DoEFsiE/kkS3zZU8grLmUgOAOCVCCyQsOAgPWJIofgWAOCNCCzQ7huSJkFmk74h4ne5hVwVAIBXIbBAS4wJlztq7+L8LkOcAQBehsCCq4pvV+w9JfmllVwZAIDvBpbNmzfLxIkTJTk5WU/vvmLFigb3//jjj+X222+XDh06SHR0tAwdOlRWr17ttM+zzz6r36vu0rt37+b/a9Aqg7u0k+uTovX9hd7fmc3VBAD4bmApKSmR9PR0WbBgQZMDjgosn3/+uezatUtGjRqlA8+ePXuc9uvbt6/k5uY6li1btjT31NBKKihOG5am1xdtOykWq41rCgDwCsHNPWD8+PF6aar58+c7bf/v//6vfPLJJ/Lpp5/KjTfeePlEgoMlMbGmhgLGmTSgk8z74rDkXCqT9YfPyu19OvJxAAACr4bFarVKUVGRxMXFOT1/9OhR3c3UrVs3ue+++yQr69qzrlZUVEhhYaHTAtcIDwmSuwfX3sWZ4lsAQKAGlldffVWKi4vlP/7jPxzPDRkyRDIyMmTVqlWycOFCOX78uIwYMUIHm/rMmzdPYmJiHEtqas0fWLjG/UPSxGwS2ZJ5XjLP1v8ZAADgt4FlyZIl8txzz8n7778vCQkJjudVF9Ndd90l/fv3l3Hjxul6l/z8fL1ffebMmSMFBQWOJTubAlFXSo2LlNHXd3TUsgAAEDCBZenSpTJ9+nQdQsaMGdPgvrGxsdKzZ0/JzMys9/WwsDA94qjuAteaVjvE+aNdOVJUXsXlBQD4f2B577335IEHHtCPEyZMaHR/1WV07NgxSUpK8sTpoR7DureX6xLaSkmlRT7clcM1AgD4VmBRYWLv3r16UVS9iVq3F8mq7popU6Y4dQOp7d/97ne6ViUvL08vqivH7oknnpBNmzbJiRMnZOvWrXLnnXdKUFCQ3Hvvva75V6JFQ5yn1LayqG4hK0OcAQC+FFh27typhyPbhyTPnj1br8+dO1dvqzlU6o7w+ctf/iLV1dUyY8YM3WJiXx577DHHPjk5OTqc9OrVSxfjtm/fXrZv364nm4NxfnxjJ4kKD5bj50tk07/O8VEAAAxjstlsPj87mBrWrEYLqVYb6llc638+OyR//eq4dO/QRj5/bIS+szMAAJ7++829hNCgmaOuk/i2oXLsXIn8dfP3XC0AgCEILGhQTGSIPP1vffT6H9dnyskLJVwxAIDHEVjQqB+lJ8utPeKlotoqv1lxQPygFxEA4GMILGjSiKHnJ/eT0GCzfHX0vHy6L5erBgDwKAILmqRrfBuZMbKHXn9+5SEpKGMyOQCA5xBY0GQ/H9lNusW3kXNFFfLq6iNcOQCAxxBY0GRqSPMLd/bT64u/OSl7s/O5egAAjyCwoFmGdY/XE8qputtff7xfqi1WriAAwO0ILGi2X0+4XmIiQuRQbqFkbD3BFQQAuB2BBc0W3zZM5ozvrddfW/svOZ1fxlUEALgVgQUt8h+DUmVQWjsprbTIc58e5CoCANyKwIKWfXHMJvmfO2+QYLNJVh88I18eOsOVBAC4DYEFLdYrMUqmj+im15/5x0EprazmagIA3ILAglZ5bPR1ktIuQk7ll8nzK79j2n4AgFsQWNAqEaFB8sLkfmIyiby3I0v+sO4oVxQA4HIEFrTayF4J8ttJNRPKzf/yqLzLUGcAgIsRWOASP70lTWbf3tNRz/LJ3lNcWQCAyxBY4DKP/rCHTBvWRa8//v63suHIWa4uAMAlCCxwGZPJJHP/rY9MHpAs1VabPLJ4l+w8cZErDABoNQILXD4/yyt3pcsPeydIeZVVHsz4p3yXW8hVBgC0CoEFLhcSZJYF/3mTngm3sLxapry9Q7IulHKlAQAtRmCB24Y7vzVtsPROjJJzRRVy/1vfyNmicq42AKBFCCxwG3VH50UP3iyd4yIl62KpTHlrhxSUVXHFAQDNRmCBWyVEh8vih4ZIh6gwOZxXJPe9uZ3uIQBAsxFY4Had20fqlpZ2kSFy4FShTHj9K1m57zRXHgDQZAQWeMT1SdHy2f8doQtxiyqqZeaSPfL/lu+X8ioLnwAAoFEEFnhMcmyELH34Fpk5qoe+99Dfv8mSyQu+lsyzxXwKAIAGEVjgUcFBZnliXC/dRRTfNlTXtUz84xb5cFcOnwQA4JoILDDEiOs6yOePjZDhPdpLWZVFnvjgW5n9/l4pqajmEwEAXIXAAsMkRIXLogeHyBNje4rZJPLx7lMy8U9b5NBpZsYFADgjsMBQQWaTzPzhdbL04aGSGB0u358r0aHlVx/uk1P5ZXw6AACNwAKvcHPXON1FdEffRLFYbbJsZ7aMemWjPPuPg8yQCwBofmDZvHmzTJw4UZKTk/XdeVesWNHoMRs3bpSbbrpJwsLCpEePHpKRkXHVPgsWLJAuXbpIeHi4DBkyRHbs2MHHE2Di2oTKGz8dKB89MkyGdmsvlRarZGw9Ibe9vEFe/OKwXCqpNPoUAQC+ElhKSkokPT1dB4ymOH78uEyYMEFGjRole/fulVmzZsn06dNl9erVjn2WLVsms2fPlmeeeUZ2796t33/cuHFy9uzZ5p4e/MDAtHby3sO3yJLpQ+TGzrH6rs9vbDqmg8v8L/8lReVM7w8AgcZks9lsLT7YZJLly5fL5MmTr7nPr371K/nss8/kwIEDjufuueceyc/Pl1WrVult1aIyePBg+dOf/qS3rVarpKamyqOPPipPPfVUo+dRWFgoMTExUlBQINHR0S3958ALqa/n+sNn5dU1/5LvcmuKcdWMuQ/f1l3uHpyqW2UAAL6pOX+/3V7Dsm3bNhkzZozTc6r1RD2vVFZWyq5du5z2MZvNetu+z5UqKir0P7LuAv+kQvHo6zvKZ4/eKn/6zxulW4c2cqm0Sl5adVhu/p8v5eFFO2XNwTyprLYafaoAADcKFjfLy8uTjh07Oj2ntlXIKCsrk0uXLonFYql3n8OHD9f7nvPmzZPnnnvOrecN72I2m+Tf+ifrotzle07J37aflH05BbLm0Bm9qJaWSQOS5ScDU6RvcozRpwsA8LXA4g5z5szRNS92KvyoLiQExky5dw1K1cu/zhTJR7ty5OM9p+RcUYW88/UJvfROjNLBZdKATvou0QAA3+f2wJKYmChnzpxxek5tq76qiIgICQoK0kt9+6hj66NGG6kFga1nxyiZ83+ul1+O6yVfHT0vH+7OkbUHz+jp/l/47DuZ98VhGZAaKyOui9cz66anxOjAAwDwPW4PLEOHDpXPP//c6bm1a9fq55XQ0FAZOHCgrFu3zlG8q4pu1fbMmTPdfXrwAyqEjOqdoJeC0ir5dN9pfW+ivdn5suvkJb3M//KoRIUHy/Du8TKiZ7zcdl0HSY2LNPrUAQDuCizFxcWSmZnpNGxZDVeOi4uTzp076+6aU6dOyaJFi/TrP//5z/XonyeffFIefPBBWb9+vbz//vt65JCd6t6ZOnWqDBo0SG6++WaZP3++Hj79wAMPNPf0EOBiIkPk/lvS9JJzqVS2HD2vW1+2ZJ6XgrIqWXUwTy9Kl/aRuuVlUJd20j8lVtLiInWtDADAD4Y1q0ng1JwqV1KBQ00IN23aNDlx4oTer+4xv/jFL+TQoUOSkpIiTz/9tN6vLhVqXnnlFV2kO2DAAHn99df1cOemYFgzGqNmz92Xk6/Dy1dHz8nurHz9XF2qBeaGTjE6vPRPidHrKe0i9EglAIDrNefvd6vmYfEWBBY0l5p8btuxC/J15nn5NqdADuUW1js0Wo0+UsGld1KUdI9vq4dVd+/QVtox/wsAtBqBBWimKotVjzran1Mg+04V6NaYw7lFUn1FK4ydmryuW4e20i2+jXRPqHnsEt9GkmMjpG2YTw6+AwCPI7AALlBeZZEjeUU6wGSeKZLvz5fIsbPFcrqgvMHjYiJCJCkmXDrFRugAU7PUbCfGhEt82zAJDwniMwIQ8Aqb0SXEfwoC16BCRXpqrF7qKq2sluMqvJwrke/PFcv359R6sWRfLJXC8mpd3KsWNbz6WqLCgiU+Kkzi24bqAKOW9nXWYyND9KLCT2xEqISHmKmlARDQCCxAM0WGBuvZdOubUVfVxuQWlMup/DI5Xbvk5tduF5RJXkG5VFlsUlRRrRcVfJoiNMisR0DVBJiaR1Uk3CYsWNqGB0vb0JpHta3CkP35yNAgiQwJlvBQsz7viJAgCWIkFAAfRGABXCgqXAWJED2pXX1UjXthWbWcL6mQ80UVcr64Us4XV8iF4go5V2fd3kqTX1ql62gqLVY9m69aWis02KyDiwoz6lG1JIWFmCUsWC1BNY/qef1Y85w6RoUm9RgSZJIQx3rN8+rR/nxwkEmCzTXbao6cYLPJ6TkVmNS6erQvwVesMzILwJUILIAHqT/EuqUkMkSPNmqMCjillRbJVwGmtEryyyqlsDbIFKtWmvJqKamo1uv2paT2ebVeVmmRsqqaxT4eUI2GUosKRN5KjSQPMpn0vDjqUQUZ1TBk33Y81j5ntq+bVNiR2v1rgo8alG42175W+xmofezH6Gf1ds26/TX7aHb7uuPY2vPTezuev3ys/fzt71uz5+VjL79uX69ZcwyedxxT+3yd96h7feruU/e5K9flWvvU836Nfi5OR9V/XFPeikDqfUxN+ODUf0z8vwl9PHE69f98w34ygCb9YlfdO2pRRbstpYJPRbVVhx8dYNRjnTBToZZqa+2iti+vl+t1i1RbbDroqBFVqsVHParuLR2AarfVPvrRapPq2terrVY9541er31NbV9rBFbN+YpUq/9pYB8AnqVaVQksANwefFTXj7eNTrLWBpeaAGMVq1WkSj/axGKz6byi1+3b1prn1LZVv167j82mQ1nd12y16yryqG21Yt/fduWj+r/a1xX7eu1hjuPt+6nnbFds25uw7K/VPFXPc3Wet7v82uWfb9+37nZ9+zi/T/3H1/ezrny/hva51ns16cAWvI/vzw7mPWxN++SaJEg1VRqIFhYAhlHdOaGOImDvClMAvAu3rgUAAF6PwAIAALwegQUAAHg9AgsAAPB6BBYAAOD1CCwAAMDrEVgAAIDXI7AAAACvR2ABAABej8ACAAC8HoEFAAB4PQILAADwegQWAADg9fzibs32W6oXFhYafSoAAKCJ7H+37X/H/T6wFBUV6cfU1FSjTwUAALTg73hMTEyD+5hsTYk1Xs5qtcrp06clKipKTCaTy9OfCkLZ2dkSHR3t0vcG19tofL+53v6M77f3X28VQVRYSU5OFrPZ7P8tLOofmZKS4tafoS4+gcVzuN6exfXmevszvt/efb0ba1mxo+gWAAB4PQILAADwegSWRoSFhckzzzyjH+F+XG/P4npzvf0Z32//ut5+UXQLAAD8Gy0sAADA6xFYAACA1yOwAAAAr0dgAQAAXo/A0ogFCxZIly5dJDw8XIYMGSI7duzwzCfj5zZv3iwTJ07Usxuq2YlXrFjh9LqqBZ87d64kJSVJRESEjBkzRo4ePWrY+fqyefPmyeDBg/VM0AkJCTJ58mQ5cuSI0z7l5eUyY8YMad++vbRt21b+/d//Xc6cOWPYOfuyhQsXSv/+/R2TZw0dOlS++OILx+tca/d68cUX9e+UWbNmcc3d4Nlnn9XXt+7Su3dvj3y/CSwNWLZsmcyePVsP09q9e7ekp6fLuHHj5OzZsy65+IGspKREX08VCOvz8ssvy+uvvy5vvPGGfPPNN9KmTRt97dX/M6B5Nm3apH+BbN++XdauXStVVVUyduxY/RnY/eIXv5BPP/1UPvjgA72/utXFj3/8Yy51C6hZt9UfzV27dsnOnTvlhz/8oUyaNEkOHjzItXazf/7zn/LnP/9ZB8a6+H67Vt++fSU3N9exbNmyxTPXWg1rRv1uvvlm24wZMxzbFovFlpycbJs3bx6XzIXU13D58uWObavVaktMTLS98sorjufy8/NtYWFhtvfee49r30pnz57V13zTpk2OaxsSEmL74IMPHPt89913ep9t27ZxvV2gXbt2tjfffJNr7UZFRUW26667zrZ27VrbD37wA9tjjz2mn+f77VrPPPOMLT09vd7X3H2taWG5hsrKSv1fSKorou49i9T2tm3bXJMWUa/jx49LXl6e07VX95pQXXJc+9YrKCjQj3FxcfpRfc9Vq0vd662aeDt37sz1biWLxSJLly7VrVmqa4hr7T6qFXHChAlO32OFa+56qnteded369ZN7rvvPsnKyvLItfaLmx+6w/nz5/Uvm44dOzo9r7YPHz5s2HkFAhVWlPquvf01tPzO5qpvf/jw4dKvXz/H9Q4NDZXY2Fiut4vs379fBxTVhan68ZcvXy59+vSRvXv3cq3dQIVC1W2vuoSuxPfbtdR/OGZkZEivXr10d9Bzzz0nI0aMkAMHDrj9WhNYgAD7r1D1i6VunzNcT/0yV+FEtWZ9+OGHMnXqVN2fD9fLzs6Wxx57TNdnqcERcK/x48c71lWtkAowaWlp8v777+sBEu5El9A1xMfHS1BQ0FXVzWo7MTHRrR9KoLNfX669a82cOVNWrlwpGzZs0IWhda+36gLNz8932p/vesup/8rs0aOHDBw4UI/SUgXmf/jDH7jWbqC6IdRAiJtuukmCg4P1osKhKtpX6+q/7vl+u49qTenZs6dkZma6/ftNYGngF476ZbNu3Tqn5nS1rZp64T5du3bVX+66176wsFCPFuLaN5+qa1ZhRXVLrF+/Xl/futT3PCQkxOl6q2HPql+a6+0a6ndHRUUF19oNRo8erbvgVIuWfRk0aJCurbCv8/12n+LiYjl27JiegsLtv0taXbbrx5YuXapHpmRkZNgOHTpke/jhh22xsbG2vLw8o0/NLyr69+zZoxf1NXzttdf0+smTJ/XrL774or7Wn3zyiW3fvn22SZMm2bp27WorKysz+tR9ziOPPGKLiYmxbdy40Zabm+tYSktLHfv8/Oc/t3Xu3Nm2fv16286dO21Dhw7VC5rvqaee0iOwjh8/rr+7attkMtnWrFnDtfaQuqOEFL7frvP444/r3yXq+/3111/bxowZY4uPj9ejD919rQksjfjjH/+oL35oaKge5rx9+3aXXPhAt2HDBh1UrlymTp3qGNr89NNP2zp27KhD4+jRo21Hjhwx+rR9Un3XWS3vvPOOYx8VBP/7v/9bD7+NjIy03XnnnTrUoPkefPBBW1pamv6d0aFDB/3dtYcVrrUxgYXvt+vcfffdtqSkJP397tSpk97OzMz0yLU2qf9pfTsNAACA+1DDAgAAvB6BBQAAeD0CCwAA8HoEFgAA4PUILAAAwOsRWAAAgNcjsAAAAK9HYAEAAF6PwAIAALwegQUAAHg9AgsAAPB6BBYAACDe7v8DDt7p3+XcHScAAAAASUVORK5CYII=", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "def train_step(model, optimizer, x, y):\n", " # Create state-based loss function for LBFGS\n", @@ -328,10 +203,7 @@ " return loss\n", "\n", "model = make_model(rngs)\n", - "optimizer = nnx.Optimizer(\n", - " model,\n", - " tx=optax.lbfgs(1e-3),\n", - " wrt=nnx.Param)\n", + "optimizer = nnx.Optimizer(model, optax.lbfgs(1e-3), wrt=nnx.Param)\n", "\n", "losses = []\n", "for _ in range(50):\n", @@ -342,7 +214,7 @@ }, { "cell_type": "markdown", - "id": "e6bacba1", + "id": "b14f69c8", "metadata": {}, "source": [ "# Per-Parameter Learning Rates\n", @@ -356,8 +228,8 @@ }, { "cell_type": "code", - "execution_count": 18, - "id": "59dd4440", + "execution_count": null, + "id": "1463df77", "metadata": { "lines_to_next_cell": 2 }, @@ -383,7 +255,7 @@ }, { "cell_type": "markdown", - "id": "3ddce67c", + "id": "31ca36e2", "metadata": {}, "source": [ "# Gradient Accumulation" @@ -391,7 +263,7 @@ }, { "cell_type": "markdown", - "id": "3088c386", + "id": "3710c939", "metadata": {}, "source": [ "Gradient accumulation in Flax is easy: just use the `optax.MultiSteps` optimizer." @@ -399,34 +271,14 @@ }, { "cell_type": "code", - "execution_count": 19, - "id": "3865a3dc", + "execution_count": null, + "id": "d64e4551", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAOqBJREFUeJzt3Ql4VdWh9vE380QSDJAQSAhhFERQEBCCiBXBoVaq1xFEFPQWwYra9oqt2tpeafW2XycvWq+CFhmKFbWoKDIjBAQcAJHBhCRAwkwSMg/ne9YKiaQmkOEkZ/r/nmebfZJ9dtZZiTkva/RzOBwOAQAAuDF/VxcAAADgfAgsAADA7RFYAACA2yOwAAAAt0dgAQAAbo/AAgAA3B6BBQAAuD0CCwAAcHuB8gKVlZU6dOiQIiMj5efn5+riAACABjBr1+bn56tTp07y9/f3/sBiwkpiYqKriwEAAJogKytLCQkJ3h9YTMtK9QuOiopydXEAAEAD5OXl2QaH6vdxrw8s1d1AJqwQWAAA8CwNGc7BoFsAAOD2CCwAAMDtEVgAAIDbI7AAAAC3R2ABAABuj8ACAADcHoEFAAC4PQILAABwewQWAADg9ggsAADA7RFYAACA2yOwAAAAt+cVmx+2lPKKSv33+7ucdr+I4EDdP7KbosOCnHZPAAB8AYHlHCod0pxP9ju1wovKKvTk9/s69Z4AAHg7Ass5+PtJ067q7pSKPpZfqkVbsvTPbQf007G9FRoU4JT7AgDgCwgs56qcAH/9dOyFTqnoikqH1u87poOnivT+9mzdPDDBKfcFAMAXMOi2lQT4++nOIYn2/I1Nma31bQEA8AoEllZ022WJCvT309aMk/o6J681vzUAAB6NwNKKYqNCdU3fOHs+n1YWAAAajMDSysYPTbIf39p2UAUl5a397QEA8EgEllY2vHs7dW0XrtMl5frXF4da+9sDAOCRCCytXeF28G0Xez5/M4NvAQBoCAKLC/zHoAQFB/jrywO5+vLAKVcUAQAAj0JgcYF2bUJ03cUd7TmDbwEAOD8Ci4vcdaZb6N0vDimvuMxVxQAAwCMQWFxkSHKMesS2UWFphd757KCrigEAgEcgsLiIn5+fxg/tUrPyrcPhcFVRAABwewQWF7r50gSFBvnr65x8bcs86cqiAADg1ggsLhQdHqTv9+9kz9lfCACA+hFYXKy6W2jpl9k6VVjq6uIAAOCWCCwudkliW/WNj1JpeaXe3HrA1cUBAMAtEVjcYPDtXUO/XfmWwbcAAHwXgcUNjLu0syKCA5R2tECpaSdcXRwAANwOgcUNtAkJ1E2Xdrbnb2zKcHVxAABwOwQWN1v59sOdOTp2usTVxQEAwK0EuroAqNKvc7QGJLbVF1mn9P0/r1d4SECzq+aShLb6n1sH2B2iAQDwZAQWNzJ5RLJ+vOAz5eQVO+V+ZkyM6Wq6slcHp9wPAABXIbC4kRv7x6tnbBvlF5c3+14LN2fqrc8O6o3UDAILAMDjEVjcbIpzn/gop9zrgvAgG1g+3nVY2blFio8Oc8p9AQBw+0G3s2bN0uDBgxUZGanY2FiNGzdOu3fvPu/z/vjHP6p3794KCwtTYmKiHnnkERUX1+72eOGFF9S1a1eFhoZq6NCh2rx5c+NfDWr0jIvU0OQYVTqkBZuzqBkAgO8EljVr1mjatGlKTU3V8uXLVVZWpjFjxqigoKDe58yfP1+PP/64nn76ae3atUuvvPKKFi1apCeeeKLmGvP40Ucftdds27ZNAwYM0NixY3XkyJHmvTofN+HypJruobKKSlcXBwCAJvNzNGNp1aNHj9qWFhNkRo4cWec106dPt0FlxYoVNZ977LHHtGnTJq1fv94+Ni0qpuXmr3/9q31cWVlpW2IeeughG3bOJy8vT9HR0crNzVVUlHO6VLyBWe5/+G9X6NjpUr04YaCu7Rfv6iIBANCk9+9mrcNivoERExNT7zXDhw/X1q1ba7p40tLS9P777+v666+3j0tLS+3XR48e/W2h/P3t440bN9Z5z5KSEvsizz7wXcGB/rrtskR7Pi81kyoCAHisJgcW0woyY8YMpaSkqF+/fvVed9ddd+mZZ57RiBEjFBQUpO7du2vUqFE1XULHjh1TRUWF4uLiaj3PPM7Jyal3LI1JZNWHaY1B3e4c0kV+ftL6fceUfqz+rjsAALwysJixLDt27NDChQvPed3q1av17LPP6n//93/t+JS33npL7733nn7961839Vtr5syZtnWn+sjKYlBpfRJjwjXqzDos81n2HwDgS9OazbiUpUuXau3atUpISDjntU8++aTuvvtuTZkyxT6++OKL7SDdBx54QD//+c/Vvn17BQQE6PDhw7WeZx537NixznuGhITYAw0ffLtq91Et3npAj43prdCg5q+iCwCA27awmPG5JqwsWbJEK1euVHJy8nmfU1hYaMeknM0ElOr7BQcHa9CgQbUG5ZruJvN42LBhjSke6jGqd6w6tw3TqcIyvfdlNvUEAPDuwGK6gebNm2enKpu1WMwYE3MUFRXVXDNx4kTbZVPtxhtv1OzZs23XUXp6up0ObVpdzOerg4uZ0vzyyy/rtddeszOKpk6dalth7r33Xme+Vp8V4O+nu4ZWba44j24hAIC3dwmZ4GGYQbNnmzNnjiZNmmTPMzMza7Wo/OIXv7AruJqPBw8eVIcOHWxY+e///u+aa26//XY7Rfqpp56yAeiSSy7RsmXLvjMQF01362UJ+n/L9+izzFPaeShXF3WKpjoBAL6xDou7YB2Whpk2f5vtEjKtLc/+8OIW/qkAAOAm67DAs0wYWrXy7dufHVR+cZmriwMAQIMRWHzI5d1i1L1DhApLK2xoAQDAUxBYfIgZSzT+TCvLG5sy7SwtAAA8AYHFx9wyKEGhQf76OidfWzNOuro4AAA0CIHFx0SHBekHAzrZ83mpGa4uDgAADUJg8UHV3ULvb8/RiYJSVxcHAIDzIrD4oAGJbXVx52iVVlRq8Rb2YQIAeOleQvB8Ey7vov/653a78m3H6NBm3y8owF8jerZXVGiQU8oHAMDZWDjORxWWlmvosyuUX1zutHvecHG8Xhg/0Gn3AwB4t7xGLBxHC4uPCg8OtKvdLvo0S5XNnN5snp+adkIf7MhWdm6R4qPDnFZOAAAMAosPu3FAJ3s4w+0vbdSm9BNasClTj47p7ZR7AgBQjUG3cIq7h1XNPFrwaZbKKiqpVQCAUxFY4BRj+nZUh8gQHc0v0Uc7D1OrAACnIrDAKYID/XXH4ER7/vfU/dQqAMCpCCxwmjuHdJG/n+wA3L2H86lZAIDTEFjgNJ3ahml0nzh7zrL/AABnIrCgRQbfvrXtoApKnLfGCwDAtxFY4FQp3dura7tw5ZeU653PD1G7AACnILDAqfz9/TTh8qpWltc37pejmYvSAQBg31+oBjjbfwxKUEigv77Oyde2zJNUMACg2QgscLq24cH6wZkVdP++MYMaBgA0G4EFLTr49v3tOTp+uoRaBgA0C4EFLaJ/Qlv1T4hWaUWl/rHlALUMAGgWAgtaTPXg2zc2ZaiiksG3AICmI7CgxdzYv5Oiw4J04GSR1uw5Qk0DAJqMwIIWExYcoFsHJdjzeamZ1DQAoMkILGhR4890C63afURZJwqpbQBAkxBY0KKS20foip7tZdaPe2MTrSwAgKYhsKDVBt/+Y0uWissqqHEAQKMFNv4pQONcfWGs4qNDlZ1brCWfHdS1F3VsdhWGhwQoJDCAHwUA+Ag/hxds9pKXl6fo6Gjl5uYqKirK1cVBHf6yYq9+v3yP0+omIjhA70wfoR6xbahvAPCB92+6hNAq7hzaRQkXhDntfgWlFXr1k3Sn3Q8A4N5oYUGrMY15zmjPS00/rrte3qTw4AClPnG1okKDnFE8AEAro4UFbsnPz0/+/s0/hnVrp56xbVRYWqG3trLsPwD4ArqE4JHBp3pzxb+nZtiWGwCAdyOwwCP98NLOduDtN0cLtOGb464uDgCghRFY4JEiQ4N088CqZf9f37jf1cUBALhTYJk1a5YGDx6syMhIxcbGaty4cdq9e/c5nzNq1CjbhP/vxw033FBzzaRJk77z9Wuvvbbprwo+obpbaPlXh3XoVJGriwMAcJfAsmbNGk2bNk2pqalavny5ysrKNGbMGBUUFNT7nLfeekvZ2dk1x44dOxQQEKBbb7211nUmoJx93YIFC5r+quATesVF6vJuMap0SPNZ9h8AvFqjVrpdtmxZrcdz5861LS1bt27VyJEj63xOTExMrccLFy5UeHj4dwJLSEiIOnZs/gqo8C0Th3VVatoJLfw0Uz++uqeCA+nlBABv1Ky/7mZlurpCybm88soruuOOOxQREVHr86tXr7bhp3fv3po6daqOH69/IGVJSYmdu332Ad90Td84xUWF6NjpUn2wI9vVxQEAuFtgqays1IwZM5SSkqJ+/fo16DmbN2+2XUJTpkz5TnfQ66+/rhUrVuh3v/ud7Xq67rrrVFFRUe9YGrOUb/WRmJjY1JcBDxcU4K+7hpyZ4rwxw9XFAQC420q3phXkgw8+0Pr165WQUDVb43z+8z//Uxs3btSXX355zuvS0tLUvXt3ffzxx7r66qvrbGExRzXTwmJCC3sJ+aYjecUa/tuVKq906P0fX6G+ndhPCgA8QYuvdDt9+nQtXbpUq1atanBYMQNzzfiVyZMnn/fabt26qX379tq3b1+dXzfjXcwLO/uA74qNCtXYflXjn/6eyhRnAPBGjQospjHGhJUlS5Zo5cqVSk5ObvBzFy9ebFtFJkyYcN5rDxw4YMewxMfHN6Z48GETL6/qFnr7s0PKLSpzdXEAAK4MLGZK87x58zR//ny7FktOTo49ioq+XQNj4sSJmjlzZp2Dbc26Le3atav1+dOnT+unP/2pnSq9f/9+O47lpptuUo8ePTR27NjmvDb4kCHJMeodF6misgr9k/2FAMC3A8vs2bNtP5NZDM60flQfixYtqrkmMzPTrqNyNrO4nBnrUld3kFmTxYxp+cEPfqBevXrZawYNGqR169bZrh+gsfsLzUvNUKVZnAUA4DWaPOjWUwftwHsVlJTr8mdXKL+kXH+fPERX9Ozg6iIBAFw56BZwRxEhgbplUPX+QkxxBgBvQmCBV5lwZvDtil2HdZD9hQDAaxBY4FV6xLZRSo92dn+hN1JpZQEAb0Fggde5+0wry6JPs3S6pFwl5RXNPgAAHrT5IeAJRveJU3x0qLJzi9Xv6Q+dcs8bLo7XC+MHOuVeAIDGo4UFXicwwF/TruohPz/n3fO97dnaeahqs08AQOujhQVeO/j2loEJKq+sbPa9Hv/ndhtYzOaKv72lv1PKBwBoHAILvFZYcIBZmrDZ95mU0tUGlrc/P6jHr7tQbcODnVI+AEDD0SUEnMdlSReoT3yUissqtXjLAeoLAFyAwAI0YNn/e84s+//31AxVsOw/ALQ6AgvQADdd0llRoYHKPFGoNXuOUGcA0MoILEADx8PcdlmiPX9tAwvSAUBrI7AAjZh5ZKZKr9lzVOnHCqg3AGhFBBaggbq2j9CoXlU7QJspzgCA1kNgARph4vCu9uPirVkqLC2n7gCglRBYgEa4smcHJbULV35xud7+7BB1BwCthMACNOZ/GH+/ms0VX9+4Xw6Hg/oDgFZAYAEa6dZBiQoLCtDXOfnanH6C+gOAVkBgARopOjxI4y7tbM9fZ/AtALQKAgvQBBPPrHy7bGeOcnKLqUMAaGEEFqAJzN5CQ5Jj7DL98zcxxRkAWhqBBWiie4ZVTXGevzlTJeUV1CMAtCACC9BEYy6KU1xUiI6dLtWyHTnUIwC0IAIL0ERBAf4aP7RqLMtrG/ZTjwDQgggsQDPcMSRRQQF+2pZ5StsP5FKXANBCAlvqxoAviI0M1fUXx+udzw/pxTXf6P6R3Zp9z/DgAPWMbSM/s9MiAMAisADNNHFYVxtY3tuebQ9n+PVNF+nuM4N6AQAEFqDZBnZpqzuHJGrd3mPNvldpeaWO5Jfob+vSdNfQJAX408oCAAYtLEAzma6bWTf3d0o9FpVW6PJZK5R1okirvj6i0X3j+PkAAINuAfcSFhyg2wcn2vPXNjLzCACqMUsIcDNmN2gz3tZ0Me07ctrVxQEAt0BgAdxMYky4rr6wqivodVpZAMAisABuaNLwqhlC/9x6QPnFZa4uDgC4HIEFcEMpPdqpR2wbFZRW6M2tB1xdHABwOQIL4KYzj+4508pilv2vrHS4ukgA4FIEFsBN3XxpZ0WGBmr/8UKt2XvU1cUBAJcisABuKiIkULcOOjPFmc0VAfi4RgWWWbNmafDgwYqMjFRsbKzGjRun3bt3n/M5o0aNss3b/37ccMMNNdc4HA499dRTio+PV1hYmEaPHq29e/c2/VUBXmLisKopzqt3H1X6sQJXFwcAPCOwrFmzRtOmTVNqaqqWL1+usrIyjRkzRgUF9f8hfeutt5SdnV1z7NixQwEBAbr11ltrrnnuuef05z//WS+++KI2bdqkiIgIjR07VsXFxc17dYCH69o+QqN6dbDnTHEG4Mv8HKZ5o4mOHj1qW1pMkBk5cmSDnvPHP/7RtqaY8GKCifn2nTp10mOPPaaf/OQn9prc3FzFxcVp7ty5uuOOO857z7y8PEVHR9vnRUVFNfXlAG5p9e4jmjTnU0WGBGrjE1erTQg7agDwDo15/27WGBbzDYyYmJgGP+eVV16xIcSEFSM9PV05OTm2G6iaKfzQoUO1cePGOu9RUlJiX+TZB+CtRvbsoOT2EcovKddb25jiDMA3NTmwVFZWasaMGUpJSVG/fv0a9JzNmzfbLqEpU6bUfM6EFcO0qJzNPK7+Wl1jaUyoqT4SE6sGJgLeyN/fT/cMS6oZfNuMRlEA8L3AYsaymPCxcOHCRrWuXHzxxRoyZIiaY+bMmbZ1p/rIyspq1v0Ad3fLoARFBAfom6MFWr/vmKuLAwCeEVimT5+upUuXatWqVUpISGjQc8zAXBNuJk+eXOvzHTt2tB8PHz5c6/PmcfXX/l1ISIjt6zr7ALxZZGiQ/mNQ1f9rTHEG4IsaFVhMU7QJK0uWLNHKlSuVnJzc4OcuXrzYjj2ZMGFCrc+be5hgsmLFiprPmTEpZrbQsGHDGlM8wKtNPLPy7YqvjyjzeKGriwMA7htYTDfQvHnzNH/+fLsWixljYo6ioqKaayZOnGi7bOrqDjLrtrRr167W582aLGYszG9+8xu9++672r59u72HmTlkrgdQpXuHNrqiZ3uZISx/T91PtQDwKY2aHzl79uyaxeDONmfOHE2aNMmeZ2Zmyt+/dg4yi8utX79eH330UZ33/dnPfma7jB544AGdOnVKI0aM0LJlyxQaGtrY1wN4/S7O6/Ye06JPs/TINb0UHswUZwC+oVnrsLgL1mGBrzCbIF71+9XKOF5oW1tiI5sf6i/qFKV7U7ra1k4AcNf3b/55BnjcFOeuembpV7alxRn+uU26OCFag7s2fD0lAGhtBBbAA/cXCg8OUG5RWbPvtXbvUX2y77jmfrKfwALArRFYAA8TGOCvO4Z0ccq9RvbqoOv+tE7Ldubo0KkidWob5pT7AoCzNWtpfgCerU98lC7vFqOKSofmpWa4ujgAUC8CC+DjJg2vWk9pweZMFZdVuLo4AFAnAgvg467pG6fObcN0srBM73x+0NXFAYA6EVgAHxdgZh4Nr9pccc4nbK4IwD0RWADo9su6KCwoQF/n5Cs17QQ1AsDtEFgAKDo8SDcP7GxrYu6GdGoEgNshsACoWfbfWP7VYWWdYHNFAO6FwALA6hkXqRE92qvSbq7IFGcA7oXAAqCG2VPIWLg5U4Wl5dQMALdBYAFQ46resUpqF6684nIt+YwpzgDcB4EFwLd/EPz9NHFYVSuL2V/ICzZzB+AlCCwAarn1sgRFBAdo75HTdmNEAHAHBBYAtUSFBuk/BiXYc6Y4A3AXBBYA3zHxzBTnFV8fUcbxAmoIgMsRWAB8R/cObTSqdweZISyvbWCKMwDXI7AAOOdCcou3ZOl0CVOcAbgWgQVAnUb27KBuHSKUX1Kut7YdoJYAuFSga789AHee4mxaWZ56Z6fdxblru4hm3zPQ30+XdrlAYcEBTikjAN/h5/CChRby8vIUHR2t3NxcRUVFubo4gNcwXUHDnl1hW1mc5YaL4/XC+IFOux8A33j/poUFQL3ahATqyRv76vWN+1VZ2byKMv8y2pWdp/d3ZNuZR0lOaLEB4DsILADO6bbLEu3hDPe8ullr9hy1M4+eurEvNQ+gwRh0C6DV3Dci2X78x5Ys5ReXUfMAGozAAqDVjOzZXt07RNixMW9uZeYRgIYjsABoNX5+fro3paqVZe6G/aqo9Pgx/wBaCYEFQKu6eWBnRYcFKeN4oVZ9fYTaB9AgBBYArSo8OFB3DKkaxPvqJ+nUPoAGIbAAaHUTh3VVgL+fNnxz3E51BoDzIbAAaHWd24bp2os62vO5n+znJwDgvAgsAFzivhFVmysu+fygjp8u4acA4JwILABcYmCXC9Q/IVql5ZVasDmTnwKAcyKwAHDZFOf7zkxxfn1jhg0uAFAfAgsAl7n+4njFRoboSH6JPtiRzU8CQL0ILABcJjjQX3dfnmTPX12fLi/YPB5ACyGwAHCpu4Z2scHliwO52pZ5ip8GgOYHllmzZmnw4MGKjIxUbGysxo0bp927d5/3eadOndK0adMUHx+vkJAQ9erVS++//37N13/5y1/a/uyzjwsvvLAxRQPgodq1CdG4SzrZcxaSA1CfQDXCmjVrbPAwoaW8vFxPPPGExowZo6+++koRERF1Pqe0tFTXXHONDThvvvmmOnfurIyMDLVt27bWdRdddJE+/vjjbwsW2KiiAfBgZn+hf2w5oGU7cnToVJE6tQ1zdZEAuJlGpYJly5bVejx37lwbRLZu3aqRI0fW+ZxXX31VJ06c0IYNGxQUFGQ/17Vr1+8WJDBQHTtWLSQFwLf0iY/SsG7ttDHtuJ0x9Ph1tLACcOIYltzcXPsxJiam3mveffddDRs2zLbMxMXFqV+/fnr22WdVUVFR67q9e/eqU6dO6tatm8aPH6/MzPrXZSgpKVFeXl6tA4Bnuzel6h8yZk2WotLafx8AoMn9LpWVlZoxY4ZSUlJsCKlPWlqaVq5caUOIGbeyb98+PfjggyorK9PTTz9trxk6dKhtrendu7eys7P1q1/9SldccYV27Nhhx8vUNZbGXAPAe1zdJ05dYsKVeaJQczaka+yZpfubo21YkB0jA8Dz+TmaOI9w6tSp+uCDD7R+/XolJCTUe50ZYFtcXKz09HQFBATYz/3hD3/Q888/b8NJfYN0k5KS7HWTJ0+us4XFHNVMC0tiYqJt8YmKimrKywHgBszU5meWfuW0+5kNFhf/aJhdVReA+zHv39HR0Q16/25SC8v06dO1dOlSrV279pxhxTAzg8zYleqwYvTp00c5OTl2QG5wcPB3nmMG5JqgY1pj6mJmGpkDgHe5bXCi3v3ikNKOnm72vUorKlVcVqkXV3+jv028zCnlA+A6jQospjHmoYce0pIlS7R69WolJ1ctq30upsto/vz5tgvJ379qyMyePXtskKkrrBinT5/WN998o7vvvrsxxQPg4dqEBOrtaSlOude+I/ka/Ye1Wr7rsDKOFyipXd0zGQF44aBbM3B23rx5NoCYsSWmlcQcRUVFNddMnDhRM2fOrNV1ZGYJPfzwwzaovPfee3bQrblXtZ/85Cd2yvT+/fvtbKIf/vCHtkXmzjvvdNbrBOBjesRG6speHWQ6ved8st/VxQHQmoFl9uzZtp9p1KhRtoWk+li0aFHNNWZ2z9ljU8zYkg8//FCffvqp+vfvrx//+Mc2vDz++OM11xw4cMCGEzPo9rbbblO7du2UmpqqDh06NPf1AfBh942oagVevCVLecVlri4OAFcMunUnjRm0A8B3mD9vY/7fWu09clq/uKGPplzRzdVFAtDE92/2EgLgtcw2H9WtLKZbqLyi0tVFAtBEBBYAXu2Hl3ZWTESwDp4q0kdfHXZ1cQA0EYEFgFcLDQrQ+KFdatZ5AeCZCCwAvN7dlycpKMBPWzJO6ousU64uDoAmILAA8HqxUaG6sX8ne/4KrSyARyKwAPAJ1YNv39+erezcb9eOAuAZCCwAfEK/ztEamhyj8kqHXt+Y4eriAGgkAgsAn2tlmb8pU4Wl5a4uDoBGILAA8Bmj+8SpS0y4covK9M9tB11dHACNQGAB4DMC/P10b0pXez5nfboqKz1+oW/AZxBYAPiUWy9LVGRIoNKOFWjNnqOuLg6ABiKwAPApbUICdfvgRHvOFGfAcxBYAPice4Z3lb+ftH7fMX2dk+fq4gBoAAILAJ+TGBOua/t1tOcs1w94hkBXFwAAXGHyiGS9vz1Hb39+SGMv6qiQwIBm37NPfKTatQlxSvkA1EZgAeCTBna5QAMSovXFgVxNfm2LU+7ZtV24Pn70SgUG0HgNOBuBBYBP8vPz0y++31e/eW+XSsoqmn2/jOOF2n+8UB/syNGNA6r2LQLgPH4Oh8PjFyLIy8tTdHS0cnNzFRUV5eriAPBBf/p4r/7fx3tsq83b01JsIALgvPdv2i0BwAkmXN5FwYH+totpa8ZJ6hRwMgILADiBGWx7y8DO9vz/1qVTp4CTEVgAwEnuS6naXPHDr3KUcbyAegWciMACAE7SMy5So3p3kBkZOOeT/dQr4EQEFgBwoikjutmP/9iSpdzCMuoWcBICCwA4UUqPdrqwY6QKSyu04NNM6hZwEgILADiRmc5sVtE15n6yX2UVldQv4AQEFgBwsh9c0knt24QoJ69Y72/Ppn4BJyCwAICTmX2J7hmWZM9fXpcmL1ifE3A5AgsAtIDxlycpNMhfOw7maXP6CeoYaCYCCwC0gJiIYN0yMMGe/996FpIDmovAAgAt5L4zg28/3nVY6cdYSA5oDgILALSQ7h3a6OoLY+1Ccq/SygI0C4EFAFrQ5CuqWlkWb83SqcJS6hpoIgILALSgYd3aqW98lIrLKvXGJhaSA5qKwAIALbyQ3JQzrSyvbdiv0nIWkgOagsACAC3s+/07KTYyREfyS7T0y0PUN9AEgU15EgCg4YID/XXP8K56/sPd+v1He7Qprfnrspg1Xn40qrvio8P4UcAnEFgAoBWMH9pFL6zap4OnirRoS5ZT7nmsoFQv3DXQKfcCvCqwzJo1S2+99Za+/vprhYWFafjw4frd736n3r17n/N5p06d0s9//nP73BMnTigpKUl//OMfdf3119dc88ILL+j5559XTk6OBgwYoL/85S8aMmRI018ZALiRtuHB+vvkoUpNO97sexWXVegvK/fpg+3ZyjpRqMSYcKeUEfCawLJmzRpNmzZNgwcPVnl5uZ544gmNGTNGX331lSIiIup8Tmlpqa655hrFxsbqzTffVOfOnZWRkaG2bdvWXLNo0SI9+uijevHFFzV06FAbZsaOHavdu3fb5wGANxiUdIE9nOHzrFNat/eYXv0kXU/feJFT7gm4Mz9HM3blOnr0qA0UJsiMHDmyzmtMCDEtJ6ZVJigoqM5rTEgxIeivf/2rfVxZWanExEQ99NBDevzxx89bjry8PEVHRys3N1dRUVFNfTkA4DHW7jmqia9uVnhwgDY+frWiw+v++wq4s8a8fzdrlpD5BkZMTEy917z77rsaNmyYbZmJi4tTv3799Oyzz6qioqKmBWbr1q0aPXr0t4Xy97ePN27cWOc9S0pK7Is8+wAAX3JFz/a6sGOkCksrNH8z67vA+zU5sJhWkBkzZiglJcWGkPqkpaXZriATUN5//309+eST+v3vf6/f/OY39uvHjh2zXzNh5mzmsRnPUt9YGpPIqg/TGgMAvre+Szd7PndDOuu7wOs1ObCYFpMdO3Zo4cKF5w02ptvob3/7mwYNGqTbb7/dDsA1XUVNNXPmTNu6U31kZTlnxD0AeJIfDKha3+VwHuu7wPs1KbBMnz5dS5cu1apVq5SQULV9en3i4+PVq1cvBQQE1HyuT58+tvXEdAe1b9/efu3w4cO1nmced+zYsc57hoSE2L6usw8A8MX1XSaldLXnf1ubpmYMSQS8K7CY/xlMWFmyZIlWrlyp5OSq5abPxXQZ7du3z7a0VNuzZ48NMsHBwfYwLS8rVqyo+bq51jw2Y18AAPUbPyTJDrz9Oidfn+xr/pRpwCsCi+kGmjdvnubPn6/IyEjbSmKOoqKimmsmTpxou2yqTZ061a698vDDD9ug8t5779lBt+Ze1cyU5pdfflmvvfaadu3aZZ9TUFCge++911mvEwC8kpkddNtlVeP4Xl6X5uriAO6xDsvs2bPtx1GjRtX6/Jw5czRp0iR7npmZaWf5VDMDYj/88EM98sgj6t+/v12HxYSX//qv/6q5xoxrMVOkn3rqKRuALrnkEi1btuw7A3EBAN91X0qyXt+4X2v2HNXunHz17hhJNcHrNGsdFnfBOiwAfN2Db2zV+9tzdOugBD1/6wBXFwdwr3VYAADuoXqK89ufH9SRvGJXFwdwOgILAHiBgV0u0GVJF6iswqHXNu53dXEApyOwAICXtbLMS81UYWm5q4sDOBWBBQC8xDV949S1Xbhyi8q0eMsBVxcHcCoCCwB4iQB/P00eUbU+1ivr01VR6fFzKoAaBBYA8CL/MShRbcODlHmiUMu/qns/NsATEVgAwIuEBQfo7suTapbrB3xy4TgAgPu7e1iSXlqTpm2Zp/SvLw4puX1Es+8ZHx2qdm1CnFI+oCkILADgZWIjQ/XDSztr0ZYsPbTgM6fcMzIkUMsfvVIdo0Odcj+gsQgsAOCFpn+vh744cEonC0ubfa/84nLll5RrzoZ0zbyuj1PKBzQWS/MDAM7p468Oa8rrW2wry4aZ31NkaBA1BqdgaX4AgNN878JYde8QYVtZFm7OombhEswSAgCc+43C30/3n1lF99VP0lVWUUmNodURWAAA5zXu0s5q3yZE2bnFWvrlIWoMrY7AAgA4r9CgAN2b0tWemynTDger6KJ1EVgAAA0yfmgXhQcH6OucfK3fd4xaQ6sisAAAGqRteLBuuyzRnrOKLlobgQUA0GBmc0WzyeK6vce081AuNYdWQ2ABADRYYky4rr843p6/zF5FaEUEFgBAozxwZorzv77M1qFTRdQeWgWBBQDQKBcnRGtYt3aqqHTo1fXp1B5aBYEFANBoD1xZ1cqyYHOmcovKqEG0OAILAKDRRvXqoF5xbVRQWmFDC9DSCCwAgEbz8/t2uf45n6SrtJzl+tGyCCwAgCa56ZLOiosK0eG8Er3z+UFqES2KwAIAaJLgQH9NGp5sz19ex3L9aFkEFgBAk901tIsiggO05/Bprd5zlJpEiyGwAACaLDosSHcO6WLP/3ZmU0RnHMC/83N4wW9GXl6eoqOjlZubq6ioKFcXBwB8ilk8buRzq1Re6Zy3Ez8/6aGreujRMb2dcj94x/s3LSwAgGbp1DZMEy5Pclotmn9Gv7gmTUfyi512T3g+WlgAAE5xqrDUrn7bXFNe36LPMk/pwVHd9bNrL3RK2eD5LSyBrVYqAIBXaxse7JT7/OjK7vrPv2/V31Mz9OBVPdQmhLcq0CUEAHAz1/SJU7f2EcovLtdCVtHFGYxhAQC4FX9/Pz0wsmoV3VfWs4ouqhBYAABuZ9ylndUhMkTZucX61xeHXF0cuAECCwDA7YQGBejelK72/KW137A2CwgsAAD3NH5o0rer6O5mFV1f16gWllmzZmnw4MGKjIxUbGysxo0bp927d5/zOXPnzrW7ep59hIaG1rpm0qRJ37nm2muvbdorAgB4zSq6Zul/48U137i6OPCkwLJmzRpNmzZNqampWr58ucrKyjRmzBgVFBSc83lmbnV2dnbNkZGR8Z1rTEA5+5oFCxY0/tUAALzKfSOSFejvp03pJ/RZ5klXFwcu1KjJ7cuWLftO64lpadm6datGjhxZ7/NMi0nHjh3Pee+QkJDzXgMA8C3x0WG66ZLO+ue2A/rb2jTNnjDI1UWCJw66NSvTGTExMee87vTp00pKSlJiYqJuuukm7dy58zvXrF692oaf3r17a+rUqTp+/Hi99yspKbGr4519AAC8U/UU52U7c5R+7Nwt+vBeTQ4slZWVmjFjhlJSUtSvX796rzMB5NVXX9U777yjefPm2ecNHz5cBw4cqNUd9Prrr2vFihX63e9+Z7uerrvuOlVUVNQ7lsYs5Vt9mCAEAPBOvTtG6nsXxto9hl5el+bq4sDT9hIyrSAffPCB1q9fr4SEhAY/z4x76dOnj+688079+te/rvOatLQ0de/eXR9//LGuvvrqOltYzFHNtLCY0MJuzQDgnTalHdftf0tVcKC/Pvmv79k1WuD5Wny35unTp2vp0qVatWpVo8KKERQUpEsvvVT79u2r95pu3bqpffv29V5jxruYF3b2AQDwXkOSY3RJYluVllfqtQ37XV0cuECjAotpjDFhZcmSJVq5cqWSk5Mb/Q1NN8/27dsVHx9f7zWmu8iMYTnXNQAA32Emb5hNEY3XN+5XQUm5q4sEdw4sZkqzGYcyf/58uxZLTk6OPYqKimqumThxombOnFnz+JlnntFHH31ku3m2bdumCRMm2GnNU6ZMqRmQ+9Of/tROld6/f78dx2IG5vbo0UNjx4515msFAHiwa/pWbYqYZzZF/DTL1cWBOweW2bNn236mUaNG2daP6mPRokU112RmZtp1VKqdPHlS999/vx23cv3119v+qg0bNqhv37726wEBAfryyy/1gx/8QL169dLkyZM1aNAgrVu3znb9AABg3y/8/XR/9aaI69JUVlFJxfiQJg+69dRBOwAAz1VcVqERv1ulY6dL9Nwt/XVD/+YPHTADeYMC2FrP3d+/CSwAAI/ywqp9ev7Dc28L0xjhwQFa9MAwXZwQ7bR7wk1mCQEA4CoTLk9Sl5hwp92vsLRCf1qxx2n3gxsszQ8AgDtsirj6J6NU6oQxLBnHC3Xtn9bq411HtOdwvnrFRTqljHA+WlgAAB7H399PoUEBzT7MKrpj+1btY8eO0O6NwAIA8Gk/GlW1vsu7nx/SwVPfLtMB90JgAQD4NLOC7rBu7VRe6dAr69JdXRzUg8ACAPB51a0sCzZn6mRBqc/XhzsisAAAfN7Inu3VNz5KRWUVen1jhs/XhzsisAAAfJ7dq+hMK8vcDekqLGWvIndDYAEAQNL1/Tra9V1OFpbpH+xV5HYILAAAmIXJAvxr9ip6eV06exW5GQILAABn3DooQe3bBNvpzUu/PES9uBECCwAAZ5jF5O5NSbbnL65OkxfsD+w1CCwAAJxlwtAkRQQHaPfhfK3afYS6cRMEFgAAzhIdHqTxlyfZ89mrv6Fu3ASBBQCAf3NfSrKCAvz06f6T2ppxgvpxAwQWAAD+TcfoUN18aYI9n706jfpxAwQWAADq8MCV3eTnJ32867D2HM6njlyMwAIAQB26d2ijsX072vOX1tDK4mqBri4AAADuyizXv2xnjt75/KBG9Gyn0MCAZt/z4oRoJVwQ7pTy+RICCwAA9bgksa0u7xaj1LQTemTRF06pJ7Mw3dqfXaXwYN6CG4PaAgDgHJ76/kWa9cEuFZdVNLue9hw+rWOnS7Vgc5Ymj6haoA4N4+fwgmX88vLyFB0drdzcXEVFRbm6OAAA1OmNTRn6+ZIdio8O1ZqfXqXgQN8eSprXiPdv364pAABa0S0DE9QhMkTZucV6+7OD1H0jEFgAAGjFvYqmnOkKenHtN6qo9PhOjlZDYAEAoBWZZf+jQgOVdrRAH+3Moe4biMACAEArahMSqHuGd7Xn/7v6G3aEbiACCwAArWzS8K4KDfLX9oO5Wr/vGPXfAAQWAABaWbs2IbpjcBd7/r+r2BG6IQgsAAC4wP0juynQ308b047rs8yT/AzOg8ACAIALdG4bppsu6WzPZ6+mleV8CCwAALjI1FFVO0J/9NVh7WVH6HMisAAA4CI9YiM1pm+cPZ+9hlaWcyGwAADgQg+O6mE/vvv5IR04WcjPoh4EFgAAXGhAYlul9Gin8kqHXl6bxs+iHgQWAADcpJVl4adZOna6xNXFcUsEFgAAXGx493YakBCtkvJKzf1kv6uL4/mBZdasWRo8eLAiIyMVGxurcePGaffu3ed8zty5c+Xn51frCA0NrXWNw+HQU089pfj4eIWFhWn06NHau3dv014RAAAexrw3Tj3TyvLaxv3KLy5zdZE8O7CsWbNG06ZNU2pqqpYvX66ysjKNGTNGBQUF53xeVFSUsrOza46MjIxaX3/uuef05z//WS+++KI2bdqkiIgIjR07VsXFxU17VQAAeBgzW6h7hwjlF5frjU2Zri6O2/FzmOaNJjp69KhtaTFBZuTIkfW2sMyYMUOnTp2q8+vm23fq1EmPPfaYfvKTn9jP5ebmKi4uzj73jjvuOG858vLyFB0dbZ9nwhEAAJ5o8ZYs/fTNLxUWFKDEmLBm3y8kMEC/uKGPhnZrJ3fUmPfvwOZ8I/MNjJiYmHNed/r0aSUlJamyslIDBw7Us88+q4suush+LT09XTk5ObYbqJop/NChQ7Vx48Y6A0tJSYk9zn7BAAB4unGXdtYLq/Zp//FC7Tl82in3/OW/vtL7Px5hu508WZMDiwkfpuUkJSVF/fr1q/e63r1769VXX1X//v1twPmf//kfDR8+XDt37lRCQoINK4ZpUTmbeVz9tbrG0vzqV79qatEBAHBLQQH+WvJginblNP8f4uUVDk2dt1W7svO0evdRXXVhrHwysJixLDt27ND69evPed2wYcPsUc2ElT59+uill17Sr3/96yZ975kzZ+rRRx+t1cKSmJjYpHsBAOBOLogI1vDu7Z1yrwmXJ+mltWn666p9GtW7g0e3sjRpWvP06dO1dOlSrVq1yraSNEZQUJAuvfRS7du3zz7u2LGj/Xj48OFa15nH1V/7dyEhIbav6+wDAADUNnlEsoID/bU146Q2pZ+QJ2tUYDEDZE1YWbJkiVauXKnk5ORGf8OKigpt377dTmE2zD1MMFmxYkWtFhMzW+jslhkAANA4sVGhuu2yqoYFMzbGZwKL6QaaN2+e5s+fb9diMWNMzFFUVFRzzcSJE22XTbVnnnlGH330kdLS0rRt2zZNmDDBTmueMmWK/bppnjJjYX7zm9/o3XfftWHG3MPMHDLrvAAAgKb7z5HdFeDvp3V7j+nLA3XP2PW6wDJ79mw7cHbUqFG2haT6WLRoUc01mZmZdq2VaidPntT9999vx61cf/31tvVkw4YN6tu3b801P/vZz/TQQw/pgQcesAvTmVlFy5Yt+84CcwAAoHESY8J10yWdPL6VpVnrsLgL1mEBAKB++47k65r/t1bmHX/5IyPVMy5Snvb+zV5CAAB4uR6xkRrbt2oiy+zV38gTEVgAAPAB066q2qvonS8OKfN4oTwNgQUAAB9wcUK0RvbqoIpKh15a63mtLAQWAAB8xLRR3e3HxVsO6EieZ20wTGABAMBHDEmO0WVJF6i0olIvr0uTJyGwAADgI/z8/DTte1VjWd7YlKmTBaXyFAQWAAB8yKheHdQ3PkqFpRWau2G/PAWBBQAAX2tluaqqlcUEltMl5fIEBBYAAHzMtf06qluHCOUWlemN1Ax5AgILAAA+JsDfT1OvrJox9PK6dBWXVcjdBbq6AAAAoPWNu7Sz/vjxXh08VaSHFnymhAvCznl9oL+ffn7Dt/sAtjYCCwAAPigowF8PjOymp9/dqeVfHT7v9cGB/gQWAADQ+sYP7WJXvj1eUHLeawP8XTuKhBYWAAB8VGCAv+4bkSxPwKBbAADg9ggsAADA7RFYAACA2yOwAAAAt0dgAQAAbo/AAgAA3B6BBQAAuD0CCwAAcHsEFgAA4PYILAAAwO0RWAAAgNsjsAAAALdHYAEAAG7PK3Zrdjgc9mNeXp6riwIAABqo+n27+n3c6wNLfn6+/ZiYmOjqogAAgCa8j0dHR5/zGj9HQ2KNm6usrNShQ4cUGRkpPz8/p6c/E4SysrIUFRXl1HuD+nY1fr+pb2/G77f717eJICasdOrUSf7+/t7fwmJeZEJCQot+D1P5BJbWQ323Luqb+vZm/H67d32fr2WlGoNuAQCA2yOwAAAAt0dgOY+QkBA9/fTT9iNaHvXduqhv6tub8fvtXfXtFYNuAQCAd6OFBQAAuD0CCwAAcHsEFgAA4PYILAAAwO0RWM7jhRdeUNeuXRUaGqqhQ4dq8+bNrfOT8XJr167VjTfeaFc3NKsTv/3227W+bsaCP/XUU4qPj1dYWJhGjx6tvXv3uqy8nmzWrFkaPHiwXQk6NjZW48aN0+7du2tdU1xcrGnTpqldu3Zq06aNbrnlFh0+fNhlZfZks2fPVv/+/WsWzxo2bJg++OCDmq9T1y3rt7/9rf2bMmPGDOq8Bfzyl7+09Xv2ceGFF7bK7zeB5RwWLVqkRx991E7T2rZtmwYMGKCxY8fqyJEjTql8X1ZQUGDr0wTCujz33HP685//rBdffFGbNm1SRESErXvzPwMaZ82aNfYPSGpqqpYvX66ysjKNGTPG/gyqPfLII/rXv/6lxYsX2+vNVhc333wzVd0EZtVt86a5detWbdmyRd/73vd00003aefOndR1C/v000/10ksv2cB4Nn6/neuiiy5SdnZ2zbF+/frWqWszrRl1GzJkiGPatGk1jysqKhydOnVyzJo1iypzIvNruGTJkprHlZWVjo4dOzqef/75ms+dOnXKERIS4liwYAF130xHjhyxdb5mzZqaug0KCnIsXry45ppdu3bZazZu3Eh9O8EFF1zg+L//+z/qugXl5+c7evbs6Vi+fLnjyiuvdDz88MP28/x+O9fTTz/tGDBgQJ1fa+m6poWlHqWlpfZfSKYr4uw9i8zjjRs3Oictok7p6enKycmpVfdmrwnTJUfdN19ubq79GBMTYz+a33PT6nJ2fZsm3i5dulDfzVRRUaGFCxfa1izTNURdtxzTinjDDTfU+j02qHPnM93zpju/W7duGj9+vDIzM1ulrr1i88OWcOzYMfvHJi4urtbnzeOvv/7aZeXyBSasGHXVffXX0PSdzU3ffkpKivr161dT38HBwWrbti317STbt2+3AcV0YZp+/CVLlqhv3776/PPPqesWYEKh6bY3XUL/jt9v5zL/cJw7d6569+5tu4N+9atf6YorrtCOHTtavK4JLICP/SvU/GE5u88Zzmf+mJtwYlqz3nzzTd1zzz22Px/Ol5WVpYcfftiOzzKTI9CyrrvuuppzM1bIBJikpCT94x//sBMkWhJdQvVo3769AgICvjO62Tzu2LFji/5QfF11/VL3zjV9+nQtXbpUq1atsgNDz65v0wV66tSpWtfzu9505l+ZPXr00KBBg+wsLTPA/E9/+hN13QJMN4SZCDFw4EAFBgbaw4RDM2jfnJt/3fP73XJMa0qvXr20b9++Fv/9JrCc4w+O+WOzYsWKWs3p5rFp6kXLSU5Otr/cZ9d9Xl6enS1E3TeeGddsworplli5cqWt37OZ3/OgoKBa9W2mPZt+aerbOczfjpKSEuq6BVx99dW2C860aFUfl112mR1bUX3O73fLOX36tL755hu7BEWL/y1p9rBdL7Zw4UI7M2Xu3LmOr776yvHAAw842rZt68jJyXF10bxiRP9nn31mD/Nr+Ic//MGeZ2Rk2K//9re/tXX9zjvvOL788kvHTTfd5EhOTnYUFRW5uugeZ+rUqY7o6GjH6tWrHdnZ2TVHYWFhzTU/+tGPHF26dHGsXLnSsWXLFsewYcPsgcZ7/PHH7Qys9PR0+7trHvv5+Tk++ugj6rqVnD1LyOD323kee+wx+7fE/H5/8sknjtGjRzvat29vZx+2dF0TWM7jL3/5i6384OBgO805NTXVKRXv61atWmWDyr8f99xzT83U5ieffNIRFxdnQ+PVV1/t2L17t6uL7ZHqqmdzzJkzp+YaEwQffPBBO/02PDzc8cMf/tCGGjTefffd50hKSrJ/Mzp06GB/d6vDCnXtmsDC77fz3H777Y74+Hj7+925c2f7eN++fa1S137mP81vpwEAAGg5jGEBAABuj8ACAADcHoEFAAC4PQILAABwewQWAADg9ggsAADA7RFYAACA2yOwAAAAt0dgAQAAbo/AAgAA3B6BBQAAuD0CCwAAkLv7//kZmyKGCgWiAAAAAElFTkSuQmCC", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "model = make_model(rngs)\n", - "optimizer = nnx.Optimizer(model, tx=optax.MultiSteps(optax.adam(1e-3), every_k_schedule=3), wrt=nnx.Param)\n", + "tx = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=3)\n", + "optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)\n", "\n", "@nnx.jit\n", "def train_step(model, optimizer, x, y):\n", @@ -443,7 +295,7 @@ }, { "cell_type": "markdown", - "id": "4c6930d0", + "id": "a664628f", "metadata": {}, "source": [ "# Sharding Optimization State Differently from Parameters" @@ -451,7 +303,7 @@ }, { "cell_type": "markdown", - "id": "ff70909f", + "id": "21d36a4f", "metadata": {}, "source": [ "Say we're doing data parallelism. We want to replicate our parameters across all GPUs so we can do the forward and backward passes without communication latency." @@ -459,7 +311,7 @@ }, { "cell_type": "markdown", - "id": "5cf6ede5", + "id": "b6d407e0", "metadata": {}, "source": [ "But we don't need to replicate the optimizer state, as it's not invovled in SPMD computations. One copy is enough, and we can shard this copy across our mesh to reduce memory usage. This means that we need the optimizer state to be sharded differently from the parameters themselves. To do this, we can add the 'optimizer_sharding' metadata to the initializer." @@ -467,8 +319,8 @@ }, { "cell_type": "code", - "execution_count": 25, - "id": "8387cb21", + "execution_count": null, + "id": "cd2f28cf", "metadata": {}, "outputs": [], "source": [ @@ -496,7 +348,7 @@ }, { "cell_type": "markdown", - "id": "f16aab53", + "id": "d1835582", "metadata": { "lines_to_next_cell": 0 }, @@ -506,28 +358,17 @@ }, { "cell_type": "code", - "execution_count": 26, - "id": "10c3b40c", + "execution_count": null, + "id": "88318319", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ShapedArray(float32[2@x,8@y])" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "jax.typeof(optimizer.opt_state[0][1].layers[0]['kernel'][...])" ] }, { "cell_type": "markdown", - "id": "34796a25", + "id": "c0062bd7", "metadata": { "lines_to_next_cell": 0 }, @@ -537,42 +378,27 @@ }, { "cell_type": "code", - "execution_count": 27, - "id": "49b34175", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ShapedArray(float32[2,8])" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "id": "befe02f0", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], "source": [ "jax.typeof(model.layers[0].kernel[...])" ] + }, + { + "cell_type": "markdown", + "id": "65136083", + "metadata": {}, + "source": [] } ], "metadata": { "jupytext": { "formats": "ipynb,md", "main_language": "python" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.5" } }, "nbformat": 4, diff --git a/docs_nnx/guides/optimization_cookbook.md b/docs_nnx/guides/optimization_cookbook.md index afaea94ca..59c70073b 100644 --- a/docs_nnx/guides/optimization_cookbook.md +++ b/docs_nnx/guides/optimization_cookbook.md @@ -46,75 +46,29 @@ y = rngs.normal((32, 8)) # Exponential Moving Average -Neural network see increased robustness when, rather than using only the weights available at the end of training, we use an exponential moving average of the weights produced throughout training. To modify the standard Flax training loop to accomodate calculating exponential moving averages, we can introduce a new `nnx.Variable` subclass for EMA Params, along with a function that converts all variables in a module to this subclass. +Neural networks see increased robustness when, rather than using only the weights available at the end of training, we use an exponential moving average of the weights produced throughout training. NNX provides `nnx.EMA` to make this easy. Simply create an `nnx.EMA` from your model and call `ema.update` after each optimizer step. The averaged parameters are stored in `ema.params`. ```python -class EmaParam(nnx.Variable): - @classmethod - def from_variable(cls, var): - return cls.from_metadata(jnp.copy(var.get_value()), var.get_metadata()) - -def as_ema_params(node): - return jax.tree.map( - EmaParam.from_variable, - node, - is_leaf=lambda x: isinstance(x, nnx.Variable), - ) -``` - -Now, we'll add a method to update the EMA params based on current model values. - -```python -class EMA(nnx.Pytree): - def __init__(self, params, decay: float, *, only: nnx.filterlib.Filter = ...): - self.decay = decay - self.filter = only - self.ema_params = nnx.data(as_ema_params(nnx.state(params, only))) - - def update(self, new_params): - def _update(ema_param, new_param): - ema_param[...] = ( - self.decay * ema_param[...] + (1.0 - self.decay) * new_param[...] - ) - jax.tree.map( - _update, - self.ema_params, - nnx.state(new_params, self.filter), - is_leaf=lambda x: isinstance(x, nnx.Variable), - ) -``` - - -The training loop proceeds as normal, but calls `ema.update` after each optimizer step. - -```python -# initialization model = make_model(rngs) -ema = EMA(model, decay=0.9) - -# simulate parameter update -def double(param): - param[...] *= 2.0 -jax.tree.map(double, model, is_leaf=lambda x: isinstance(x, nnx.Variable)) -ema.update(model) +optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) +ema = nnx.EMA(model, decay=0.9) +ema_model = ema.apply_to(model) @nnx.jit def train_step(model, optimizer, ema, x, y): loss, grads = nnx.value_and_grad(loss_fn)(model, x, y) optimizer.update(model, grads) ema.update(model) - return loss -optimizer = nnx.Optimizer( - model, - tx=optax.adam(1e-3), - wrt=nnx.Param, -) -losses = [] +@nnx.jit +def eval_step(model, x, y): + return loss_fn(model, x, y) + for _ in range(50): - loss = train_step(model, optimizer, ema, x, y) - losses.append(loss) -plt.plot(losses); + train_step(model, optimizer, ema, x, y) + +loss = eval_step(ema_model, x, y) +print(f"final eval loss: {loss}") ``` # Low Rank Adaptation @@ -129,23 +83,21 @@ def add_rank2_lora(path, node): return node base_model = make_model(rngs) -lora_model = nnx.recursive_map(add_rank2_lora, base_model) +model = nnx.recursive_map(add_rank2_lora, base_model) nnx.display(model) ``` +The training loop is the same as before, but we pass `wrt=nnx.LoRAParam` to the optimizer so that only the low-rank adaptation parameters are updated while the base model weights remain frozen. ```python +optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.LoRAParam) + @nnx.jit def train_step(model, optimizer, x, y): loss, grads = nnx.value_and_grad(loss_fn)(model, x, y) optimizer.update(model, grads) return loss -optimizer = nnx.Optimizer( - model, - tx=optax.adam(1e-3), - wrt=nnx.LoRAParam, -) losses = [] for _ in range(50): @@ -174,10 +126,7 @@ def train_step(model, optimizer, x, y): return loss model = make_model(rngs) -optimizer = nnx.Optimizer( - model, - tx=optax.lbfgs(1e-3), - wrt=nnx.Param) +optimizer = nnx.Optimizer(model, optax.lbfgs(1e-3), wrt=nnx.Param) losses = [] for _ in range(50): @@ -221,7 +170,8 @@ Gradient accumulation in Flax is easy: just use the `optax.MultiSteps` optimizer ```python model = make_model(rngs) -optimizer = nnx.Optimizer(model, tx=optax.MultiSteps(optax.adam(1e-3), every_k_schedule=3), wrt=nnx.Param) +tx = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=3) +optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) @nnx.jit def train_step(model, optimizer, x, y): @@ -276,3 +226,4 @@ But the model is not: ```python jax.typeof(model.layers[0].kernel[...]) ``` + diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index bbabaa834..78a550411 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -181,6 +181,7 @@ from .training.optimizer import Optimizer as Optimizer from .training.optimizer import ModelAndOptimizer as ModelAndOptimizer from .training.optimizer import OptState as OptState +from .training.ema import EMA as EMA from .transforms.autodiff import DiffState as DiffState from .transforms.autodiff import grad as grad from .transforms.autodiff import value_and_grad as value_and_grad diff --git a/flax/nnx/training/ema.py b/flax/nnx/training/ema.py new file mode 100644 index 000000000..ec8a29723 --- /dev/null +++ b/flax/nnx/training/ema.py @@ -0,0 +1,179 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import typing as tp + +from flax.nnx import filterlib +from flax.nnx import graphlib +from flax.nnx import pytreelib +from flax.nnx import statelib +from flax.nnx import variablelib +import jax +import jax.numpy as jnp + +A = tp.TypeVar('A') + + +def _to_ema_param(node: tp.Any): + def ema_param(path, x): + if not isinstance(x, variablelib.Variable): + path_str = '/'.join(str(k) for k in path) + raise TypeError( + f"EMA only supports Variable leaves, got {type(x).__name__} at " + f"path '{path_str}'. Use the `only` filter to select Variable leaves." + ) + ema_metadata = x.get_metadata() + value = jnp.copy(x.get_value()) + return type(x)(value, **ema_metadata) + + return jax.tree.map_with_path( + ema_param, node, is_leaf=lambda x: isinstance(x, variablelib.Variable) + ) + + +class EMA(pytreelib.Pytree): + """Exponential Moving Average (EMA) of parameters. + + Maintains a shadow copy of model Variables that is updated as an + exponentially weighted moving average on each call to :meth:`update`. + This is commonly used to stabilize training and improve evaluation + performance by applying the averaged parameters at inference time. + + Example usage:: + + >>> from flax import nnx + >>> import jax, jax.numpy as jnp + >>> import optax + ... + >>> model = nnx.Linear(2, 2, rngs=nnx.Rngs(0)) + >>> optimizer = nnx.Optimizer(model, optax.sgd(0.1), wrt=nnx.Param) + >>> ema = nnx.EMA(model, decay=0.9) + >>> ema_model = ema.apply_to(model) + ... + >>> def loss_fn(model, x, y): + ... return jnp.mean((model(x) - y) ** 2) + ... + >>> @nnx.jit + ... def train_step(model, optimizer, ema, x, y): + ... grads = nnx.grad(loss_fn)(model, x, y) + ... optimizer.update(model, grads) + ... ema.update(model) + ... + >>> @nnx.jit + ... def eval_step(model, x, y): + ... return loss_fn(model, x, y) + ... + >>> x, y = jnp.ones((1, 2)), jnp.ones((1, 2)) + >>> train_step(model, optimizer, ema, x, y) + >>> loss = eval_step(ema_model, x, y) + + In this example, ``ema.update`` computes the moving average and updates + the internal state of ``ema``. ``ema.apply_to`` creates a new model + instance (``ema_model``) that shares its Variables with ``ema``. + Therefore, ``ema_model`` will automatically reflect the updates performed by + ``ema.update`` and can be used directly in ``eval_step``. + + Attributes: + decay: The decay rate for the exponential moving average. + filter: The filter used to select which variables to track. + params: A pytree of variables holding the current + moving average values. + """ + + def __init__( + self, + params: tp.Any, + decay: float, + *, + only: filterlib.Filter = ..., + graph: bool | None = None, + ): + """Initializes the EMA module. + + Args: + params: Any object, typically an NNX module/node, whose parameters + will be tracked. + decay: The decay rate for the moving average. + only: A filter indicating which variables should be included in the + EMA tracking. Defaults to matching everything. Note that EMA only + tracks ``nnx.Variable`` instances. + graph: If ``True``, uses graph-mode which supports the full NNX + feature set including shared references. If ``False``, uses + tree-mode which treats Modules as regular JAX pytrees, avoiding + the overhead of the graph protocol. If ``None`` (default), the + value is determined by the current ``nnx.set_graph_mode`` context. + """ + self.graph = graph + self.decay = decay + self.filter = only + self.params: graphlib.State = pytreelib.data( + _to_ema_param(graphlib.state(params, only, graph=graph)) + ) + + def update(self, updates: tp.Any) -> None: + """Updates the EMA parameters towards the given new parameters. + + The update rule for each parameter is:: + + ema = decay * ema + (1 - decay) * update + + Args: + updates: The new parameters or module to blend into the current EMA. + This should have the same structure as the ``params`` object passed + during initialization. + """ + def _update_ema(ema: variablelib.Variable, update: tp.Any) -> tp.Any: + ema[...] = self.decay * ema + (1.0 - self.decay) * update + + jax.tree.map( + _update_ema, + self.params, + graphlib.state(updates, self.filter, graph=self.graph), + is_leaf=lambda x: isinstance(x, variablelib.Variable), + ) + + def apply_to(self, model: A) -> A: + """Returns a view of the model using the EMA parameters. + + Constructs a new model instance with the same structure as ``model`` + but whose tracked parameters are replaced by their exponential moving + average values. Non-tracked state (e.g. variables excluded by the + ``only`` filter) is preserved from the original ``model``. + + This is typically used at evaluation time to obtain a model whose + parameters reflect the smoothed training trajectory. + + Example usage:: + + >>> from flax import nnx + >>> import jax.numpy as jnp + ... + >>> model = nnx.Linear(2, 2, use_bias=False, rngs=nnx.Rngs(0)) + >>> ema = nnx.EMA(model, decay=0.9) + >>> ema_model = ema.apply_to(model) + >>> assert ema_model.kernel is ema.params.kernel + + Args: + model: A model instance whose graph structure is used to build + the output. The model should have the same structure as the + ``params`` originally passed to :class:`EMA`. + + Returns: + A new model of the same type as ``model`` with tracked parameters + replaced by the current EMA values. + """ + graphdef, state = graphlib.split(model, graph=self.graph) + merged_state = statelib.merge_state(state, self.params) + return graphlib.merge(graphdef, merged_state) diff --git a/flax/traceback_util.py b/flax/traceback_util.py index 7b33e66cc..021396f7b 100644 --- a/flax/traceback_util.py +++ b/flax/traceback_util.py @@ -25,7 +25,7 @@ # Whether to filter flax frames from traceback. _flax_filter_tracebacks = config.flax_filter_frames # Flax specific set of paths to exclude from tracebacks. -_flax_exclusions = set() +_flax_exclusions: set[str] = set() # re-import JAX symbol for convenience. diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 17ac93355..9fbd09514 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -1163,7 +1163,7 @@ def read_chunk(i): else: checkpoint_contents = fp.read() - state_dict = serialization.msgpack_restore(checkpoint_contents) + state_dict = serialization.msgpack_restore(checkpoint_contents) # type: ignore[arg-type] state_dict = _restore_mpas( state_dict, target, diff --git a/tests/nnx/ema_test.py b/tests/nnx/ema_test.py new file mode 100644 index 000000000..db6fe6807 --- /dev/null +++ b/tests/nnx/ema_test.py @@ -0,0 +1,117 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4' + +from absl.testing import absltest +from flax import nnx +import jax +import jax.numpy as jnp +import numpy as np +import optax + +class TestEMA(absltest.TestCase): + + def test_ema_initialization_and_update(self): + model = nnx.Linear(2, 2, use_bias=False, rngs=nnx.Rngs(0)) + initial_kernel = jnp.copy(model.kernel[...]) + + ema = nnx.EMA(model, decay=0.9) + + np.testing.assert_allclose(ema.params.kernel, initial_kernel) + + def double(param): + param[...] *= 2.0 + + jax.tree.map(double, model, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + ema.update(model) + expected = 0.9 * initial_kernel + 0.1 * (2.0 * initial_kernel) + + np.testing.assert_allclose(ema.params.kernel[...], expected) + + def test_ema_sharding(self): + if jax.device_count() < 4: + self.skipTest('At least 4 devices required') + + mesh = jax.make_mesh( + (2, 2), ('row', 'col'), + axis_types=(jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Auto), + ) + with jax.set_mesh(mesh): + model = nnx.Linear( + 4, 2, rngs=nnx.Rngs(0), + kernel_init=nnx.with_partitioning( + nnx.initializers.lecun_normal(), + sharding=('row', 'col'), + ), + use_bias=False, + ) + ema = nnx.EMA(model, decay=0.9) + + # EMA params should have the same sharding as the model variables. + self.assertTrue( + ema.params.kernel.sharding.is_equivalent_to( + model.kernel.sharding, + ndim=2, + ) + ) + + def test_ema_example(self): + def loss_fn(model, x, y): + return jnp.mean((model(x) - y) ** 2) + + rngs = nnx.Rngs(0) + x = rngs.normal((1, 2)) + y = rngs.normal((1, 3)) + + model = nnx.Linear(2, 3, rngs=rngs) + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + ema = nnx.EMA(model, decay=0.9) + ema_model = ema.apply_to(model) + original_kernel = ema_model.kernel[...] + + @nnx.jit + def train_step(model, optimizer, ema, x, y): + loss, grads = nnx.value_and_grad(loss_fn)(model, x, y) + optimizer.update(model, grads) + ema.update(model) + return loss + + train_step(model, optimizer, ema, x, y) + + self.assertIsInstance(ema_model.kernel, nnx.Param) + self.assertIs(ema_model.kernel, ema.params.kernel) + self.assertIsNot(ema_model.kernel, model.kernel) + self.assertFalse(jnp.allclose(ema_model.kernel[...], original_kernel)) + + def test_ema_apply_to(self): + model = nnx.Linear(2, 2, use_bias=False, rngs=nnx.Rngs(0)) + ema = nnx.EMA(model, decay=0.9) + ema_model = ema.apply_to(model) + + def double(param): + param[...] *= 2.0 + + jax.tree.map(double, model, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + ema.update(model) + + np.testing.assert_allclose(ema_model.kernel[...], ema.params.kernel[...]) + self.assertFalse(jnp.allclose(ema_model.kernel[...], model.kernel[...])) + +if __name__ == '__main__': + absltest.main()