{ "cells": [ { "cell_type": "markdown", "id": "532eeb1b-e320-4479-b11d-15c3bc7d960f", "metadata": {}, "source": [ "# 2025-11-03 Optimization\n", "\n", "* Differentiation\n", "\n", "* Second order (Newton type) optimization\n", "\n", "* Project discussion" ] }, { "cell_type": "code", "execution_count": 1, "id": "9dcc4924-786c-4b90-b175-c785af3406dd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "diff_wp (generic function with 1 method)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "using LinearAlgebra\n", "using Plots\n", "using Polynomials\n", "default(lw=4, ms=5, legendfontsize=12, xtickfontsize=12, ytickfontsize=12)\n", "\n", "# Here's our Vandermonde matrix again\n", "function vander(x, k=nothing)\n", " if isnothing(k)\n", " k = length(x)\n", " end\n", " m = length(x)\n", " V = ones(m, k)\n", " for j in 2:k\n", " V[:, j] = V[:, j-1] .* x\n", " end\n", " V\n", "end\n", "\n", "# With Chebyshev polynomials\n", "function vander_chebyshev(x, n=nothing)\n", " if isnothing(n)\n", " n = length(x) # Square by default\n", " end\n", " m = length(x)\n", " T = ones(m, n)\n", " if n > 1\n", " T[:, 2] = x\n", " end\n", " for k in 3:n\n", " #T[:, k] = x .* T[:, k-1]\n", " T[:, k] = 2 * x .* T[:,k-1] - T[:, k-2]\n", " end\n", " T\n", "end\n", "\n", "# And our \"bad\" function\n", "runge(x) = 1 / (1 + 10*x^2)\n", "runge_noisy(x, sigma) = runge.(x) + randn(size(x)) * sigma\n", "\n", "# And our gradient descent algorithm\n", "function grad_descent(loss, grad, c0; gamma=1e-3, tol=1e-5)\n", " \"\"\"Minimize loss(c) via gradient descent with initial guess c0\n", " using learning rate gamma. Declares convergence when gradient\n", " is less than tol or after 500 steps.\n", " \"\"\"\n", " c = copy(c0)\n", " chist = [copy(c)]\n", " lhist = [loss(c)]\n", " for it in 1:500\n", " g = grad(c)\n", " c -= gamma * g\n", " push!(chist, copy(c))\n", " push!(lhist, loss(c))\n", " if norm(g) < tol\n", " break\n", " end\n", " end\n", " (c, hcat(chist...), lhist)\n", "end\n", "\n", "# And our function for a finite difference while picking h automatically\n", "function diff_wp(f, x; eps=1e-8)\n", " \"\"\"Diff using Walker and Pernice (1998) choice of step\"\"\"\n", " h = eps * (1 + abs(x))\n", " (f(x+h) - f(x)) / h\n", "end" ] }, { "cell_type": "markdown", "id": "eb05af69-8630-436a-ac03-34f6cc3c2022", "metadata": {}, "source": [ "## Hand-coded derivatives\n", "\n", "With (mild) algebra abuse, the expression\n", "\n", "$$ \\frac{df}{dx} = f' \\left( x \\right) $$\n", "\n", "is equivalent to\n", "\n", "$$ df = f' \\left( x \\right) dx $$" ] }, { "cell_type": "code", "execution_count": 2, "id": "01b58aeb-b351-4947-bd69-b8e90c54541c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(-1.5346823414986814, -34.032439961925064)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function f(x)\n", " y = x\n", " for _ in 1:2\n", " a = y^π\n", " b = cos(a)\n", " c = log(y)\n", " y = b * c\n", " end\n", " y\n", "end\n", "\n", "f(1.9), diff_wp(f, 1.9)" ] }, { "cell_type": "code", "execution_count": 3, "id": "5f3f616a-6304-4cf4-9b19-270640eece21", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(-1.5346823414986814, -34.032419599140475)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function df(x, dx)\n", " y = x\n", " dy = dx\n", " for _ in 1:2\n", " a = y^π\n", " da = π * y^(π - 1) * dy\n", " b = cos(a)\n", " db = -sin(a) * da\n", " c = log(y)\n", " dc = dy / y\n", " y = b 
* c\n", " dy = db * c + b * dc\n", " end\n", " y, dy\n", "end\n", "\n", "df(1.9, 1)" ] }, { "cell_type": "markdown", "id": "d64d4de7-6a57-4d1e-97a0-0f85b6ced100", "metadata": {}, "source": [ "### Forward vs reverse mode\n", "\n", "We can differentiate a composition $h \\left( g \\left( f \\left( x \\right) \\right) \\right)$ as\n", "\n", "$$ \\begin{align}\n", " \\operatorname{d} h &= h' \\operatorname{d} g \\\\\n", " \\operatorname{d} g &= g' \\operatorname{d} f \\\\\n", " \\operatorname{d} f &= f' \\operatorname{d} x\n", "\\end{align}$$\n", "\n", "What we've done above is called \"forward mode\", and amounts to placing the parentheses in the chain rule like\n", "\n", "$$ \\operatorname{d} h = \\frac{dh}{dg} \\left( \\frac{dg}{df} \\left( \\frac{df}{dx} \\operatorname{d} x \\right) \\right) $$\n", "\n", "This expression means the same thing if we rearrange the parenthesis,\n", "\n", "$$ \\operatorname{d} h = \\left( \\left( \\left( \\frac{dh}{dg} \\right) \\frac{dg}{df} \\right) \\frac{df}{dx} \\right) \\operatorname{d} x $$" ] }, { "cell_type": "markdown", "id": "06a40c52-4265-46c1-b173-1cd171876061", "metadata": {}, "source": [ "## Reverse mode example\n", "\n", "Let's do an example to better understand.\n", "\n", "$$ \\underbrace{\\bar x}_{\\frac{dh}{dx}} = \\underbrace{\\bar g \\frac{dg}{df}}_{\\bar f} \\frac{df}{dx} $$" ] }, { "cell_type": "code", "execution_count": 4, "id": "17ec01d1-ef6c-41a9-92a5-7a601d15b5fe", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.2155134138380423, -1.2559760384500684)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function g(x)\n", " a = x^π\n", " b = cos(a)\n", " c = log(x)\n", " y = b * c\n", " y\n", "end\n", "\n", "(g(1.9), diff_wp(g, 1.4))" ] }, { "cell_type": "code", "execution_count": 5, "id": "e816695b-1123-47b5-9d40-56f7aea6abf8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-1.2559761698835525" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function gback(x, y_)\n", " a = x^π\n", " b = cos(a)\n", " c = log(x)\n", " y = b * c\n", " # backward pass\n", " c_ = y_ * b\n", " b_ = c * y_\n", " a_ = -sin(a) * b_\n", " x_ = 1/x * c_ + π * x^(π - 1) * a_\n", "end\n", "\n", "gback(1.4, 1)" ] }, { "cell_type": "markdown", "id": "88216231-9e94-4ab0-b6d5-6cb47fed40fb", "metadata": {}, "source": [ "## Automatic differentiation\n", "\n", "See also [Enzyme.jl](https://enzyme.mit.edu/julia/stable)" ] }, { "cell_type": "code", "execution_count": 6, "id": "bb0d67b2-a820-446c-8c24-6d724ba08b7e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m\u001b[1m Resolving\u001b[22m\u001b[39m package versions...\n", "\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `~/.julia/environments/v1.11/Project.toml`\n", "\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `~/.julia/environments/v1.11/Manifest.toml`\n" ] } ], "source": [ "using Pkg\n", "pkg\"add Zygote\"\n", "import Zygote" ] }, { "cell_type": "code", "execution_count": 7, "id": "ff624d1e-c2a9-469f-91aa-aa5a3773cce5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(-1.2559761698835525,)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Zygote.gradient(g, 1.4)" ] }, { "cell_type": "markdown", "id": "c6e0d318-8c4e-4b5e-a211-5e906f4579ed", "metadata": {}, "source": [ "## But how?\n", "\n", "It's cool that Zygote works, but how does it actually work?" 
] }, { "cell_type": "code", "execution_count": 8, "id": "36f05106-5ab2-4844-ad64-a45e3f6cfa56", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[90m; Function Signature: square(Float64)\u001b[39m\n", "\u001b[90m; @ In[8]:1 within `square`\u001b[39m\n", "\u001b[95mdefine\u001b[39m \u001b[36mdouble\u001b[39m \u001b[93m@julia_square_11421\u001b[39m\u001b[33m(\u001b[39m\u001b[36mdouble\u001b[39m \u001b[0m%\"x::Float64\"\u001b[33m)\u001b[39m \u001b[0m#0 \u001b[33m{\u001b[39m\n", "\u001b[91mtop:\u001b[39m\n", "\u001b[90m; ┌ @ intfuncs.jl:370 within `literal_pow`\u001b[39m\n", "\u001b[90m; │┌ @ float.jl:493 within `*`\u001b[39m\n", " \u001b[0m%0 \u001b[0m= \u001b[96m\u001b[1mfmul\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%\"x::Float64\"\u001b[0m, \u001b[0m%\"x::Float64\"\n", " \u001b[96m\u001b[1mret\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%0\n", "\u001b[90m; └└\u001b[39m\n", "\u001b[33m}\u001b[39m\n" ] } ], "source": [ "square(x) = x^2\n", "# Let's look at the LLVM bitcode here\n", "@code_llvm square(1.5)" ] }, { "cell_type": "code", "execution_count": 9, "id": "db557a3a-38ad-4ecc-8c2e-f3d2e89bb222", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[90m; Function Signature: gradient(typeof(Main.square), Float64)\u001b[39m\n", "\u001b[90m; @ /home/jeremy/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:152 within `gradient`\u001b[39m\n", "\u001b[95mdefine\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[93m@julia_gradient_11636\u001b[39m\u001b[33m(\u001b[39m\u001b[36mdouble\u001b[39m \u001b[0m%\"args\u001b[33m[\u001b[39m\u001b[33m1\u001b[39m\u001b[33m]\u001b[39m\u001b[0m::Float64\"\u001b[33m)\u001b[39m \u001b[0m#0 \u001b[33m{\u001b[39m\n", "\u001b[91mtop:\u001b[39m\n", "\u001b[90m; @ /home/jeremy/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:154 within `gradient`\u001b[39m\n", "\u001b[90m; ┌ @ /home/jeremy/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:97 within `#88`\u001b[39m\n", "\u001b[90m; │┌ @ In[8]:1 within `square`\u001b[39m\n", "\u001b[90m; ││┌ @ /home/jeremy/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:222 within `ZBack`\u001b[39m\n", "\u001b[90m; │││┌ @ /home/jeremy/.julia/packages/Zygote/55SqB/src/lib/number.jl:12 within `literal_pow_pullback`\u001b[39m\n", "\u001b[90m; ││││┌ @ promotion.jl:430 within `*` @ float.jl:493\u001b[39m\n", " \u001b[0m%0 \u001b[0m= \u001b[96m\u001b[1mfmul\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%\"args\u001b[33m[\u001b[39m\u001b[33m1\u001b[39m\u001b[33m]\u001b[39m\u001b[0m::Float64\"\u001b[0m, \u001b[33m2.000000e+00\u001b[39m\n", "\u001b[90m; └└└└└\u001b[39m\n", "\u001b[90m; @ /home/jeremy/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:155 within `gradient`\u001b[39m\n", " \u001b[91m%\"new:\u001b[39m\u001b[0m:Tuple2.unbox.fca.0.insert\" \u001b[0m= \u001b[95minsertvalue\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[95mzeroinitializer\u001b[39m\u001b[0m, \u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[0m, \u001b[33m0\u001b[39m\n", " \u001b[96m\u001b[1mret\u001b[22m\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[0m%\"new::Tuple2.unbox.fca.0.insert\"\n", "\u001b[33m}\u001b[39m\n" ] } ], "source": [ "# And here is the LLVM bitcode of Zygote's derivative\n", "@code_llvm 
Zygote.gradient(square, 1.5)" ] }, { "cell_type": "markdown", "id": "0b4dd66e-f8d0-422d-affe-6e8c17399ec4", "metadata": {}, "source": [ "## Types of algorithmic differentiation\n", "\n", "* Source transformation: Fortran code in, Fortran code out\n", "\n", " * Duplicates compiler features, usually incomplete language coverage\n", "\n", " * Produces efficient code\n", "\n", "* Operator overloading: C++ types\n", "\n", " * Hard to vectorize\n", "\n", " * Loops are effectively unrolled/inefficient\n", "\n", "* Just-in-time compilation: tightly coupled with compiler\n", "\n", " * JIT lag\n", "\n", " * Needs dynamic language features (JAX) or tight integration with compiler (Zygote, Enzyme)\n", "\n", " * [Some sharp bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow)" ] }, { "cell_type": "markdown", "id": "ef006b6b-9ece-478a-a432-a0a374795d96", "metadata": {}, "source": [ "## Forward or reverse mode\n", "\n", "Pick forward or reverse mode depending upon the 'shape' of your function.\n", "\n", "* One input, many outputs: use forward mode\n", "\n", " * \"One input\" can be looking in one direction\n", "\n", "* Many inputs, one output: use reverse mode\n", "\n", " * Will need to traverse execution backwards (\"tape\")\n", " \n", " * Hierarchical checkpointing\n", " \n", "* About square? Forward mode is usually a bit more efficient" ] }, { "cell_type": "markdown", "id": "fa5f0a61-09e4-4a5b-916a-d96e382d9ef4", "metadata": {}, "source": [ "## Differentating an algorithm?\n", "\n", "* Consider an input $c$ and output $x$ such that $f \\left( x, c \\right) = 0$.\n", "\n", "* Consider an input $A$ and output $\\lambda$ such that $A x = \\lambda x$ for some nonzero vector $x$\n", "\n", "* Consider an input `buffer` and and output `sha256(buffer)`" ] }, { "cell_type": "markdown", "id": "3cbca3ea-5b96-46b0-9aba-a578f02d8afd", "metadata": {}, "source": [ "## Ill-conditioned optimization\n", "\n", "$$ L \\left( c; x, y \\right) = \\frac 1 2 \\lVert \\underbrace{f \\left( x, c \\right) - y}_{r \\left( c \\right)} \\rVert_{C^{-1}}^2 $$\n", "\n", "Gradient of $L$ requires the Jacobian $J$ of the model $f$.\n", "\n", "$$ g \\left( c \\right) = \\nabla_c L = r^T \\underbrace{\\nabla_c f}_{J} $$\n", "\n", "We can solve $g \\left( c \\right) = 0$ using a Newton method\n", "\n", "$$ g \\left( c + \\delta c \\right) = g \\left( c \\right) + \\underbrace{\\nabla_c g}_{H} \\delta c + \\mathcal{O} \\left( \\left( \\delta c \\right)^2 \\right) $$\n", "\n", "The Hessian requires the second derivative of $f$, which can cause problems\n", "\n", "$$ H = J^T J + r^T \\left( \\nabla_c J \\right) $$\n", "\n", "Consider - if the Jacobian (fist derivative) can have significantly many more terms than our loss function, then that will compound for the Hessian." ] }, { "cell_type": "markdown", "id": "010da10c-232a-48df-a3b7-0054dda6e5ce", "metadata": {}, "source": [ "## Newton-like methods for optimization\n", "\n", "Solve\n", "\n", "$$ H \\delta c = - g \\left( c \\right) $$\n", "\n", "Update $c \\leftarrow c + \\gamma \\delta c$ using using a line search or [trust region](https://en.wikipedia.org/wiki/Trust_region)." 
] }, { "cell_type": "markdown", "id": "c87f1721-7866-4e0e-b4b6-3bcefa29ec0e", "metadata": {}, "source": [ "## Outlook\n", "\n", "* The optimization problem can be solved using a Newton method.\n", "It can be onerous to implement the needed derivatives.\n", "\n", "* The [Gauss-Newton method](https://en.wikipedia.org/wiki/Gauss%E2%80%93Newton_algorithm) is often more practical than Newton while being faster than gradient descent, though it lacks robustness.\n", "\n", "* The [Levenberg-Marquardt](https://en.wikipedia.org/wiki/Levenberg%E2%80%93Marquardt_algorithm) method provides a sort of middle-ground between Gauss-Newton and gradient descent.\n", "\n", "* Many globalization techniques are used for models that possess many local minima.\n", "\n", "* One pervasive approach is stochastic gradient descent, where small batches (e.g., 1 or 10 or 20) are selected randomly from the corpus of observations, and a step of gradient descent is applied to that reduced set of observations.\n", "This helps to escape saddle points and weak local minima.\n", "\n", "* Among expressive models $f \\left( x, c \\right)$, some may converge much more easily than others.\n", "Having a good optimization algorithm is essential for nonlinear regression with complicated models, especially those with many parameters $c$.\n", "\n", "* Classification is a very similar problem to regression, but the observations $y$ are discrete, thus\n", "\n", " * models $f \\left( x, c \\right)$ must have discrete output\n", "\n", " * the least squares loss function is not appropriate\n", "\n", "* [Why momentum really works](https://distill.pub/2017/momentum/)" ] } ], "metadata": { "kernelspec": { "display_name": "Julia 1.11.6", "language": "julia", "name": "julia-1.11" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }