Overall improvement of STIMP over two climatological-mean baseline predictions (weekly and monthly)
[1]:
import h5py
import scipy
import numpy as np
import torch
[2]:
# Directory prefix that is joined with per-method sub-paths when the saved
# prediction arrays are loaded below.
# NOTE(review): this is a machine-specific absolute path; it must be adjusted
# before the notebook can be run on another machine.
base_dir = "/home/mafzhang/code/Project/CHLA-Imputation-and-Prediction-for-Bay/log/prediction/PRE/"
[3]:
# Load the reference values and their accompanying mask array in one go.
# NOTE(review): absolute, machine-specific paths.
label, label_masks = (
    np.load(npy_path)
    for npy_path in (
        "/home/mafzhang/data/PRE/8d/trues.npy",
        "/home/mafzhang/data/PRE/8d/true_masks.npy",
    )
)
[4]:
# Load the saved predictions of the three methods being compared: the
# GraphTransformer run (with imputation) and the weekly / monthly climatologies.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects,
# which can execute code from the file. Only load trusted files, and drop the
# flag if this array is stored with a plain numeric dtype — TODO confirm.
prediction_our = np.load(base_dir+"GraphTransformer/with_imputation/prediction.npy", allow_pickle=True)
prediction_weekly = np.load(base_dir+"weekly_climatology/prediction.npy")
prediction_monthly = np.load(base_dir+"monthly_climatology/prediction.npy")
[5]:
# Drop singleton dimensions from the NumPy arrays and hand them to torch.
label = torch.from_numpy(label.squeeze())
label_masks = torch.from_numpy(label_masks.squeeze())
[6]:
# Convert the prediction arrays to torch tensors. Only our model's output is
# squeezed after conversion; the two climatology arrays are used as loaded.
prediction_weekly, prediction_monthly = map(
    torch.from_numpy, (prediction_weekly, prediction_monthly)
)
prediction_our = torch.from_numpy(prediction_our).squeeze()
[7]:
# Build a boolean selector over the entries flagged by `is_sea`: among those
# entries, keep only the ones that are also flagged in `is_sea2`.
is_sea = np.load("/home/mafzhang/code/Project/STIMP/data/PRE/is_sea.npy")
is_sea2 = np.load("/home/mafzhang/code/Project/STIMP/data/PRE/is_sea_2.npy")
tmp = is_sea2[is_sea.astype(bool)].astype(bool)

# Apply the selector along the third dimension of the 3-D tensors ...
label, label_masks, prediction_monthly, prediction_weekly = (
    field[:, :, tmp]
    for field in (label, label_masks, prediction_monthly, prediction_weekly)
)
# ... and along the fourth dimension of our prediction, which has one extra
# dimension at index 1 (it is averaged away with .mean(1) further down).
prediction_our = prediction_our[:, :, :, tmp]
[8]:
def _masked_mse_over_leading_dims(prediction, target, mask):
    """Masked squared error summed over dims 0 and 1, divided by the mask sum.

    The 1e-5 added to the denominator avoids a division by zero where the
    mask sums to zero over the first two dimensions.
    """
    masked_squared_error = ((prediction - target) * mask) ** 2
    return masked_squared_error.sum([0, 1]) / (mask.sum([0, 1]) + 1e-5)


# Our prediction is first averaged over dim 1 before being scored.
mse_our = _masked_mse_over_leading_dims(prediction_our.mean(1), label, label_masks)
print(np.nanmean(mse_our))

mse_weekly = _masked_mse_over_leading_dims(prediction_weekly, label, label_masks)
mse_monthly = _masked_mse_over_leading_dims(prediction_monthly, label, label_masks)
print(np.nanmean(mse_weekly))
print(np.nanmean(mse_monthly))
0.070237435
0.0853331522367091
0.07903061542068927
[9]:
# Replace exact zeros with NaN in place so the NaN-aware statistics and plots
# below skip them. Presumably these are entries without any valid (unmasked)
# observation — TODO confirm.
mse_our.masked_fill_(mse_our == 0, np.nan)
mse_weekly.masked_fill_(mse_weekly == 0, np.nan)
mse_monthly.masked_fill_(mse_monthly == 0, np.nan)
[11]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Bar chart comparing the MSE of the two climatology baselines with our model.
method_names = ['Weekly Mean State', 'Monthly Mean State', 'Our']
method_mse = [mse_weekly.numpy(), mse_monthly.numpy(), mse_our.numpy()]

# One method label per MSE entry. The repeat counts are taken from the data
# itself instead of a hard-coded 4325, so labels and values always line up.
category = np.repeat(method_names, [len(values) for values in method_mse]).tolist()

print(np.nanmean(mse_weekly))
print(np.nanmean(mse_monthly))
print(np.nanmean(mse_our))

data = pd.DataFrame({'mse': np.concatenate(method_mse, 0), 'methods': category})

color = ["#F8766D", "#80AE6B", "#4E5E7B"][::-1]
print(color)

# Keep the returned Axes: the original never assigned `g`, so the styling calls
# below either raised NameError or silently modified a stale Axes from another
# cell. `hue` + `legend=False` is the non-deprecated way to colour bars by x.
g = sns.barplot(x="methods", y="mse", hue="methods", data=data,
                palette=color, legend=False)

# Style the axes after plotting so the settings apply to the final ticks.
g.set_ylim(0.04, 0.10)
g.set_yticks([0.04, 0.05, 0.06, 0.07, 0.08, 0.09])
g.tick_params(axis="x", labelrotation=30, labelsize=16)
g.tick_params(axis="y", labelsize=16)
g.set_xlabel("Methods", fontsize=20)
g.set_ylabel("Mean Square Error", fontsize=20);
0.0853331522367091
0.07903061542068927
0.070237435
['#4E5E7B', '#80AE6B', '#F8766D']
/tmp/ipykernel_172255/494347466.py:29: FutureWarning:
Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.
sns.barplot(x="methods", y="mse", data=data, palette=color)
[12]:
def _masked_mae_over_leading_dims(prediction, target, mask):
    """Masked absolute error summed over dims 0 and 1, divided by the mask sum.

    The 1e-5 added to the denominator avoids a division by zero where the
    mask sums to zero over the first two dimensions.
    """
    masked_absolute_error = np.abs(prediction - target) * mask
    return masked_absolute_error.sum([0, 1]) / (mask.sum([0, 1]) + 1e-5)


# Our prediction is first averaged over dim 1 before being scored.
mae_our = _masked_mae_over_leading_dims(prediction_our.mean(1), label, label_masks)
print(np.nanmean(mae_our))

mae_weekly = _masked_mae_over_leading_dims(prediction_weekly, label, label_masks)
mae_monthly = _masked_mae_over_leading_dims(prediction_monthly, label, label_masks)
print(np.nanmean(mae_weekly))
print(np.nanmean(mae_monthly))
0.19564027
0.21751725238377978
0.20831942264104245
[13]:
# Replace exact zeros with NaN in place so the NaN-aware statistics and plots
# below skip them. Presumably these are entries without any valid (unmasked)
# observation — TODO confirm.
mae_our.masked_fill_(mae_our == 0, np.nan)
mae_weekly.masked_fill_(mae_weekly == 0, np.nan)
mae_monthly.masked_fill_(mae_monthly == 0, np.nan)
[14]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Box plot comparing the MAE distribution of the two climatology baselines
# with our model. seaborn is imported here so the cell does not rely on an
# import made by a different cell.
method_names = ['Weekly Mean State', 'Monthly Mean State', 'Our']
method_mae = [mae_weekly.numpy(), mae_monthly.numpy(), mae_our.numpy()]

# One method label per MAE entry. The repeat counts are taken from the data
# itself instead of a hard-coded 4325, so labels and values always line up.
category = np.repeat(method_names, [len(values) for values in method_mae]).tolist()

data = pd.DataFrame({'mae': np.concatenate(method_mae, 0), 'methods': category})

print(np.nanmean(mae_our))
print(np.nanmean(mae_weekly))
print(np.nanmean(mae_monthly))

color = ["#F8766D", "#80AE6B", "#4E5E7B"][::-1]

# `hue` + `legend=False` is the non-deprecated way to colour the boxes by x.
g = sns.boxplot(
    x='methods', y='mae', hue='methods', data=data,
    linewidth=2, showfliers=False, showmeans=True,
    palette=color, legend=False,
    meanprops={
        "markerfacecolor": "red",
        "markeredgecolor": "black",
        "markersize": 6,
    },
)
g.xaxis.label.set_size(20)
g.yaxis.label.set_size(20)
# The x ticks are hidden, as in the original figure. The earlier
# plt.xticks(rotation=30) call was dropped because it had no visible effect.
# The trailing semicolon suppresses the stray cell output.
g.set_xticks([]);
0.19564027
0.21751725238377978
0.20831942264104245
/tmp/ipykernel_172255/96548801.py:17: FutureWarning:
Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.
g = sns.boxplot(x='methods', y='mae', linewidth=2,showfliers=False,showmeans=True,data=data,palette=color, meanprops={
[14]:
([], [])
[16]:
# Element-wise ratio of our MSE to the weekly-climatology MSE; values below 1
# mean our error is the smaller of the two.
improvement = torch.div(mse_our, mse_weekly)
[17]:
# Disable NumPy's summarisation of long arrays when printing (global setting,
# stays in effect for the rest of the session).
# NOTE(review): the object printed below shows up as a torch tensor in the
# recorded output, and torch has its own print options — confirm whether this
# NumPy setting is still needed.
np.set_printoptions(threshold = np.inf)
# Show the indices of the 500 smallest `improvement` ratios, in ascending order.
print(np.argsort(improvement)[:500])
tensor([ 594, 4078, 3983, 4174, 4079, 3886, 3982, 3791, 3984, 653, 3887, 4077,
3790, 4175, 712, 3792, 4173, 3694, 4075, 3980, 3979, 3888, 3985, 4076,
3689, 4172, 4270, 3690, 4171, 3788, 4269, 3884, 3981, 4177, 4168, 4082,
3695, 3699, 3663, 3664, 3698, 3598, 4163, 39, 3796, 3885, 4176, 4080,
3986, 1556, 15, 3794, 14, 3787, 4074, 4081, 3883, 3665, 3988, 3789,
61, 4178, 3696, 3593, 4167, 3686, 2282, 2090, 3889, 3793, 1460, 4185,
3978, 2473, 2568, 3892, 3693, 3700, 3662, 3051, 4268, 38, 2472, 3989,
4273, 2955, 3493, 4084, 3987, 4083, 3691, 3785, 4068, 4179, 2186, 4170,
3692, 4085, 3504, 780, 3795, 2377, 975, 3890, 2860, 904, 3667, 3797,
3893, 3589, 3592, 3146, 4157, 3784, 3891, 3421, 3383, 2567, 3126, 3497,
3052, 4067, 4056, 2935, 4272, 3974, 3505, 3147, 3697, 905, 3496, 3595,
3786, 3599, 4180, 2376, 2240, 2770, 3590, 3594, 4072, 3518, 779, 3703,
2050, 3470, 2839, 4186, 4090, 1992, 4181, 2091, 3239, 3882, 4162, 2241,
3480, 842, 3704, 2956, 3805, 3494, 2911, 4064, 3409, 2146, 4153, 3798,
3519, 3477, 4184, 2957, 3125, 2283, 3701, 4271, 1713, 1902, 1991, 3516,
4169, 2671, 3408, 4059, 906, 1903, 4155, 2088, 1904, 4164, 2089, 3503,
3709, 3880, 327, 2814, 1802, 2643, 3602, 3607, 2669, 3597, 4062, 3606,
2767, 3287, 4154, 2375, 3083, 1805, 2988, 3877, 3900, 3082, 1900, 3612,
2815, 4073, 3568, 3422, 242, 3710, 3373, 1712, 37, 3031, 3996, 2766,
3510, 3975, 3600, 1809, 3030, 3567, 2242, 3668, 36, 156, 2569, 3878,
4022, 241, 1461, 4089, 4058, 4061, 2382, 3613, 3507, 13, 841, 4094,
4156, 1901, 3384, 3666, 3569, 1897, 3799, 3053, 3238, 2187, 2759, 2674,
3499, 3804, 4023, 3702, 1617, 1801, 1807, 2673, 4086, 1808, 1905, 2093,
286, 4264, 1527, 3081, 3512, 3588, 3469, 3143, 2822, 4267, 3587, 4057,
3084, 3968, 4041, 3901, 2913, 3509, 3960, 3511, 3977, 3402, 3144, 3517,
2840, 1896, 2764, 3210, 3091, 2861, 3603, 2185, 1806, 2049, 1990, 3515,
3095, 2790, 3608, 2791, 4158, 227, 285, 2000, 3685, 3508, 3688, 3611,
200, 1895, 3879, 3495, 3762, 3566, 3286, 2765, 2048, 4071, 4182, 3502,
278, 243, 2051, 3127, 2987, 3961, 1898, 3605, 3191, 2936, 1474, 2813,
3500, 3219, 2092, 4024, 3894, 2246, 2985, 2958, 3767, 2863, 3093, 2717,
3094, 3596, 1457, 143, 2145, 4160, 3382, 4055, 4003, 4274, 3718, 3959,
1525, 3190, 3374, 3926, 3781, 3687, 3760, 2739, 2910, 2564, 3601, 3896,
2916, 1804, 1993, 2855, 1899, 235, 3990, 840, 4263, 2760, 3189, 3335,
2670, 272, 2474, 1810, 3478, 3407, 2912, 2184, 3471, 3217, 4002, 1710,
12, 4042, 3501, 1996, 3993, 331, 2982, 2574, 2095, 3187, 2887, 1462,
3669, 2665, 328, 3230, 3086, 321, 4060, 3415, 2859, 4120, 3661, 3145,
2866, 2808, 3078, 2961, 287, 2355, 3759, 4069, 3997, 3572, 3783, 3614,
3973, 442, 2097, 3285, 2478, 2763, 3148, 3416, 3571, 3113, 4213, 3188,
1195, 3186, 3766, 3707, 1526, 1906, 3476, 3090, 3288, 282, 3958, 3211,
2356, 3185, 3918, 1999, 3491, 3897, 3475, 3007, 3906, 236, 1803, 3992,
2743, 3054, 3009, 3725, 1471, 3092, 2960, 326, 279, 3152, 1717, 3807,
3000, 3800, 1998, 2841, 2251, 3319, 4152, 4032])
[19]:
# Coordinate arrays used for the map extent and as contour coordinates below.
lat, lon = (
    np.load(npy_path)
    for npy_path in (
        "/home/mafzhang/code/Project/STIMP/data/PRE/lati.npy",
        "/home/mafzhang/code/Project/STIMP/data/PRE/lon.npy",
    )
)
[25]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader
import matplotlib.ticker as mticker
from cmap import Colormap
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

# Map of the relative MSE reduction of our model versus the weekly climatology.
# The ratio is recomputed locally so this figure does not depend on whichever
# baseline the shared `improvement` variable was last assigned to (the monthly
# map cell rebinds it).
improvement_weekly = mse_our / mse_weekly

cm = Colormap('vispy:fire').to_mpl()  # case insensitive
fig, ax = plt.subplots(figsize=(6, 4),
                       subplot_kw={'projection': ccrs.PlateCarree()})

# Restrict the map to the bounding box of the coordinate arrays.
lat1 = np.nanmin(lat)
lat2 = np.nanmax(lat)
lon1 = np.nanmin(lon)
lon2 = np.nanmax(lon)
ax.set_extent([lon1, lon2, lat1, lat2], crs=ccrs.PlateCarree())

# NOTE: `is_sea` / `is_sea2` are rebound here to different files than in the
# earlier subsetting cell (there: is_sea.npy / is_sea_2.npy).
is_sea = np.load("/home/mafzhang/code/Project/STIMP/data/PRE/is_sea_2.npy")
is_sea2 = np.load("/home/mafzhang/code/Project/STIMP/data/PRE/is_sea_new.npy")

# Scatter the per-location ratios back onto the 2-D grid. The grid starts as
# NaN (not 0) so a cell that receives no value is left blank instead of being
# drawn as (1 - 0) * 100 = 100 % improvement; its shape follows the mask
# rather than a hard-coded (60, 96).
tmp = np.full(is_sea.shape, np.nan)
tmp[is_sea.astype(bool)] = np.asarray(improvement_weekly)
tmp[~is_sea2.astype(bool)] = np.nan

# Add basic geographic features.
ax.add_feature(cfeature.LAND, facecolor='#f0f0f0')
ax.add_feature(cfeature.COASTLINE.with_scale('10m'), linewidth=0.8)
ax.add_feature(cfeature.OCEAN, facecolor='#d0e0f0')
ax.add_feature(cfeature.LAKES.with_scale('10m'), facecolor='#d0e0f0')
ax.add_feature(cfeature.RIVERS.with_scale('10m'), edgecolor='#004d99')

# (1 - ratio) * 100 is the percentage reduction in MSE; values outside the
# 0-50 % levels are drawn with the colormap's extension colours.
contourf = ax.contourf(lon, lat, (1 - tmp) * 100,
                       levels=np.linspace(0, 50, 40),
                       cmap=cm,  # colormap
                       transform=ccrs.PlateCarree(), extend="both")

# Dedicated axes for the horizontal colorbar; increase the last value (height)
# to make the colorbar thicker.
cbar_ax = fig.add_axes([0.15, 0.1, 0.7, 0.03])
cbar = fig.colorbar(contourf, ticks=[0, 25, 50], cax=cbar_ax,
                    orientation='horizontal', label='improvement (%)')
plt.show()
[26]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader
import matplotlib.ticker as mticker
from cmap import Colormap
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

# Map of the relative MSE reduction of our model versus the monthly climatology.
# NOTE: this rebinds the shared `improvement` variable, which previously held
# the ratio against the weekly climatology.
improvement = mse_our/mse_monthly

cm = Colormap('vispy:fire').to_mpl()  # case insensitive
fig, ax = plt.subplots(figsize=(6, 4),
                       subplot_kw={'projection': ccrs.PlateCarree()})

# Restrict the map to the bounding box of the coordinate arrays.
lat1 = np.nanmin(lat)
lat2 = np.nanmax(lat)
lon1 = np.nanmin(lon)
lon2 = np.nanmax(lon)
ax.set_extent([lon1, lon2, lat1, lat2], crs=ccrs.PlateCarree())

# NOTE: `is_sea` / `is_sea2` are rebound here to different files than in the
# earlier subsetting cell (there: is_sea.npy / is_sea_2.npy).
is_sea = np.load("/home/mafzhang/code/Project/STIMP/data/PRE/is_sea_2.npy")
is_sea2 = np.load("/home/mafzhang/code/Project/STIMP/data/PRE/is_sea_new.npy")

# Scatter the per-location ratios back onto the 2-D grid. The grid starts as
# NaN (not 0) so a cell that receives no value is left blank instead of being
# drawn as (1 - 0) * 100 = 100 % improvement; its shape follows the mask
# rather than a hard-coded (60, 96).
tmp = np.full(is_sea.shape, np.nan)
tmp[is_sea.astype(bool)] = np.asarray(improvement)
tmp[~is_sea2.astype(bool)] = np.nan

# Add basic geographic features.
ax.add_feature(cfeature.LAND, facecolor='#f0f0f0')
ax.add_feature(cfeature.COASTLINE.with_scale('10m'), linewidth=0.8)
ax.add_feature(cfeature.OCEAN, facecolor='#d0e0f0')
ax.add_feature(cfeature.LAKES.with_scale('10m'), facecolor='#d0e0f0')
ax.add_feature(cfeature.RIVERS.with_scale('10m'), edgecolor='#004d99')

# (1 - ratio) * 100 is the percentage reduction in MSE; values outside the
# 0-50 % levels are drawn with the colormap's extension colours.
contourf = ax.contourf(lon, lat, (1 - tmp) * 100,
                       levels=np.linspace(0, 50, 40),
                       cmap=cm,  # colormap
                       transform=ccrs.PlateCarree(), extend='both')

# Dedicated axes for the horizontal colorbar; increase the last value (height)
# to make the colorbar thicker.
cbar_ax = fig.add_axes([0.15, 0.1, 0.7, 0.03])
cbar = fig.colorbar(contourf, ticks=[0, 25, 50], cax=cbar_ax,
                    orientation='horizontal', label='improvement (%)')
plt.show()