# Skip to content
# Snippets Groups Projects
# run.py 8.72 KiB
# Newer Older
import datetime
import os
import pathlib
import pdb
import pickle
import shutil
import warnings

import geopandas as gpd
import matplotlib.patheffects
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xdem

import marma_dems.analysis
import marma_dems.main
import marma_dems.plotting


def _decimal_year(timestamp) -> float:
    """Return ``timestamp`` as a decimal year, so that 1 January maps to exactly ``year + 0``.

    ``timestamp`` may be anything exposing ``.year`` and ``.timetuple()`` (``datetime.date``,
    ``datetime.datetime``, ``pandas.Timestamp``). The time of day is ignored.
    """
    # The day-of-year of 31 December is the number of days in that year (365 or 366).
    days_in_year = datetime.date(timestamp.year, 12, 31).timetuple().tm_yday
    return timestamp.year + (timestamp.timetuple().tm_yday - 1) / days_in_year


def _integration_uncertainty(ddems, areas):
    """Collect the empirical variograms and area-integrated standard errors of each dDEM.

    Parameters
    ----------
    ddems
        dDEMs whose ``variograms`` dict holds a ``"variogram"`` DataFrame (with a ``"bins"``
        column) and four ``"vgm_params"`` that are passed to ``neff_circ`` as two "Sph" models.
    areas
        Integration areas at which the standard error is evaluated.

    Returns
    -------
    variograms : pandas.DataFrame
        All variograms concatenated, indexed by (dDEM label, original index), without ``"bins"``.
    stderrs : pandas.DataFrame
        Indexed by ``areas``, one column of standard errors per dDEM (``"dH_<start>-<end>"``).
    """
    variogram_dfs = []
    stderrs = pd.DataFrame(index=areas)

    for ddem in ddems:
        ddem_label = f"dH_{ddem.start_time.year}-{ddem.end_time.year}"

        vgm = ddem.variograms["variogram"].copy()
        vgm = vgm.set_index([[ddem_label] * vgm.shape[0], vgm.index])
        variogram_dfs.append(vgm)

        # Neither the variogram model nor the dH spread depends on the area, so compute them
        # once per dDEM rather than once per (dDEM, area) pair; nanstd scans the whole raster.
        params = ddem.variograms["vgm_params"]
        double_range_model = [(params[0], "Sph", params[1]), (params[2], "Sph", params[3])]
        dh_std = np.nanstd(ddem.data)

        # Standard error = spread / sqrt(number of effective samples within the area).
        stderrs[ddem_label] = [
            dh_std / np.sqrt(xdem.spatialstats.neff_circ(area, double_range_model)) for area in areas
        ]

    variograms = pd.concat(variogram_dfs).drop(columns=["bins"])
    return variograms, stderrs


def _plot_dh_series(changes):
    """Draw elevation-change rates (left) and cumulative mass balance (right) through time.

    ``changes`` must be sorted chronologically and contain the ``start_year_dec``,
    ``end_year_dec``, ``mean_year_dec``, ``year_duration_dec``, ``dhdt``, ``dhdt_err``,
    ``mb_cumsum`` and ``mb_err`` columns. Saves ``output/figures/marma_dh_series.jpg``.
    """
    plt.figure(figsize=(9, 4))

    axis0 = plt.subplot(121)
    # Dashed connectors from the end of each period to the start of the next one.
    for (_, prev), (_, nxt) in zip(changes.iloc[:-1].iterrows(), changes.iloc[1:].iterrows()):
        plt.plot(
            [prev["end_year_dec"], nxt["start_year_dec"]],
            [prev["dhdt"], nxt["dhdt"]],
            linestyle="--",
            color="black",
            zorder=1,
        )
    # Each bar spans its period horizontally and value ± uncertainty vertically.
    plt.bar(
        changes["mean_year_dec"],
        changes["dhdt_err"] * 2,
        width=changes["year_duration_dec"],
        bottom=(changes["dhdt"] - changes["dhdt_err"]),
        color="crimson",
        edgecolor="black",
        zorder=2,
    )
    plt.ylabel("Elevation change rate (m/a)")

    axis1 = plt.subplot(122)
    plt.plot(changes["mean_year_dec"], changes["mb_cumsum"], linestyle="--", color="black", zorder=1)
    plt.bar(
        changes["mean_year_dec"],
        changes["mb_err"] * 2,
        width=changes["year_duration_dec"],
        bottom=changes["mb_cumsum"] - changes["mb_err"],
        zorder=2,
        edgecolor="black",
    )
    plt.ylabel("Cumulative mass balance (m w.e.)")

    for axis, col in [(axis0, "dhdt"), (axis1, "mb_cumsum")]:
        # Solid horizontal line at the period's value, spanning the period.
        for _, row in changes.iterrows():
            axis.plot([row["start_year_dec"], row["end_year_dec"]], [row[col]] * 2, color="black", zorder=3)
        # NOTE(review): both limits are padded by 10 % of |min| (also on the max side), and the
        # error bars are not taken into account — confirm this is intended.
        padding = abs(changes[col].min()) * 0.1
        axis.set_ylim(changes[col].min() - padding, changes[col].max() + padding)
        axis.grid(alpha=0.5, linewidth=1, zorder=0)

    plt.tight_layout()
    plt.savefig("output/figures/marma_dh_series.jpg", dpi=600)


def _plot_dh_maps(ddems):
    """Draw per-period maps of elevation-change rate (top row) and its uncertainty (bottom row).

    Columns 1-5 of a 2x6 grid hold one dDEM each; column 6 holds the two colour bars.
    NOTE: each dDEM's ``data`` is divided in place by its period length (in whole years).
    Saves ``output/figures/marma_dh_maps.jpg`` and then calls ``plt.show()``.
    """
    plt.figure(figsize=(12.5, 5))
    for i, ddem in enumerate(ddems):
        # Convert the elevation change into an annual rate (mutates the dDEM).
        ddem.data /= abs(ddem.start_time.year - ddem.end_time.year)
        first_year, last_year = sorted((ddem.start_time.year, ddem.end_time.year))

        axis0 = plt.subplot(2, 6, i + 1)
        ddem.show(
            ax=axis0,
            cmap="RdBu",
            vmin=-2,
            vmax=2,
            title=f"{first_year}–{last_year}",
            add_cb=False,
        )
        axis1 = plt.subplot(2, 6, i + 7)
        # NOTE(review): `ddem.error` is not divided by the period length like `ddem.data`, yet its
        # colour bar below is labelled in m/a — confirm the unit.
        xdem.DEM.from_array(ddem.error, ddem.transform, ddem.crs, -9999).show(
            cmap="Reds", ax=axis1, vmin=0, vmax=10, add_cb=False
        )

        for j, axis in enumerate([axis0, axis1]):
            axis.set_xlim(ddem.bounds.left, ddem.bounds.right)
            axis.set_ylim(ddem.bounds.bottom, ddem.bounds.top)

            # Only the first column gets northing labels (first and last tick only).
            yticks = np.linspace(ddem.bounds.bottom, ddem.bounds.top, 5)
            axis.set_yticks(yticks)
            if i % 5 == 0:
                yticks_str = [str(tick)[:5] for tick in yticks / 1e6]
                axis.set_yticklabels(
                    [f"{yticks_str[0]} * 10⁶"] + [""] * (len(yticks) - 2) + [f"{yticks_str[-1]} * 10⁶"]
                )
                axis.set_ylabel("Northing (m)", labelpad=-50)
            else:
                axis.set_yticklabels([""] * len(yticks))

            # Only the bottom-left panel gets easting labels (first and last tick only).
            xticks = np.linspace(ddem.bounds.left, ddem.bounds.right, 5)
            axis.set_xticks(xticks)
            if i % 5 == 0 and j == 1:
                axis.set_xticklabels(
                    [int(round(xticks[0], -1))] + [""] * (len(xticks) - 2) + [int(round(xticks[-1], -1))]
                )
                axis.set_xlabel("Easting (m)", labelpad=-10)
            else:
                axis.set_xticklabels([""] * len(xticks))

    # Colour bars are drawn from dummy images in the otherwise empty sixth column.
    plt.subplot(2, 6, 6)
    plt.imshow([[0, 0], [0, 0]], cmap="RdBu", vmin=-2, vmax=2)
    plt.axis("off")
    cbar = plt.colorbar(fraction=1, aspect=10, shrink=0.7)
    cbar.set_label(
        """Elevation change
rate (m/a)"""
    )

    plt.subplot(2, 6, 12)
    plt.imshow([[0, 0], [0, 0]], cmap="Reds", vmin=0, vmax=10)
    plt.axis("off")
    cbar = plt.colorbar(fraction=1, aspect=10, shrink=0.7)
    cbar.set_label(
        """Elevation change
uncertainty (m/a)"""
    )

    plt.tight_layout()
    plt.subplots_adjust(top=0.958, bottom=0.048, left=0.079, right=0.998, hspace=0.0, wspace=0.025)
    plt.savefig("output/figures/marma_dh_maps.jpg", dpi=600)
    plt.show()


def _plot_integrated_area_uncertainty(stderrs):
    """Plot the standard error of each dDEM against integration area on a log x-axis.

    Saves ``output/figures/marma_integrated_area_uncertainty.jpg``.
    """
    plt.figure(figsize=(8, 5))
    axis = plt.subplot(111)
    # "dH_1959-2021" -> "dH 1959–2021" for the legend.
    legend_names = {col: col.replace("_", " ").replace("-", "–") for col in stderrs}
    stderrs.rename(columns=legend_names).plot.line(
        colormap="coolwarm",
        legend=True,
        ax=axis,
        path_effects=[matplotlib.patheffects.withStroke(linewidth=2, foreground="black")],
    )
    plt.xlabel("Integrated area (m²)")
    plt.ylabel("Elevation change uncertainty (m)")
    plt.xscale("log")
    plt.grid(alpha=0.5, linewidth=1)

    plt.tight_layout()
    plt.savefig("output/figures/marma_integrated_area_uncertainty.jpg", dpi=600)


def main():
    """Prepare the Marma DEMs, export the products and uncertainty tables, and draw the figures.

    Writes CSV tables, GeoTIFF DEMs and copies of ``GIS/shapes/*.geojson`` to ``output/``, and
    JPEG figures to ``output/figures/``.

    Raises
    ------
    ValueError
        If no dDEM spans ``reference_year`` to 2021 (needed for the interpolated 2021 DEM).
    """
    # Uncomment to turn warnings into errors while debugging.
    # warnings.simplefilter("error")

    reference_year = 2016
    # Factor converting elevation change to m w.e., and its uncertainty.
    # NOTE(review): presumably a density of 850 ± 60 kg m⁻³ relative to water — confirm.
    density_factor = 0.85
    density_factor_err = 0.06

    dems, ddems, unstable_terrain = marma_dems.main.prepare_dems(reference_year=reference_year)

    areas = np.round(10 ** np.linspace(0, 7, num=20), 0)
    variograms, stderrs = _integration_uncertainty(ddems, areas)

    os.makedirs("output/", exist_ok=True)

    variograms.to_csv("output/dH_variograms.csv", index_label=["product", "bins"])
    stderrs.to_csv("output/dH_integration_uncertainty.csv", index_label=["integration_area"])

    # The dDEM is interpolated, so generate an interpolated 2021 DEM using "DEM_2016 + dDEM_2016_2021"
    ddem_to_2021 = next(
        (d for d in ddems if d.end_time.year == 2021 and d.start_time.year == reference_year),
        None,
    )
    if ddem_to_2021 is None:
        raise ValueError(f"No dDEM spans {reference_year}-2021; cannot build the interpolated 2021 DEM.")
    dem_2021_interp = dems[reference_year] + ddem_to_2021

    # NOTE(review): `tdem` is never read afterwards; kept in case the constructor has side effects.
    tdem = marma_dems.analysis.InterpolatedDEM(dems.copy())
    tdem.dems[2021] = dem_2021_interp

    changes = marma_dems.analysis.volume_change(ddems, unstable_terrain)

    changes["mean_area"] = changes[["start_area", "end_area"]].mean(axis=1)
    changes["mb"] = (changes["mean_dv"] / changes["mean_area"]) * density_factor
    # NOTE(review): the dv_error term is not scaled by `density_factor` (it is in m, the second
    # term in m w.e.) — confirm this is intended.
    changes["mb_err"] = np.sqrt(
        (changes["dv_error"] / changes["mean_area"]) ** 2
        + ((changes["mb"] / density_factor) * density_factor_err) ** 2
    )
    changes["start_year"] = changes.index.left
    changes["end_year"] = changes.index.right

    creation_options = {"COMPRESS": "DEFLATE", "ZLEVEL": 12, "PREDICTOR": 3, "TILED": True, "NUM_THREADS": "ALL_CPUS"}

    for year in dems:
        dems[year].save(f"output/Marma_DEM_{year}.tif", co_opts=creation_options)

    dem_2021_interp.save("output/Marma_DEM_2021_interp.tif", co_opts=creation_options)

    changes.to_csv("output/Marma_geodetic_1959-2021.csv", index=False)

    for filepath in pathlib.Path("GIS/shapes/").glob("*.geojson"):
        shutil.copy(filepath, pathlib.Path("output/").joinpath(filepath.name))

    os.makedirs("output/figures/", exist_ok=True)

    changes["start_year_dec"] = changes["start_year"].apply(_decimal_year)
    changes["end_year_dec"] = changes["end_year"].apply(_decimal_year)
    changes["mean_year_dec"] = changes[["start_year_dec", "end_year_dec"]].mean(axis=1)
    changes["year_duration_dec"] = changes["end_year_dec"] - changes["start_year_dec"]

    # Chronological order is required for the cumulative sum and the period connectors.
    changes.sort_values("start_year_dec", inplace=True)

    changes["dhdt"] = changes["mean_dh"] / changes["year_duration_dec"]
    changes["dhdt_err"] = changes["dh_error"] / changes["year_duration_dec"]
    changes["mb_cumsum"] = changes["mb"].cumsum()

    _plot_dh_series(changes)
    _plot_dh_maps(ddems)
    _plot_integrated_area_uncertainty(stderrs)


if __name__ == "__main__":
    # Run the full pipeline only when this file is executed as a script, not when imported.
    main()