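Imputation with a trained STIMP model in the Northern Gulf of Mexico

For the first test sample, STIMP and the DINEOF baseline reconstruct the entries hidden from the model at missing rates of 0.1, 0.5 and 0.9. Their imputations are compared against the ground truth with joint density plots and the Pearson correlation coefficient (PCC) (Fig. 5a).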

[1]:
import torch
import numpy as np
import sys
import os

# Make the repository's top-level packages (`dataset`, `model`) importable.
# NOTE(review): when the kernel's cwd is the repo root itself, the cwd entry ('')
# already covers this and the "../" entry is redundant; it is kept for running from a
# subdirectory. The membership guard keeps re-runs of this cell from prepending the
# same entry again and again.
_parent_dir = os.path.join(os.getcwd(), "..")
if _parent_dir not in sys.path:
    sys.path.insert(0, _parent_dir)

from dataset.dataset_imputation import PRE8dDataset
[2]:
import argparse

# The notebook reuses the training script's CLI definition so the dataset/model see
# the same hyper-parameters; `parse_args([])` below takes only the defaults (and
# ignores the Jupyter kernel's own argv).
parser = argparse.ArgumentParser(description='Imputation')

# args for area and methods
parser.add_argument('--area', type=str, default='MEXICO', help='which bay area we focus')

# basic args
parser.add_argument('--epochs', type=int, default=500, help='epochs')
parser.add_argument('--batch_size', type=int, default=16, help='batch size')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--wd', type=float, default=1e-4, help='weight decay')
parser.add_argument('--test_freq', type=int, default=500, help='test per n epochs')
parser.add_argument('--embedding_size', type=int, default=32)
parser.add_argument('--hidden_channels', type=int, default=32)
parser.add_argument('--diffusion_embedding_size', type=int, default=64)
parser.add_argument('--side_channels', type=int, default=1)

# args for tasks
parser.add_argument('--in_len', type=int, default=46)
parser.add_argument('--out_len', type=int, default=46)
# Fraction of entries hidden from the model; later sections overwrite this field.
parser.add_argument('--missing_ratio', type=float, default=0.1)

# args for diffusion
parser.add_argument('--beta_start', type=float, default=0.0001, help='beta start from this')
parser.add_argument('--beta_end', type=float, default=0.2, help='beta end to this')
# A step count must be an integer: with type=float a CLI override such as
# "--num_steps 20" would produce 20.0, which breaks range()/indexing downstream.
parser.add_argument('--num_steps', type=int, default=50, help='denoising steps')
parser.add_argument('--num_samples', type=int, default=10, help='n datasets')
parser.add_argument('--schedule', type=str, default='quad', help='noise schedule type')
parser.add_argument('--target_strategy', type=str, default='random', help='mask')

# args for mae
parser.add_argument('--num_heads', type=int, default=8, help='n heads for self attention')
config = parser.parse_args([])

Fig. 5a: missing rate of 0.1

[5]:
from torch.utils.data import DataLoader

# Pin this section's missing ratio explicitly. Later sections overwrite
# `config.missing_ratio`; relying on the argparse default would make a re-run of this
# cell build the wrong test split while still loading the 0.1 checkpoint.
config.missing_ratio = 0.1

# Seed both RNGs so the diffusion samples (and the dataset's masks, if it draws them
# at random — TODO confirm in PRE8dDataset) are reproducible across re-runs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Only the first test sample is used (batch size 1, no shuffling).
test_dloader = DataLoader(PRE8dDataset(config, mode='test'), 1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# load model. The checkpoint is a whole pickled module (the loaded object has
# `.to`/`.impute`), not a state_dict, so weights_only=False is required on
# torch>=2.6. Unpickling can execute arbitrary code: only load trusted checkpoints.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model = torch.load(
    "./log/imputation/{}/STIMP/best_{}.pt".format(config.area, config.missing_ratio),
    map_location=device,
    weights_only=False,
)
model = model.to(device)
model.eval()  # inference: disable any dropout / batch-norm updates

# The model is conditioned only on the entries flagged in data_gt_masks.
cond_mask = data_gt_masks
adj = np.load("./data/{}/adj.npy".format(config.area))
adj = torch.from_numpy(adj).float().to(device)

with torch.no_grad():
    imputed_data_our = model.impute(datas, cond_mask, adj, config.num_samples)
# Median over dim 1 — presumably the num_samples axis of the sampled imputations
# (TODO confirm against STIMP.impute).
imputed_data_our = imputed_data_our.median(1).values

# Evaluation entries: observed (ground truth available) but hidden from the model.
# A boolean and-not avoids the -1 that mask subtraction yields if a conditioned
# entry is ever unobserved.
mask = data_ob_masks.bool() & ~cond_mask.bool()
imputed_our = imputed_data_our[0][mask.cpu()[0]]
truth = datas[0][mask[0]].cpu()
[8]:
# DINEOF baseline: reconstruct every time step independently from exactly the
# entries STIMP was conditioned on, so both methods see the same inputs.
from einops import rearrange
from model.dineof_per_step import DINEOF

# Land/sea mask of the study area; the dataset stores only the sea pixels.
is_sea = np.load("./data/{}/is_sea.npy".format(config.area))
H, W = is_sea.shape
sea = is_sea.astype(bool)
n_steps = datas.shape[1]  # time length of the sample (46 with the default config)

# Separate name so the STIMP `model` from the previous cell is not clobbered.
# NOTE(review): the first argument (10) is presumably the number of EOF modes — TODO confirm.
dineof_model = DINEOF(10, [H, W], keep_non_negative_only=False)

# Scatter the sea-pixel series back onto the full (H, W) grid; land stays 0.
# Built on the CPU because DINEOF consumes numpy arrays.
datas_image = torch.zeros(1, n_steps, 1, H, W)
datas_image[:, :, :, sea] = datas.cpu()
cond_mask_image = torch.zeros(1, n_steps, 1, H, W)
cond_mask_image[:, :, :, sea] = cond_mask.cpu().to(datas_image.dtype)

step_predictions = []
for t in range(n_steps):
    frame = datas_image[0, t, 0]  # (H, W)
    given = cond_mask_image[0, t, 0] != 0
    # Everything not given to STIMP (held-out entries and land) is set to NaN,
    # which DINEOF presumably treats as missing.
    frame_with_gaps = torch.where(given, frame, torch.full_like(frame, float("nan")))
    dineof_model.fit(frame_with_gaps.numpy())

    prediction = dineof_model.predict()  # (H, W) reconstruction of this step
    step_predictions.append(
        rearrange(prediction, "(b t c h) w->b t c h w", b=1, t=1, c=1, h=H, w=W)
    )

# (1, n_steps, 1, H, W) numpy array, same layout as datas_image.
imputed_data = np.concatenate(step_predictions, axis=1)
2025-05-08 16:00:05.038 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 167: 0.0024114695843309164, 9.750714525580406e-06
2025-05-08 16:00:05.187 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 160: 0.0022046086378395557, 9.918585419654846e-06
2025-05-08 16:00:05.335 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 160: 0.0023427773267030716, 9.896699339151382e-06
2025-05-08 16:00:05.469 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 142: 0.0021681818179786205, 9.95188020169735e-06
2025-05-08 16:00:05.634 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 176: 0.0022749281488358974, 9.920215234160423e-06
2025-05-08 16:00:05.808 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 187: 0.0016638115048408508, 9.958515875041485e-06
2025-05-08 16:00:05.915 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 115: 0.0013774820836260915, 9.813928045332432e-06
2025-05-08 16:00:06.050 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 145: 0.0017048983136191964, 9.840703569352627e-06
2025-05-08 16:00:06.179 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 134: 0.001551663619466126, 9.983428753912449e-06
2025-05-08 16:00:06.323 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 141: 0.001512625953182578, 9.836163371801376e-06
2025-05-08 16:00:06.402 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 82: 0.0012235901085659862, 9.976094588637352e-06
2025-05-08 16:00:06.525 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 132: 0.001823695027269423, 9.786686860024929e-06
2025-05-08 16:00:06.588 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 65: 0.00048112511285580695, 9.861571015790105e-06
2025-05-08 16:00:06.691 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 110: 0.0015023669693619013, 9.911949746310711e-06
2025-05-08 16:00:06.834 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 150: 0.0019509911071509123, 9.765848517417908e-06
2025-05-08 16:00:06.989 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 164: 0.0020181152503937483, 9.888317435979843e-06
2025-05-08 16:00:07.128 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 141: 0.002009716583415866, 9.838957339525223e-06
2025-05-08 16:00:07.281 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 159: 0.0018364328425377607, 9.771785698831081e-06
2025-05-08 16:00:07.419 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 142: 0.0016502379439771175, 9.858398698270321e-06
2025-05-08 16:00:07.607 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 197: 0.0020804658997803926, 9.97120514512062e-06
2025-05-08 16:00:07.843 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 249: 0.002218526089563966, 9.85688529908657e-06
2025-05-08 16:00:08.008 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 174: 0.002242934424430132, 9.982846677303314e-06
2025-05-08 16:00:08.182 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 186: 0.0021157069131731987, 9.956536814570427e-06
2025-05-08 16:00:08.333 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 160: 0.00217120791785419, 9.941868484020233e-06
2025-05-08 16:00:08.490 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 166: 0.0020256678108125925, 9.973067790269852e-06
2025-05-08 16:00:08.645 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 164: 0.0021624802611768246, 9.7593292593956e-06
2025-05-08 16:00:08.846 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 214: 0.0025014879647642374, 9.731389582157135e-06
2025-05-08 16:00:09.021 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 188: 0.002268986077979207, 9.99821349978447e-06
2025-05-08 16:00:09.175 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 163: 0.002355672651901841, 9.834766387939453e-06
2025-05-08 16:00:09.349 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 185: 0.0027145289350301027, 9.778188541531563e-06
2025-05-08 16:00:09.493 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 155: 0.0028205476701259613, 9.940238669514656e-06
2025-05-08 16:00:09.668 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 186: 0.002752379048615694, 9.737908840179443e-06
2025-05-08 16:00:09.889 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 237: 0.0023082986008375883, 9.797047823667526e-06
2025-05-08 16:00:10.055 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 179: 0.00194880785420537, 9.716721251606941e-06
2025-05-08 16:00:10.267 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 215: 0.00289082620292902, 9.587500244379044e-06
2025-05-08 16:00:10.473 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 220: 0.0019433313282206655, 9.993440471589565e-06
2025-05-08 16:00:10.591 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 127: 0.0022772825323045254, 9.901123121380806e-06
2025-05-08 16:00:10.737 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 156: 0.002538780216127634, 9.990530088543892e-06
2025-05-08 16:00:10.912 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 190: 0.002722960663959384, 9.739305824041367e-06
2025-05-08 16:00:11.088 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 185: 0.0022970004938542843, 9.990297257900238e-06
2025-05-08 16:00:11.248 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 171: 0.0022447442170232534, 9.74978320300579e-06
2025-05-08 16:00:11.286 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 40: 0.00034438399598002434, 9.817042155191302e-06
2025-05-08 16:00:11.424 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 151: 0.002266534138470888, 9.848037734627724e-06
2025-05-08 16:00:11.590 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 179: 0.002390695968642831, 9.997515007853508e-06
2025-05-08 16:00:11.727 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 148: 0.0019407444633543491, 9.64419450610876e-06
2025-05-08 16:00:11.876 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 162: 0.0019273994257673621, 9.737559594213963e-06
[9]:
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# DINEOF predictions at the same held-out entries used to score STIMP.
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

truth_np = truth.numpy()
ours_np = imputed_our.numpy()

# Long-format frame: one row per (held-out entry, method); each label list is sized
# from its own prediction array.
data = pd.DataFrame({
    'truth': np.concatenate([truth_np, truth_np]),
    'imputed': np.concatenate([ours_np, imputed_dineof]),
    'method': ['STIMP'] * len(ours_np) + ['DINEOF'] * len(imputed_dineof),
})

color = ["#9F0000", "#576fa0"]
# `fill` replaces the deprecated `shade` keyword of seaborn's kdeplot.
g = sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde',
                  palette=color, marginal_kws={'common_norm': False, 'fill': True})
ax = g.ax_joint

# y = x reference line across the joint panel's x-range.
lims = ax.get_xlim()
ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
ax.set_ylabel("imputed", size=24)
ax.set_xlabel("truth", size=24)

pcc1 = stat.pearsonr(truth_np, ours_np).statistic.item()
pcc2 = stat.pearsonr(truth_np, imputed_dineof).statistic.item()
ax.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=ax.transAxes, fontsize=12,
        verticalalignment='top', horizontalalignment='left', color=color[0])
ax.text(0.40, 0.88, 'PCC={:.4f}'.format(pcc2), transform=ax.transAxes, fontsize=12,
        verticalalignment='top', horizontalalignment='left', color=color[1])
ax.tick_params(axis='both', labelsize=15)
plt.show()
../../_images/analysis_MEXICO_05-imputation-mexico_6_1.png

Fig. 5a: missing rate of 0.5

[10]:
config.missing_ratio = 0.5
from torch.utils.data import DataLoader

# Seed both RNGs so the diffusion samples (and the dataset's masks, if it draws them
# at random — TODO confirm in PRE8dDataset) are reproducible across re-runs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Only the first test sample is used (batch size 1, no shuffling).
test_dloader = DataLoader(PRE8dDataset(config, mode='test'), 1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# The checkpoint is a whole pickled module, not a state_dict: weights_only=False is
# required on torch>=2.6 (only load trusted files); map_location allows CPU-only runs.
model = torch.load(
    "./log/imputation/{}/STIMP/best_{}.pt".format(config.area, config.missing_ratio),
    map_location=device,
    weights_only=False,
)
model = model.to(device)
model.eval()  # inference: disable any dropout / batch-norm updates

cond_mask = data_gt_masks
# `adj` is the adjacency tensor loaded in the missing-rate 0.1 section.
with torch.no_grad():
    imputed_data_our = model.impute(datas, cond_mask, adj, config.num_samples)
# Median over dim 1 — presumably the num_samples axis (TODO confirm).
imputed_data_our = imputed_data_our.median(1).values

# Evaluation entries: observed but hidden from the model.
mask = data_ob_masks.bool() & ~cond_mask.bool()
imputed_our = imputed_data_our[0][mask.cpu()[0]]
truth = datas[0][mask[0]].cpu()
[11]:
# DINEOF baseline: reconstruct every time step independently from exactly the
# entries STIMP was conditioned on, so both methods see the same inputs.
from einops import rearrange
from model.dineof_per_step import DINEOF

# Land/sea mask of the study area; the dataset stores only the sea pixels.
is_sea = np.load("./data/{}/is_sea.npy".format(config.area))
H, W = is_sea.shape
sea = is_sea.astype(bool)
n_steps = datas.shape[1]  # time length of the sample (46 with the default config)

# Separate name so the STIMP `model` from the previous cell is not clobbered.
# NOTE(review): the first argument (10) is presumably the number of EOF modes — TODO confirm.
dineof_model = DINEOF(10, [H, W], keep_non_negative_only=False)

# Scatter the sea-pixel series back onto the full (H, W) grid; land stays 0.
# Built on the CPU because DINEOF consumes numpy arrays.
datas_image = torch.zeros(1, n_steps, 1, H, W)
datas_image[:, :, :, sea] = datas.cpu()
cond_mask_image = torch.zeros(1, n_steps, 1, H, W)
cond_mask_image[:, :, :, sea] = cond_mask.cpu().to(datas_image.dtype)

step_predictions = []
for t in range(n_steps):
    frame = datas_image[0, t, 0]  # (H, W)
    given = cond_mask_image[0, t, 0] != 0
    # Everything not given to STIMP (held-out entries and land) is set to NaN,
    # which DINEOF presumably treats as missing.
    frame_with_gaps = torch.where(given, frame, torch.full_like(frame, float("nan")))
    dineof_model.fit(frame_with_gaps.numpy())

    prediction = dineof_model.predict()  # (H, W) reconstruction of this step
    step_predictions.append(
        rearrange(prediction, "(b t c h) w->b t c h w", b=1, t=1, c=1, h=H, w=W)
    )

# (1, n_steps, 1, H, W) numpy array, same layout as datas_image.
imputed_data = np.concatenate(step_predictions, axis=1)
2025-05-08 16:04:03.602 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 139: 0.0015157636953517795, 9.98715404421091e-06
2025-05-08 16:04:03.784 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 184: 0.0016669080359861255, 9.927665814757347e-06
2025-05-08 16:04:03.949 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 163: 0.0015358608216047287, 9.816023521125317e-06
2025-05-08 16:04:04.065 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 118: 0.001035201014019549, 9.757932275533676e-06
2025-05-08 16:04:04.186 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 120: 0.0011178131680935621, 9.990064427256584e-06
2025-05-08 16:04:04.319 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 131: 0.0015064311446622014, 9.89111140370369e-06
2025-05-08 16:04:04.440 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 121: 0.0011072386987507343, 9.85502265393734e-06
2025-05-08 16:04:04.547 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 107: 0.0014296036679297686, 9.90403350442648e-06
2025-05-08 16:04:04.655 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 109: 0.0009885813342407346, 9.844312444329262e-06
2025-05-08 16:04:04.755 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 102: 0.0008216698770411313, 9.978015441447496e-06
2025-05-08 16:04:04.846 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 92: 0.0009396987152285874, 9.939365554600954e-06
2025-05-08 16:04:04.992 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 146: 0.0013311404036357999, 9.838840924203396e-06
2025-05-08 16:04:05.037 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 44: 0.00015673169400542974, 9.80004551820457e-06
2025-05-08 16:04:05.133 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 98: 0.0007812091498635709, 9.964627679437399e-06
2025-05-08 16:04:05.246 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 114: 0.0012790574692189693, 9.900424629449844e-06
2025-05-08 16:04:05.376 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 132: 0.00124261318705976, 9.820330888032913e-06
2025-05-08 16:04:05.488 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 114: 0.0011830134317278862, 9.74640715867281e-06
2025-05-08 16:04:05.622 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 130: 0.0012919403379783034, 9.897281415760517e-06
2025-05-08 16:04:05.726 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 103: 0.0009124678908847272, 9.811134077608585e-06
2025-05-08 16:04:05.876 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 150: 0.0017269201343879104, 9.869690984487534e-06
2025-05-08 16:04:06.055 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 173: 0.0018341108225286007, 9.76654700934887e-06
2025-05-08 16:04:06.206 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 137: 0.0014894654741510749, 9.950948879122734e-06
2025-05-08 16:04:06.373 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 161: 0.001414332538843155, 9.828130714595318e-06
2025-05-08 16:04:06.518 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 145: 0.0014833847526460886, 9.912648238241673e-06
2025-05-08 16:04:06.656 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 136: 0.0013469243422150612, 9.737210348248482e-06
2025-05-08 16:04:06.805 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 149: 0.0015318260993808508, 9.987270459532738e-06
2025-05-08 16:04:06.955 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 150: 0.0014535944210365415, 9.937095455825329e-06
2025-05-08 16:04:07.113 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 159: 0.0015478064306080341, 9.857816621661186e-06
2025-05-08 16:04:07.237 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 124: 0.001629698439501226, 9.963172487914562e-06
2025-05-08 16:04:07.397 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 160: 0.0018093626713380218, 9.837327525019646e-06
2025-05-08 16:04:07.540 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 145: 0.001903114141896367, 9.949319064617157e-06
2025-05-08 16:04:07.696 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 157: 0.0018188718240708113, 9.960262104868889e-06
2025-05-08 16:04:07.846 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 151: 0.0018710327567532659, 9.846757166087627e-06
2025-05-08 16:04:08.001 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 160: 0.0015902334125712514, 9.818468242883682e-06
2025-05-08 16:04:08.147 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 149: 0.002163328230381012, 9.551877155900002e-06
2025-05-08 16:04:08.301 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 157: 0.0014534150250256062, 9.987270459532738e-06
2025-05-08 16:04:08.434 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 135: 0.0013811803655698895, 9.677489288151264e-06
2025-05-08 16:04:08.572 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 143: 0.0014524502912536263, 9.977375157177448e-06
2025-05-08 16:04:08.719 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 150: 0.0017091295449063182, 9.905779734253883e-06
2025-05-08 16:04:08.850 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 132: 0.0019908843096345663, 9.979819878935814e-06
2025-05-08 16:04:08.986 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 140: 0.00110264599788934, 9.981915354728699e-06
2025-05-08 16:04:09.028 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 44: 0.0001312033273279667, 9.238312486559153e-06
2025-05-08 16:04:09.168 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 144: 0.0019513461738824844, 9.915558621287346e-06
2025-05-08 16:04:09.325 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 161: 0.0015245361719280481, 9.820330888032913e-06
2025-05-08 16:04:09.468 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 145: 0.0016904419753700495, 9.934534318745136e-06
2025-05-08 16:04:09.604 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 139: 0.0016243739519268274, 9.85502265393734e-06
[12]:
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# DINEOF predictions at the same held-out entries used to score STIMP.
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

truth_np = truth.numpy()
ours_np = imputed_our.numpy()

# Long-format frame: one row per (held-out entry, method); each label list is sized
# from its own prediction array.
data = pd.DataFrame({
    'truth': np.concatenate([truth_np, truth_np]),
    'imputed': np.concatenate([ours_np, imputed_dineof]),
    'method': ['STIMP'] * len(ours_np) + ['DINEOF'] * len(imputed_dineof),
})

color = ["#9F0000", "#576fa0"]
# `fill` replaces the deprecated `shade` keyword of seaborn's kdeplot.
g = sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde',
                  palette=color, marginal_kws={'common_norm': False, 'fill': True})
ax = g.ax_joint

# y = x reference line across the joint panel's x-range.
lims = ax.get_xlim()
ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
ax.set_ylabel("imputed", size=24)
ax.set_xlabel("truth", size=24)

pcc1 = stat.pearsonr(truth_np, ours_np).statistic.item()
pcc2 = stat.pearsonr(truth_np, imputed_dineof).statistic.item()
ax.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=ax.transAxes, fontsize=12,
        verticalalignment='top', horizontalalignment='left', color=color[0])
ax.text(0.40, 0.88, 'PCC={:.4f}'.format(pcc2), transform=ax.transAxes, fontsize=12,
        verticalalignment='top', horizontalalignment='left', color=color[1])
ax.tick_params(axis='both', labelsize=15)
plt.show()
../../_images/analysis_MEXICO_05-imputation-mexico_10_1.png

Fig. 5a: missing rate of 0.9

[14]:
config.missing_ratio = 0.9
from torch.utils.data import DataLoader

# Seed both RNGs so the diffusion samples (and the dataset's masks, if it draws them
# at random — TODO confirm in PRE8dDataset) are reproducible across re-runs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Only the first test sample is used (batch size 1, no shuffling).
test_dloader = DataLoader(PRE8dDataset(config, mode='test'), 1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# The checkpoint is a whole pickled module, not a state_dict: weights_only=False is
# required on torch>=2.6 (only load trusted files); map_location allows CPU-only runs.
model = torch.load(
    "./log/imputation/{}/STIMP/best_{}.pt".format(config.area, config.missing_ratio),
    map_location=device,
    weights_only=False,
)
model = model.to(device)
model.eval()  # inference: disable any dropout / batch-norm updates

cond_mask = data_gt_masks
# `adj` is the adjacency tensor loaded in the missing-rate 0.1 section.
with torch.no_grad():
    imputed_data_our = model.impute(datas, cond_mask, adj, config.num_samples)
# Median over dim 1 — presumably the num_samples axis (TODO confirm).
imputed_data_our = imputed_data_our.median(1).values

# Evaluation entries: observed but hidden from the model.
mask = data_ob_masks.bool() & ~cond_mask.bool()
imputed_our = imputed_data_our[0][mask.cpu()[0]]
truth = datas[0][mask[0]].cpu()
[15]:
# DINEOF baseline: reconstruct every time step independently from exactly the
# entries STIMP was conditioned on, so both methods see the same inputs.
from einops import rearrange
from model.dineof_per_step import DINEOF

# Land/sea mask of the study area; the dataset stores only the sea pixels.
is_sea = np.load("./data/{}/is_sea.npy".format(config.area))
H, W = is_sea.shape
sea = is_sea.astype(bool)
n_steps = datas.shape[1]  # time length of the sample (46 with the default config)

# Separate name so the STIMP `model` from the previous cell is not clobbered.
# NOTE(review): the first argument (10) is presumably the number of EOF modes — TODO confirm.
dineof_model = DINEOF(10, [H, W], keep_non_negative_only=False)

# Scatter the sea-pixel series back onto the full (H, W) grid; land stays 0.
# Built on the CPU because DINEOF consumes numpy arrays.
datas_image = torch.zeros(1, n_steps, 1, H, W)
datas_image[:, :, :, sea] = datas.cpu()
cond_mask_image = torch.zeros(1, n_steps, 1, H, W)
cond_mask_image[:, :, :, sea] = cond_mask.cpu().to(datas_image.dtype)

step_predictions = []
for t in range(n_steps):
    frame = datas_image[0, t, 0]  # (H, W)
    given = cond_mask_image[0, t, 0] != 0
    # Everything not given to STIMP (held-out entries and land) is set to NaN,
    # which DINEOF presumably treats as missing.
    frame_with_gaps = torch.where(given, frame, torch.full_like(frame, float("nan")))
    dineof_model.fit(frame_with_gaps.numpy())

    prediction = dineof_model.predict()  # (H, W) reconstruction of this step
    step_predictions.append(
        rearrange(prediction, "(b t c h) w->b t c h w", b=1, t=1, c=1, h=H, w=W)
    )

# (1, n_steps, 1, H, W) numpy array, same layout as datas_image.
imputed_data = np.concatenate(step_predictions, axis=1)
2025-05-08 16:07:26.927 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 58: 0.00015187902317848057, 9.929281077347696e-06
2025-05-08 16:07:27.000 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 74: 0.0002237727167084813, 9.703464456833899e-06
2025-05-08 16:07:27.057 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 57: 0.00017613488307688385, 9.938885341398418e-06
2025-05-08 16:07:27.113 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 57: 8.697685552760959e-05, 9.286239219363779e-06
2025-05-08 16:07:27.158 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 44: 0.00036218654713593423, 8.104776497930288e-06
2025-05-08 16:07:27.228 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 71: 4.801901013706811e-05, 7.680551789235324e-06
2025-05-08 16:07:27.270 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 42: 0.0001388771052006632, 5.718000466004014e-06
2025-05-08 16:07:27.326 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 57: 0.00025041893240995705, 9.762414265424013e-06
2025-05-08 16:07:27.359 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 32: 0.0004084364336449653, 3.939931048080325e-06
2025-05-08 16:07:27.407 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 49: 0.00017096057126764208, 9.526629582978785e-06
2025-05-08 16:07:27.444 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 37: 0.00011079287651227787, 9.134848369285464e-06
2025-05-08 16:07:27.510 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 69: 0.0001042808944475837, 9.451672667637467e-06
2025-05-08 16:07:27.548 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 39: 0.0001796404249034822, 8.186689228750765e-06
2025-05-08 16:07:27.592 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 45: 0.000145958925713785, 9.630588465370238e-06
2025-05-08 16:07:27.636 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 44: 0.0001175681027234532, 9.213843441102654e-06
2025-05-08 16:07:27.681 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 45: 0.00022565377003047615, 9.378723916597664e-06
2025-05-08 16:07:27.724 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 43: 0.00015027915651444346, 9.112394764088094e-06
2025-05-08 16:07:27.766 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 43: 9.944717749021947e-05, 9.269810107070953e-06
2025-05-08 16:07:27.808 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 41: 9.884594328468665e-05, 9.866300388239324e-06
2025-05-08 16:07:27.848 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 40: 0.0005300716729834676, 6.018788553774357e-06
2025-05-08 16:07:27.923 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 74: 0.000187698271474801, 9.872252121567726e-06
2025-05-08 16:07:27.962 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 37: 0.0007283929153345525, 4.24310564994812e-06
2025-05-08 16:07:28.018 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 57: 0.00020205293549224734, 9.961382602341473e-06
2025-05-08 16:07:28.064 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 46: 0.0005271050613373518, 7.425551302731037e-07
2025-05-08 16:07:28.121 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 58: 0.00021120367455296218, 9.574097930453718e-06
2025-05-08 16:07:28.186 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 67: 0.0001723091263556853, 9.87622479442507e-06
2025-05-08 16:07:28.227 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 40: 0.00039949396159499884, 3.246619598940015e-06
2025-05-08 16:07:28.284 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 59: 0.00012777798110619187, 9.618699550628662e-06
2025-05-08 16:07:28.339 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 56: 0.00020881296950392425, 9.725990821607411e-06
2025-05-08 16:07:28.397 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 59: 0.00014268478844314814, 9.572409908287227e-06
2025-05-08 16:07:28.454 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 58: 0.0001750936935422942, 9.582261554896832e-06
2025-05-08 16:07:28.521 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 67: 0.00024027546169236302, 9.76555747911334e-06
2025-05-08 16:07:28.580 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 60: 0.00017896997451316565, 9.778246749192476e-06
2025-05-08 16:07:28.637 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 58: 0.000167188816703856, 9.616589522920549e-06
2025-05-08 16:07:28.695 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 59: 0.00041574350325390697, 9.605544619262218e-06
2025-05-08 16:07:28.762 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 69: 0.00020687095820903778, 9.370458428747952e-06
2025-05-08 16:07:28.824 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 64: 0.0001603587734280154, 9.606665116734803e-06
2025-05-08 16:07:28.880 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 57: 0.0002463874698150903, 9.479583241045475e-06
2025-05-08 16:07:28.945 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 63: 0.0001788117951946333, 9.628449333831668e-06
2025-05-08 16:07:28.999 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 51: 0.0001592903572600335, 9.383860742673278e-06
2025-05-08 16:07:29.055 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 56: 0.00020430688164196908, 9.535520803183317e-06
2025-05-08 16:07:29.074 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 18: 3.636835390352644e-05, 8.21973299025558e-06
2025-05-08 16:07:29.127 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 55: 0.0001306253398070112, 9.472758392803371e-06
2025-05-08 16:07:29.178 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 52: 0.00018380387336947024, 9.648618288338184e-06
2025-05-08 16:07:29.241 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 64: 0.0001010614141705446, 9.077521099243313e-06
2025-05-08 16:07:29.290 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 48: 0.00022896577138453722, 3.043169272132218e-06
[16]:
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# DINEOF predictions at the same held-out entries used to score STIMP.
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

truth_np = truth.numpy()
ours_np = imputed_our.numpy()

# Long-format frame: one row per (held-out entry, method); each label list is sized
# from its own prediction array.
data = pd.DataFrame({
    'truth': np.concatenate([truth_np, truth_np]),
    'imputed': np.concatenate([ours_np, imputed_dineof]),
    'method': ['STIMP'] * len(ours_np) + ['DINEOF'] * len(imputed_dineof),
})

color = ["#9F0000", "#576fa0"]
# `fill` replaces the deprecated `shade` keyword of seaborn's kdeplot.
g = sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde',
                  palette=color, marginal_kws={'common_norm': False, 'fill': True})
ax = g.ax_joint

# y = x reference line across the joint panel's x-range.
lims = ax.get_xlim()
ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
ax.set_ylabel("imputed", size=24)
ax.set_xlabel("truth", size=24)

pcc1 = stat.pearsonr(truth_np, ours_np).statistic.item()
pcc2 = stat.pearsonr(truth_np, imputed_dineof).statistic.item()
ax.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=ax.transAxes, fontsize=12,
        verticalalignment='top', horizontalalignment='left', color=color[0])
ax.text(0.40, 0.88, 'PCC={:.4f}'.format(pcc2), transform=ax.transAxes, fontsize=12,
        verticalalignment='top', horizontalalignment='left', color=color[1])
ax.tick_params(axis='both', labelsize=15)
plt.show()
../../_images/analysis_MEXICO_05-imputation-mexico_14_1.png