essos.augmented_lagrangian
ALM (Augmented Lagrangian Method) implemented in JAX, using optimizers from Optax, JAXopt and Optimistix. Inspired by the mdmm_jax GitHub repository.
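The method itself can be illustrated independently of this module's API. For an equality constraint g(x) = 0, the standard augmented Lagrangian is L(x, λ, μ) = f(x) - λ·g(x) + (μ/2)·g(x)². Each outer iteration minimizes L over x with λ and μ held fixed, then updates the multiplier as λ ← λ - μ·g(x). Sign conventions for the multiplier term vary between references. The sketch below applies this textbook scheme to a hypothetical toy problem using plain gradient descent. The problem, step size and iteration counts are arbitrary choices and are not taken from essos.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy problem: minimize f(x) = x0^2 + x1^2 subject to g(x) = x0 + x1 - 1 = 0.
def f(x):
    return jnp.sum(x ** 2)


def g(x):
    return x[0] + x[1] - 1.0


def augmented_lagrangian(x, lam, mu):
    # Standard form: f(x) - lam * g(x) + (mu / 2) * g(x)^2
    return f(x) - lam * g(x) + 0.5 * mu * g(x) ** 2


# Gradient with respect to x only (argnums defaults to 0).
grad_x = jax.jit(jax.grad(augmented_lagrangian))


def solve(x0, lam=0.0, mu=10.0, outer_steps=20, inner_steps=200, lr=0.05):
    x = x0
    for _ in range(outer_steps):
        # Inner loop: approximately minimize the augmented Lagrangian over x.
        for _ in range(inner_steps):
            x = x - lr * grad_x(x, lam, mu)
        # Outer loop: first-order multiplier update.
        lam = lam - mu * g(x)
    return x, lam


x_opt, lam_opt = solve(jnp.zeros(2))
```

For this problem the exact solution is x = (0.5, 0.5) with multiplier λ = 1. The loop converges to it because each outer update shrinks the multiplier error by a constant factor.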
Classes
- LagrangeMultiplier – A class containing the constraint parameters for the Augmented Lagrangian Method.
- Constraint – A pair of pure functions implementing a constraint.
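Functions
- update_method – Different methods for updating multipliers and penalties.
- update_method_squared – Different methods for updating multipliers and penalties.
- lagrange_update – A gradient transformation for Optax that prepares an MDMM gradient descent ascent update from a normal gradient descent update.
- eq – Represents an equality constraint, g(x) = 0.
- ineq – Represents an inequality constraint, h(x) >= 0, which uses a slack variable internally to convert it to an equality constraint.
- combine – Combines multiple constraint tuples into a single constraint tuple.
- total_infeasibility
- norm_constraints
- infty_norm_constraints
- penalty_average
- ALM_model_optax
- ALM_model_jaxopt_lbfgsb
- ALM_model_jaxopt_LevenbergMarquardt
- ALM_model_jaxopt_lbfgs
- ALM_model_optimistix_LevenbergMarquardt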
Module Contents
- class essos.augmented_lagrangian.LagrangeMultiplier
Bases: NamedTuple
A class containing the constraint parameters for the Augmented Lagrangian Method.
- value: Any
- penalty: Any
- sq_grad: Any
- essos.augmented_lagrangian.update_method(params, updates, eta, omega, model_mu='Constant', beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06)
Different methods for updating multipliers and penalties.
- essos.augmented_lagrangian.update_method_squared(params, updates, eta, omega, model_mu='Constant', beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06)
Different methods for updating multipliers and penalties.
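The signature suggests a tolerance-driven scheme, but this reading is an assumption. Under it, eta and omega are the current feasibility and inner-solve tolerances, eta_tol and omega_tol are the final targets, and beta and mu_max are the penalty growth factor and its cap. A classic rule of this kind from the augmented Lagrangian literature is sketched below. When the constraint is nearly satisfied, the rule updates the multiplier and tightens the tolerances. Otherwise it keeps the multiplier and increases the penalty. ALMState and update_multiplier_and_penalty are hypothetical names, and the exponents 0.9 and 0.1 come from that textbook scheme rather than from essos. update_method may behave differently, in particular with model_mu='Constant'.

```python
from typing import NamedTuple


class ALMState(NamedTuple):
    # Hypothetical scalar state; not the essos LagrangeMultiplier class.
    lam: float    # Lagrange multiplier
    mu: float     # penalty parameter
    eta: float    # current feasibility tolerance
    omega: float  # current tolerance for the inner (subproblem) solve


def update_multiplier_and_penalty(state, g, beta=2.0, mu_max=1e4):
    """One outer-iteration update after the inner solve, given g = g(x_k)."""
    if abs(g) <= state.eta:
        # Nearly feasible: first-order multiplier update, keep mu, tighten tolerances.
        return ALMState(
            lam=state.lam - state.mu * g,
            mu=state.mu,
            eta=state.eta / state.mu ** 0.9,
            omega=state.omega / state.mu,
        )
    # Too infeasible: keep lam, grow the penalty (capped at mu_max), reset tolerances.
    mu = min(beta * state.mu, mu_max)
    return ALMState(lam=state.lam, mu=mu, eta=1.0 / mu ** 0.1, omega=1.0 / mu)


state = ALMState(lam=0.0, mu=4.0, eta=0.5, omega=1.0)
accepted = update_multiplier_and_penalty(state, g=0.25)  # |g| <= eta
rejected = update_multiplier_and_penalty(state, g=2.0)   # |g| > eta
capped = update_multiplier_and_penalty(ALMState(0.0, 8000.0, 0.0, 1.0), g=1.0)
```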
- essos.augmented_lagrangian.lagrange_update(model_lagrangian='Standard')
A gradient transformation for Optax that prepares an MDMM gradient descent ascent update from a normal gradient descent update.
- It should be chained after a base optimizer, for example:
    optimizer = optax.chain(
        optax.sgd(1e-3),
        essos.augmented_lagrangian.lagrange_update(),
    )
- Returns:
An Optax gradient transformation that converts a gradient descent update into a gradient descent ascent update.
- class essos.augmented_lagrangian.Constraint
Bases: NamedTuple
A pair of pure functions implementing a constraint.
- init: Callable
A pure function which, when called with an example instance of the arguments to the constraint functions, returns a pytree containing the constraint's learnable parameters.
- loss: Callable
A pure function which, when called with the learnable parameters returned by init() followed by the arguments to the constraint functions, returns the loss value for the constraint.
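The init/loss protocol can be shown with a hand-written constraint that follows it. The Constraint class below only mirrors the two documented fields. The unit-norm constraint and its single-multiplier parameter layout are hypothetical, and essos does not necessarily build its constraints this way.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class Constraint(NamedTuple):
    # Mirrors the two documented fields: a pair of pure functions.
    init: Callable
    loss: Callable


def unit_norm(x):
    # Hypothetical constraint function: zero exactly when ||x|| == 1.
    return jnp.linalg.norm(x) - 1.0


def init(x):
    # Called with example arguments; returns the learnable parameters as a pytree.
    return {"lambda": jnp.zeros_like(unit_norm(x))}


def loss(params, x):
    # Called with the learnable parameters first, then the constraint arguments.
    g = unit_norm(x)
    return -params["lambda"] * g + 0.5 * g ** 2


constraint = Constraint(init, loss)
x = jnp.array([3.0, 4.0])
params = constraint.init(x)
value = constraint.loss(params, x)
# params is a pytree, so jax.grad returns a matching pytree; d(loss)/d(lambda) = -g(x).
grad_lambda = jax.grad(constraint.loss)(params, x)["lambda"]
```

Because both functions are pure and the parameters are an ordinary pytree, the whole constraint composes with jax.grad and jax.jit like any other loss.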
- essos.augmented_lagrangian.eq(fun, model_lagrangian='Standard', multiplier=0.0, penalty=1.0, sq_grad=0.0, weight=1.0, reduction=jnp.sum)
Represents an equality constraint, g(x) = 0.
- Parameters:
fun – The constraint function: a differentiable function of your parameters that outputs zero when the constraint is satisfied and values that move smoothly further from zero as the violation grows.
model_lagrangian – The Lagrangian formulation to use (default 'Standard').
multiplier, penalty, sq_grad – Initial values of the constraint's multiplier state (compare the value, penalty and sq_grad fields of LagrangeMultiplier).
weight – Weights the loss from the constraint relative to the value of the primary loss function.
reduction – The function used to aggregate the constraint values when the constraint function outputs more than one element.
- Returns:
An (init_fn, loss_fn) constraint tuple for the equality constraint.
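The following sketch shows one plausible shape for the equality-constraint term. It assumes the learnable state is a LagrangeMultiplier-like tuple whose value and penalty fields play the roles of λ and μ in the standard augmented Lagrangian term -λ·g(x) + (μ/2)·g(x)². That term is then scaled by weight and aggregated with reduction. The names make_eq and Multiplier are hypothetical. The sketch ignores the sq_grad field and the model_lagrangian options, and the real eq may use a different form.

```python
from typing import Any, Callable, NamedTuple

import jax.numpy as jnp


class Multiplier(NamedTuple):
    # Hypothetical stand-in for LagrangeMultiplier (the sq_grad field is omitted).
    value: Any
    penalty: Any


class Constraint(NamedTuple):
    init: Callable
    loss: Callable


def make_eq(fun, multiplier=0.0, penalty=1.0, weight=1.0, reduction=jnp.sum):
    """Sketch of an equality constraint g(x) = 0 as an (init, loss) pair."""

    def init(*args):
        g = fun(*args)
        # One multiplier and one penalty per output element of the constraint function.
        return Multiplier(jnp.full_like(g, multiplier), jnp.full_like(g, penalty))

    def loss(params, *args):
        g = fun(*args)
        # Standard augmented Lagrangian term: -lambda * g + (mu / 2) * g**2.
        term = -params.value * g + 0.5 * params.penalty * g ** 2
        return weight * reduction(term)

    return Constraint(init, loss)


# Two equality constraints, x0 + x1 = 1 and x0 = x1, evaluated at x = (1, 1).
con = make_eq(lambda x: jnp.array([x[0] + x[1] - 1.0, x[0] - x[1]]), multiplier=2.0)
x = jnp.array([1.0, 1.0])
params = con.init(x)
total = con.loss(params, x)
```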
- essos.augmented_lagrangian.ineq(fun, model_lagrangian='Standard', multiplier=0.0, penalty=1.0, sq_grad=0.0, weight=1.0, reduction=jnp.sum)
Represents an inequality constraint, h(x) >= 0, which uses a slack variable internally to convert it to an equality constraint.
- Parameters:
fun – The constraint function: a differentiable function of your parameters that outputs values greater than or equal to zero when the constraint is satisfied and values that become smoothly more negative as the violation grows.
model_lagrangian – The Lagrangian formulation to use (default 'Standard').
multiplier, penalty, sq_grad – Initial values of the constraint's multiplier state (compare the value, penalty and sq_grad fields of LagrangeMultiplier).
weight – Weights the loss from the constraint relative to the value of the primary loss function.
reduction – The function used to aggregate the constraint values when the constraint function outputs more than one element.
- Returns:
An (init_fn, loss_fn) constraint tuple for the inequality constraint.
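A common way to implement the slack-variable conversion is to introduce a learnable slack s and enforce the equality h(x) - s² = 0, which can only hold when h(x) >= 0. The sketch below shows that idea. make_ineq and the dictionary layout are hypothetical. Whether essos squares the slack in exactly this way is an assumption.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class Constraint(NamedTuple):
    init: Callable
    loss: Callable


def make_ineq(fun, penalty=1.0, weight=1.0, reduction=jnp.sum):
    """Sketch: h(x) >= 0 rewritten as the equality h(x) - s**2 = 0."""

    def init(*args):
        h = fun(*args)
        # Start the slack at sqrt(max(h, 0)) so feasible points begin with zero residual.
        return {"lambda": jnp.zeros_like(h), "slack": jnp.sqrt(jax.nn.relu(h))}

    def loss(params, *args):
        residual = fun(*args) - params["slack"] ** 2
        term = -params["lambda"] * residual + 0.5 * penalty * residual ** 2
        return weight * reduction(term)

    return Constraint(init, loss)


# Require x >= 2 elementwise, i.e. h(x) = x - 2 >= 0; the last element is infeasible.
con = make_ineq(lambda x: x - 2.0)
x = jnp.array([3.0, 6.0, 1.0])
params = con.init(x)
total = con.loss(params, x)
```

Squaring the slack keeps s unconstrained, so the optimizer never has to enforce s >= 0 separately. Only the infeasible element contributes to the loss here.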
- essos.augmented_lagrangian.combine(*args)
Combines multiple constraint tuples into a single constraint tuple.
- Parameters:
*args – A series of constraint (init_fn, loss_fn) tuples.
- Returns:
A single (init_fn, loss_fn) tuple that wraps the input constraints.
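Combining constraints amounts to running each init and summing each loss, with one entry of learnable parameters per constraint. The sketch below assumes that the combined parameters form a tuple in argument order and that the total loss is a plain sum. Both are assumptions about the layout. combine_constraints and simple_eq are hypothetical names.

```python
from typing import Callable, NamedTuple

import jax.numpy as jnp


class Constraint(NamedTuple):
    init: Callable
    loss: Callable


def combine_constraints(*constraints):
    """Sketch: merge several (init, loss) pairs into one pair."""
    inits, losses = zip(*constraints)

    def init(*args):
        # One entry of learnable parameters per input constraint, in order.
        return tuple(fn(*args) for fn in inits)

    def loss(params, *args):
        # Total constraint loss is the sum of the individual losses.
        return sum(fn(p, *args) for p, fn in zip(params, losses))

    return Constraint(init, loss)


def simple_eq(fun):
    # Minimal hypothetical equality constraint: -lambda * g + g**2 / 2.
    def init(*args):
        return jnp.zeros_like(fun(*args))

    def loss(lam, *args):
        g = fun(*args)
        return jnp.sum(-lam * g + 0.5 * g ** 2)

    return Constraint(init, loss)


both = combine_constraints(simple_eq(lambda x: x[0] - 1.0), simple_eq(lambda x: x[1] + 2.0))
x = jnp.array([3.0, 0.0])
params = both.init(x)
total = both.loss(params, x)
```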
- essos.augmented_lagrangian.total_infeasibility(tree)
- essos.augmented_lagrangian.norm_constraints(tree)
- essos.augmented_lagrangian.infty_norm_constraints(tree)
- essos.augmented_lagrangian.penalty_average(tree)
- essos.augmented_lagrangian.ALM_model_optax(optimizer: optax.GradientTransformation, constraints: Constraint, loss=lambda x: ..., model_lagrangian='Standard', model_mu='Constant', beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06, **kargs)
- essos.augmented_lagrangian.ALM_model_jaxopt_lbfgsb(constraints: Constraint, loss=lambda x: ..., model_lagrangian='Standard', beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06, **kargs)
- essos.augmented_lagrangian.ALM_model_jaxopt_LevenbergMarquardt(constraints: Constraint, loss=lambda x: ..., beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06, **kargs)
- essos.augmented_lagrangian.ALM_model_jaxopt_lbfgs(constraints: Constraint, loss=lambda x: ..., model_lagrangian='Standard', beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06, **kargs)
- essos.augmented_lagrangian.ALM_model_optimistix_LevenbergMarquardt(constraints: Constraint, loss=lambda x: ..., beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06, **kargs)