# ruff: noqa: E402
Linear regression with ESMDA#
We solve a linear regression problem using ESMDA. First we define the forward model as \(g(x) = Ax\), then we set up a prior ensemble on the linear regression coefficients, so \(x \sim \mathcal{N}(0, 1)\).
As shown in the 2013 paper by Emerick and Reynolds, when a set of inflation weights \(\alpha_i\) is chosen so that \(\sum_i \alpha_i^{-1} = 1\), ESMDA yields the correct posterior mean for the linear-Gaussian case.
Import packages#
import numpy as np
from matplotlib import pyplot as plt
from iterative_ensemble_smoother import ESMDA
Matplotlib is building the font cache; this may take a moment.
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[2], line 2
1 import numpy as np
----> 2 from matplotlib import pyplot as plt
4 from iterative_ensemble_smoother import ESMDA
File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/pyplot.py:56
54 from cycler import cycler
55 import matplotlib
---> 56 import matplotlib.colorbar
57 import matplotlib.image
58 from matplotlib import _api
File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/colorbar.py:19
16 import numpy as np
18 import matplotlib as mpl
---> 19 from matplotlib import _api, cbook, collections, cm, colors, contour, ticker
20 import matplotlib.artist as martist
21 import matplotlib.patches as mpatches
File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/contour.py:15
13 import matplotlib as mpl
14 from matplotlib import _api, _docstring
---> 15 from matplotlib.backend_bases import MouseButton
16 from matplotlib.lines import Line2D
17 from matplotlib.path import Path
File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/backend_bases.py:46
43 import numpy as np
45 import matplotlib as mpl
---> 46 from matplotlib import (
47 _api, backend_tools as tools, cbook, colors, _docstring, text,
48 _tight_bbox, transforms, widgets, is_interactive, rcParams)
49 from matplotlib._pylab_helpers import Gcf
50 from matplotlib.backend_managers import ToolManager
File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/text.py:16
14 from . import _api, artist, cbook, _docstring
15 from .artist import Artist
---> 16 from .font_manager import FontProperties
17 from .patches import FancyArrowPatch, FancyBboxPatch, Rectangle
18 from .textpath import TextPath, TextToPath # noqa # Logically located here
File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/font_manager.py:1582
1578 _log.info("generated new fontManager")
1579 return fm
-> 1582 fontManager = _load_fontmanager()
1583 findfont = fontManager.findfont
1584 get_font_names = fontManager.get_font_names
File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/font_manager.py:1576, in _load_fontmanager(try_read_cache)
1574 _log.debug("Using fontManager instance from %s", fm_path)
1575 return fm
-> 1576 fm = FontManager()
1577 json_dump(fm, fm_path)
1578 _log.info("generated new fontManager")
File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/font_manager.py:1043, in FontManager.__init__(self, size, weight)
1040 for path in [*findSystemFonts(paths, fontext=fontext),
1041 *findSystemFonts(fontext=fontext)]:
1042 try:
-> 1043 self.addfont(path)
1044 except OSError as exc:
1045 _log.info("Failed to open font file %s: %s", path, exc)
File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/font_manager.py:1076, in FontManager.addfont(self, path)
1074 self.afmlist.append(prop)
1075 else:
-> 1076 font = ft2font.FT2Font(path)
1077 prop = ttfFontProperty(font)
1078 self.ttflist.append(prop)
KeyboardInterrupt:
Create problem data#
Some settings worth experimenting with:
- Decreasing `prior_std=1` will pull the posterior solution toward zero.
- Increasing `num_ensemble` will increase the quality of the solution.
- Increasing `num_observations / num_parameters` will increase the quality of the solution.
# Problem dimensions and prior spread
num_parameters = 25
num_observations = 100
num_ensemble = 30
prior_std = 1

# Seeded generator so every run draws the same numbers
rng = np.random.default_rng(42)

# Create a problem with g(x) = A @ x
A = rng.standard_normal(size=(num_observations, num_parameters))


def g(X):
    """Forward model: apply the linear map ``A`` to ``X``.

    ``X`` is either a single parameter vector (as in ``g(x_true)``) or a
    matrix holding one ensemble member per column.
    """
    return A @ X


# Create observations: obs = g(x) + N(0, 1)
x_true = np.linspace(-1, 1, num=num_parameters)
observation_noise = rng.standard_normal(size=num_observations)
observations = g(x_true) + observation_noise

# Initial ensemble: standard-normal draws scaled by prior_std,
# one column per ensemble member
X = rng.normal(size=(num_parameters, num_ensemble)) * prior_std

# Diagonal observation covariance (all ones), matching the unit-variance
# noise added to the observations above
covariance = np.ones(num_observations)
Solve the maximum likelihood problem#
We can solve \(Ax = b\) in the least-squares sense, where \(b\) is the observations, to obtain the maximum likelihood estimate. Notice that unlike using a Ridge model, solving \(Ax = b\) directly does not use any prior information.
# Least-squares solution of A x = observations; lstsq returns several
# values and only the solution vector (the first one) is kept.
x_ml = np.linalg.lstsq(A, observations, rcond=None)[0]

# Shared x-coordinates: one position per parameter
param_index = np.arange(len(x_true))

plt.figure(figsize=(8, 3))
for series, series_label in (
    (x_true, "True parameter values"),
    (x_ml, "ML estimate (no prior)"),
):
    plt.scatter(param_index, series, label=series_label)
plt.xlabel("Parameter index")
plt.ylabel("Parameter value")
plt.grid(True, ls="--", zorder=0, alpha=0.33)
plt.legend()
plt.show()
Solve using ESMDA#
We create an ESMDA instance and solve the Gauss-linear problem.
# Set up the smoother with the observation covariance and observed values.
# NOTE(review): alpha=5 presumably requests five assimilation steps; the loop
# below iterates over smoother.alpha -- confirm against the ESMDA docs.
smoother = ESMDA(
    covariance=covariance,
    observations=observations,
    alpha=5,
    seed=1,
)

# Work on a copy so the prior ensemble X stays available for comparison
X_i = np.copy(X)
for i, alpha_i in enumerate(smoother.alpha, 1):
    print(
        f"ESMDA iteration {i}/{smoother.num_assimilations()}"
        + f" with inflation factor alpha_i={alpha_i}"
    )
    # Run the current ensemble through the forward model and pass both the
    # ensemble and its responses to the smoother for the next update
    X_i = smoother.assimilate(X_i, Y=g(X_i))

# Ensemble after the final assimilation step
X_posterior = np.copy(X_i)
Plot and compare solutions#
Compare the true parameters with both the ML estimate
from linear regression and the posterior means obtained using ESMDA.
# One x-position per parameter, shared by every series in the figure
param_index = np.arange(len(x_true))
# Average over ensemble members (columns) to get one value per parameter
posterior_mean = np.mean(X_posterior, axis=1)

plt.figure(figsize=(8, 3))
for series, series_label in (
    (x_true, "True parameter values"),
    (x_ml, "ML estimate (no prior)"),
    (posterior_mean, "Posterior mean"),
):
    plt.scatter(param_index, series, label=series_label)
plt.xlabel("Parameter index")
plt.ylabel("Parameter value")
plt.grid(True, ls="--", zorder=0, alpha=0.33)
plt.legend()
plt.show()
We now include the posterior samples as well.
# One x-position per parameter; hoisted because it is the same for every
# series and every ensemble member
param_index = np.arange(len(x_true))

plt.figure(figsize=(8, 3))
plt.scatter(param_index, x_true, label="True parameter values")
plt.scatter(param_index, x_ml, label="ML estimate (no prior)")
plt.scatter(param_index, np.mean(X_posterior, axis=1), label="Posterior mean")

# Loop over every ensemble member and plot it
for j in range(num_ensemble):
    # Jitter along the x-axis a little bit so members do not overlap exactly
    x_jitter = param_index + rng.normal(loc=0, scale=0.1, size=len(x_true))

    # Plot this ensemble member; only the first one gets a legend entry so
    # the legend shows a single "Posterior values" item
    plt.scatter(
        x_jitter,
        X_posterior[:, j],
        label=("Posterior values" if j == 0 else None),
        color="black",
        alpha=0.2,
        s=5,
        zorder=0,
    )

plt.xlabel("Parameter index")
plt.ylabel("Parameter value")
plt.grid(True, ls="--", zorder=0, alpha=0.33)
plt.legend()
plt.show()