{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "56fb6c6b", "metadata": {}, "outputs": [], "source": [ "# ruff: noqa: E402" ] }, { "cell_type": "markdown", "id": "99ec1e27", "metadata": {}, "source": [ "# Linear regression with ESMDA\n", "\n", "We solve a linear regression problem using ESMDA.\n", "First we define the forward model as $g(x) = Ax$,\n", "then we set up a prior ensemble on the linear\n", "regression coefficients, so $x \\sim \\mathcal{N}(0, 1)$.\n", "\n", "As shown in the 2013 paper by Emerick and Reynolds, when a set of\n", "inflation weights $\\alpha_i$ is chosen so that $\\sum_i \\alpha_i^{-1} = 1$,\n", "ESMDA yields the correct posterior mean for the linear-Gaussian case." ] }, { "cell_type": "markdown", "id": "3a4fcdd8", "metadata": {}, "source": [ "## Import packages" ] }, { "cell_type": "code", "execution_count": null, "id": "f26669a7", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from matplotlib import pyplot as plt\n", "\n", "from iterative_ensemble_smoother import ESMDA" ] }, { "cell_type": "markdown", "id": "358d9fbe", "metadata": {}, "source": [ "## Create problem data\n", "\n", "Some settings worth experimenting with:\n", "\n", "- Decreasing `prior_std` below 1 will pull the posterior solution toward zero.\n", "- Increasing `num_ensemble` will increase the quality of the solution.\n", "- Increasing `num_observations / num_parameters`\n", "  will increase the quality of the solution."
] }, { "cell_type": "code", "execution_count": null, "id": "f1ea78b5", "metadata": {}, "outputs": [], "source": [ "num_parameters = 25\n", "num_observations = 100\n", "num_ensemble = 30\n", "prior_std = 1" ] }, { "cell_type": "code", "execution_count": null, "id": "3694f24d", "metadata": {}, "outputs": [], "source": [ "rng = np.random.default_rng(42)\n", "\n", "# Create a problem with g(x) = A @ x\n", "A = rng.standard_normal(size=(num_observations, num_parameters))\n", "\n", "\n", "def g(X):\n", "    \"\"\"Forward model.\"\"\"\n", "    return A @ X\n", "\n", "\n", "# Create observations: obs = g(x) + N(0, 1)\n", "x_true = np.linspace(-1, 1, num=num_parameters)\n", "observation_noise = rng.standard_normal(size=num_observations)\n", "observations = g(x_true) + observation_noise\n", "\n", "# Initial ensemble X ~ N(0, prior_std) and diagonal covariance with ones\n", "X = rng.normal(size=(num_parameters, num_ensemble)) * prior_std\n", "\n", "# Covariance matches the noise added to observations above\n", "covariance = np.ones(num_observations)" ] }, { "cell_type": "markdown", "id": "f03acfc0", "metadata": {}, "source": [ "## Solve the maximum likelihood problem\n", "\n", "We can solve $Ax = b$ in the least-squares sense, where $b$ is the vector of observations,\n", "to obtain the maximum likelihood estimate.\n", "Note that, unlike ridge regression,\n", "solving $Ax = b$ directly does not use any prior information."
] }, { "cell_type": "code", "execution_count": null, "id": "506e5da0", "metadata": {}, "outputs": [], "source": [ "x_ml, *_ = np.linalg.lstsq(A, observations, rcond=None)\n", "\n", "plt.figure(figsize=(8, 3))\n", "plt.scatter(np.arange(len(x_true)), x_true, label=\"True parameter values\")\n", "plt.scatter(np.arange(len(x_true)), x_ml, label=\"ML estimate (no prior)\")\n", "plt.xlabel(\"Parameter index\")\n", "plt.ylabel(\"Parameter value\")\n", "plt.grid(True, ls=\"--\", zorder=0, alpha=0.33)\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "45bf99b7", "metadata": {}, "source": [ "## Solve using ESMDA\n", "\n", "We create an `ESMDA` instance and solve the Gauss-linear problem." ] }, { "cell_type": "code", "execution_count": null, "id": "b25c4178", "metadata": {}, "outputs": [], "source": [ "smoother = ESMDA(\n", "    covariance=covariance,\n", "    observations=observations,\n", "    alpha=5,\n", "    seed=1,\n", ")\n", "\n", "X_i = np.copy(X)\n", "for i, alpha_i in enumerate(smoother.alpha, 1):\n", "    print(\n", "        f\"ESMDA iteration {i}/{smoother.num_assimilations()}\"\n", "        f\" with inflation factor alpha_i={alpha_i}\"\n", "    )\n", "    smoother.prepare_assimilation(Y=g(X_i))\n", "    X_i = smoother.assimilate_batch(X=X_i)\n", "\n", "\n", "X_posterior = np.copy(X_i)" ] }, { "cell_type": "markdown", "id": "3517d37a", "metadata": {}, "source": [ "## Plot and compare solutions\n", "\n", "Compare the true parameters with both the ML estimate\n", "from linear regression and the posterior means obtained using `ESMDA`."
] }, { "cell_type": "code", "execution_count": null, "id": "55b5c9bb", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8, 3))\n", "plt.scatter(np.arange(len(x_true)), x_true, label=\"True parameter values\")\n", "plt.scatter(np.arange(len(x_true)), x_ml, label=\"ML estimate (no prior)\")\n", "plt.scatter(\n", "    np.arange(len(x_true)), np.mean(X_posterior, axis=1), label=\"Posterior mean\"\n", ")\n", "plt.xlabel(\"Parameter index\")\n", "plt.ylabel(\"Parameter value\")\n", "plt.grid(True, ls=\"--\", zorder=0, alpha=0.33)\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "93d0443e", "metadata": {}, "source": [ "We now include the posterior samples as well." ] }, { "cell_type": "code", "execution_count": null, "id": "2fd08940", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8, 3))\n", "plt.scatter(np.arange(len(x_true)), x_true, label=\"True parameter values\")\n", "plt.scatter(np.arange(len(x_true)), x_ml, label=\"ML estimate (no prior)\")\n", "plt.scatter(\n", "    np.arange(len(x_true)), np.mean(X_posterior, axis=1), label=\"Posterior mean\"\n", ")\n", "\n", "# Loop over every ensemble member and plot it\n", "for j in range(num_ensemble):\n", "    # Jitter along the x-axis a little bit\n", "    x_jitter = np.arange(len(x_true)) + rng.normal(loc=0, scale=0.1, size=len(x_true))\n", "\n", "    # Plot this ensemble member\n", "    plt.scatter(\n", "        x_jitter,\n", "        X_posterior[:, j],\n", "        label=(\"Posterior values\" if j == 0 else None),\n", "        color=\"black\",\n", "        alpha=0.2,\n", "        s=5,\n", "        zorder=0,\n", "    )\n", "plt.xlabel(\"Parameter index\")\n", "plt.ylabel(\"Parameter value\")\n", "plt.grid(True, ls=\"--\", zorder=0, alpha=0.33)\n", "plt.legend()\n", "plt.show()" ] } ], "metadata": { "jupytext": { "formats": "ipynb,py:percent" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }