{
"cells": [
{
"cell_type": "markdown",
"id": "e5a89f53-5f63-43e9-bd66-a6790d1fe87a",
"metadata": {},
"source": [
"# 2025-10-31 Differentiation\n",
"\n",
"* Computing derivatives in maintainable code\n",
"\n",
"* Forward and reverse\n",
"\n",
"* Algorithmic (automatic) differentiation"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f00d4064-320e-4992-a27b-c5d729ed4c8a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"grad_descent (generic function with 1 method)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using LinearAlgebra\n",
"using Plots\n",
"using Polynomials\n",
"default(lw=4, ms=5, legendfontsize=12, xtickfontsize=12, ytickfontsize=12)\n",
"\n",
"# Here's our Vandermonde matrix again\n",
"function vander(x, k=nothing)\n",
" if isnothing(k)\n",
" k = length(x)\n",
" end\n",
" m = length(x)\n",
" V = ones(m, k)\n",
" for j in 2:k\n",
" V[:, j] = V[:, j-1] .* x\n",
" end\n",
" V\n",
"end\n",
"\n",
"# With Chebyshev polynomials\n",
"function vander_chebyshev(x, n=nothing)\n",
" if isnothing(n)\n",
" n = length(x) # Square by default\n",
" end\n",
" m = length(x)\n",
" T = ones(m, n)\n",
" if n > 1\n",
" T[:, 2] = x\n",
" end\n",
" for k in 3:n\n",
" # Chebyshev three-term recurrence: T_{n+1}(x) = 2x T_n(x) - T_{n-1}(x)\n",
" T[:, k] = 2 * x .* T[:,k-1] - T[:, k-2]\n",
" end\n",
" T\n",
"end\n",
"\n",
"# A utility for evaluating our regression\n",
"function chebyshev_regress_eval(x, xx, n)\n",
" V = vander_chebyshev(x, n)\n",
" vander_chebyshev(xx, n) / V\n",
"end\n",
"\n",
"# And our \"bad\" function\n",
"runge(x) = 1 / (1 + 10*x^2)\n",
"runge_noisy(x, sigma) = runge.(x) + randn(size(x)) * sigma\n",
"\n",
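"# And our gradient descent algorithm\n",
"\"\"\"\n",
"    grad_descent(loss, grad, c0; gamma=1e-3, tol=1e-5)\n",
"\n",
"Minimize `loss(c)` by gradient descent from the initial guess `c0` with learning\n",
"rate `gamma`. Stops when the norm of the gradient falls below `tol` or after\n",
"500 steps, and returns the final `c`, the history of iterates (one per column),\n",
"and the history of the loss.\n",
"\"\"\"\n",
"function grad_descent(loss, grad, c0; gamma=1e-3, tol=1e-5)\n",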
" c = copy(c0)\n",
" chist = [copy(c)]\n",
" lhist = [loss(c)]\n",
" for it in 1:500\n",
" g = grad(c)\n",
" c -= gamma * g\n",
" push!(chist, copy(c))\n",
" push!(lhist, loss(c))\n",
" if norm(g) < tol\n",
" break\n",
" end\n",
" end\n",
" (c, hcat(chist...), lhist)\n",
"end"
]
},
{
"cell_type": "markdown",
"id": "2b8543fb-6c37-48cd-8a5d-09e1184d4684",
"metadata": {},
"source": [
"## Nonlinear models\n",
"\n",
"Instead of the linear model\n",
"\n",
"$$ f \\left( x, c \\right) = V \\left( x \\right) c = c_0 + c_1 T_1 \\left( x \\right) + c_2 T_2 \\left( x \\right) + \\cdots $$\n",
"\n",
"let's consider a rational model with only three parameters\n",
"\n",
"$$ f \\left( x, c \\right) = \\frac{1}{c_1 + c_2 x + c_3 x^2} = \\left( c_1 + c_2 x + c_3 x^2 \\right)^{-1} $$\n",
"\n",
"We'll use the same loss function\n",
"\n",
"$$ L \\left( c; x, y \\right) = \\frac{1}{2} \\left\\lVert f \\left( x, c \\right) - y \\right\\rVert^2 $$\n",
"\n",
"We will also need the gradient\n",
"\n",
"$$ \\nabla_c L \\left( c; x, y \\right) = \\left( f \\left( x, c \\right) - y \\right)^T \\nabla_c f \\left( x, c \\right) $$\n",
"\n",
"where\n",
"\n",
"$$ \\frac{\\partial f \\left( x, c \\right)}{\\partial c_1} = - \\left( c_1 + c_2 x + c_3 x^2 \\right)^{-2} = - f \\left( x, c \\right)^2 $$\n",
"$$ \\frac{\\partial f \\left( x, c \\right)}{\\partial c_2} = - \\left( c_1 + c_2 x + c_3 x^2 \\right)^{-2} x = - f \\left( x, c \\right)^2 x $$\n",
"$$ \\frac{\\partial f \\left( x, c \\right)}{\\partial c_3} = - \\left( c_1 + c_2 x + c_3 x^2 \\right)^{-2} x^2 = - f \\left( x, c \\right)^2 x^2 $$"
]
},
{
"cell_type": "markdown",
"id": "03467447-80d5-45bb-84e3-855b51b2026a",
"metadata": {},
"source": [
"## Fitting a rational function\n",
"\n",
"Now let's fit our rational function with gradient descent, as above."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ad366000-5ca6-4380-b4c0-d2cb36a23066",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"gradient (generic function with 1 method)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f(x, c) = 1 ./ (c[1] .+ c[2] .* x .+ c[3] .* x .^ 2)\n",
"\n",
"function gradf(x, c)\n",
" f2 = f(x, c).^2\n",
" [-f2 -f2.*x -f2.*x.^2]\n",
"end\n",
"\n",
"# loss and gradient read the global data x and y, which are defined in the next cell\n",
"function loss(c)\n",
" r = f(x, c) - y\n",
" 0.5 * r' * r\n",
"end\n",
"\n",
"function gradient(c)\n",
" r = f(x, c) - y\n",
" vec(r' * gradf(x, c))\n",
"end"
]
},
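{
"cell_type": "markdown",
"id": "c7e2a1f0-4b3d-4d6e-9a12-5f0e8b7c3d21",
"metadata": {},
"source": [
"Hand-derived derivatives are easy to get wrong, so it is worth checking them numerically. The sketch below compares the analytic Jacobian against a one-sided finite-difference Jacobian built one parameter at a time. To keep the cell self-contained, it copies `f` and `gradf` under the hypothetical names `model_rational` and `jac_rational`. The sample points, parameters, and step `h` are arbitrary choices. If the formulas agree, the relative error is on the order of the finite-difference accuracy rather than $O(1)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7e2a1f0-4b3d-4d6e-9a12-5f0e8b7c3d22",
"metadata": {},
"outputs": [],
"source": [
"using LinearAlgebra\n",
"\n",
"# Copies of f and gradf above, renamed so this cell runs on its own\n",
"model_rational(x, c) = 1 ./ (c[1] .+ c[2] .* x .+ c[3] .* x .^ 2)\n",
"\n",
"function jac_rational(x, c)\n",
"    f2 = model_rational(x, c) .^ 2\n",
"    [-f2 -f2 .* x -f2 .* x .^ 2]\n",
"end\n",
"\n",
"# Finite-difference Jacobian: perturb one parameter at a time\n",
"function fd_jacobian(F, c; h=1e-7)\n",
"    F0 = F(c)\n",
"    J = zeros(length(F0), length(c))\n",
"    for j in eachindex(c)\n",
"        e = zeros(length(c))\n",
"        e[j] = h\n",
"        J[:, j] = (F(c + e) - F0) / h\n",
"    end\n",
"    J\n",
"end\n",
"\n",
"xs_check = collect(LinRange(-1, 1, 7))\n",
"c_check = [1.0, 0.5, 8.0]\n",
"J_exact = jac_rational(xs_check, c_check)\n",
"J_fd = fd_jacobian(c -> model_rational(xs_check, c), c_check)\n",
"relerr = opnorm(J_exact - J_fd) / opnorm(J_exact)\n",
"@show relerr"
]
},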
{
"cell_type": "code",
"execution_count": 3,
"id": "65dc18d1-ef9c-4703-a317-ad70c1dc0ac5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"c = [1.1380315133586891, 0.9808803838099355, 9.257852500955655]\n"
]
}
],
"source": [
"# Let's fit our Runge function with noise\n",
"x = LinRange(-1, 1, 200)\n",
"y = runge_noisy(x, .5)\n",
"c, _, lhist = grad_descent(loss, gradient, [1., 0, 10.], gamma=1e-2)\n",
"@show c\n",
"\n",
"plot(lhist, yscale=:log10, label=\"loss\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "abac8e2a-ac55-457f-b289-753ab21bc3f2",
"metadata": {},
"outputs": [],
"source": [
"scatter(x, y, label=\"data\")\n",
"V = vander_chebyshev(x, 12) # Will overfit for Chebyshev (fit to the noise!)\n",
"plot!(x -> runge(x), color=:black, label=\"True\")\n",
"plot!(x, V * (V \\ y), label=\"Chebyshev fit\")\n",
"plot!(x -> f(x, c), label=\"Rational fit\")"
]
},
{
"cell_type": "markdown",
"id": "7de6fd47-3ac9-4cf6-b0c1-05606eff9e7c",
"metadata": {},
"source": [
"## Computing derivatives\n",
"\n",
"How should we compute these derivatives as the model gets complicated?\n",
"\n",
"Recall the definition of the derivative:\n",
"\n",
"$$ f' \\left( x \\right) = \\lim_{h \\rightarrow 0} \\frac{f \\left( x + h \\right) - f \\left( x \\right)}{h} $$\n",
"\n",
"* How should we pick $h$?\n",
"\n",
"* Too big: discretization error dominates (think of truncating the Taylor series)\n",
"\n",
"* Too small: rounding error dominates"
]
},
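{
"cell_type": "markdown",
"id": "9d4b6e2a-1c8f-4a37-b5e0-2e6f4a9c1b31",
"metadata": {},
"source": [
"To see the trade-off, the sketch below sweeps $h$ from $10^{-1}$ down to $10^{-15}$ and measures the error of the one-sided difference quotient for $\\sin$ at $x = 1$. The point is an arbitrary choice where the exact derivative is $\\cos 1$. The truncation error shrinks roughly like $h$, while the rounding error grows roughly like $\\epsilon_{\\text{machine}} / h$. The smallest error should therefore appear somewhere near $h \\approx \\sqrt{\\epsilon_{\\text{machine}}} \\approx 10^{-8}$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d4b6e2a-1c8f-4a37-b5e0-2e6f4a9c1b32",
"metadata": {},
"outputs": [],
"source": [
"# Forward-difference error for the derivative of sin at x0 = 1, whose exact value is cos(1)\n",
"x0 = 1.0\n",
"hs = 10.0 .^ (-1:-1:-15)\n",
"errs = [abs((sin(x0 + h) - sin(x0)) / h - cos(x0)) for h in hs]\n",
"\n",
"for (h, e) in zip(hs, errs)\n",
"    println(\"h = $h    error = $e\")\n",
"end"
]
},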
{
"cell_type": "markdown",
"id": "b8b6b9e1-6ba2-41d9-a6d2-2471c587d237",
"metadata": {},
"source": [
"### Automatic step size selection\n",
"\n",
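"Walker and Pernice (1998) and Dennis and Schnabel proposed ways to choose the step size automatically when differentiating numerically. The variant below scales $h$ by $1 + \\lvert x \\rvert$, so the step stays significant relative to $x$ and is not lost to rounding when $\\lvert x \\rvert$ is large."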
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8b338775-9137-4b23-aab6-a143320076a9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"diff(sin, x) - cos(x) = -7.201219542896098e-8\n",
"diff_wp(sin, x) - cos(x) = 2.790787678730311e-8\n"
]
}
],
"source": [
"# Derivative via forward differencing\n",
"diff(f, x; h=1e-8) = (f(x+h) - f(x)) / h\n",
"\n",
"# And automatic selection of h\n",
"\"\"\"Forward difference using the Walker and Pernice (1998) choice of step, h * (1 + |x|)\"\"\"\n",
"function diff_wp(f, x; h=1e-8)\n",
"    h *= (1 + abs(x))  # scale the step with the magnitude of x\n",
"    (f(x+h) - f(x)) / h\n",
"end\n",
"\n",
"# Let's try it!\n",
"x = 10\n",
"@show diff(sin, x) - cos(x)\n",
"@show diff_wp(sin, x) - cos(x);"
]
},
{
"cell_type": "markdown",
"id": "0c1b1e03-1ca2-4e66-8ac3-58095e368235",
"metadata": {},
"source": [
"## Symbolic differentiation\n",
"\n",
"We can also use a computer algebra package, here Symbolics.jl, to differentiate symbolically, just as we would by hand."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e98f0c5a-5928-478a-b11b-64ae51834256",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m\u001b[1m Resolving\u001b[22m\u001b[39m package versions...\n",
"\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `~/.julia/environments/v1.11/Project.toml`\n",
"\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `~/.julia/environments/v1.11/Manifest.toml`\n"
]
},
{
"data": {
"text/latex": [
"$$ \\begin{equation}\n",
"\\frac{\\mathrm{d} \\sin\\left( x \\right)}{\\mathrm{d}x}\n",
"\\end{equation}\n",
" $$"
],
"text/plain": [
"Differential(x)(sin(x))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using Pkg\n",
"pkg\"add Symbolics\"\n",
"using Symbolics\n",
"\n",
"# Some setup\n",
"@variables x\n",
"Dx = Differential(x)\n",
"\n",
"# And an example\n",
"y = sin(x)\n",
"Dx(y)"
]
},
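{
"cell_type": "markdown",
"id": "5a8c3e7d-2f1b-4c9a-8e64-7b0d2c5f9e41",
"metadata": {},
"source": [
"On its own, `Dx(y)` only builds an unevaluated derivative, which is why the output above still reads $\\frac{d \\sin(x)}{dx}$. The sketch below shows the usual next steps and assumes a recent version of Symbolics.jl. `expand_derivatives` applies the differentiation rules, `substitute` plugs in a number, and `Symbolics.value` unwraps the result into a plain number. A fresh variable `t` is used so that the `x` defined above is left alone."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a8c3e7d-2f1b-4c9a-8e64-7b0d2c5f9e42",
"metadata": {},
"outputs": [],
"source": [
"using Symbolics\n",
"\n",
"@variables t\n",
"Dt = Differential(t)\n",
"\n",
"# Apply the differentiation rules: the result should be cos(t)\n",
"dexpr = expand_derivatives(Dt(sin(t)))\n",
"\n",
"# Evaluate the symbolic derivative at t = 1 and unwrap it into a plain number\n",
"dval = Symbolics.value(substitute(dexpr, Dict(t => 1.0)))\n",
"@show dexpr dval"
]
},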
{
"cell_type": "markdown",
"id": "14c4c72d-3eb8-4f36-8782-c65b9e11f2e7",
"metadata": {},
"source": [
"### Product rule\n",
"\n",
"This package can follow the product rule!"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e41073c3-6082-4159-a54a-8fe44dc00331",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$$ \\begin{equation}\n",
"\\frac{\\left( \\frac{\\cos\\left( x^{\\pi} \\right)}{x} - 3.1416 x^{2.1416} \\log\\left( x \\right) \\sin\\left( x^{\\pi} \\right) \\right) \\cos\\left( \\cos^{3.1416}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{3.1416} \\right)}{\\log\\left( x \\right) \\cos\\left( x^{\\pi} \\right)} - \\left( \\frac{3.1416 \\cos^{3.1416}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{2.1416}}{x} - 9.8696 \\cos^{2.1416}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{3.1416} x^{2.1416} \\sin\\left( x^{\\pi} \\right) \\right) \\log\\left( \\log\\left( x \\right) \\cos\\left( x^{\\pi} \\right) \\right) \\sin\\left( \\cos^{3.1416}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{3.1416} \\right)\n",
"\\end{equation}\n",
" $$"
],
"text/plain": [
"((cos(x^π) / x - 3.141592653589793(x^2.141592653589793)*log(x)*sin(x^π))*cos((log(x)^3.141592653589793)*(cos(x^π)^3.141592653589793))) / (log(x)*cos(x^π)) - ((3.141592653589793(log(x)^2.141592653589793)*(cos(x^π)^3.141592653589793)) / x - 9.869604401089358(x^2.141592653589793)*(log(x)^3.141592653589793)*sin(x^π)*(cos(x^π)^2.141592653589793))*log(log(x)*cos(x^π))*sin((log(x)^3.141592653589793)*(cos(x^π)^3.141592653589793))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = x\n",
"for _ in 1:2\n",
" y = cos(y^π) * log(y)\n",
"end\n",
"\n",
"# This will be... a lot\n",
"expand_derivatives(Dx(y))"
]
},
{
"cell_type": "markdown",
"id": "231305d3-dd15-4568-9019-04b87ef991e1",
"metadata": {},
"source": [
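"The size of these expressions can grow **exponentially** with the depth of the composition. At every level, the product and chain rules duplicate subexpressions such as $\\cos(x^\\pi)$ and $\\log(x)$ instead of reusing them."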
]
},
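{
"cell_type": "markdown",
"id": "e1f7b3c9-6a2d-4f58-9c03-4d8a1e6b7f51",
"metadata": {},
"source": [
"One rough way to see the growth is to measure the length of the printed derivative as the composition from the product-rule example gets deeper. String length is only an assumed proxy for expression size, and the exact counts depend on how Symbolics.jl prints and simplifies. The helper `derivative_size` is hypothetical and uses a fresh variable `s`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1f7b3c9-6a2d-4f58-9c03-4d8a1e6b7f52",
"metadata": {},
"outputs": [],
"source": [
"using Symbolics\n",
"\n",
"@variables s\n",
"Ds = Differential(s)\n",
"\n",
"# Length of the printed derivative of the n-fold composition cos(y^π) * log(y)\n",
"function derivative_size(n)\n",
"    y = s\n",
"    for _ in 1:n\n",
"        y = cos(y^π) * log(y)\n",
"    end\n",
"    length(string(expand_derivatives(Ds(y))))\n",
"end\n",
"\n",
"sizes = [derivative_size(n) for n in 1:3]\n",
"@show sizes"
]
},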
{
"cell_type": "markdown",
"id": "1468105e-6d13-440d-8799-e833968b492d",
"metadata": {},
"source": [
"## Hand-coded derivatives\n",
"\n",
"With (mild) algebra abuse, the expression\n",
"\n",
"$$ \\frac{df}{dx} = f' \\left( x \\right) $$\n",
"\n",
"is equivalent to\n",
"\n",
"$$ df = f' \\left( x \\right) dx $$"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "fb7caa5a-b9b2-4bbd-a6da-5e80b47da514",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(-1.5346823414986814, -34.032439961925064)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function f(x)\n",
" y = x\n",
" for _ in 1:2\n",
" a = y^π\n",
" b = cos(a)\n",
" c = log(y)\n",
" y = b * c\n",
" end\n",
" y\n",
"end\n",
"\n",
"f(1.9), diff_wp(f, 1.9)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fdd708aa-d70f-42e4-ab82-aa5c5c8947f1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(-1.5346823414986814, -34.032419599140475)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function df(x, dx)\n",
" y = x\n",
" dy = dx\n",
" for _ in 1:2\n",
" a = y^π\n",
" da = π * y^(π - 1) * dy\n",
" b = cos(a)\n",
" db = -sin(a) * da\n",
" c = log(y)\n",
" dc = dy / y\n",
" y = b * c\n",
" dy = db * c + b * dc\n",
" end\n",
" y, dy\n",
"end\n",
"\n",
"df(1.9, 1)"
]
},
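{
"cell_type": "markdown",
"id": "2b9e4d6f-8c1a-4e73-a5d2-9f3c0b8e6a61",
"metadata": {},
"source": [
"The pattern in `df` carries a value together with its differential through every operation, and operator overloading can automate it. The sketch below defines a hypothetical minimal dual-number type `Dual` that overloads only the operations `f` uses. It is a teaching sketch rather than a full implementation; packages such as ForwardDiff.jl develop this idea fully. Seeding the derivative part with 1 should reproduce the output of `df(1.9, 1)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b9e4d6f-8c1a-4e73-a5d2-9f3c0b8e6a62",
"metadata": {},
"outputs": [],
"source": [
"# A minimal forward-mode dual number: val holds the value, der holds its differential\n",
"struct Dual\n",
"    val::Float64\n",
"    der::Float64\n",
"end\n",
"\n",
"# Each rule mirrors one line of df above\n",
"Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)\n",
"Base.:^(a::Dual, p::Real) = Dual(a.val^p, p * a.val^(p - 1) * a.der)\n",
"Base.cos(a::Dual) = Dual(cos(a.val), -sin(a.val) * a.der)\n",
"Base.log(a::Dual) = Dual(log(a.val), a.der / a.val)\n",
"\n",
"# Same computation as f above, written generically so it accepts a Dual\n",
"function f_generic(x)\n",
"    y = x\n",
"    for _ in 1:2\n",
"        y = cos(y^π) * log(y)\n",
"    end\n",
"    y\n",
"end\n",
"\n",
"d = f_generic(Dual(1.9, 1.0))\n",
"@show d.val d.der"
]
},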
{
"cell_type": "markdown",
"id": "f04e9352-442a-406d-b10a-9f103bd152c9",
"metadata": {},
"source": [
"### Forward vs reverse mode\n",
"\n",
"We can differentiate a composition $h \\left( g \\left( f \\left( x \\right) \\right) \\right)$ as\n",
"\n",
"$$ \\begin{align}\n",
" \\operatorname{d} h &= h' \\operatorname{d} g \\\\\n",
" \\operatorname{d} g &= g' \\operatorname{d} f \\\\\n",
" \\operatorname{d} f &= f' \\operatorname{d} x\n",
"\\end{align}$$\n",
"\n",
"What we've done above is called \"forward mode\", and amounts to placing the parentheses in the chain rule like\n",
"\n",
"$$ \\operatorname{d} h = \\frac{dh}{dg} \\left( \\frac{dg}{df} \\left( \\frac{df}{dx} \\operatorname{d} x \\right) \\right) $$\n",
"\n",
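"The chain rule is a product of derivative (Jacobian) factors, and multiplication is associative. The expression therefore means the same thing if we rearrange the parentheses,\n",
"\n",
"$$ \\operatorname{d} h = \\left( \\left( \\left( \\frac{dh}{dg} \\right) \\frac{dg}{df} \\right) \\frac{df}{dx} \\right) \\operatorname{d} x $$\n",
"\n",
"This grouping, which starts at the output and works back toward the input, is called \"reverse mode\"."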
]
},
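{
"cell_type": "markdown",
"id": "7f3a5c1e-9b2d-4a86-b1e7-3c6d8f0a2b71",
"metadata": {},
"source": [
"Viewing each factor as a Jacobian matrix makes the two groupings concrete. In the hypothetical example below, an input $x \\in \\mathbb{R}^4$ passes through three stages that end in a scalar $h$, and random matrices stand in for the stage Jacobians. Both groupings give the same $dh$. The reverse grouping also produces the whole $1 \\times 4$ row $dh/dx$ at once, whereas the forward grouping propagates a single input direction per pass."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f3a5c1e-9b2d-4a86-b1e7-3c6d8f0a2b72",
"metadata": {},
"outputs": [],
"source": [
"using Random\n",
"\n",
"Random.seed!(1)\n",
"\n",
"# Stand-in Jacobians for x in R^4 -> f in R^5 -> g in R^3 -> h in R^1\n",
"Jf = randn(5, 4)\n",
"Jg = randn(3, 5)\n",
"Jh = randn(1, 3)\n",
"dx = randn(4)\n",
"\n",
"# Forward mode: push dx through from right to left (matrix-vector products)\n",
"dh_forward = Jh * (Jg * (Jf * dx))\n",
"\n",
"# Reverse mode: pull the output sensitivity back from left to right (row-vector products)\n",
"xbar = (Jh * Jg) * Jf    # 1x4 row holding dh/dx for every input\n",
"dh_reverse = xbar * dx\n",
"\n",
"@show dh_forward dh_reverse"
]
},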
{
"cell_type": "markdown",
"id": "89c8a57a-415c-47f1-9580-9cb002a05e1e",
"metadata": {},
"source": [
"## Reverse mode example\n",
"\n",
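"Reverse mode propagates the sensitivity of the output with respect to each intermediate quantity, written $\\bar u = dh / du$, backwards from the output toward the input:\n",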
"\n",
"$$ \\underbrace{\\bar x}_{\\frac{dh}{dx}} = \\underbrace{\\bar g \\frac{dg}{df}}_{\\bar f} \\frac{df}{dx} $$"
]
},
{
"cell_type": "markdown",
"id": "d35d952f-75ad-40c9-abc6-0e910eb1a23d",
"metadata": {},
"source": [
"Let's consider the function\n",
"\n",
"$$ z = x \\cdot y + \\sin \\left( x \\right) $$\n",
"\n",
"We can represent this in pseudocode as\n",
"\n",
"```\n",
"x = ?\n",
"y = ?\n",
"a = x * y\n",
"b = sin(x)\n",
"z = a + b\n",
"```\n",
"\n",
"Evaluating the derivative in the forward mode gives\n",
"\n",
"```\n",
"dx = ?\n",
"dy = ?\n",
"da = y * dx + x * dy\n",
"db = cos(x) * dx\n",
"dz = da + db\n",
"```\n",
"\n",
"and evaluating in the reverse mode, which runs backwards from a seed `gz` (usually 1) once the values above are known, gives\n",
"\n",
"```\n",
"gz = ?\n",
"gb = gz\n",
"ga = gz\n",
"gy = x * ga\n",
"gx = y * ga + cos(x) * gb\n",
"```"
]
},
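{
"cell_type": "markdown",
"id": "4c6e8a2b-3d5f-4b19-9e27-6a1c4d8f0e81",
"metadata": {},
"source": [
"The pseudocode above translates directly into Julia. In the sketch below, the function names and input values are arbitrary choices. Forward mode needs one sweep per input direction to recover both partial derivatives, while reverse mode gets both from a single backward sweep seeded with `gz = 1`. Both should agree with $\\partial z / \\partial x = y + \\cos x$ and $\\partial z / \\partial y = x$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4c6e8a2b-3d5f-4b19-9e27-6a1c4d8f0e82",
"metadata": {},
"outputs": [],
"source": [
"# Forward mode: propagate the input differentials (dx, dy) alongside the values\n",
"function z_forward(x, y, dx, dy)\n",
"    a = x * y\n",
"    da = y * dx + x * dy\n",
"    b = sin(x)\n",
"    db = cos(x) * dx\n",
"    z = a + b\n",
"    dz = da + db\n",
"    z, dz\n",
"end\n",
"\n",
"# Reverse mode: evaluate forward first, then propagate sensitivities backwards from gz\n",
"function z_reverse(x, y; gz=1.0)\n",
"    a = x * y\n",
"    b = sin(x)\n",
"    z = a + b\n",
"    gb = gz\n",
"    ga = gz\n",
"    gy = x * ga\n",
"    gx = y * ga + cos(x) * gb\n",
"    z, (gx, gy)\n",
"end\n",
"\n",
"x0, y0 = 0.5, 2.0\n",
"_, dz_dx = z_forward(x0, y0, 1.0, 0.0)    # sweep 1: seed dx = 1\n",
"_, dz_dy = z_forward(x0, y0, 0.0, 1.0)    # sweep 2: seed dy = 1\n",
"_, (gx, gy) = z_reverse(x0, y0)           # one reverse sweep gives both\n",
"@show (dz_dx, dz_dy) (gx, gy)"
]
},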
{
"cell_type": "markdown",
"id": "7a179966-27b3-422c-a437-14364d719904",
"metadata": {},
"source": [
"## Automatic differentiation"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7c6270a4-c48b-4a2f-86d4-d86a095508f4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m\u001b[1m Resolving\u001b[22m\u001b[39m package versions...\n",
"\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `~/.julia/environments/v1.11/Project.toml`\n",
"\u001b[32m\u001b[1m No Changes\u001b[22m\u001b[39m to `~/.julia/environments/v1.11/Manifest.toml`\n"
]
}
],
"source": [
"using Pkg\n",
"pkg\"add Zygote\"\n",
"import Zygote"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f6480617-5257-4e20-a5d2-d01c1d52c583",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(-34.03241959914049,)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Zygote.gradient(f, 1.9)"
]
},
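{
"cell_type": "markdown",
"id": "8e0b2d4f-6a7c-4d31-a8f5-1b3e5c7a9d91",
"metadata": {},
"source": [
"Zygote can also differentiate the rational-model loss from earlier, which gives an independent check on the hand-derived `gradient`. The globals `x` and `y` have since been reused for the symbolic examples, so the cell below redefines the model and some data under new names. The names `model_rational`, `xs_ad`, `ys_ad`, `loss_ad`, and `grad_hand` are hypothetical, and the data and parameter values are arbitrary choices."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e0b2d4f-6a7c-4d31-a8f5-1b3e5c7a9d92",
"metadata": {},
"outputs": [],
"source": [
"import Zygote\n",
"\n",
"# Rational model from earlier, under a new name\n",
"model_rational(x, c) = 1 ./ (c[1] .+ c[2] .* x .+ c[3] .* x .^ 2)\n",
"\n",
"xs_ad = collect(LinRange(-1, 1, 20))\n",
"ys_ad = 1 ./ (1 .+ 10 .* xs_ad .^ 2)    # noise-free Runge data\n",
"loss_ad(c) = 0.5 * sum(abs2, model_rational(xs_ad, c) .- ys_ad)\n",
"\n",
"# Hand-derived gradient: r' * J with J as in gradf\n",
"function grad_hand(c)\n",
"    fv = model_rational(xs_ad, c)\n",
"    f2 = fv .^ 2\n",
"    J = [-f2 -f2 .* xs_ad -f2 .* xs_ad .^ 2]\n",
"    vec((fv .- ys_ad)' * J)\n",
"end\n",
"\n",
"c_test = [1.2, 0.1, 8.0]\n",
"g_ad = Zygote.gradient(loss_ad, c_test)[1]\n",
"@show g_ad grad_hand(c_test)"
]
},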
{
"cell_type": "markdown",
"id": "dab90022-6472-4a04-a63c-292daf434e01",
"metadata": {},
"source": [
"## But how?\n",
"\n",
"It's cool that Zygote works, but how does it actually work?"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c15e67c1-93a2-4575-9533-c96ac84be5d6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[90m; Function Signature: square(Float64)\u001b[39m\n",
"\u001b[90m; @ In[12]:1 within `square`\u001b[39m\n",
"\u001b[95mdefine\u001b[39m \u001b[36mdouble\u001b[39m \u001b[93m@julia_square_35825\u001b[39m\u001b[33m(\u001b[39m\u001b[36mdouble\u001b[39m \u001b[0m%\"x::Float64\"\u001b[33m)\u001b[39m \u001b[0m#0 \u001b[33m{\u001b[39m\n",
"\u001b[91mtop:\u001b[39m\n",
"\u001b[90m; ┌ @ intfuncs.jl:370 within `literal_pow`\u001b[39m\n",
"\u001b[90m; │┌ @ float.jl:493 within `*`\u001b[39m\n",
" \u001b[0m%0 \u001b[0m= \u001b[96m\u001b[1mfmul\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%\"x::Float64\"\u001b[0m, \u001b[0m%\"x::Float64\"\n",
" \u001b[96m\u001b[1mret\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%0\n",
"\u001b[90m; └└\u001b[39m\n",
"\u001b[33m}\u001b[39m\n"
]
}
],
"source": [
"square(x) = x^2\n",
"# Let's look at the LLVM intermediate representation (IR) generated for it\n",
"@code_llvm square(1.5)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "65a3be3d-546d-4d1d-8f0e-8964091f7393",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[90m; Function Signature: gradient(typeof(Main.square), Float64)\u001b[39m\n",
"\u001b[90m; @ /home/jeremy/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:152 within `gradient`\u001b[39m\n",
"\u001b[95mdefine\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[93m@julia_gradient_36053\u001b[39m\u001b[33m(\u001b[39m\u001b[36mdouble\u001b[39m \u001b[0m%\"args\u001b[33m[\u001b[39m\u001b[33m1\u001b[39m\u001b[33m]\u001b[39m\u001b[0m::Float64\"\u001b[33m)\u001b[39m \u001b[0m#0 \u001b[33m{\u001b[39m\n",
"\u001b[91mtop:\u001b[39m\n",
"\u001b[90m; @ /home/jeremy/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:154 within `gradient`\u001b[39m\n",
"\u001b[90m; ┌ @ /home/jeremy/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:97 within `#88`\u001b[39m\n",
"\u001b[90m; │┌ @ In[12]:1 within `square`\u001b[39m\n",
"\u001b[90m; ││┌ @ /home/jeremy/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:222 within `ZBack`\u001b[39m\n",
"\u001b[90m; │││┌ @ /home/jeremy/.julia/packages/Zygote/55SqB/src/lib/number.jl:12 within `literal_pow_pullback`\u001b[39m\n",
"\u001b[90m; ││││┌ @ promotion.jl:430 within `*` @ float.jl:493\u001b[39m\n",
" \u001b[0m%0 \u001b[0m= \u001b[96m\u001b[1mfmul\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%\"args\u001b[33m[\u001b[39m\u001b[33m1\u001b[39m\u001b[33m]\u001b[39m\u001b[0m::Float64\"\u001b[0m, \u001b[33m2.000000e+00\u001b[39m\n",
"\u001b[90m; └└└└└\u001b[39m\n",
"\u001b[90m; @ /home/jeremy/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:155 within `gradient`\u001b[39m\n",
" \u001b[91m%\"new:\u001b[39m\u001b[0m:Tuple2.unbox.fca.0.insert\" \u001b[0m= \u001b[95minsertvalue\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[95mzeroinitializer\u001b[39m\u001b[0m, \u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[0m, \u001b[33m0\u001b[39m\n",
" \u001b[96m\u001b[1mret\u001b[22m\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[0m%\"new::Tuple2.unbox.fca.0.insert\"\n",
"\u001b[33m}\u001b[39m\n"
]
}
],
"source": [
"# And here is the LLVM IR for Zygote's derivative: it reduces to multiplying x by 2\n",
"@code_llvm Zygote.gradient(square, 1.5)"
]
},
{
"cell_type": "markdown",
"id": "6313d9fb-e0e3-4145-8e9b-a5ed4376111d",
"metadata": {},
"source": [
"## Types of algorithmic differentiation\n",
"\n",
"* Source transformation: Fortran code in, Fortran code out\n",
"\n",
" * Duplicates compiler features, usually incomplete language coverage\n",
"\n",
" * Produces efficient code\n",
"\n",
"* Operator overloading: C++ types\n",
"\n",
" * Hard to vectorize\n",
"\n",
" * Loops are effectively unrolled/inefficient\n",
"\n",
"* Just-in-time compilation: tightly coupled with compiler\n",
"\n",
" * JIT lag\n",
"\n",
" * Needs dynamic language features (JAX) or tight integration with compiler (Zygote, Enzyme)\n",
"\n",
" * [Some sharp bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow)"
]
},
{
"cell_type": "markdown",
"id": "b9ec8ac2-2af7-49eb-ae34-85b8aa77af53",
"metadata": {},
"source": [
"## Forward or reverse mode\n",
"\n",
"Pick forward or reverse mode depending upon the 'shape' of your function.\n",
"\n",
"* One input, many outputs: use forward mode\n",
"\n",
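" * \"One input\" can also mean a single direction in a many-dimensional input space (a directional derivative)\n",
"\n",
"* Many inputs, one output: use reverse mode\n",
"\n",
" * Need to traverse the execution backwards, which requires recording intermediate values (a \"tape\")\n",
"\n",
" * Hierarchical checkpointing reduces the memory used by the tape by recomputing some intermediate values\n",
"\n",
"* Roughly square Jacobian? Forward mode is usually a bit more efficient"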
]
},
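{
"cell_type": "markdown",
"id": "1d3f5b7e-9c2a-4e64-b7a8-0e2c4a6f8b01",
"metadata": {},
"source": [
"For the many-inputs, one-output case, `Zygote.pullback` exposes the reverse sweep directly. It returns the value together with a function that maps a sensitivity of the output back to sensitivities of the inputs. In the sketch below, `q` is a hypothetical scalar function of 1000 inputs. One call of the pullback yields all 1000 partial derivatives, which we compare with the gradient derived by hand, $\\cos(v_i) v_i + \\sin(v_i)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d3f5b7e-9c2a-4e64-b7a8-0e2c4a6f8b02",
"metadata": {},
"outputs": [],
"source": [
"import Zygote\n",
"\n",
"# A scalar function of many inputs (an illustrative choice)\n",
"q(v) = sum(sin.(v) .* v)\n",
"\n",
"v = collect(range(0, 1; length=1000))\n",
"val, back = Zygote.pullback(q, v)\n",
"gv = back(1.0)[1]    # seed the output sensitivity with 1; one tuple entry per argument\n",
"\n",
"@show maximum(abs.(gv .- (cos.(v) .* v .+ sin.(v))))"
]
},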
{
"cell_type": "markdown",
"id": "b5eff0db-b422-4eb5-ab09-f7f137460280",
"metadata": {},
"source": [
"## Ill-conditioned optimization\n",
"\n",
"$$ L \\left( c; x, y \\right) = \\frac 1 2 \\lVert \\underbrace{f \\left( x, c \\right) - y}_{r \\left( c \\right)} \\rVert_{C^{-1}}^2 $$\n",
"\n",
"The gradient of $L$ requires the Jacobian $J$ of the model $f$.\n",
"\n",
"$$ g \\left( c \\right) = \\nabla_c L = r^T C^{-1} \\underbrace{\\nabla_c f}_{J} $$\n",
"\n",
"We can solve $g \\left( c \\right) = 0$ using a Newton method\n",
"\n",
"$$ g \\left( c + \\delta c \\right) = g \\left( c \\right) + \\underbrace{\\nabla_c g}_{H} \\delta c + \\mathcal{O} \\left( \\left( \\delta c \\right)^2 \\right) $$\n",
"\n",
"The Hessian involves the second derivative of $f$ through $\\nabla_c J$. This term is expensive to compute and can make $H$ indefinite when the residual $r$ is large:\n",
"\n",
"$$ H = J^T C^{-1} J + r^T C^{-1} \\left( \\nabla_c J \\right) $$"
]
},
{
"cell_type": "markdown",
"id": "21b8f927-a7a4-4628-a761-778186e6d260",
"metadata": {},
"source": [
"## Newton-like methods for optimization\n",
"\n",
"Solve\n",
"\n",
"$$ H \\delta c = - g \\left( c \\right) $$\n",
"\n",
"Update $c \\leftarrow c + \\gamma \\delta c$ using a line search or trust region."
]
},
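{
"cell_type": "markdown",
"id": "6b8d0f2a-4c6e-4a97-8d1b-3f5a7c9e1c11",
"metadata": {},
"source": [
"Below is a minimal sketch of one such method, Gauss-Newton. It drops the term involving $\\nabla_c J$ and approximates $H \\approx J^T J$, taking $C = I$. It uses full steps, $\\gamma = 1$, with no line search or trust region, so it should only be expected to converge from a reasonable initial guess. The noise-free Runge data, the starting point, and the names `model_rational`, `jac_rational`, and `gauss_newton` are assumptions for illustration. For this data, the exact minimizer is $c = (1, 0, 10)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b8d0f2a-4c6e-4a97-8d1b-3f5a7c9e1c12",
"metadata": {},
"outputs": [],
"source": [
"using LinearAlgebra\n",
"\n",
"# Rational model and its Jacobian (same formulas as f and gradf above)\n",
"model_rational(x, c) = 1 ./ (c[1] .+ c[2] .* x .+ c[3] .* x .^ 2)\n",
"\n",
"function jac_rational(x, c)\n",
"    f2 = model_rational(x, c) .^ 2\n",
"    [-f2 -f2 .* x -f2 .* x .^ 2]\n",
"end\n",
"\n",
"function gauss_newton(x, y, c0; maxit=20, tol=1e-12)\n",
"    c = copy(c0)\n",
"    for _ in 1:maxit\n",
"        r = model_rational(x, c) - y\n",
"        J = jac_rational(x, c)\n",
"        # Least-squares solve of J dc = -r, equivalent to the normal equations (J'J) dc = -J'r\n",
"        dc = -(J \\ r)\n",
"        c += dc\n",
"        norm(dc) < tol && break\n",
"    end\n",
"    c\n",
"end\n",
"\n",
"xs_gn = collect(LinRange(-1, 1, 50))\n",
"ys_gn = 1 ./ (1 .+ 10 .* xs_gn .^ 2)    # noise-free Runge data\n",
"c_gn = gauss_newton(xs_gn, ys_gn, [1.1, 0.05, 9.0])\n",
"@show c_gn"
]
},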
{
"cell_type": "markdown",
"id": "72061af6-e5bb-4015-892d-6aa61a4fc6a0",
"metadata": {},
"source": [
"## Outlook\n",
"\n",
"* The optimization problem can be solved using a Newton method.\n",
"It can be onerous to implement the needed derivatives.\n",
"\n",
"* The Gauss-Newton method (see activity) drops the second-derivative term and keeps only the first term of $H$. It is often more practical than Newton while being faster than gradient descent, though it lacks robustness.\n",
"\n",
"* The Levenberg-Marquardt method provides a sort of middle-ground between Gauss-Newton and gradient descent.\n",
"\n",
"* Many globalization techniques are used for models that possess many local minima.\n",
"\n",
"* One pervasive approach is stochastic gradient descent, where small batches (e.g., 1 or 10 or 20) are selected randomly from the corpus of observations (200 in our current example), and a step of gradient descent is applied to that reduced set of observations.\n",
"This helps to escape saddle points and weak local minima.\n",
"\n",
"* Among expressive models $f \\left( x, c \\right)$, some may converge much more easily than others.\n",
"Having a good optimization algorithm is essential for nonlinear regression with complicated models, especially those with many parameters $c$.\n",
"\n",
"* Classification is a very similar problem to regression, but the observations $y$ are discrete, thus\n",
"\n",
" * models $f \\left( x, c \\right)$ must ultimately produce discrete output (often by first predicting a probability for each class)\n",
"\n",
" * the least squares loss function is not appropriate\n",
"\n",
"* [Why momentum really works](https://distill.pub/2017/momentum/)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.11.6",
"language": "julia",
"name": "julia-1.11"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}