Warm starting


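Now we use the madelon dataset (OpenML dataset 1485) to compare the running time of plqERM_Ridge_path_sol with and without warm starting. The solver fits a ridge-penalized SVM for each value of the regularization parameter C on a grid. With warm_start=True, each fit starts from the solution for the previous C instead of being solved from scratch.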
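To make the idea concrete, here is a small self-contained sketch that does not use rehline. It fits an L2-regularized logistic regression by plain fixed-step gradient descent for each C on the same geometric grid. It does this twice: once starting every fit from zero (cold start) and once starting from the previous C's solution (warm start), then compares iteration counts. The logistic loss, the gradient-descent solver, the synthetic data, and the helper names fit_logreg_gd and solve_path are all illustrative assumptions. The solver used in this example relies on the SVM loss and rehline's own algorithm.

```python
import numpy as np


def fit_logreg_gd(X, y, C, w0, tol=1e-6, max_iter=20_000):
    """Minimize 0.5*||w||^2 + C * sum_i log(1 + exp(-y_i * x_i @ w)) by gradient descent.

    Starts from w0; returns (w, number of iterations used).
    """
    # Fixed step 1/L: the Hessian is I + C * X^T D X with D <= 1/4,
    # so L = 1 + C * ||X||_2^2 / 4 bounds the gradient's Lipschitz constant.
    L = 1.0 + C * np.linalg.norm(X, 2) ** 2 / 4.0
    w = w0.copy()
    for it in range(1, max_iter + 1):
        margins = y * (X @ w)
        # sigmoid(-m) written via tanh, which stays finite for large |m|
        sig_neg = 0.5 * (1.0 - np.tanh(margins / 2.0))
        grad = w - C * (X.T @ (y * sig_neg))
        if np.linalg.norm(grad) <= tol:  # converged
            return w, it
        w = w - grad / L
    return w, max_iter  # hit the iteration cap without converging


def solve_path(X, y, Cs, warm_start):
    """Fit every C in order; with warm_start, reuse the previous solution as w0."""
    d = X.shape[1]
    w_prev = np.zeros(d)
    n_iters = []
    for C in Cs:
        w0 = w_prev if warm_start else np.zeros(d)
        w_prev, n_it = fit_logreg_gd(X, y, C, w0)
        n_iters.append(n_it)
    return n_iters


# Hypothetical synthetic data with noisy labels, so the two classes overlap.
rng = np.random.default_rng(0)
X_toy = rng.standard_normal((200, 5))
w_true = rng.standard_normal(5)
y_toy = np.where(X_toy @ w_true + 2.0 * rng.standard_normal(200) > 0, 1.0, -1.0)

Cs_toy = np.logspace(0, 2, 10)  # same grid of C values as this example
toy_iters_cold = solve_path(X_toy, y_toy, Cs_toy, warm_start=False)
toy_iters_warm = solve_path(X_toy, y_toy, Cs_toy, warm_start=True)
print("cold start:", toy_iters_cold, "total", sum(toy_iters_cold))
print("warm start:", toy_iters_warm, "total", sum(toy_iters_warm))
```

Counting iterations rather than seconds keeps the comparison deterministic. Both runs do identical work at the first C because there is no earlier solution to reuse. The same pattern appears in the madelon tables below, where both solvers take 118636 iterations at C = 1.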

import time

import numpy as np
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler

from rehline import plqERM_Ridge_path_sol
import openml

## load the dataset
dataset = openml.datasets.get_dataset(1485)

X, y, _, _ = dataset.get_data(target=dataset.default_target_attribute)

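y = np.where(y == "2", 1, -1)  # encode labels as +1 (class "2") and -1 (the other class)
scaler = StandardScaler()  # standardize every feature to zero mean and unit variance
X = scaler.fit_transform(X)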
## define the loss function and the grid of C values (10 log-spaced points from 1 to 100)
loss = {"name": "svm"}
Cs = np.logspace(0, 2, 10)
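The values in the C Value column of the tables below (1, 1.668, 2.783, ..., 100) come from this geometric grid. np.logspace(0, 2, 10) places 10 points evenly in log10 between 10^0 and 10^2, so each value is the previous one multiplied by the constant factor 10^(2/9) ≈ 1.668. Small, equal steps like these keep each problem on the path close to the previous one, which is the setting where reusing the previous solution is expected to help. A quick check:

```python
import numpy as np

Cs = np.logspace(0, 2, 10)  # 10 points evenly spaced in log10 from 1 to 100
ratios = Cs[1:] / Cs[:-1]   # ratio between consecutive values of C
print(np.round(Cs, 3))
print(ratios.min(), ratios.max(), 10 ** (2 / 9))
```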
## first, run the solver without warm start

print("\nRunning solver WITHOUT warm start...")
start_no_warm = time.time()
Cs_no_warm, times_no_warm, n_iters_no_warm, loss_no_warm, L2_no_warm, coefs_no_warm = plqERM_Ridge_path_sol(
    X,
    y,
    loss=loss,
    Cs=Cs,
    max_iter=1000000,
    tol=1e-4,
    verbose=1,
    warm_start=False,
    return_time=True,
)
end_no_warm = time.time()
total_time_no_warm = end_no_warm - start_no_warm

Running solver WITHOUT warm start...
/usr/local/lib/python3.11/dist-packages/rehline/_class.py:419: ConvergenceWarning: ReHLine failed to converge, increase the number of iterations: `max_iter`.
  warnings.warn(

PLQ ERM Path Solution Results
==========================================================================================
C Value        Iterations     Time (s)            Loss                L2 Norm
------------------------------------------------------------------------------------------
1              118636         39.936846           1594.670300         2.543800
1.668          192306         64.421898           1594.737600         2.679600
2.783          485638         151.079622          1594.824000         2.742100
4.642          700475         211.876613          1594.906600         2.783700
7.743          1000000        299.254036          1595.525600         2.815700
12.92          1000000        304.475502          1596.999200         2.878100
21.54          1000000        312.734514          1629.704900         2.913800
35.94          1000000        323.737584          1631.008400         2.961700
59.95          1000000        336.862036          1606.801500         2.892900
100            1000000        363.869178          1605.223100         2.783600
==========================================================================================
Total Time  2408.248485 sec
Avg Time/Iter 0.000321 sec
==========================================================================================
## then run the solver with warm start

print("\nRunning solver WITH warm start...")
start_warm = time.time()
Cs_warm, times_warm, n_iters_warm, loss_warm, L2_warm, coefs_warm = plqERM_Ridge_path_sol(
    X,
    y,
    loss=loss,
    Cs=Cs,
    max_iter=1000000,
    tol=1e-4,
    verbose=1,
    warm_start=True,
    return_time=True,
)
end_warm = time.time()
total_time_warm = end_warm - start_warm

Running solver WITH warm start...

PLQ ERM Path Solution Results
==========================================================================================
C Value        Iterations     Time (s)            Loss                L2 Norm
------------------------------------------------------------------------------------------
1              118636         40.181799           1594.670300         2.543800
1.668          104841         31.897681           1594.740700         2.682700
2.783          202547         61.151798           1594.831900         2.746200
4.642          126871         37.826569           1594.896400         2.779400
7.743          248854         73.859064           1595.001200         2.822700
12.92          285847         85.796846           1595.117000         2.869000
21.54          465084         139.253013          1595.146600         2.879800
35.94          190967         56.896123           1595.165800         2.885900
59.95          399966         118.717955          1595.165600         2.886400
100            353164         107.564119          1595.166100         2.886800
==========================================================================================
Total Time  753.145267 sec
Avg Time/Iter 0.000302 sec
==========================================================================================
## print the comparison and summary

print("\nComparison: Warm Start vs. No Warm Start")
print("=" * 90)
print(f"{'C Value':<12}{'Iter_without_WS':<18}{'Time_without_WS':<20}{'Iter_WS':<18}{'Time_WS':<15}")
print("-" * 90)

for C, iters_nw, time_nw, iters_w, time_w in zip(Cs, n_iters_no_warm, times_no_warm, n_iters_warm, times_warm):
    print(f"{C:<12.4g}{iters_nw:<18}{time_nw:<20.6f}{iters_w:<18}{time_w:<15.6f}")

print("=" * 90)
print(f"{'Total Time (without WS)':<30}{total_time_no_warm:.6f} sec")
print(f"{'Total Time (WS)':<30}{total_time_warm:.6f} sec")
print(f"{'Speedup with Warm Start':<30}{total_time_no_warm / total_time_warm:.2f}x Faster")
print("=" * 90)

Comparison: Warm Start vs. No Warm Start
==========================================================================================
C Value     Iter_without_WS   Time_without_WS     Iter_WS           Time_WS
------------------------------------------------------------------------------------------
1           118636            39.936846           118636            40.181799
1.668       192306            64.421898           104841            31.897681
2.783       485638            151.079622          202547            61.151798
4.642       700475            211.876613          126871            37.826569
7.743       1000000           299.254036          248854            73.859064
12.92       1000000           304.475502          285847            85.796846
21.54       1000000           312.734514          465084            139.253013
35.94       1000000           323.737584          190967            56.896123
59.95       1000000           336.862036          399966            118.717955
100         1000000           363.869178          353164            107.564119
==========================================================================================
Total Time (without WS)       2408.250201 sec
Total Time (WS)               753.150022 sec
Speedup with Warm Start       3.20x Faster
==========================================================================================

Time comparison

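The summary table above shows that the warm-start solver is 3.20 times faster than the cold-start solver over the whole path of 10 values of C (about 753 s versus 2408 s). At the first value, C = 1, both runs take the same 118636 iterations because there is no earlier solution to start from. The savings therefore come entirely from the later values of C.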

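The warm-start solver also converges within the max_iter limit of 1000000 iterations at every value of C. The cold-start solver, by contrast, reaches that limit for the six largest values (C = 7.743 to 100) and stops without converging, which is what the ConvergenceWarnings report. Because those cold-start runs were cut off early, the 3.20x figure understates the real speedup.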
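The headline numbers can be recomputed directly from the per-C columns of the comparison table above. The short sketch below only copies those printed values and does not rerun the solver. It sums them and counts how many cold-start runs stopped at the iteration cap.

```python
import numpy as np

# Per-C values copied from the comparison table above (madelon, SVM loss, tol=1e-4).
max_iter = 1_000_000
iters_cold = np.array([118636, 192306, 485638, 700475] + [1000000] * 6)
time_cold = np.array([39.936846, 64.421898, 151.079622, 211.876613, 299.254036,
                      304.475502, 312.734514, 323.737584, 336.862036, 363.869178])
iters_warm = np.array([118636, 104841, 202547, 126871, 248854,
                       285847, 465084, 190967, 399966, 353164])
time_warm = np.array([40.181799, 31.897681, 61.151798, 37.826569, 73.859064,
                      85.796846, 139.253013, 56.896123, 118.717955, 107.564119])

time_speedup = time_cold.sum() / time_warm.sum()
iter_ratio = iters_cold.sum() / iters_warm.sum()
n_capped = int(np.sum(iters_cold >= max_iter))  # cold runs stopped by the iteration cap

print(f"time speedup:     {time_speedup:.2f}x")  # → 3.20x
print(f"iteration ratio:  {iter_ratio:.2f}x")    # → 3.00x
print(f"cold runs capped: {n_capped} of {len(iters_cold)}")  # → 6 of 10
```

The time ratio is slightly larger than the iteration ratio because the warm-start run also spent a little less time per iteration (0.000302 s versus 0.000321 s in the solver's own summaries). Six cold-start counts are truncated at max_iter, so both ratios are lower bounds on the speedup relative to a fully converged cold-start path.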

## visualize the time used for each solver

plt.figure(figsize=(8, 6))
plt.plot(Cs, times_no_warm, label="No Warm Start", marker="o", linestyle="--")
plt.plot(Cs, times_warm, label="With Warm Start", marker="s", linestyle="-")
plt.xscale("log")
plt.xlabel("Regularization Parameter C")
plt.ylabel("Time (seconds)")
plt.title("Warm Start vs. No Warm Start")
plt.legend()
plt.grid(True)
plt.show()
Figure: Warm Start vs. No Warm Start (time in seconds for each value of the regularization parameter C, log-scaled x-axis)