essos.augmented_lagrangian

Augmented Lagrangian Method (ALM) implemented in JAX, using optimizers from Optax, JAXopt, and Optimistix. Inspired by the mdmm_jax GitHub repository.
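As a concrete picture of the method, here is a minimal sketch of an augmented Lagrangian loop on a toy problem: minimize x0^2 + x1^2 subject to x0 + x1 - 1 = 0. It is written in pure Python with hand-written gradients instead of JAX autodiff, and uses plain gradient descent as the inner solver. The Lagrangian form f + lam*g + (mu/2)*g^2, the fixed penalty mu, and the names alm and grad_L are assumptions for illustration, not essos code.

```python
# Minimal augmented Lagrangian loop (pure Python, hand-written gradients).
# Problem: minimize f(x) = x0**2 + x1**2  subject to  g(x) = x0 + x1 - 1 = 0.
# Augmented Lagrangian (assumed form): L(x, lam; mu) = f(x) + lam*g(x) + (mu/2)*g(x)**2

def g(x):
    return x[0] + x[1] - 1.0

def grad_L(x, lam, mu):
    # dL/dx_i = 2*x_i + lam + mu*g(x), since dg/dx_i = 1 for both components.
    c = lam + mu * g(x)
    return [2.0 * x[0] + c, 2.0 * x[1] + c]

def alm(mu=10.0, outer_steps=20, inner_steps=50):
    x = [0.0, 0.0]
    lam = 0.0
    # The largest eigenvalue of the Hessian of L in x is 2 + 2*mu;
    # stepping with its reciprocal keeps gradient descent stable.
    lr = 1.0 / (2.0 + 2.0 * mu)
    for _ in range(outer_steps):
        # Inner loop: approximately minimize L over x with lam and mu held fixed.
        for _ in range(inner_steps):
            dx = grad_L(x, lam, mu)
            x = [x[0] - lr * dx[0], x[1] - lr * dx[1]]
        # Outer loop: first-order multiplier update.
        lam = lam + mu * g(x)
    return x, lam

x, lam = alm()
print(x, lam)  # approaches x = [0.5, 0.5], lam = -1.0
```

Holding mu fixed is the simplest penalty schedule. Increasing mu between outer iterations shrinks the constraint violation faster, at the cost of a stiffer inner problem (smaller stable step size).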

Classes

LagrangeMultiplier

A class containing constraint parameters for the Augmented Lagrangian Method.

Constraint

A pair of pure functions implementing a constraint.

ALM

Functions

update_method(params, updates, eta, omega[, model_mu, ...])

Different methods for updating multipliers and penalties

update_method_squared(params, updates, eta, omega[, ...])

Different methods for updating multipliers and penalties

lagrange_update([model_lagrangian])

A gradient transformation for Optax that prepares an MDMM gradient descent ascent update from a normal gradient descent update.

eq(fun[, model_lagrangian, multiplier, penalty, ...])

Represents an equality constraint, g(x) = 0.

ineq(fun[, model_lagrangian, multiplier, penalty, ...])

Represents an inequality constraint, h(x) >= 0, which uses a slack variable internally to convert it to an equality constraint.

combine(*args)

Combines multiple constraint tuples into a single constraint tuple.

total_infeasibility(tree)

norm_constraints(tree)

infty_norm_constraints(tree)

penalty_average(tree)

ALM_model_optax(optimizer, constraints[, loss, ...])

ALM_model_jaxopt_lbfgsb(constraints[, loss, ...])

ALM_model_jaxopt_LevenbergMarquardt(constraints[, ...])

ALM_model_jaxopt_lbfgs(constraints[, loss, ...])

ALM_model_optimistix_LevenbergMarquardt(constraints[, ...])

Module Contents

class essos.augmented_lagrangian.LagrangeMultiplier

Bases: NamedTuple

A class containing constraint parameters for the Augmented Lagrangian Method.

value: Any
penalty: Any
sq_grad: Any
essos.augmented_lagrangian.update_method(params, updates, eta, omega, model_mu='Constant', beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06)

Different methods for updating multipliers and penalties
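The update rules themselves are not described here. As a hedged illustration of the general pattern that parameters such as eta, beta, and mu_max suggest, the sketch below uses a hypothetical conditional_update: if the constraint violation is within the tolerance eta, it takes a first-order multiplier step; otherwise it multiplies the penalty by beta, capped at mu_max. This is an assumption about the style of update, not essos's actual logic, and it does not model omega, alpha, gamma, or any tolerance schedule.

```python
# Hypothetical sketch of a tolerance-driven multiplier/penalty update.
# beta and mu_max mirror names in update_method's signature; the logic is assumed.

def conditional_update(lam, mu, g, eta, beta=2.0, mu_max=10000.0):
    """Return updated (lam, mu) for a scalar equality-constraint value g = g(x)."""
    if abs(g) <= eta:
        # Violation is within tolerance: take a first-order multiplier step.
        lam = lam + mu * g
    else:
        # Violation is too large: keep lam and grow the penalty, capped at mu_max.
        mu = min(beta * mu, mu_max)
    return lam, mu

print(conditional_update(0.0, 1.0, 0.05, eta=0.1))    # (0.05, 1.0)
print(conditional_update(0.0, 1.0, 0.5, eta=0.1))     # (0.0, 2.0)
print(conditional_update(0.0, 8000.0, 1.0, eta=0.1))  # (0.0, 10000.0)
```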

essos.augmented_lagrangian.update_method_squared(params, updates, eta, omega, model_mu='Constant', beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06)

Different methods for updating multipliers and penalties

essos.augmented_lagrangian.lagrange_update(model_lagrangian='Standard')

A gradient transformation for Optax that prepares an MDMM gradient descent ascent update from a normal gradient descent update.

It should be used together with a base optimizer, like this:

    optimizer = optax.chain(
        optax.sgd(1e-3),
        essos.augmented_lagrangian.lagrange_update(),
    )

Returns:

An Optax gradient transformation that converts a gradient descent update into a gradient descent ascent update.
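The descent-ascent idea can be sketched without Optax. Flipping the sign of the multiplier's update turns a descent step into an ascent step on the Lagrangian, so x is minimized while the multiplier is maximized. In the sketch below, sgd_updates stands in for optax.sgd and prepare_descent_ascent is a hypothetical stand-in for lagrange_update(); the toy Lagrangian x^2 + lam*(x - 1) and the dictionary parameter layout are assumptions.

```python
# Gradient descent-ascent on L(x, lam) = x**2 + lam*(x - 1) via a sign flip
# on the multiplier's update. All names here are hypothetical stand-ins.

def grads_of_lagrangian(params):
    x, lam = params["x"], params["lam"]
    return {"x": 2.0 * x + lam, "lam": x - 1.0}

def sgd_updates(grads, lr):
    # Plain gradient descent: each update is -lr * gradient.
    return {k: -lr * v for k, v in grads.items()}

def prepare_descent_ascent(updates):
    # Negate the multiplier's update so it ascends L while x descends.
    out = dict(updates)
    out["lam"] = -out["lam"]
    return out

def apply_updates(params, updates):
    return {k: params[k] + updates[k] for k in params}

params = {"x": 0.0, "lam": 0.0}
for _ in range(500):
    updates = prepare_descent_ascent(sgd_updates(grads_of_lagrangian(params), lr=0.1))
    params = apply_updates(params, updates)
print(params)  # approaches {'x': 1.0, 'lam': -2.0}
```

The saddle point satisfies 2x + lam = 0 and x - 1 = 0, giving x = 1 and lam = -2, which is where the iteration settles.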

class essos.augmented_lagrangian.Constraint

Bases: NamedTuple

A pair of pure functions implementing a constraint.

init

A pure function which, when called with an example instance of the arguments to the constraint functions, returns a pytree containing the constraint’s learnable parameters.

loss

A pure function which, when called with the learnable parameters returned by init() followed by the arguments to the constraint functions, returns the loss value for the constraint.

init: Callable
loss: Callable
essos.augmented_lagrangian.eq(fun, model_lagrangian='Standard', multiplier=0.0, penalty=1.0, sq_grad=0.0, weight=1.0, reduction=jnp.sum)

Represents an equality constraint, g(x) = 0.

Parameters:
  • fun – The constraint function: a differentiable function of your parameters that outputs zero when the constraint is satisfied and values that move smoothly further from zero as the constraint violation increases.

  • multiplier – Initial value of the Lagrange multiplier.

  • penalty – Initial value of the penalty parameter.

  • weight – Weights the loss from the constraint relative to the primary loss function’s value.

  • reduction – The function that is used to aggregate the constraints if the constraint function outputs more than one element.

Returns:

An (init_fn, loss_fn) constraint tuple for the equality constraint.
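A hypothetical sketch of what an equality constraint's (init_fn, loss_fn) pair can look like, assuming the per-element loss lam*g + (penalty/2)*g^2 and Python's built-in sum as the reduction. make_eq and the parameter dictionary layout are illustrative assumptions, not essos's implementation; in JAX the lists would be arrays and the reduction would default to jnp.sum.

```python
# Hypothetical sketch of an equality constraint as an (init_fn, loss_fn) pair.

def make_eq(fun, multiplier=0.0, penalty=1.0, weight=1.0, reduction=sum):
    def init_fn(x):
        # One multiplier per output element of fun, all starting at `multiplier`.
        n = len(fun(x))
        return {"lam": [multiplier] * n, "penalty": penalty}

    def loss_fn(params, x):
        g = fun(x)
        terms = [lam * gi + 0.5 * params["penalty"] * gi ** 2
                 for lam, gi in zip(params["lam"], g)]
        # Aggregate the per-element terms and weight them against the primary loss.
        return weight * reduction(terms)

    return init_fn, loss_fn

# Example: x0 + x1 = 1 as a single-element constraint.
init_fn, loss_fn = make_eq(lambda x: [x[0] + x[1] - 1.0])
params = init_fn([0.0, 0.0])
print(loss_fn(params, [0.0, 0.0]))  # 0.5
```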

essos.augmented_lagrangian.ineq(fun, model_lagrangian='Standard', multiplier=0.0, penalty=1.0, sq_grad=0.0, weight=1.0, reduction=jnp.sum)

Represents an inequality constraint, h(x) >= 0, which uses a slack variable internally to convert it to an equality constraint.

Parameters:
  • fun – The constraint function: a differentiable function of your parameters that outputs values greater than or equal to zero when the constraint is satisfied and values that become smoothly more negative as the constraint violation increases.

  • multiplier – Initial value of the Lagrange multiplier.

  • penalty – Initial value of the penalty parameter.

  • weight – Weights the loss from the constraint relative to the primary loss function’s value.

  • reduction – The function that is used to aggregate the constraints if the constraint function outputs more than one element.

Returns:

An (init_fn, loss_fn) constraint tuple for the inequality constraint.
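The slack-variable conversion can be sketched as follows, assuming the common choice h(x) - s^2 = 0 with an unconstrained learnable slack s, so that s^2 >= 0 plays the role of a nonnegative slack. make_ineq is hypothetical, and essos may parameterize the slack differently.

```python
# Hypothetical sketch: h(x) >= 0 becomes the equality h(x) - s**2 = 0.
import math

def make_ineq(fun, multiplier=0.0, penalty=1.0, weight=1.0, reduction=sum):
    def init_fn(x):
        h = fun(x)
        # Start each slack at sqrt(max(h, 0)): zero infeasibility wherever h(x) >= 0.
        return {"lam": [multiplier] * len(h),
                "slack": [math.sqrt(max(hi, 0.0)) for hi in h],
                "penalty": penalty}

    def loss_fn(params, x):
        inf = [hi - si ** 2 for hi, si in zip(fun(x), params["slack"])]
        terms = [lam * c + 0.5 * params["penalty"] * c ** 2
                 for lam, c in zip(params["lam"], inf)]
        return weight * reduction(terms)

    return init_fn, loss_fn

# Example: x0 >= 2 written as h(x) = x0 - 2 >= 0.
init_fn, loss_fn = make_ineq(lambda x: [x[0] - 2.0])
print(loss_fn(init_fn([3.0]), [3.0]))  # 0.0  (satisfied: slack absorbs h = 1)
print(loss_fn(init_fn([0.0]), [0.0]))  # 2.0  (violated: h = -2, slack = 0)
```

During optimization the slack is trained alongside x, so a satisfied constraint can keep its infeasibility at zero while a violated one is pushed back toward h(x) >= 0.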

essos.augmented_lagrangian.combine(*args)

Combines multiple constraint tuples into a single constraint tuple.

Parameters:

*args – A series of constraint (init_fn, loss_fn) tuples.

Returns:

A single (init_fn, loss_fn) tuple that wraps the input constraints.
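A hypothetical combine_sketch showing one way to merge constraint pairs: the combined init_fn returns a tuple with one parameter entry per constraint, and the combined loss_fn sums the individual losses. The tuple layout and the summation are assumptions for illustration, not a description of essos's combine.

```python
# Hypothetical sketch: merge several (init_fn, loss_fn) pairs into one pair.

def combine_sketch(*constraints):
    init_fns, loss_fns = zip(*constraints)

    def init_fn(*args):
        # One parameter entry per wrapped constraint, in order.
        return tuple(init(*args) for init in init_fns)

    def loss_fn(params, *args):
        # Total constraint loss is the sum of the individual losses.
        return sum(loss(p, *args) for loss, p in zip(loss_fns, params))

    return init_fn, loss_fn

# Two toy constraints on a scalar x, each with a simple quadratic loss.
c1 = (lambda x: {"w": 1.0}, lambda p, x: p["w"] * (x - 1.0) ** 2)
c2 = (lambda x: {"w": 2.0}, lambda p, x: p["w"] * (x + 1.0) ** 2)
init_fn, loss_fn = combine_sketch(c1, c2)
params = init_fn(0.0)
print(params)                # ({'w': 1.0}, {'w': 2.0})
print(loss_fn(params, 0.0))  # 3.0
```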

essos.augmented_lagrangian.total_infeasibility(tree)
essos.augmented_lagrangian.norm_constraints(tree)
essos.augmented_lagrangian.infty_norm_constraints(tree)
essos.augmented_lagrangian.penalty_average(tree)
class essos.augmented_lagrangian.ALM

Bases: NamedTuple

init: Callable
update: Callable
essos.augmented_lagrangian.ALM_model_optax(optimizer: optax.GradientTransformation, constraints: Constraint, loss=lambda x: ..., model_lagrangian='Standard', model_mu='Constant', beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06, **kargs)
essos.augmented_lagrangian.ALM_model_jaxopt_lbfgsb(constraints: Constraint, loss=lambda x: ..., model_lagrangian='Standard', beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06, **kargs)
essos.augmented_lagrangian.ALM_model_jaxopt_LevenbergMarquardt(constraints: Constraint, loss=lambda x: ..., beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06, **kargs)
essos.augmented_lagrangian.ALM_model_jaxopt_lbfgs(constraints: Constraint, loss=lambda x: ..., model_lagrangian='Standard', beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06, **kargs)
essos.augmented_lagrangian.ALM_model_optimistix_LevenbergMarquardt(constraints: Constraint, loss=lambda x: ..., beta=2.0, mu_max=10000.0, alpha=0.99, gamma=0.01, epsilon=1e-08, eta_tol=0.0001, omega_tol=1e-06, **kargs)