desc.optimize.sgd

desc.optimize.sgd(fun, x0, grad, args=(), method='sgd', x_scale='auto', ftol=1e-06, xtol=1e-06, gtol=1e-06, verbose=1, maxiter=None, callback=None, options=None)

Minimize a scalar function using one of the stochastic gradient descent methods.

This is the generic function. The update method is chosen based on the method argument.

Update rule for 'sgd':

v_k = β * v_{k-1} + (1-β) * ∇f(x_k)

x_{k+1} = x_k - α * v_k

where α is the step size and β is the momentum parameter.
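As a plain-Python sketch of the update rule above (an illustrative toy on a simple quadratic, not DESC's implementation):

```python
def sgd_momentum_step(x, v, grad_f, alpha=0.1, beta=0.9):
    """One 'sgd' step: v_k = beta*v_{k-1} + (1-beta)*grad(x_k); x_{k+1} = x_k - alpha*v_k."""
    g = grad_f(x)
    v = [beta * vi + (1 - beta) * gi for vi, gi in zip(v, g)]
    x = [xi - alpha * vi for xi, vi in zip(x, v)]
    return x, v

# minimize f(x) = x0^2 + x1^2, whose gradient is (2*x0, 2*x1)
grad_f = lambda x: [2 * x[0], 2 * x[1]]
x, v = [1.0, -1.0], [0.0, 0.0]
for _ in range(200):
    x, v = sgd_momentum_step(x, v, grad_f)
print(x)  # both components approach 0
```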

Additionally, optax optimizers can be used by specifying the method as 'optax-<optimizer_name>', where <optimizer_name> is any valid optax optimizer. Hyperparameters for the optax optimizer must be passed via the 'optax-options' key of the options dictionary.

Parameters:
  • fun (callable) – objective to be minimized. Should have a signature like fun(x, *args) -> float

  • x0 (array-like) – initial guess

  • grad (callable) – function to compute gradient, df/dx. Should take the same arguments as fun

  • args (tuple) – additional arguments passed to fun and grad

  • method (str) – Name of the method to use. The only built-in option is 'sgd'. Additionally, optax optimizers can be used by specifying the method as 'optax-<optimizer_name>', where <optimizer_name> is any valid optax optimizer. Hyperparameters for the optax optimizer must be passed via the 'optax-options' key of the options dictionary. A custom optax optimizer can be used by specifying the method as 'optax-custom' and passing the optax optimizer via the 'update_rule' key of 'optax-options' in the options dictionary.

  • x_scale (array_like or 'auto', optional) – Characteristic scale of each variable. Setting x_scale is equivalent to reformulating the problem in scaled variables xs = x / x_scale. Improved convergence may be achieved by setting x_scale such that a step of a given size along any of the scaled variables has a similar effect on the cost function. Defaults to ‘auto’, meaning no scaling.

  • ftol (float or None, optional) – Tolerance for termination by the change of the cost function. The optimization process is stopped when dF < ftol * F. If None, the termination by this condition is disabled.

  • xtol (float or None, optional) – Tolerance for termination by the change of the independent variables. Optimization is stopped when norm(dx) < xtol * (xtol + norm(x)). If None, the termination by this condition is disabled.

  • gtol (float or None, optional) – Absolute tolerance for termination by the norm of the gradient. Optimizer terminates when max(abs(g)) < gtol. If None, the termination by this condition is disabled.

  • verbose ({0, 1, 2}, optional) –

    • 0 : work silently.

    • 1 (default) : display a termination report.

    • 2 : display progress during iterations

  • maxiter (int, optional) – Maximum number of iterations. Defaults to size(x)*100.

  • callback (callable, optional) –

    Called after each iteration. Should be a callable with the signature:

    callback(xk, *args) -> bool

    where xk is the current parameter vector and args are the same arguments passed to fun and grad. If callback returns True, the algorithm execution is terminated.

  • options (dict, optional) –

    dictionary of optional keyword arguments to override default solver settings for the update rule chosen.

    • "alpha" : (float > 0) Learning rate. Defaults to 1e-2 * ||x_scaled|| / ||g_scaled||.

    • "beta" : (float > 0) Exponential decay rate for the first moment estimates. Default 0.9.

    For optax optimizers, hyperparameters specific to the chosen optimizer must be passed via the 'optax-options' key of the options dictionary.
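For illustration, an options dictionary might look like the following (the hyperparameter values and the 'optax-adam' method name are examples, not recommendations):

```python
# hypothetical settings for the built-in 'sgd' method
sgd_options = {
    "alpha": 1e-3,  # learning rate
    "beta": 0.9,    # momentum parameter
}

# hypothetical settings for method='optax-adam'; the inner dict holds
# whatever hyperparameters that optax optimizer accepts
adam_options = {
    "optax-options": {"learning_rate": 1e-3},
}
```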

Returns:

res (OptimizeResult) – The optimization result represented as an OptimizeResult object. Important attributes are: x, the solution array, and success, a Boolean flag indicating whether the optimizer exited successfully.

Examples

Custom optax optimizers can be used as follows:

import optax
from desc.optimize import Optimizer
from desc.examples import get

eq = get("DSHAPE")

# custom optax update rule: SGD with a zoom line search
opt = optax.chain(
    optax.sgd(learning_rate=1.0),
    optax.scale_by_zoom_linesearch(max_linesearch_steps=15),
)
optimizer = Optimizer("optax-custom")
eq.solve(optimizer=optimizer, options={"optax-options": {"update_rule": opt}})
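A callback can be used to monitor or stop a solve early. As a minimal sketch (a hypothetical callback and a toy driver loop standing in for the optimizer, not DESC code):

```python
class StopWhenStalled:
    """Callback requesting termination when successive iterates stop moving."""

    def __init__(self, tol=1e-3):
        self.tol = tol
        self.prev = None
        self.history = []

    def __call__(self, xk, *args):
        self.history.append(list(xk))
        stalled = self.prev is not None and max(
            abs(a - b) for a, b in zip(xk, self.prev)
        ) < self.tol
        self.prev = list(xk)
        return stalled  # returning True terminates the optimization

# toy driver loop imitating the optimizer's iteration
cb = StopWhenStalled(tol=1e-3)
x = [1.0]
for _ in range(100):
    x = [0.5 * x[0]]  # fake update converging to 0
    if cb(x):
        break
```

With desc.optimize.sgd, such an object would be passed via the callback argument.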