{ "cells": [ { "cell_type": "markdown", "metadata": { "vscode": { "languageId": "plaintext" } }, "source": [ "# Derivatives" ] }, { "cell_type": "markdown", "metadata": { "vscode": { "languageId": "plaintext" } }, "source": [ "As explained in the backend developer guide, we use `JAX` for automatic differentiation. If `JAX` is not installed, we fall back to finite differences to compute the derivatives. This choice is made in `desc/derivatives.py`:\n", "\n", "```python\n", "from desc.backend import use_jax # True if JAX is installed, False otherwise\n", "Derivative = AutoDiffDerivative if use_jax else FiniteDiffDerivative\n", "```\n", "\n", "This part of the code is rarely used directly, since `Objective` classes provide derivative methods such as `jac_scaled`, `jac_scaled_error`, and `jvp_scaled_error`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start with an example of getting the full Jacobian matrix of the `ForceBalance` objective." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import sys\n", "import os\n", "\n", "sys.path.insert(0, os.path.abspath(\".\"))\n", "sys.path.append(os.path.abspath(\"../../../\"))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from desc.objectives import ObjectiveFunction, ForceBalance\n", "from desc.examples import get" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Precomputing transforms\n" ] } ], "source": [ "# Use W7-X equilibrium from examples\n", "eq = get(\"W7-X\")\n", "# Initialize and build the objective\n", "obj = ForceBalance(eq)\n", "obj.build()" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "({'R_lmn': Array([-7.61552162e-06, 9.16996648e-05, 1.86963797e-05, ...,\n", " 2.19188894e-08, -1.74145484e-06, 0.00000000e+00], 
dtype=float64),\n", " 'Z_lmn': Array([-2.40502618e-05, -8.58547166e-05, -1.50730988e-05, ...,\n", " 5.67762944e-07, -5.57790527e-07, 8.68739525e-09], dtype=float64),\n", " 'L_lmn': Array([ 2.44528892e-05, -3.26433442e-05, 4.54266096e-05, ...,\n", " 1.96745697e-06, 1.08110518e-06, 9.52033876e-07], dtype=float64),\n", " 'p_l': Array([ 185596.929, -371193.859, 185596.929, 0. , 0. ,\n", " 0. , 0. ], dtype=float64),\n", " 'i_l': Array([-0.85604702, -0.03880954, -0.06867951, -0.01869703, 0.01905612,\n", " 0. , 0. ], dtype=float64),\n", " 'c_l': Array([], shape=(0,), dtype=float64),\n", " 'Psi': Array([-2.133], dtype=float64),\n", " 'Te_l': Array([], shape=(0,), dtype=float64),\n", " 'ne_l': Array([], shape=(0,), dtype=float64),\n", " 'Ti_l': Array([], shape=(0,), dtype=float64),\n", " 'Zeff_l': Array([], shape=(0,), dtype=float64),\n", " 'a_lmn': Array([], shape=(0,), dtype=float64),\n", " 'Ra_n': Array([ 5.60514435e+00, 3.59944651e-01, 1.35342448e-02, 8.05206303e-04,\n", " -5.52216662e-05, -4.44705537e-05, -8.69796953e-05, -1.31024894e-05,\n", " -1.16094194e-04, 1.46746080e-05, 8.95046765e-05, 6.43224852e-05,\n", " 2.50291789e-05], dtype=float64),\n", " 'Za_n': Array([ 9.41747282e-06, 1.91787732e-05, -1.31424684e-05, -6.94635436e-05,\n", " 8.13912452e-05, -2.27957475e-05, -3.17561695e-06, 8.07758929e-05,\n", " -5.52676301e-05, -8.85642777e-05, 1.62604346e-02, 3.02638117e-01], dtype=float64),\n", " 'Rb_lmn': Array([ 0.00000000e+00, -1.25396811e-06, 2.37682928e-06, 5.95368860e-07,\n", " -9.54472906e-07, 1.76054297e-06, -7.80023478e-07, 2.85992538e-05,\n", " -5.44278689e-05, 4.35065475e-05, 1.08419042e-04, -4.94117636e-05,\n", " 0.00000000e+00, 1.34029977e-05, 2.04240633e-06, -7.06203921e-06,\n", " -5.25684139e-06, 5.26857420e-06, 9.89961072e-06, -1.98539507e-05,\n", " -1.37306672e-05, 5.71040199e-05, -1.49695429e-04, -4.60454698e-04,\n", " 0.00000000e+00, -3.93586902e-06, -1.24341432e-05, -4.31824078e-06,\n", " -5.97446210e-06, 1.93177836e-05, -1.08202463e-05, 
-3.19000253e-05,\n", " 6.01629881e-05, -5.05481211e-05, 1.09639483e-04, 5.60451377e-04,\n", " 0.00000000e+00, 2.19434568e-05, 5.68552543e-05, -3.39579527e-06,\n", " 3.40497368e-07, -7.07304561e-07, -2.90358710e-06, 2.39414988e-06,\n", " -3.41254747e-05, 1.20202218e-04, 1.18651124e-04, -8.32041206e-04,\n", " 0.00000000e+00, -1.24025745e-04, -1.83680116e-05, -3.52409581e-05,\n", " -4.37411937e-05, -2.60458588e-05, 1.62577249e-05, -4.77267423e-05,\n", " -1.01162081e-05, 1.03273942e-04, 1.34530102e-04, 1.23911346e-03,\n", " 0.00000000e+00, -6.76061823e-06, 8.70797194e-05, 1.51334703e-04,\n", " -2.48705314e-05, 8.40929499e-05, -1.76719549e-05, -8.28560112e-06,\n", " -8.39944402e-05, 1.09293507e-05, -2.35790085e-04, -2.90309898e-04,\n", " 0.00000000e+00, -1.40711798e-04, -3.72085801e-05, 1.05608247e-04,\n", " -3.56672224e-04, -1.64028860e-05, -7.46879435e-05, -1.89335072e-04,\n", " 9.95517680e-05, -1.68951959e-04, 3.34416910e-04, -5.08760113e-04,\n", " 0.00000000e+00, -2.01053288e-04, 1.98060835e-04, 2.82046428e-04,\n", " 3.24819528e-05, -9.16676344e-06, 5.84552500e-04, -2.48935819e-04,\n", " 5.72081394e-04, -8.79364811e-06, -4.79856375e-04, -4.27183536e-04,\n", " 0.00000000e+00, -2.33672679e-04, 1.47109523e-05, 4.95887339e-04,\n", " -3.63675836e-04, -8.63986194e-04, 7.70859371e-04, -8.90638973e-04,\n", " -4.37097812e-04, -1.15208699e-03, 6.61923757e-04, -2.87994589e-03,\n", " 0.00000000e+00, -1.71082980e-04, 1.02042117e-04, 7.05405746e-04,\n", " -8.08356490e-04, -1.00149518e-04, 1.39281338e-03, 1.11088184e-03,\n", " -2.98470909e-03, 1.31711124e-02, 1.72199858e-03, -2.51570115e-03,\n", " 0.00000000e+00, -6.66266145e-05, 3.53781230e-04, 3.67472397e-04,\n", " -7.30928521e-04, -5.53093355e-04, 1.67717947e-03, 4.49574982e-04,\n", " -8.59264072e-03, 2.10631146e-02, -6.55119383e-02, 2.04921849e-02,\n", " 0.00000000e+00, 4.93164572e-05, -1.37934002e-06, 3.48747582e-05,\n", " -2.49952008e-04, -1.26808781e-04, 1.55916882e-03, -1.37401892e-03,\n", " 6.40857796e-04, 1.38368078e-02, 
-3.30987517e-02, 2.37685981e-01,\n", " 5.52080588e+00, 4.88750272e-01, 3.80212296e-02, -2.74340380e-03,\n", " 2.26250894e-03, 4.59523537e-04, 1.45473621e-04, -4.19920226e-04,\n", " 6.61312391e-05, 1.20380771e-04, -4.93785187e-05, 5.06253648e-05,\n", " 0.00000000e+00, 2.77874172e-01, -1.88444111e-01, 5.60375622e-02,\n", " -1.06466419e-02, -1.61546780e-03, 2.60987218e-03, -1.65588471e-03,\n", " 1.54348374e-04, 1.41240066e-04, -4.68985644e-05, -8.74637332e-06,\n", " -9.09094130e-05, 0.00000000e+00, -6.89426147e-03, -1.58823405e-02,\n", " 6.98701599e-02, -2.11167665e-02, 8.73971829e-03, -6.15425293e-04,\n", " -1.47192353e-03, 5.63154628e-04, 7.94970612e-04, -3.66554531e-04,\n", " -2.75449640e-04, 4.96466569e-05, 0.00000000e+00, -1.07405103e-04,\n", " 2.07450036e-03, -1.44334488e-03, -1.32273068e-02, 3.34646659e-03,\n", " -1.24245401e-03, -1.25560748e-03, -1.53205832e-05, 7.74630935e-04,\n", " -7.17526096e-04, -1.38685668e-04, 1.80207036e-04, 0.00000000e+00,\n", " -1.44894391e-03, 5.37529112e-04, -7.51293596e-04, 1.63860140e-03,\n", " 4.27561816e-04, 1.21209333e-03, -8.17142801e-04, 1.02439933e-03,\n", " 3.24105307e-04, -4.33093962e-04, 1.53808458e-06, 2.31586111e-04,\n", " 0.00000000e+00, -7.70555083e-05, 1.54415389e-04, -2.57148748e-04,\n", " 1.34582918e-04, -7.82040872e-04, 6.89024028e-05, -6.41100249e-04,\n", " -1.00080461e-04, -3.13016107e-05, -2.91005885e-04, -2.03612905e-04,\n", " 2.00479184e-04, 0.00000000e+00, -3.27434063e-04, -6.43809437e-04,\n", " 7.09680036e-04, -1.80269282e-04, 1.67172133e-04, -1.82015689e-05,\n", " 1.68919192e-04, -2.10527819e-06, 3.60964029e-04, -9.20602182e-05,\n", " 2.74357397e-05, 1.45678556e-04, 0.00000000e+00, 1.32418908e-04,\n", " -2.35219751e-04, -3.66240508e-04, 7.96095394e-05, -1.48855380e-04,\n", " -1.34999964e-05, -2.50806296e-05, -1.30755858e-04, 1.94245152e-05,\n", " -1.45982291e-04, -8.14607170e-05, 9.32716986e-06, 0.00000000e+00,\n", " -3.97326415e-04, 4.74403413e-04, 1.62224727e-04, 1.08398906e-04,\n", " -8.91249878e-05, 
1.83884831e-05, -2.40289881e-05, 7.38105894e-06,\n", " 3.02917180e-05, 3.58997199e-05, 1.86696917e-05, 1.27153388e-04,\n", " 0.00000000e+00, 3.14252134e-04, 2.42775627e-04, 3.45611712e-05,\n", " 1.31094839e-04, -7.53604442e-06, -2.74296245e-05, 5.72755107e-06,\n", " -1.15537077e-07, -1.41883374e-05, -8.27834068e-07, -5.55663371e-05,\n", " -1.95851990e-05, 0.00000000e+00, 2.81914663e-04, 4.40443120e-06,\n", " -1.49769088e-04, 1.47248844e-04, -4.96393928e-05, 8.95590646e-06,\n", " 1.62116235e-05, -1.41344880e-05, -1.57793034e-06, -1.08519927e-06,\n", " 1.40464052e-05, 9.31369257e-06, 0.00000000e+00, 1.20763030e-04,\n", " 2.83437068e-04, 1.12192149e-04, -3.35557551e-05, 9.58778856e-06,\n", " 6.55749433e-06, -8.07589787e-06, 3.41500830e-06, -5.71000472e-07,\n", " 2.86601658e-07, -9.66274877e-07, -9.79588887e-06, 0.00000000e+00,\n", " 1.56387880e-05, -5.65274168e-05, -1.30469763e-04, -3.56260557e-05,\n", " 5.26731348e-05, -3.92943660e-05, 4.93030264e-06, -1.03452643e-06,\n", " 8.20391030e-07, -1.42760697e-06, -3.71332831e-06, 1.81690419e-06,\n", " 0.00000000e+00], dtype=float64),\n", " 'Zb_lmn': Array([-1.59145234e-05, -1.71391266e-04, -7.91973944e-05, -4.39062257e-05,\n", " 7.98893892e-05, 2.72582146e-06, -2.93867340e-05, 1.57097819e-05,\n", " -4.26560267e-06, -4.55237782e-06, 7.54645015e-06, -3.17065262e-07,\n", " 0.00000000e+00, -9.17007866e-05, -3.55313493e-04, 1.48658850e-04,\n", " -2.99905782e-05, -9.68237731e-05, 4.47130254e-05, 1.53726948e-05,\n", " -1.41235588e-05, 1.57602711e-06, -3.51287786e-06, -1.05258685e-06,\n", " -2.14237913e-06, 0.00000000e+00, -3.55416724e-04, 9.57510199e-05,\n", " -2.31592515e-04, 1.55081199e-04, 3.97931887e-05, -6.29524645e-05,\n", " 1.17195519e-05, 1.15344561e-05, -1.09573381e-05, 9.66159063e-06,\n", " 6.70075813e-06, 1.11565939e-05, 0.00000000e+00, -5.44847705e-04,\n", " -2.94447353e-04, -1.14415141e-04, 5.59860566e-06, 4.10113373e-05,\n", " 3.51568705e-05, -3.61987404e-05, 4.77649802e-06, -1.77793812e-06,\n", " -4.60374739e-06, 
-2.71376911e-05, -2.19835560e-05, 0.00000000e+00,\n", " 6.04240195e-04, 2.95953162e-04, -1.53984757e-04, -1.37946293e-04,\n", " 3.80053900e-07, 5.42775470e-05, 2.16718456e-05, -3.83085098e-05,\n", " 4.17106338e-05, -4.29763317e-06, 5.00218959e-05, -2.66433575e-05,\n", " 0.00000000e+00, -2.01210283e-04, 2.72924191e-04, 4.89452597e-04,\n", " -9.00784540e-05, 1.47845856e-04, -2.33234209e-05, -2.70819118e-05,\n", " -5.22733231e-05, -3.59423038e-05, 2.28237806e-05, 2.20527970e-05,\n", " 1.32913610e-04, 0.00000000e+00, -1.57367536e-04, 3.71725377e-04,\n", " -6.36426311e-04, 2.22300734e-04, -6.49692379e-05, 5.59928928e-05,\n", " -1.45262709e-05, 1.82817611e-04, -1.90225679e-04, -1.65789066e-05,\n", " -3.41635802e-04, 1.42370007e-04, 0.00000000e+00, -2.18280703e-04,\n", " -6.74904801e-04, -1.39052993e-05, 6.33690036e-05, -5.29789071e-04,\n", " 1.89630376e-04, -1.60017118e-04, 1.81558635e-04, 3.51937873e-04,\n", " 8.63840349e-05, -3.79244210e-04, -8.62999673e-05, 0.00000000e+00,\n", " -1.79922860e-03, -4.21381224e-04, -8.40471962e-04, 7.59434356e-04,\n", " 1.55512209e-03, -1.26593821e-03, 5.66615606e-04, -1.23249156e-03,\n", " 4.24568541e-04, 3.51304802e-04, -3.26262397e-04, -1.99365207e-04,\n", " 0.00000000e+00, -2.07332464e-03, 2.19586433e-03, -2.88732658e-03,\n", " -1.10527210e-02, 4.70896266e-03, 4.78332141e-04, 1.06683627e-03,\n", " -7.17034272e-04, -2.02991704e-04, 4.28040480e-04, -1.65562537e-04,\n", " -4.19786631e-05, 0.00000000e+00, -2.68475500e-03, -5.04561720e-03,\n", " 5.05067595e-02, -6.80661202e-03, -8.65891445e-03, 2.79244197e-03,\n", " 1.07713589e-03, -3.70605228e-04, -2.57765881e-04, -3.87070707e-05,\n", " -2.61759977e-04, 5.64993554e-05, 0.00000000e+00, 2.35387780e-01,\n", " -1.85125153e-01, -1.15440713e-02, 4.04994737e-03, -1.86158757e-03,\n", " 3.03139010e-04, 8.67626564e-04, -1.08117062e-04, -1.49375476e-04,\n", " 8.38550268e-05, 2.11667699e-04, -8.97990281e-05, 0.00000000e+00,\n", " 0.00000000e+00, 1.37712271e-05, -6.36611914e-06, -1.18681624e-05,\n", 
" -1.24789101e-04, 2.44036600e-04, 9.45804113e-06, -9.83701436e-04,\n", " -7.80804617e-04, 1.45689524e-03, 4.22713826e-03, -6.25162134e-01,\n", " 0.00000000e+00, -8.85719468e-05, 2.52969231e-04, -7.29692775e-06,\n", " -8.38344991e-05, -2.90782661e-05, 8.66129398e-04, -5.38883085e-04,\n", " -1.98479088e-03, 5.02096257e-03, -2.78218540e-02, -2.30947866e-01,\n", " 0.00000000e+00, 5.03558841e-05, -2.63815440e-04, -1.82971769e-05,\n", " -2.36363570e-04, -4.43164027e-04, 1.13029804e-03, 2.87552785e-03,\n", " -8.74378594e-03, -8.03135304e-03, 5.04579599e-02, -2.09901887e-02,\n", " 0.00000000e+00, -4.14802544e-05, -1.65366676e-04, 4.80830286e-04,\n", " -2.63663437e-04, -6.40520873e-04, 8.43705177e-04, 3.94648880e-04,\n", " 5.05183873e-03, -1.09983099e-02, -2.77445960e-03, 4.13869020e-04,\n", " 0.00000000e+00, -2.07992208e-04, -3.24063918e-04, 3.02798196e-04,\n", " 4.85144482e-04, -1.29914738e-03, 5.40226896e-04, -1.07054306e-03,\n", " 1.40360267e-03, 1.20732576e-03, -1.19048363e-03, 7.32864212e-04,\n", " 0.00000000e+00, -9.15227898e-05, -3.69086955e-04, 1.06458651e-04,\n", " 3.45515089e-04, 2.14998237e-04, -3.81748474e-05, -2.01645712e-04,\n", " -5.70283675e-04, 8.77129950e-05, -4.47129604e-04, 6.91900424e-05,\n", " 0.00000000e+00, 1.33523378e-04, -3.38399645e-04, -2.61091588e-05,\n", " -2.05885200e-04, 1.93308397e-04, -5.31859955e-05, 2.45222218e-04,\n", " 1.47426092e-05, -2.25792401e-04, 7.56420778e-04, -9.31423439e-04,\n", " 0.00000000e+00, 1.26395074e-04, 2.76665231e-05, 1.78997551e-05,\n", " -4.72089225e-05, -6.07456676e-06, 2.26200365e-05, -8.54052921e-05,\n", " -1.00621364e-04, 1.47021789e-04, -4.73474277e-04, -2.13248831e-04,\n", " 0.00000000e+00, -3.59572996e-05, 4.30604167e-05, -1.61850437e-05,\n", " 4.29341084e-05, -6.52970608e-06, 3.88704850e-05, 3.69296819e-06,\n", " -1.10540268e-04, 1.25347560e-04, 2.51674680e-04, 5.51525894e-04,\n", " 0.00000000e+00, -3.79914757e-05, -3.00119956e-05, -4.93139249e-06,\n", " 2.28252475e-06, 1.22920907e-05, -3.76491919e-05, 
1.68796655e-05,\n", " 2.66907520e-05, -1.10558632e-04, 1.75346678e-05, 4.43658096e-04,\n", " 0.00000000e+00, -3.45211694e-06, 1.89232570e-06, 1.32004602e-05,\n", " -1.21583278e-05, 1.46353196e-05, 1.39711447e-05, -6.49184566e-05,\n", " 8.94858145e-06, 1.07985868e-04, 2.22851574e-06, 2.58804360e-05,\n", " 0.00000000e+00, -1.00618400e-05, 3.49319825e-07, -8.77056443e-07,\n", " 2.93769886e-06, -1.75338508e-05, 1.92668754e-05, 4.79370077e-05,\n", " -1.14222396e-04, -3.44231084e-05, 1.99798412e-04, 1.38273445e-04,\n", " 0.00000000e+00, -1.70913709e-06, 8.69644914e-06, -2.13835896e-06,\n", " -3.74268677e-06, 1.68316270e-05, -2.94087823e-05, 5.84730031e-06,\n", " 7.94816640e-05, -4.92286011e-05, -6.37424206e-05, -4.43234863e-05], dtype=float64),\n", " 'I': Array([], shape=(0,), dtype=float64),\n", " 'G': Array([], shape=(0,), dtype=float64),\n", " 'Phi_mn': Array([], shape=(0,), dtype=float64)},)" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params = obj.xs(eq)\n", "params" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The portion of the Jacobian for\n", "\tG has shape (5346, 0)\n", "\tI has shape (5346, 0)\n", "\tL_lmn has shape (5346, 1134)\n", "\tPhi_mn has shape (5346, 0)\n", "\tPsi has shape (5346, 1)\n", "\tR_lmn has shape (5346, 1141)\n", "\tRa_n has shape (5346, 13)\n", "\tRb_lmn has shape (5346, 313)\n", "\tTe_l has shape (5346, 0)\n", "\tTi_l has shape (5346, 0)\n", "\tZ_lmn has shape (5346, 1134)\n", "\tZa_n has shape (5346, 12)\n", "\tZb_lmn has shape (5346, 312)\n", "\tZeff_l has shape (5346, 0)\n", "\ta_lmn has shape (5346, 0)\n", "\tc_l has shape (5346, 0)\n", "\ti_l has shape (5346, 7)\n", "\tne_l has shape (5346, 0)\n", "\tp_l has shape (5346, 7)\n", "Total number of parameters that we took the derivative for is 4074\n" ] } ], "source": [ "(J,) = obj.jac_scaled(*params)\n", "sum = 0\n", "print(\"The portion of the 
Jacobian for\")\n", "for key in J.keys():\n", "    print(f\"\\t{key:10} has shape {J[key].shape}\")\n", "    sum += J[key].shape[1]\n", "print(\"Total number of parameters that we took the derivative for is\", sum)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternatively, we can pass the equilibrium's parameter dictionary directly:" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(5346, 1141)" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(J,) = obj.jac_scaled(eq.params_dict)\n", "J[\"R_lmn\"].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This way of taking the Jacobian is useful if you need to investigate the effect of individual parameters. However, if you want a single Jacobian matrix, the proper way is to wrap the `Objective` in an `ObjectiveFunction`:" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Building objective: force\n", "Precomputing transforms\n" ] }, { "data": { "text/plain": [ "(5346, 4074)" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "objfun = ObjectiveFunction(ForceBalance(eq))\n", "objfun.build()\n", "J = objfun.jac_scaled(objfun.x(eq))\n", "J.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see that concatenating the per-parameter blocks from the previous method column-wise would give this same Jacobian matrix, with 4074 columns in total.\n", "\n", "In the code, you will see many methods whose names start with `compute_`, `jac_`, `vjp_`, or `jvp_`. They are all variations of the base methods that apply some combination of scaling, normalization, and bounds/target adjustments. 
Here is a brief summary of what they do:\n", "\n", "| **Function** | **Purpose** | **Full Jacobian** |\n", "|-----------------------------|------------------------------------|-------------------------------|\n", "| `compute` | Main method to compute the raw objective function. | |\n", "| `compute_unscaled` | Compute the raw value of the objective, optionally applying a loss function. | `jac_unscaled` |\n", "| `compute_scaled` | Compute the objective with weighting and normalization applied. | `jac_scaled` |\n", "| `compute_scaled_error` | Compute the objective with bounds/target adjustments applied, in addition to the weighting and normalization of `compute_scaled`. | `jac_scaled_error` |\n", "| `compute_scalar` | Compute the scalar value of the objective, $\\lVert\\mathbf{f}\\rVert^2/2$. | `grad` |\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `jvp_` and `vjp_` methods compute derivatives in specific directions. The names stand for Jacobian-vector product and vector-Jacobian product, and these methods are more efficient than forming the full Jacobian matrix and extracting the column or row of interest. If you look at the implementation of the `jac_` methods, you will see that we actually take a `jvp` in each unit direction to form the full Jacobian.\n", "\n", "```python\n", "@jit\n", "def jac_scaled_error(self, x, constants=None):\n", "    \"\"\"Compute Jacobian matrix of self.compute_scaled_error wrt x.\"\"\"\n", "    v = jnp.eye(x.shape[0])\n", "    return self.jvp_scaled_error(v, x, constants).T\n", "```\n", "\n", "Here, `v` holds the unit tangent vectors, one for each coordinate direction. In practice, the code rarely needs the Jacobian with respect to the full state vector. For example, `LinearConstraintProjection` reduces the number of parameters so that the optimizer only operates in the null-space of the constraint matrix, but our `compute` function still takes the full state vector. So, how do we take the derivative in that case? The solution is a little bit of linear algebra. 
Let's consider the following problem.\n", "\n", "$$ \\min_{\\mathbf{x}} \\mathbf{f}(\\mathbf{x}) $$\n", "$$ \\text{subject to } x_1 = x_2 $$\n", "$$ \\mathbf{x} = [x_1, x_2, x_3, x_4] $$\n", "\n", "Since the constraint links $x_1$ and $x_2$, the reduced state vector has only 3 parameters, $\\mathbf{y} = [y_1, y_2, y_3]$, with $y_1=x_1=x_2$, $y_2=x_3$, and $y_3=x_4$.\n", "\n", "Taking the derivative of $\\mathbf{f}$ with respect to $y_2$ and $y_3$ is straightforward. But when we take the derivative with respect to $y_1$, both $x_1$ and $x_2$ change, so we have to take the derivative in both directions at once. In this simple example, deciding which parameters are free and which are dependent was easy. However, for more complex linear constraints, the more systematic way is to use the null-space matrix $Z$. If we want to take the derivative in the $y_1$ direction, the tangent vector in reduced space is $\\mathbf{v}_r = [1, 0, 0]$, and the tangent vector in full space is $Z\\mathbf{v}_r$. We have a handy utility function to calculate the pseudo-inverse and null-space of a matrix. Here is how we get the full tangent direction for the simple example; note that the resulting tangent is normalized to unit length, $[1, 1, 0, 0]/\\sqrt{2}$." ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Full tangent: [0.70710678 0.70710678 0. 0. ]\n", "Null-space:\n", " [[0.70710678 0. 0. ]\n", " [0.70710678 0. 0. ]\n", " [0. 1. 0. ]\n", " [0. 0. 1. 
]]\n" ] } ], "source": [ "import numpy as np\n", "from desc.utils import svd_inv_null\n", "\n", "A = np.array([[1.0, -1.0, 0.0, 0.0]])\n", "Ainv, Z = svd_inv_null(A)\n", "vr = np.array([1, 0, 0])\n", "print(\"Full tangent: \", Z @ vr)\n", "print(\"Null-space:\\n\", Z)" ] } ], "metadata": { "kernelspec": { "display_name": "desc-env-cpu", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 2 }