#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 10 11:14:16 2019

@author: Martin Renoult
correspondence: martin.renoult@misu.su.se
"""

## Library

import numpy as np
import statsmodels.api as sm

from pymc3 import *

#------------------------------------------------------------------------------
## Model data
# x = ECS
# y = Tropical temperature (presumably the LGM tropical temperature change
#     simulated by each model; it is compared with the observed T below)
#
# Each ensemble is written exactly once; the combined PMIP2+3 lists are
# derived from it so the copies cannot drift apart.

# PMIP2 [MIROC, IPSL, CCSM, ECHAM, FGOALS, HadCM3, ECBILT]
x_pmip2 = [4,4.4,2.7,3.4,2.3,3.3,1.8]
y_pmip2 = [-2.74899797,-2.82983243,-2.11649457,-3.15521371,-2.35793821,-2.77644642,-1.33816631]

# PMIP3 [NCAR-CCSM4, IPSL-CM5A-LR, MIROC-ESM, MPI-ESM-P, CNRM-CM5, MRI-CGCM3, FGOALS-g2]
x_pmip3 = [3.2,4.13,4.67,3.45,3.25,2.6,3.37]
y_pmip3 = [-2.5957031, -3.3755188, -2.5169983, -2.5567322, -1.6669922, -2.81839, -3.1465664]

# Every model must contribute exactly one (ECS, temperature) pair.
if len(x_pmip2) != len(y_pmip2) or len(x_pmip3) != len(y_pmip3):
    raise ValueError("ECS and temperature lists must have one entry per model")

# PMIP2+3, in the order PMIP2 models then PMIP3 models.
x_pmip23 = x_pmip2 + x_pmip3
y_pmip23 = y_pmip2 + y_pmip3

# Full-ensemble aliases kept as independent list objects: y_pmip23 is
# rebound to an array further down, and x / y must not be affected by that.
x = list(x_pmip23)
y = list(y_pmip23)

# Accumulators; presumably filled later in the script (they are not used
# by the statements below).
list_predict_t_stats = list()
low_bound = list()
up_bound = list()


def _fit_ecs_on_temperature(ecs, temperature):
    """Fit the linear relation ECS = c0 + c1 * temperature by OLS.

    Parameters
    ----------
    ecs : sequence of float
        Response, one ECS value per model.
    temperature : sequence of float
        Single regressor, one tropical temperature per model.

    Returns
    -------
    model : statsmodels.regression.linear_model.OLS
        The unfitted model.
    results : RegressionResults
        The fit; ``results.params`` is ``[c0, c1]`` (intercept first,
        because the constant column is prepended).
    temperature_2d : numpy.ndarray, shape (n, 1)
        The regressor column of the design matrix with the intercept
        column dropped.
    """
    # has_constant='add' always prepends the intercept column. With the
    # default 'skip', a constant regressor would get no intercept column,
    # and dropping column 0 below would then discard real data.
    design = sm.add_constant(temperature, prepend=True, has_constant='add')
    model = sm.OLS(ecs, design)
    results = model.fit()
    return model, results, design[:, 1:]


# OLS on the dataset of Schmidt et al., 2014 (PMIP2+3)
model_schmidt, results_schmidt, y_pmip23 = _fit_ecs_on_temperature(x_pmip23, y_pmip23)

coeff_schmidt = results_schmidt.params
resid_schmidt = results_schmidt.resid
mse_resid_schmidt = results_schmidt.mse_resid

# OLS on the dataset of Hargreaves et al., 2012 (PMIP2 only)
model_hargreaves, results_hargreaves, y_pmip2 = _fit_ecs_on_temperature(x_pmip2, y_pmip2)

coeff_hargreaves = results_hargreaves.params
resid_hargreaves = results_hargreaves.resid
mse_resid_hargreaves = results_hargreaves.mse_resid

# Evenly spaced evaluation grids.
ran_renoult = np.linspace(start=0, stop=20, num=1000)
ran_schmidt = np.linspace(start=-6, stop=10, num=500)

# Invert each fitted line ECS = c0 + c1*T, i.e. T = (ECS - c0) / c1,
# treating the values of `ran_schmidt` as ECS. Both fits share that grid.
prediction_schmidt = np.divide(
    np.subtract(ran_schmidt, coeff_schmidt[0]), coeff_schmidt[1]
)
prediction_hargreaves = np.divide(
    np.subtract(ran_schmidt, coeff_hargreaves[0]), coeff_hargreaves[1]
)

# "Real" observed data of tropical LGM, modelled as a Gaussian with
# mean T and standard deviation stdT.
T = -2.2
stdT = 0.4

# Number of Monte Carlo draws. One shared constant keeps the three draws
# below the same length, so they combine element-wise.
N_SAMPLES = 100000

# Set to an int for reproducible draws. None keeps drawing from NumPy's
# global RNG exactly as before, so an upstream np.random.seed() still applies.
RANDOM_SEED = None
_rng = np.random if RANDOM_SEED is None else np.random.RandomState(RANDOM_SEED)

distrib_T = _rng.normal(loc=T, scale=stdT, size=N_SAMPLES)

# Inference: push each observed-T draw through the fitted line
# ECS = c0 + c1*T, then add Gaussian noise whose std is the regression's
# residual standard error, sqrt(mse_resid).
# NOTE(review): only point estimates of c0/c1 are used, so uncertainty in
# the fitted coefficients themselves is not propagated.
# The draw order (T, Schmidt noise, Hargreaves noise) is unchanged.
inference_schmidt = (
    distrib_T*coeff_schmidt[1] + coeff_schmidt[0]
    + _rng.normal(loc=0, scale=mse_resid_schmidt**0.5, size=N_SAMPLES)
)
inference_hargreaves = (
    distrib_T*coeff_hargreaves[1] + coeff_hargreaves[0]
    + _rng.normal(loc=0, scale=mse_resid_hargreaves**0.5, size=N_SAMPLES)
)

# Some statistics: medians rounded to 2 decimals, and 90% intervals
# given by the 5th and 95th percentiles.
median_23 = np.round(np.median(inference_schmidt), decimals=2)
median_2 = np.round(np.median(inference_hargreaves), decimals=2)


def _empirical_quantiles(samples, qlist):
    """Return ``{q: q-th percentile of samples}`` for each q in *qlist*.

    Reproduces the rule of the legacy ``pymc3.stats.quantiles`` (sort the
    samples, take the element at index ``int(n * q / 100)``) using plain
    NumPy, because that function no longer exists in recent PyMC3 releases.

    Raises
    ------
    ValueError
        If *samples* is empty.
    """
    ordered = np.sort(np.asarray(samples))
    n = len(ordered)
    if n == 0:
        raise ValueError("cannot compute quantiles of an empty sample")
    # Clamp so that q = 100 maps to the maximum instead of indexing past the end.
    return {q: ordered[min(int(n * q / 100.0), n - 1)] for q in qlist}


post_stats_90_23 = _empirical_quantiles(inference_schmidt, qlist=(5, 95))

post_stats_90_2 = _empirical_quantiles(inference_hargreaves, qlist=(5, 95))