Case study of prediction performance in the Northern Gulf of Mexico

[1]:
import h5py
import scipy
import numpy as np
import torch
[2]:
def load(path, num_runs=10, subdir="with_imputation"):
    """Load the per-run prediction arrays of one model and stack them.

    Reads ``prediction_0.npy`` ... ``prediction_{num_runs-1}.npy`` from
    ``path/subdir`` and stacks them along a new axis 1.

    Parameters
    ----------
    path : str
        Directory of the model's predictions (the parent of ``subdir``).
    num_runs : int, optional
        Number of consecutive run files to read, starting at index 0.
        Defaults to 10, the value that used to be hard-coded.
    subdir : str, optional
        Sub-folder of ``path`` that holds the ``.npy`` files.
        Defaults to ``"with_imputation"``, the value that used to be hard-coded.

    Returns
    -------
    numpy.ndarray
        Array whose axis 1 has length ``num_runs``; the remaining axes are
        those of the individual prediction files.

    Raises
    ------
    ValueError
        If ``num_runs`` is smaller than 1, or if the files do not share a
        common shape (raised by ``np.stack``).
    OSError
        If one of the expected files cannot be opened (raised by ``np.load``).
    """
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1, got {}".format(num_runs))
    run_dir = path + "/" + subdir
    preds = [
        np.load("{}/prediction_{}.npy".format(run_dir, i))
        for i in range(num_runs)
    ]
    return np.stack(preds, axis=1)
[4]:
# Root folder of the saved MEXICO predictions (relative to the notebook).
base_dir = "../log_bak/prediction/MEXICO/"

# Files of the two baselines: run 0 from each model's `without_imputation` folder.
xgb_file = base_dir + "XGBoost/without_imputation/prediction_0.npy"
predrnn_file = base_dir + "PredRNN/without_imputation/prediction_0.npy"

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, which
# can execute code from an untrusted file. Confirm these files are trusted, or
# drop the flag if they are plain numeric arrays.
prediction_xg_wo = torch.from_numpy(np.load(xgb_file, allow_pickle=True))

# STIMP predictions are read through load() and wrapped as a tensor.
prediction_our = torch.from_numpy(load(base_dir + "STIMP"))

prediction_predrnn_wo = torch.from_numpy(np.load(predrnn_file, allow_pickle=True))
[7]:
import pandas as pd

# Ground-truth values and their masks: drop singleton axes, then wrap the
# arrays as torch tensors.
label = torch.from_numpy(np.load("../data/MEXICO/trues.npy").squeeze())
label_masks = torch.from_numpy(np.load("../data/MEXICO/true_masks.npy").squeeze())

# First-axis positions 0, 46, ..., 230 (six entries, spaced 46 apart).
index = list(range(0, 46 * (306 // 46), 46))

# One timestamp every 8 days between the two dates; used as the x axis of the plots.
date = pd.date_range(start='2016-02-02', end='2022-02-10', freq='8D')

Fig 5c: Predicted Chl_a of XGBoost, STIMP and PredRNN at Position 1

[18]:
#XGBoost
from copy import deepcopy
import seaborn as sns
import matplotlib.pyplot as plt

# Flattened spatial index of the location shown in this panel (Position 1).
n=530

# Pick every 46th entry along the first axis and flatten the selection into
# 276 time points by (all remaining values).
index = [46*i for i in range(306//46)]
tmp  = deepcopy(label[index].reshape(276,-1))
tmp_mask  = deepcopy(label_masks[index].reshape(276,-1))
# Entries whose mask is false become NaN so they are left out of the scatter.
tmp[~tmp_mask.bool()]=np.nan
plt.scatter(date, tmp[:,n], c='red', s=10,label="Observed")

predict = deepcopy(prediction_xg_wo[index].reshape(276,1, 2907))
predict = predict[:,0,n]

# Fix: this curve is the XGBoost prediction; it was previously labelled "PredRNN".
plt.plot(date, predict, label="XGBoost")

# plt.legend()  # left disabled, as in the original panel
plt.ylim(-0.5,1.5)
plt.yticks([-0.4,0.4,1.4], fontsize=16)
plt.ylabel("Chl_a (Log(ug/L))", fontsize=20)
plt.xticks(fontsize=16)
plt.show()
../../_images/analysis_MEXICO_07-prediction-mexico-case-study_6_0.png
[19]:
#STIMP
# Observed values at location n, drawn as red dots.
plt.scatter(date, tmp[:,n], c='red', s=10,label="Observed")

# Swap axes 1 and 2, flatten to 276 time points x 10 runs x 2907 locations,
# and keep only location n.
predict = deepcopy(prediction_our[index].transpose(1,2).reshape(276,10,2907))[:,:,n]

# Mean and standard deviation across the 10 runs at each time point.
mean = predict.mean(dim=1)
std = predict.std(dim=1)
lower_band = mean-std
upper_band = mean+std

plt.plot(date, mean, label="STImp")
# Shaded band: one standard deviation either side of the mean.
plt.fill_between(date, lower_band, upper_band, alpha=0.3)

plt.xticks(fontsize=16)
plt.yticks([-0.4,0.4,1.4], fontsize=16)
plt.ylim(-0.5,1.5)
plt.show()
../../_images/analysis_MEXICO_07-prediction-mexico-case-study_7_0.png
[21]:
#PredRNN
# Observed values at location n, drawn as red dots.
plt.scatter(date, tmp[:,n], c='red', s=10,label="Observed")

# Flatten the selected PredRNN predictions to 276 time points x 2907 locations
# and keep location n.
predict = deepcopy(prediction_predrnn_wo[index].reshape(276,2907))
predict = predict[:,n]
# Fix: this curve is the PredRNN prediction; it was previously labelled "STImp".
plt.plot(date, predict, label="PredRNN")

plt.ylim(-0.5,1.5)
plt.xticks(fontsize=16)
plt.yticks([-0.4,0.4,1.4], fontsize=16)
plt.show()
../../_images/analysis_MEXICO_07-prediction-mexico-case-study_8_0.png

Fig 5c: Predicted Chl_a of XGBoost, STIMP and PredRNN at Position 2

[22]:
# Flattened spatial index of the location shown in this panel (Position 2).
n=2000

# Pick every 46th entry along the first axis and flatten the selection into
# 276 time points by (all remaining values).
index = [46*i for i in range(306//46)]
tmp  = deepcopy(label[index].reshape(276,-1))
tmp_mask  = deepcopy(label_masks[index].reshape(276,-1))
# Entries whose mask is false become NaN so they are left out of the scatter.
tmp[~tmp_mask.bool()]=np.nan
plt.scatter(date, tmp[:,n], c='red', s=10,label="Observed")

predict = deepcopy(prediction_xg_wo[index].reshape(276,1, 2907))
predict = predict[:,0,n]

# Fix: this curve is the XGBoost prediction; it was previously labelled "PredRNN".
plt.plot(date, predict, label="XGBoost")

# plt.legend()  # left disabled, as in the original panel
plt.xticks(fontsize=16)
plt.ylim(-0.5,1.5)
plt.yticks([-0.4,0.4,1.4], fontsize=16)
plt.ylabel("Chl_a (Log(ug/L))", fontsize=20)
plt.show()
../../_images/analysis_MEXICO_07-prediction-mexico-case-study_10_0.png
[24]:
#STIMP
# Observed values at location n, drawn as red dots.
plt.scatter(date, tmp[:,n], c='red', s=10,label="Observed")

# Swap axes 1 and 2, flatten to 276 time points x 10 runs x 2907 locations,
# and keep only location n.
predict = deepcopy(prediction_our[index].transpose(1,2).reshape(276,10,2907))[:,:,n]

# Mean and standard deviation across the 10 runs at each time point.
mean = predict.mean(dim=1)
std = predict.std(dim=1)
lower_band = mean-std
upper_band = mean+std

plt.plot(date, mean, label="STImp")
# Shaded band: one standard deviation either side of the mean.
plt.fill_between(date, lower_band, upper_band, alpha=0.3)

plt.xticks(fontsize=16)
plt.yticks([-0.4,0.4,1.4], fontsize=16)
plt.ylim(-0.5,1.5)
plt.show()
../../_images/analysis_MEXICO_07-prediction-mexico-case-study_11_0.png
[25]:
#PredRNN
# Observed values at location n, drawn as red dots.
plt.scatter(date, tmp[:,n], c='red', s=10,label="Observed")

# Flatten the selected PredRNN predictions to 276 time points x 2907 locations
# and keep location n.
predict = deepcopy(prediction_predrnn_wo[index].reshape(276,2907))
predict = predict[:,n]

# Fix: this curve is the PredRNN prediction; it was previously labelled "STImp".
plt.plot(date, predict, label="PredRNN")

# plt.legend()  # left disabled, as in the original panel
plt.xticks(fontsize=16)
plt.yticks([-0.4,0.4,1.4], fontsize=16)
plt.ylim(-0.5,1.5)
plt.show()
../../_images/analysis_MEXICO_07-prediction-mexico-case-study_12_0.png