Back to snippets
jaxopt_gradient_descent_quadratic_and_logistic_regression.py
This quickstart demonstrates how to use JAXopt to minimize a simple quadratic function and to fit a multi-class logistic regression model with gradient descent.
Agent votes: 1 positive, 0 negative (100% positive).
jaxopt_gradient_descent_quadratic_and_logistic_regression.py
1import jax
2import jax.numpy as jnp
3from jaxopt import GradientDescent
4from sklearn.datasets import make_classification
5
# 1. Simple Quadratic Minimization
def f(x, params):
    """Squared Euclidean distance between x and params (minimized at x == params)."""
    diff = x - params
    return jnp.sum(diff * diff)
9
# Starting point for the optimizer
x0 = jnp.zeros(5)
# Target vector the quadratic is centred on
params = jnp.arange(5, dtype=jnp.float32)

# Minimize f with plain gradient descent (at most 100 iterations)
gd = GradientDescent(fun=f, maxiter=100)
res = gd.run(x0, params=params)

optimum = res.params
final_value = f(optimum, params)
print("Quadratic Minimization Result:")
print(f"Optimal x: {optimum}")
print(f"Function value: {final_value}")
22
# 2. Multi-class Logistic Regression
# Synthetic dataset: 100 samples, 5 features, 3 classes (fixed seed for reproducibility)
n_samples, n_features = 100, 5
n_classes = 3
X, y = make_classification(
    n_samples=n_samples,
    n_features=n_features,
    n_informative=3,
    n_classes=n_classes,
    random_state=42,
)
# Convert the NumPy arrays to JAX arrays
X = jnp.array(X)
y = jnp.array(y)
30
# Logistic regression objective
def logistic_regression_loss(W, data):
    """Mean multinomial cross-entropy (negative log-likelihood) loss.

    Args:
        W: weight matrix of shape (n_features, n_classes).
        data: tuple (X, y) where X has shape (n_samples, n_features) and
            y holds integer class labels of shape (n_samples,).

    Returns:
        Scalar mean negative log-likelihood over all samples.
    """
    X, y = data
    logits = jnp.dot(X, W)
    # Log-softmax via log-sum-exp for numerical stability
    log_probs = logits - jax.scipy.special.logsumexp(logits, axis=1, keepdims=True)
    # Derive the class count from W itself rather than relying on the
    # module-level n_classes global, so the loss works for any weight matrix.
    y_one_hot = jax.nn.one_hot(y, W.shape[1])
    return -jnp.mean(jnp.sum(y_one_hot * log_probs, axis=1))
40
# Initialize weights to zero, shape (n_features, n_classes)
W_init = jnp.zeros((n_features, n_classes))

# Run gradient descent on the logistic-regression objective
gd_lr = GradientDescent(fun=logistic_regression_loss, maxiter=500)
res_lr = gd_lr.run(W_init, data=(X, y))

print("\nLogistic Regression Result:")
# Bug fix: res_lr.state.error is JAXopt's optimality error (gradient norm),
# not the objective value — evaluate the loss at the solution instead, and
# report the error under an honest label.
print(f"Final loss: {logistic_regression_loss(res_lr.params, (X, y))}")
print(f"Optimality error (gradient norm): {res_lr.state.error}")