# ruff: noqa: E402

Linear regression with ESMDA

We solve a linear regression problem using ESMDA. First we define the forward model as \(g(x) = Ax\), then we set up a prior ensemble on the linear regression coefficients, so \(x \sim \mathcal{N}(0, 1)\).

As shown by Emerick and Reynolds (2013), when the inflation weights \(\alpha_i\) are chosen so that \(\sum_i \alpha_i^{-1} = 1\), ESMDA yields the correct posterior mean for the linear-Gaussian case.

Import packages

import numpy as np
from matplotlib import pyplot as plt

from iterative_ensemble_smoother import ESMDA
Matplotlib is building the font cache; this may take a moment.
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[2], line 2
      1 import numpy as np
----> 2 from matplotlib import pyplot as plt
      4 from iterative_ensemble_smoother import ESMDA

File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/pyplot.py:56
     54 from cycler import cycler
     55 import matplotlib
---> 56 import matplotlib.colorbar
     57 import matplotlib.image
     58 from matplotlib import _api

File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/colorbar.py:19
     16 import numpy as np
     18 import matplotlib as mpl
---> 19 from matplotlib import _api, cbook, collections, cm, colors, contour, ticker
     20 import matplotlib.artist as martist
     21 import matplotlib.patches as mpatches

File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/contour.py:15
     13 import matplotlib as mpl
     14 from matplotlib import _api, _docstring
---> 15 from matplotlib.backend_bases import MouseButton
     16 from matplotlib.lines import Line2D
     17 from matplotlib.path import Path

File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/backend_bases.py:46
     43 import numpy as np
     45 import matplotlib as mpl
---> 46 from matplotlib import (
     47     _api, backend_tools as tools, cbook, colors, _docstring, text,
     48     _tight_bbox, transforms, widgets, is_interactive, rcParams)
     49 from matplotlib._pylab_helpers import Gcf
     50 from matplotlib.backend_managers import ToolManager

File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/text.py:16
     14 from . import _api, artist, cbook, _docstring
     15 from .artist import Artist
---> 16 from .font_manager import FontProperties
     17 from .patches import FancyArrowPatch, FancyBboxPatch, Rectangle
     18 from .textpath import TextPath, TextToPath  # noqa # Logically located here

File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/font_manager.py:1582
   1578     _log.info("generated new fontManager")
   1579     return fm
-> 1582 fontManager = _load_fontmanager()
   1583 findfont = fontManager.findfont
   1584 get_font_names = fontManager.get_font_names

File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/font_manager.py:1576, in _load_fontmanager(try_read_cache)
   1574             _log.debug("Using fontManager instance from %s", fm_path)
   1575             return fm
-> 1576 fm = FontManager()
   1577 json_dump(fm, fm_path)
   1578 _log.info("generated new fontManager")

File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/font_manager.py:1043, in FontManager.__init__(self, size, weight)
   1040 for path in [*findSystemFonts(paths, fontext=fontext),
   1041              *findSystemFonts(fontext=fontext)]:
   1042     try:
-> 1043         self.addfont(path)
   1044     except OSError as exc:
   1045         _log.info("Failed to open font file %s: %s", path, exc)

File ~/checkouts/readthedocs.org/user_builds/dafeda-iterative-ensemble-smoother/envs/fix_readthedocs/lib/python3.10/site-packages/matplotlib/font_manager.py:1076, in FontManager.addfont(self, path)
   1074     self.afmlist.append(prop)
   1075 else:
-> 1076     font = ft2font.FT2Font(path)
   1077     prop = ttfFontProperty(font)
   1078     self.ttflist.append(prop)

KeyboardInterrupt: 

Create problem data

Some settings worth experimenting with:

  • Decreasing prior_std (currently 1) makes the prior tighter and pulls the posterior solution toward zero.

  • Increasing num_ensemble will increase the quality of the solution.

  • Increasing the ratio num_observations / num_parameters will increase the quality of the solution.

# Problem dimensions and prior scale; these are the knobs worth experimenting with.
num_parameters, num_observations, num_ensemble = 25, 100, 30
prior_std = 1

# A single seeded generator drives every random draw below, so the example is reproducible.
rng = np.random.default_rng(42)

# Linear forward operator for g(x) = A @ x: one row per observation, one column per parameter.
A = rng.standard_normal((num_observations, num_parameters))


def g(X):
    """Evaluate the linear forward model, returning ``A @ X``.

    Called below with a single parameter vector (``x_true``) and with a whole
    ensemble whose columns are the members; the matrix product handles both.
    """
    return np.matmul(A, X)


# Synthetic data: obs = g(x_true) + noise, with noise ~ N(0, observation_std**2) i.i.d.
x_true = np.linspace(-1, 1, num=num_parameters)
observation_std = 1.0
observation_noise = rng.standard_normal(size=num_observations) * observation_std
observations = g(x_true) + observation_noise

# Prior ensemble: shape (num_parameters, num_ensemble), one member per column,
# with every entry drawn from N(0, prior_std**2) (prior_std is a standard deviation).
X = rng.normal(size=(num_parameters, num_ensemble)) * prior_std

# Diagonal observation-error covariance, given as a vector of variances.
# It is derived from observation_std so it always matches the noise added above.
covariance = np.full(num_observations, observation_std**2)

Solve the maximum likelihood problem

Because the observation noise is Gaussian, the maximum likelihood estimate is the least-squares solution of \(Ax = b\), where \(b\) is the vector of observations. Notice that, unlike a Ridge model, solving \(Ax = b\) directly does not use any prior information.

# Least-squares fit of A x = observations, i.e. the ML estimate without any prior.
x_ml = np.linalg.lstsq(A, observations, rcond=None)[0]

param_index = np.arange(len(x_true))
plt.figure(figsize=(8, 3))
for values, label in (
    (x_true, "True parameter values"),
    (x_ml, "ML estimate (no prior)"),
):
    plt.scatter(param_index, values, label=label)
plt.xlabel("Parameter index")
plt.ylabel("Parameter value")
plt.grid(True, ls="--", zorder=0, alpha=0.33)
plt.legend()
plt.show()

Solve using ESMDA

We create an ESMDA instance and solve the linear-Gaussian problem.

# NOTE(review): an integer alpha presumably means that many equally weighted
# assimilations; confirm against the ESMDA API docs.
smoother = ESMDA(covariance=covariance, observations=observations, alpha=5, seed=1)

# Run each assimilation step on the current ensemble, passing the forward-model
# responses Y = g(X) for that ensemble.
X_i = np.copy(X)
for step, inflation in enumerate(smoother.alpha, start=1):
    message = (
        f"ESMDA iteration {step}/{smoother.num_assimilations()}"
        f" with inflation factor alpha_i={inflation}"
    )
    print(message)
    X_i = smoother.assimilate(X_i, Y=g(X_i))


X_posterior = np.copy(X_i)

Plot and compare solutions

Compare the true parameters with both the ML estimate from linear regression and the posterior mean obtained using ESMDA.

# Average over ensemble members (columns) to get the per-parameter posterior mean.
param_index = np.arange(len(x_true))
posterior_mean = X_posterior.mean(axis=1)

plt.figure(figsize=(8, 3))
plt.scatter(param_index, x_true, label="True parameter values")
plt.scatter(param_index, x_ml, label="ML estimate (no prior)")
plt.scatter(param_index, posterior_mean, label="Posterior mean")
plt.xlabel("Parameter index")
plt.ylabel("Parameter value")
plt.grid(True, ls="--", zorder=0, alpha=0.33)
plt.legend()
plt.show()

We now include the posterior samples as well.

param_index = np.arange(len(x_true))

plt.figure(figsize=(8, 3))
plt.scatter(param_index, x_true, label="True parameter values")
plt.scatter(param_index, x_ml, label="ML estimate (no prior)")
plt.scatter(param_index, np.mean(X_posterior, axis=1), label="Posterior mean")

# Draw each ensemble member (a column of X_posterior) as faint dots. Only the
# first member gets a label, so the legend shows a single "Posterior values" entry.
for j in range(num_ensemble):
    # A small random horizontal offset keeps overlapping members distinguishable.
    jitter = rng.normal(loc=0, scale=0.1, size=len(x_true))
    member_label = "Posterior values" if j == 0 else None
    plt.scatter(
        param_index + jitter,
        X_posterior[:, j],
        label=member_label,
        color="black",
        alpha=0.2,
        s=5,
        zorder=0,
    )

plt.xlabel("Parameter index")
plt.ylabel("Parameter value")
plt.grid(True, ls="--", zorder=0, alpha=0.33)
plt.legend()
plt.show()