Derivatives
As explained in the backend dev guide, we use JAX for automatic differentiation. If JAX is not installed, we fall back to finite differences to compute the derivatives. This selection happens in desc/derivatives.py:
from desc.backend import use_jax  # True if JAX is installed, False otherwise
Derivative = AutoDiffDerivative if use_jax else FiniteDiffDerivative
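The sketch below shows roughly what a finite-difference fallback computes. It takes one central difference per parameter and builds the Jacobian column by column. The helper `fd_jacobian`, the step size `h`, and the toy `residual` are hypothetical, and this is not the implementation of FiniteDiffDerivative.

```python
import numpy as np


def fd_jacobian(fun, x, h=1e-6):
    """Approximate the Jacobian of fun at x with central differences.

    Hypothetical helper for illustration; not DESC's FiniteDiffDerivative.
    """
    x = np.asarray(x, dtype=float)
    f0 = np.atleast_1d(fun(x))
    jac = np.empty((f0.size, x.size))
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = h
        # Column j: sensitivity of every output to input j (two evaluations)
        jac[:, j] = (np.atleast_1d(fun(x + step)) - np.atleast_1d(fun(x - step))) / (2 * h)
    return jac


def residual(x):
    # Toy residual whose exact Jacobian is [[2*x0, 1], [cos(x0), 0]]
    return np.array([x[0] ** 2 + x[1], np.sin(x[0])])


x = np.array([0.5, -1.0])
J_fd = fd_jacobian(residual, x)
J_exact = np.array([[2 * x[0], 1.0], [np.cos(x[0]), 0.0]])
print(np.max(np.abs(J_fd - J_exact)))  # small: only truncation and round-off error
```

Each column costs two function evaluations, and the result also depends on the step size `h`. Automatic differentiation with JAX gives derivatives accurate to floating-point precision and needs no step size.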
In practice you rarely interact with this selection directly. Instead, you usually call the derivative methods of the Objective classes, such as jac_scaled, jac_scaled_error, and jvp_scaled_error.
Let’s start with an example that computes the full Jacobian matrix of the ForceBalance objective.
[4]:
import sys
import os
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))
[5]:
from desc.objectives import ObjectiveFunction, ForceBalance
from desc.examples import get
[ ]:
# Use W7-X equilibrium from examples
eq = get("W7-X")
# Initialize and build the objective
obj = ForceBalance(eq)
obj.build()
Precomputing transforms
[43]:
params = obj.xs(eq)
params
[43]:
({'R_lmn': Array([-7.61552162e-06, 9.16996648e-05, 1.86963797e-05, ...,
2.19188894e-08, -1.74145484e-06, 0.00000000e+00], dtype=float64),
'Z_lmn': Array([-2.40502618e-05, -8.58547166e-05, -1.50730988e-05, ...,
5.67762944e-07, -5.57790527e-07, 8.68739525e-09], dtype=float64),
'L_lmn': Array([ 2.44528892e-05, -3.26433442e-05, 4.54266096e-05, ...,
1.96745697e-06, 1.08110518e-06, 9.52033876e-07], dtype=float64),
'p_l': Array([ 185596.929, -371193.859, 185596.929, 0. , 0. ,
0. , 0. ], dtype=float64),
'i_l': Array([-0.85604702, -0.03880954, -0.06867951, -0.01869703, 0.01905612,
0. , 0. ], dtype=float64),
'c_l': Array([], shape=(0,), dtype=float64),
'Psi': Array([-2.133], dtype=float64),
'Te_l': Array([], shape=(0,), dtype=float64),
'ne_l': Array([], shape=(0,), dtype=float64),
'Ti_l': Array([], shape=(0,), dtype=float64),
'Zeff_l': Array([], shape=(0,), dtype=float64),
'a_lmn': Array([], shape=(0,), dtype=float64),
'Ra_n': Array([ 5.60514435e+00, 3.59944651e-01, 1.35342448e-02, 8.05206303e-04,
-5.52216662e-05, -4.44705537e-05, -8.69796953e-05, -1.31024894e-05,
-1.16094194e-04, 1.46746080e-05, 8.95046765e-05, 6.43224852e-05,
2.50291789e-05], dtype=float64),
'Za_n': Array([ 9.41747282e-06, 1.91787732e-05, -1.31424684e-05, -6.94635436e-05,
8.13912452e-05, -2.27957475e-05, -3.17561695e-06, 8.07758929e-05,
-5.52676301e-05, -8.85642777e-05, 1.62604346e-02, 3.02638117e-01], dtype=float64),
'Rb_lmn': Array([ 0.00000000e+00, -1.25396811e-06, 2.37682928e-06, ...,
-3.71332831e-06, 1.81690419e-06, 0.00000000e+00], dtype=float64),
'Zb_lmn': Array([-1.59145234e-05, -1.71391266e-04, -7.91973944e-05, ...,
-4.92286011e-05, -6.37424206e-05, -4.43234863e-05], dtype=float64),
'I': Array([], shape=(0,), dtype=float64),
'G': Array([], shape=(0,), dtype=float64),
'Phi_mn': Array([], shape=(0,), dtype=float64)},)
[50]:
(J,) = obj.jac_scaled(*params)
n_params = 0  # avoid shadowing the built-in sum()
print("The portion of the Jacobian for")
for key in J.keys():
    print(f"\t{key:10} has shape {J[key].shape}")
    n_params += J[key].shape[1]  # number of columns = parameters in this block
print("Total number of parameters that we took the derivative for is", n_params)
The portion of the Jacobian for
G has shape (5346, 0)
I has shape (5346, 0)
L_lmn has shape (5346, 1134)
Phi_mn has shape (5346, 0)
Psi has shape (5346, 1)
R_lmn has shape (5346, 1141)
Ra_n has shape (5346, 13)
Rb_lmn has shape (5346, 313)
Te_l has shape (5346, 0)
Ti_l has shape (5346, 0)
Z_lmn has shape (5346, 1134)
Za_n has shape (5346, 12)
Zb_lmn has shape (5346, 312)
Zeff_l has shape (5346, 0)
a_lmn has shape (5346, 0)
c_l has shape (5346, 0)
i_l has shape (5346, 7)
ne_l has shape (5346, 0)
p_l has shape (5346, 7)
Total number of parameters that we took the derivative for is 4074
Alternatively, we can pass the equilibrium’s parameter dictionary directly:
[54]:
(J,) = obj.jac_scaled(eq.params_dict)
J["R_lmn"].shape
[54]:
(5346, 1141)
This per-parameter form is useful when you want to investigate the effect of individual parameter groups. If you need a single Jacobian matrix instead, the proper way is to wrap the Objective in an ObjectiveFunction:
[56]:
objfun = ObjectiveFunction(ForceBalance(eq))
objfun.build()
J = objfun.jac_scaled(objfun.x(eq))
J.shape
Building objective: force
Precomputing transforms
[56]:
(5346, 4074)
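Every block from the previous method has 5346 rows, and their column counts add up to 4074. Concatenating the blocks side by side, in the order the ObjectiveFunction stacks its state vector, therefore gives the same Jacobian matrix.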
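The equivalence comes from column partitioning. Suppose you differentiate with respect to each parameter group separately and place the blocks side by side, in the same order the groups are stacked into the state vector. The result reproduces the Jacobian with respect to the full vector. The NumPy sketch below checks this without DESC. The two-group dictionary and `residual` are hypothetical stand-ins for the equilibrium parameters and ForceBalance.

```python
import numpy as np


def fd_jacobian(fun, x, h=1e-6):
    """Central-difference Jacobian of fun with respect to the 1D array x."""
    x = np.asarray(x, dtype=float)
    cols = []
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = h
        cols.append((fun(x + step) - fun(x - step)) / (2 * h))
    return np.stack(cols, axis=1)


def residual(p):
    # Toy residual of a parameter dict (hypothetical stand-in for ForceBalance)
    a, b = p["a"], p["b"]
    return np.array([a[0] * b[0], a[1] + b[1] ** 2, np.sin(a[0]) * b[2]])


params = {"a": np.array([0.3, -0.2]), "b": np.array([1.5, 0.7, -0.4])}
keys = ["a", "b"]  # order in which the groups are stacked into the state vector

# Per-group blocks: differentiate w.r.t. one group, holding the others fixed
blocks = {}
for k in keys:
    def f_k(v, k=k):
        return residual({**params, k: v})
    blocks[k] = fd_jacobian(f_k, params[k])

# Full Jacobian w.r.t. the stacked vector x = [a, b]
split_points = np.cumsum([params[k].size for k in keys])[:-1]


def residual_flat(x):
    return residual(dict(zip(keys, np.split(x, split_points))))


x = np.concatenate([params[k] for k in keys])
J_full = fd_jacobian(residual_flat, x)
J_stacked = np.hstack([blocks[k] for k in keys])
print(J_full.shape, np.allclose(J_full, J_stacked))  # (3, 5) True
```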
Throughout the code you will see families of methods whose names start with compute_, jac_, jvp_ and vjp_. The variants within each family differ in whether they apply weighting, normalization, or a target/bounds. Here is a brief summary of what they do:
| Function | Purpose | Full Jacobian |
|---|---|---|
| | Main method to compute the raw objective function. | |
| | Compute the raw value of the objective, optionally applying a loss function. | |
| | Compute the objective with weighting and normalization applied. | |
| | In addition to … | |
| | Compute the scalar value of the objective, \(\lVert\mathbf{f}\rVert^2/2\). | |
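The differences between these variants are easiest to see with numbers. In the NumPy sketch below, `weight`, `normalization`, `target`, and the bounds are hypothetical values. The formulas are an illustrative reading of the table, not DESC's source: weight and normalize the raw values, subtract a target or measure how far each value lies outside the bounds, then take half the sum of squares.

```python
import numpy as np

f_raw = np.array([2.0, -1.0, 0.5])   # hypothetical raw objective values
weight = 2.0                         # hypothetical user-supplied weight
normalization = 4.0                  # hypothetical normalization scale
target = np.array([1.0, -1.0, 0.0])  # hypothetical target values


def scaled(f):
    # Weighting and normalization only
    return weight * f / normalization


def scaled_error_target(f):
    # Subtract the target, then weight and normalize
    return weight * (f - target) / normalization


def scaled_error_bounds(f, lower, upper):
    # Zero inside [lower, upper]; signed distance to the violated bound outside
    err = np.where(f < lower, f - lower, np.where(f > upper, f - upper, 0.0))
    return weight * err / normalization


def scalar(err):
    # Scalar objective: half the sum of squared errors
    return 0.5 * np.sum(err ** 2)


e = scaled_error_target(f_raw)
print(scaled(f_raw))
print(e)
print(scalar(e))  # 0.15625
print(scaled_error_bounds(f_raw, -0.5, 1.0))
```

In this sketch, subtracting a constant target leaves the derivative unchanged, so the scaled and target-based scaled-error forms share the same Jacobian. With bounds, the derivative is zero for entries that lie strictly inside them.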
The jvp_ and vjp_ methods compute derivatives along given directions. The names stand for Jacobian-vector product (\(J\mathbf{v}\)) and vector-Jacobian product (\(\mathbf{u}^T J\)). When you only need a few directions, they are more efficient than forming the full Jacobian and then extracting the column you want. In fact, if you look at the implementation of the jac_ methods, you will see that the full Jacobian is formed by taking a jvp along each unit direction.
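@jit
def jac_scaled_error(self, x, constants=None):
    """Compute Jacobian matrix of self.compute_scaled_error wrt x."""
    # One unit tangent vector per parameter (identity matrix)
    v = jnp.eye(x.shape[0])
    # JVPs along all unit directions; transposing yields the Jacobian
    return self.jvp_scaled_error(v, x, constants).T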
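The same construction can be written in plain NumPy. In the sketch below, `jvp` approximates \(J\mathbf{v}\) with a central difference along \(\mathbf{v}\); with JAX you would use `jax.jvp`. The hypothetical helper `jac_from_jvps` feeds `jvp` each row of an identity matrix `v` and then transposes, mirroring jac_scaled_error. The toy `residual` has a known Jacobian for checking.

```python
import numpy as np


def residual(x):
    # Toy residual; its exact Jacobian is [[x1, x0, 0], [0, 2*x1, 1]]
    return np.array([x[0] * x[1], x[1] ** 2 + x[2]])


def jvp(fun, x, tangent, h=1e-6):
    """Approximate J(x) @ tangent with a central difference (2 evaluations)."""
    return (fun(x + h * tangent) - fun(x - h * tangent)) / (2 * h)


def jac_from_jvps(fun, x):
    """Hypothetical helper: full Jacobian from one JVP per unit direction."""
    v = np.eye(x.size)  # row j is the unit tangent e_j
    # Each JVP is one column of J; stacking them as rows gives J.T, so transpose
    return np.stack([jvp(fun, x, t) for t in v]).T


x = np.array([1.0, 2.0, -0.5])
J_exact = np.array([[x[1], x[0], 0.0], [0.0, 2 * x[1], 1.0]])
J = jac_from_jvps(residual, x)

# A single directional derivative needs only one JVP, not the whole matrix
w = np.array([0.5, -1.0, 2.0])
print(np.allclose(J, J_exact), np.allclose(jvp(residual, x, w), J_exact @ w))
```

If you only need the derivative along one direction `w`, a single JVP is enough. That is why calling the jvp_ methods directly is cheaper than building the full matrix.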
In jac_scaled_error, v holds the unit tangent vectors, one per parameter. In practice, we usually don’t need the full Jacobian with respect to the full state vector. For example, LinearConstraintProjection reduces the number of parameters so that the optimizer works only in the null space of the constraint matrix, but our compute functions still take the full state vector. So, how do we take the derivative in that case? The answer is a little bit of linear algebra. Let’s consider the following problem.
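Suppose \(f\) depends on \(\mathbf{x} = [x_1, x_2, x_3, x_4]\), and we impose the linear constraint \(x_1 = x_2\), i.e. \(A\mathbf{x} = 0\) with \(A = [1, -1, 0, 0]\). Since the constraint links \(x_1\) and \(x_2\), the reduced state vector has only 3 parameters, \(\mathbf{y} = [y_1, y_2, y_3]\), with \(y_1 = x_1 = x_2\), \(y_2 = x_3\), and \(y_3 = x_4\).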
Taking the derivative of \(f\) with respect to \(y_2\) and \(y_3\) is straightforward. When we take the derivative in \(y_1\), however, both \(x_1\) and \(x_2\) change, so we have to move in both directions at once. In this simple example it was easy to decide which parameters are free and which are dependent. For more complex linear constraints, the systematic way is to use the null-space matrix \(Z\), whose columns form a basis of the directions that satisfy the constraint. To take the derivative in the \(y_1\) direction, the tangent vector in the reduced space is \(\mathbf{v}_r = [1, 0, 0]\), and the tangent vector in the full space is \(Z\mathbf{v}_r\). The utility function svd_inv_null computes both the pseudo-inverse and the null space of a matrix. The null-space basis it returns is orthonormal. As a result, the full tangent comes out as \([1, 1, 0, 0]/\sqrt{2}\) rather than \([1, 1, 0, 0]\). This is the same direction, with \(y_1\) rescaled by a factor of \(\sqrt{2}\). Here is how we get the full tangent direction for the simple example.
[66]:
import numpy as np
from desc.utils import svd_inv_null
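A = np.array([[1.0, -1.0, 0.0, 0.0]])  # constraint x1 - x2 = 0, i.e. A @ x = 0
Ainv, Z = svd_inv_null(A)  # pseudo-inverse and null-space basis of A
vr = np.array([1, 0, 0])  # unit tangent along y1 in the reduced space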
print("Full tangent: ", Z @ vr)
print("Null-space:\n", Z)
Full tangent: [0.70710678 0.70710678 0. 0. ]
Null-space:
[[0.70710678 0. 0. ]
[0.70710678 0. 0. ]
[0. 1. 0. ]
[0. 0. 1. ]]