desc.derivatives.AutoDiffDerivative
- class desc.derivatives.AutoDiffDerivative(fun, argnum=0, mode='fwd', chunk_size=None, **kwargs)
Computes derivatives using automatic differentiation with JAX.
- Parameters:
fun (callable) – Function to be differentiated.
argnum (int, optional) – Specifies which positional argument to differentiate with respect to.
mode (str, optional) – Automatic differentiation mode. One of 'fwd' (forward-mode Jacobian), 'rev' (reverse-mode Jacobian), 'grad' (gradient of a scalar function), 'hess' (Hessian of a scalar function), or 'jvp' (Jacobian-vector product). Default = 'fwd'.
- Raises:
ValueError – if mode is not supported.
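Each mode string above corresponds naturally to a standard JAX transform. The sketch below illustrates that correspondence under that assumption. It is not DESC's implementation or API: the helper `make_derivative`, the `SUPPORTED_MODES` tuple, and the `(v, *args)` calling convention used for 'jvp' are hypothetical. It also mirrors the documented behavior of raising ValueError for an unsupported mode.

```python
import jax
import jax.numpy as jnp

# Mode strings taken from the documented options (hypothetical constant).
SUPPORTED_MODES = ("fwd", "rev", "grad", "hess", "jvp")


def make_derivative(fun, argnum=0, mode="fwd"):
    """Return a callable for the derivative of ``fun`` w.r.t. positional arg ``argnum``.

    Hypothetical sketch: maps each documented mode string to a plain JAX transform.
    """
    if mode == "fwd":
        return jax.jacfwd(fun, argnums=argnum)  # Jacobian, forward mode
    if mode == "rev":
        return jax.jacrev(fun, argnums=argnum)  # Jacobian, reverse mode
    if mode == "grad":
        return jax.grad(fun, argnums=argnum)  # fun must return a scalar
    if mode == "hess":
        return jax.hessian(fun, argnums=argnum)  # fun must return a scalar
    if mode == "jvp":
        def jvp_fun(v, *args):
            # Hold the other positional arguments fixed and push the tangent v
            # through fun along args[argnum]; returns df/dx * v.
            def f_x(x):
                return fun(*args[:argnum], x, *args[argnum + 1:])
            _, tangent_out = jax.jvp(f_x, (args[argnum],), (v,))
            return tangent_out
        return jvp_fun
    raise ValueError(f"mode must be one of {SUPPORTED_MODES}, got {mode!r}")


# Usage: a vector-valued f and a scalar-valued g evaluated at x = [1, 2].
f = lambda x: jnp.stack([x[0] ** 2, x[0] * x[1]])
g = lambda x: jnp.sum(x**3)
x = jnp.array([1.0, 2.0])
print(make_derivative(f, mode="fwd")(x))   # [[2*x0, 0], [x1, x0]]
print(make_derivative(g, mode="hess")(x))  # diag(6 * x)
```

Returning a function rather than a value follows the style of JAX's own transforms. As in the parameter description, 'grad' and 'hess' only make sense when `fun` returns a scalar.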
Methods
compute(*args, **kwargs) – Compute the derivative matrix.
compute_jvp(fun, argnum, v, *args, **kwargs) – Compute df/dx*v.
compute_jvp2(fun, argnum1, argnum2, v1, v2, ...) – Compute d^2f/dx^2*v1*v2.
compute_jvp3(fun, argnum1, argnum2, argnum3, ...) – Compute d^3f/dx^3*v1*v2*v3.
compute_vjp(fun, argnum, v, *args, **kwargs) – Compute v.T * df/dx.
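compute_jvp and compute_vjp form the products df/dx*v and v.T * df/dx for the positional argument selected by argnum. The following is a minimal sketch of those two products using `jax.jvp` and `jax.vjp`, assuming the remaining positional arguments are held fixed. The names `jvp_wrt`, `vjp_wrt`, and `_restrict` are hypothetical and are not DESC functions.

```python
import jax
import jax.numpy as jnp


def _restrict(fun, argnum, args):
    """Return x -> fun(*args) with args[argnum] replaced by x (others fixed)."""
    def f_x(x):
        return fun(*args[:argnum], x, *args[argnum + 1:])
    return f_x


def jvp_wrt(fun, argnum, v, *args):
    """Forward mode: df/dx * v, where x = args[argnum]."""
    _, out = jax.jvp(_restrict(fun, argnum, args), (args[argnum],), (v,))
    return out


def vjp_wrt(fun, argnum, v, *args):
    """Reverse mode: v.T * df/dx, where x = args[argnum]."""
    _, pullback = jax.vjp(_restrict(fun, argnum, args), args[argnum])
    (out,) = pullback(v)  # pullback returns one cotangent per primal input
    return out


# Usage: differentiate with respect to the second argument (argnum=1).
# For f(s, x) = s * [x0**2, x0*x1] with s = 2, df/dx = [[4*x0, 0], [2*x1, 2*x0]].
f = lambda s, x: s * jnp.stack([x[0] ** 2, x[0] * x[1]])
x = jnp.array([1.0, 2.0])
print(jvp_wrt(f, 1, jnp.array([1.0, 0.0]), 2.0, x))  # first column of df/dx
print(vjp_wrt(f, 1, jnp.array([1.0, 1.0]), 2.0, x))  # column sums of df/dx
```

Building a full Jacobian from these products takes one JVP per input dimension or one VJP per output dimension, which is the usual reason to prefer 'fwd' for functions with few inputs and 'rev' for functions with few outputs.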
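compute_jvp2 and compute_jvp3 contract the second and third derivative tensors with two or three vectors. One way to obtain these contractions without forming the tensors is to nest forward-mode JVPs, as sketched below. This sketch assumes every direction refers to a single argument, whereas the documented methods accept a separate argnum per direction. The names `jvp2` and `jvp3` are hypothetical.

```python
import jax
import jax.numpy as jnp


def jvp2(f, x, v1, v2):
    """Second directional derivative d^2f/dx^2 * v1 * v2 at x (single argument)."""
    def df_v1(y):
        # First directional derivative of f along v1, as a function of y.
        return jax.jvp(f, (y,), (v1,))[1]
    # Differentiate that directional derivative along v2.
    return jax.jvp(df_v1, (x,), (v2,))[1]


def jvp3(f, x, v1, v2, v3):
    """Third directional derivative d^3f/dx^3 * v1 * v2 * v3 at x (single argument)."""
    def d2f(y):
        return jvp2(f, y, v1, v2)
    return jax.jvp(d2f, (x,), (v3,))[1]


# Usage: for g(x) = sum(x**3), the Hessian is diag(6x) and the third
# derivative tensor has the value 6 on its diagonal and zeros elsewhere.
g = lambda x: jnp.sum(x**3)
x = jnp.array([1.0, 2.0])
ones = jnp.ones(2)
print(jvp2(g, x, ones, ones))        # v1.T @ diag(6x) @ v2 = 6 + 12
print(jvp3(g, x, ones, ones, ones))  # sum of 6 * v_i**3 = 12
```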
Attributes
argnum – argument being differentiated with respect to.
fun – function being differentiated.
mode – the kind of derivative being computed (e.g. 'grad').