#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 10 2019

@authors: James Annan (translated to Python by Martin Renoult)
correspondence: martin.renoult@misu.su.se
"""

## Library
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stat
from scipy.stats import gaussian_kde
from adjustText import adjust_text
from pymc3 import *

#------------------------------------------------------------------------------
## Model data
# x = ECS
# y = Tropical temperature

# In the following order:
# PMIP2[MIROC, IPSL, CCSM, ECHAM, FGOALS, HadCM3, ECBILT]
# PMIP3[NCAR-CCSM4, IPSL-CM5A-LR, MIROC-ESM, MPI-ESM-P, CNRM-CM5, MRI-CGCM3, FGOALS-g2]

# ECS of each model (the `x` of the legend above): the seven PMIP2 models
# first, then the seven PMIP3 models, in the order listed in the header.
x = (
    [4, 4.4, 2.7, 3.4, 2.3, 3.3, 1.8]                 # PMIP2
    + [3.2, 4.13, 4.67, 3.45, 3.25, 2.6, 3.37]        # PMIP3
)

# Tropical temperature of each model (the `y` of the legend above), stored in
# the same model order as `x` so that x[i] and y[i] describe the same model.
y = (
    [-2.74899797, -2.82983243, -2.11649457, -3.15521371,
     -2.35793821, -2.77644642, -1.33816631]           # PMIP2
    + [-2.56430725, -3.45682995, -2.40686508, -2.58021444,
       -1.67634272, -2.80245053, -3.1465664]          # PMIP3
)

# PMIP2 only
#x = [4,4.4,2.7,3.4,2.3,3.3,1.8]
#y = [-2.74899797,-2.82983243,-2.11649457,-3.15521371,-2.35793821,-2.77644642,-1.33816631]

#------------------------------------------------------------------------------
## Required statistics for the prior

# Ensemble moments of the model data. np.std is called with its default
# ddof, so these are population (divide-by-N) standard deviations.
mnx = np.mean(x)
mny = np.mean(y)
stdx = np.std(x)
stdy = np.std(y)
# pearsonr returns (correlation coefficient, p-value); only the coefficient
# is used below.
pearson = stat.pearsonr(x, y)

# Number of Monte Carlo draws used for the samples generated in this script.
nsamp = 5000

# Bivariate normal prior over (x, y), parameterised by the ensemble means,
# standard deviations and correlation.
pri_mn = [mnx, mny]
pri_sd = [stdx, stdy]
pri_rho = pearson[0]
pri_cov_xy = pri_sd[0] * pri_sd[1] * pri_rho
pri_var = [[pri_sd[0] ** 2, pri_cov_xy],
           [pri_cov_xy, pri_sd[1] ** 2]]

# Draw the prior sample. The size follows `nsamp` instead of a second
# hard-coded 5000, so changing `nsamp` changes every sample consistently.
pri = np.random.multivariate_normal(mean=pri_mn, cov=pri_var, size=nsamp)
# Transposed view with one row per variable, the layout np.cov expects by
# default.
df_t = pri.transpose()

## Predictions from prior and plot

# Re-estimate mean and covariance from the Monte Carlo prior sample, then
# draw a second sample from the Gaussian defined by those estimates. That
# second sample is the one shown as the grey "Prior" cloud.
samp_mn = pri.mean(axis=0)
ls_mn = list(samp_mn)
samp_var = np.cov(df_t)
est = np.random.multivariate_normal(mean=samp_mn, cov=samp_var, size=nsamp)

prior_columns = ['Climate Sensitivity', 'LGM tropical (20S - 30N) temperature change']
df2 = pd.DataFrame(est, columns=prior_columns)

# Create the figure; every later plotting call draws on these axes.
fig, ax = plt.subplots(figsize=(7, 7))
plt.scatter(
    df2[prior_columns[0]],
    df2[prior_columns[1]],
    c='#e0e0eb',
    marker='.',
    label='Prior',
)

#------------------------------------------------------------------------------
## One-step Kalman filtering

# LGM tropical temperature from geological proxy
# Observation: mean and standard deviation assigned to the proxy-based
# LGM tropical temperature.
ob_mn = -2.2
ob_sd = 0.4

# Observation operator H: selects the temperature (second) component of the
# (sensitivity, temperature) state vector.
obs_op = np.array([0.0, 1.0])

# Innovation (observation minus predicted observation). It is stored under
# its own name so that the model temperatures held in `y` are not overwritten.
innovation = ob_mn - np.dot(obs_op, samp_mn)

# Innovation variance S = H P H^T + R and Kalman gain K = P H^T / S.
S = np.dot(obs_op, np.matmul(samp_var, obs_op)) + ob_sd ** 2
K = np.matmul(samp_var, obs_op) / S

# Posterior mean and covariance: m + K * innovation and (I - K H) P.
# K H must be the 2x2 outer product of the gain and the observation
# operator. An elementwise `K * [0, 1]` broadcasts to the wrong matrix and
# yields an asymmetric covariance with an incorrect sensitivity variance.
pos_mn = samp_mn + K * innovation
pos_var = np.matmul(np.eye(2) - np.outer(K, obs_op), samp_var)
pos_samp = np.random.multivariate_normal(mean=pos_mn, cov=pos_var, size=nsamp)

# Put the posterior sample in a labelled frame, one column per variable.
df3 = pd.DataFrame(pos_samp, columns=['Climate Sensitivity', 'LGM tropical (20S - 30N) temperature change'])

# Colour each posterior draw by a Gaussian kernel density estimate evaluated
# at the draw itself (the KDE is fitted to, and evaluated on, the same points).
xy = np.vstack([df3['Climate Sensitivity'], df3['LGM tropical (20S - 30N) temperature change']])
z = gaussian_kde(xy)(xy)

# edgecolor='none' suppresses the marker outlines; an empty string is not a
# valid colour specification in current Matplotlib releases.
plt.scatter(df3['Climate Sensitivity'], df3['LGM tropical (20S - 30N) temperature change'],
            c=z, marker='.', label='Posterior', s=100, edgecolor='none')

# Report the posterior mean and the posterior standard deviations.
print(pos_mn)
print(np.sqrt(np.diag(pos_var)))

# Final statistics
# Marginal standard deviations of the posterior and of the sampled prior
# (index 0 is climate sensitivity, index 1 is temperature).
stdfinal = np.sqrt(np.diag(pos_var))
stdprior = np.sqrt(np.diag(samp_var))

# Univariate Gaussian samples of climate sensitivity. The prior sample is
# centred on the ensemble mean `pri_mn[0]` with the spread of the sampled
# prior; the posterior sample uses the Kalman-updated mean and spread.
gausspri = np.random.normal(loc=pri_mn[0], scale=stdprior[0], size=1000)
gausspost = np.random.normal(loc=pos_mn[0], scale=stdfinal[0], size=1000)

# 5th and 95th percentiles of each sample, stored as [lower, upper].
# np.percentile replaces `stats.quantiles`, which was only reachable through
# the star import of pymc3 and is not available in later PyMC3 releases.
post_stats_90_pri = list(np.percentile(gausspri, [5, 95]))
post_stats_90_post = list(np.percentile(gausspost, [5, 95]))

# Re-assert the model temperatures so that the plotting below does not depend
# on whether `y` was rebound earlier in the script.
y = [-2.74899797,-2.82983243,-2.11649457,-3.15521371,-2.35793821,-2.77644642,-1.33816631,
     -2.56430725,-3.45682995,-2.40686508,-2.58021444,-1.67634272,-2.80245053,-3.1465664]

# Fix the axis limits; the returned (low, high) tuples are reused later to
# convert data values into axes fractions.
ylim = plt.ylim(-5, 0.5)
xlim = plt.xlim(-1, 8)

# Individual models: the first seven entries are PMIP2, the last seven PMIP3.
plt.plot(x[0:7], y[0:7], '.', label='PMIP2', markersize=17, color='#0066ff', mec='#25097C')
plt.plot(x[7:14], y[7:14], '.', label='PMIP3', markersize=17, color='#ff9933', mec='#9C590C')

# Axis labels. These are set once, with their final text and layout options;
# setting provisional labels beforehand would only be overwritten here.
plt.xlabel('Climate Sensitivity (K)', labelpad=-35, fontsize=16)
plt.ylabel('LGM tropical (20° S - 30° N) \ntemperature anomaly (K)', position=(0,0.4), fontsize=16)

# Move the spines onto the data origin so the axes cross at (0, 0), hide the
# top and right spines, and thicken the two visible ones.
ax.spines['top'].set_alpha(0)
ax.spines['right'].set_visible(False)
for side in ('bottom', 'left', 'top'):
    ax.spines[side].set_position(('data', 0))
for side in ('left', 'bottom'):
    ax.spines[side].set_linewidth(2)

# The negative pad moves the x tick labels to the other side of the x axis.
ax.tick_params('x', direction='in', pad=-20, width=2)
ax.tick_params('y', width=2)
plt.yticks(ticks=np.arange(-1, -6, -1), fontsize=14)
plt.xticks(ticks=np.arange(1, 9, 1), fontsize=14)

# Number every model 1..N at its data point, then let adjust_text reposition
# the labels.
texts = [
    plt.text(model_x, model_y, str(number), ha='center', va='center', fontsize=15)
    for number, (model_x, model_y) in enumerate(zip(x, y), start=1)
]
adjust_text(texts)

def _axes_fraction(value, limits):
    """Return the position of *value* between limits[0] and limits[1] as a fraction of the axis span."""
    lower, upper = limits
    return (value - lower) / (upper - lower)


def _draw_vertical_interval(position, low, high, centre, limits, colour, label):
    """Draw a vertical bar from *low* to *high* at x=*position*, with end arrows and a centre dot."""
    lo = _axes_fraction(low, limits)
    hi = _axes_fraction(high, limits)
    mid = _axes_fraction(centre, limits)
    plt.axvline(x=position, ymin=lo, ymax=hi, color=colour, label=label, linewidth=2)
    # Zero-length segments: only their markers are visible.
    plt.axvline(x=position, ymin=lo, ymax=lo, color=colour, marker='v')
    plt.axvline(x=position, ymin=hi, ymax=hi, color=colour, marker='^')
    plt.axvline(x=position, ymin=mid, ymax=mid, color=colour, marker='.', ms=12)


def _draw_horizontal_interval(height, low, high, centre, limits, colour, label):
    """Draw a horizontal bar from *low* to *high* at y=*height*, with end arrows and a centre dot."""
    lo = _axes_fraction(low, limits)
    hi = _axes_fraction(high, limits)
    mid = _axes_fraction(centre, limits)
    plt.axhline(y=height, xmin=lo, xmax=hi, c=colour, label=label, linewidth=2)
    # Zero-length segments: only their markers are visible.
    plt.axhline(y=height, xmin=lo, xmax=lo, marker='<', c=colour)
    plt.axhline(y=height, xmin=hi, xmax=hi, marker='>', c=colour)
    plt.axhline(y=height, xmin=mid, xmax=mid, c=colour, marker='.', ms=12)


# Half-width of the plotted observation bar. It is rebound here for plotting
# only and is labelled as a 5-95% range in the legend.
# NOTE(review): 0.7 is presumably ~1.645 x the 0.4 used in the filter - confirm.
ob_sd = 0.7

# axvline/axhline take their extent as a fraction of the axes, so the data
# values are converted with the fixed limits stored in `ylim` and `xlim`.
_draw_vertical_interval(0.2, ob_mn - ob_sd, ob_mn + ob_sd, ob_mn, ylim,
                        '#009900', '5-95% observed value')

_draw_horizontal_interval(-0.3, post_stats_90_post[0], post_stats_90_post[1], pos_mn[0], xlim,
                          '#9933ff', '5-95% posterior')

_draw_horizontal_interval(-0.15, post_stats_90_pri[0], post_stats_90_pri[1], pri_mn[0], xlim,
                          '#c2c2d6', '5-95% prior')

# Legend box, anchored by its upper-left corner inside the axes.
legend_options = dict(
    loc='upper left',
    bbox_to_anchor=(0.6, 0.85),
    fancybox=True,
    ncol=1,
    edgecolor='k',
)
plt.legend(**legend_options)

# Colour bar for the current mappable, with ticks every 0.1 from 0 to 0.9.
# NOTE(review): the current mappable is presumably the posterior scatter,
# whose colours are kernel density estimates - confirm.
colorbar_ticks = np.arange(0, 1, 0.1)
colorbar = plt.colorbar(ticks=colorbar_ticks, shrink=0.5, extend='max')
colorbar.set_label('Probability', fontsize=14)
colorbar.ax.tick_params(labelsize=12, width=2)

# Process pending GUI events before the tick labels are replaced.
fig.canvas.flush_events()

# Label every other tick ('0', '0.1', '0.3', ..., '0.9'); the ticks in
# between get an empty label.
tick_labels = ['0'] + ['0.%d' % tenth if tenth % 2 else '' for tenth in range(1, 10)]
colorbar.ax.set_yticklabels(labels=tick_labels)
plt.tight_layout()
#plt.savefig('KF_PMIP.pdf', dpi=300)