Imputation with a trained STIMP model in the Chesapeake Bay
[1]:
import os
import sys

import numpy as np
import torch

# Put the parent of the working directory at the front of the import path so the
# project-local imports below (and in later cells) resolve.
# Guarded so re-running this cell does not keep stacking duplicate entries.
PARENT_DIR = os.path.join(os.getcwd(), "..")
if PARENT_DIR not in sys.path:
    sys.path.insert(0, PARENT_DIR)

from dataset.dataset_imputation import PRE8dDataset
['/home/mafzhang/code/STIMP/..', '/home/mafzhang/miniconda3/envs/torch/lib/python39.zip', '/home/mafzhang/miniconda3/envs/torch/lib/python3.9', '/home/mafzhang/miniconda3/envs/torch/lib/python3.9/lib-dynload', '', '/home/mafzhang/.local/lib/python3.9/site-packages', '/home/mafzhang/miniconda3/envs/torch/lib/python3.9/site-packages']
[2]:
import argparse

# Experiment configuration. Parsed from an empty argv so the notebook always uses the
# defaults below; edit the defaults (or pass a list to parse_args) to change a run.
parser = argparse.ArgumentParser(description='Imputation')
# args for area and methods
parser.add_argument('--area', type=str, default='Chesapeake', help='which bay area we focus')
# basic args
parser.add_argument('--epochs', type=int, default=500, help='epochs')
parser.add_argument('--batch_size', type=int, default=16, help='batch size')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--wd', type=float, default=1e-4, help='weight decay')
parser.add_argument('--test_freq', type=int, default=500, help='test per n epochs')
parser.add_argument('--embedding_size', type=int, default=32)
parser.add_argument('--hidden_channels', type=int, default=32)
parser.add_argument('--diffusion_embedding_size', type=int, default=64)
parser.add_argument('--side_channels', type=int, default=1)
# args for tasks
parser.add_argument('--in_len', type=int, default=46)
parser.add_argument('--out_len', type=int, default=46)
parser.add_argument('--missing_ratio', type=float, default=0.1)
# args for diffusion
parser.add_argument('--beta_start', type=float, default=0.0001, help='beta start from this')
parser.add_argument('--beta_end', type=float, default=0.2, help='beta end to this')
# A step count must be an integer; type=float turned "--num_steps 50" into 50.0.
parser.add_argument('--num_steps', type=int, default=50, help='denoising steps')
parser.add_argument('--num_samples', type=int, default=10, help='n datasets')
parser.add_argument('--schedule', type=str, default='quad', help='noise schedule type')
parser.add_argument('--target_strategy', type=str, default='random', help='mask')
# args for mae
parser.add_argument('--num_heads', type=int, default=8, help='n heads for self attention')
config = parser.parse_args([])
Fig. 5d: STIMP vs. DINEOF, missing ratio = 0.1
[4]:
from torch.utils.data import DataLoader

# Set the ratio explicitly. Later sections overwrite config.missing_ratio, so without this
# line a re-run of this cell would load 0.5/0.9 test data next to the 0.1 checkpoint.
config.missing_ratio = 0.1

# Use the GPU when present; fall back to CPU so the notebook still runs without one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# First sample of the test split (batch size 1, no shuffling).
test_dloader = DataLoader(PRE8dDataset(config, mode='test'), 1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# Load the trained STIMP checkpoint for this area / missing ratio.
# NOTE(review): torch.load unpickles a whole nn.Module -- only load trusted checkpoints.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
model = torch.load(
    "./log_bak/imputation/{}/STIMP/best_{}.pt".format(config.area, config.missing_ratio),
    map_location=device,
)
model = model.to(device)
# Inference mode. Presumably a no-op if the checkpoint was saved in eval mode.
# TODO confirm impute() does not depend on train-mode layers.
model.eval()

adj = torch.from_numpy(np.load("./data/{}/adj.npy".format(config.area))).float().to(device)

cond_mask = data_gt_masks
# Seed so repeated runs draw the same samples (impute is presumably stochastic).
torch.manual_seed(0)
with torch.no_grad():
    # Draw config.num_samples samples (stacked on dim 1) and keep the element-wise median.
    imputed_data_our = model.impute(datas, cond_mask, adj, config.num_samples).median(1).values

# Entries observed in the data but not given to the model as conditioning
# (presumably the held-out evaluation targets).
mask = data_ob_masks - cond_mask
imputed_our = imputed_data_our[0][mask.bool().cpu()[0]]
truth = datas[0][mask.bool()[0]].cpu()
[6]:
# DINEOF baseline: reconstruct each time step independently from the conditioning observations.
from einops import rearrange
from model.dineof_per_step import DINEOF

is_sea = np.load("./data/{}/is_sea.npy".format(config.area))
sea = is_sea.astype(bool)
H, W = is_sea.shape
# Number of time steps; assumes datas is (batch, time, channel, n_sea_pixels) -- TODO confirm.
n_steps = datas.shape[1]

# Scatter the flattened sea-pixel vectors back onto the full H x W grid; land stays 0.
# Kept on CPU because DINEOF consumes numpy arrays.
datas_image = torch.zeros(1, n_steps, 1, H, W)
datas_image[:, :, :, sea] = datas.cpu().float()
cond_mask_image = torch.zeros(1, n_steps, 1, H, W)
cond_mask_image[:, :, :, sea] = cond_mask.cpu().float()
ob_mask_image = torch.zeros(1, n_steps, 1, H, W)
ob_mask_image[:, :, :, sea] = data_ob_masks.cpu().float()

# Separate name so the STIMP `model` global is not clobbered.
dineof_model = DINEOF(10, [H, W], keep_non_negative_only=False)
impute_data_list = []
for t in range(n_steps):
    frame = datas_image[0, t, 0]
    # Pixels without conditioning (including land, where the mask stays 0) become NaN,
    # presumably so DINEOF treats them as gaps to fill.
    frame_masked = torch.where(cond_mask_image[0, t, 0] == 0, float("nan"), frame)
    dineof_model.fit(frame_masked.numpy())
    frame_imputed = dineof_model.predict()
    frame_imputed = rearrange(frame_imputed, "(b t c h) w->b t c h w", b=1, t=1, c=1, h=H, w=W)
    impute_data_list.append(torch.from_numpy(frame_imputed))
# Full reconstruction, shape (1, n_steps, 1, H, W).
imputed_data = torch.cat(impute_data_list, dim=1).numpy()
2025-05-08 16:22:39.386 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 181: 0.002368027111515403, 9.766314178705215e-06
2025-05-08 16:22:39.542 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 175: 0.002141791395843029, 9.954441338777542e-06
2025-05-08 16:22:39.686 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 163: 0.0018823777791112661, 9.810784831643105e-06
2025-05-08 16:22:39.748 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 69: 0.0008081254782155156, 9.835290256887674e-06
2025-05-08 16:22:39.898 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 170: 0.002310964046046138, 9.95071604847908e-06
2025-05-08 16:22:40.074 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 200: 0.0021930246148258448, 9.937677532434464e-06
2025-05-08 16:22:40.193 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 134: 0.0019664261490106583, 9.924173355102539e-06
2025-05-08 16:22:40.346 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 172: 0.0026422597002238035, 9.846873581409454e-06
2025-05-08 16:22:40.505 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 178: 0.0028323030564934015, 9.594950824975967e-06
2025-05-08 16:22:40.623 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 133: 0.0019033786375075579, 9.8405871540308e-06
2025-05-08 16:22:40.767 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 165: 0.002204755088314414, 9.584939107298851e-06
2025-05-08 16:22:40.867 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 112: 0.002517536748200655, 9.579816833138466e-06
2025-05-08 16:22:40.990 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 139: 0.002249656245112419, 9.963056072592735e-06
2025-05-08 16:22:41.114 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 141: 0.001943048438988626, 9.954790584743023e-06
2025-05-08 16:22:41.235 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 137: 0.002023356733843684, 9.590527042746544e-06
2025-05-08 16:22:41.345 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 124: 0.0019023623317480087, 9.947107173502445e-06
2025-05-08 16:22:41.481 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 154: 0.0024012825451791286, 9.997515007853508e-06
2025-05-08 16:22:41.636 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 176: 0.0018642169889062643, 9.915675036609173e-06
2025-05-08 16:22:41.766 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 147: 0.0019493672298267484, 9.875628165900707e-06
2025-05-08 16:22:41.896 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 149: 0.001893012085929513, 9.901821613311768e-06
2025-05-08 16:22:42.017 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 135: 0.0017551361816003919, 9.988783858716488e-06
2025-05-08 16:22:42.140 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 138: 0.0020047093275934458, 9.95607115328312e-06
2025-05-08 16:22:42.287 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 168: 0.0018259527860209346, 9.988667443394661e-06
2025-05-08 16:22:42.420 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 152: 0.0019529719138517976, 9.699608199298382e-06
2025-05-08 16:22:42.548 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 146: 0.0022290535271167755, 9.95676964521408e-06
2025-05-08 16:22:42.697 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 170: 0.002007921226322651, 9.960029274225235e-06
2025-05-08 16:22:42.835 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 158: 0.0019222581759095192, 9.820680133998394e-06
2025-05-08 16:22:42.987 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 171: 0.0028489846736192703, 9.981216862797737e-06
2025-05-08 16:22:43.124 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 155: 0.002163659781217575, 9.910902008414268e-06
2025-05-08 16:22:43.228 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 113: 0.0012657925253733993, 9.859679266810417e-06
2025-05-08 16:22:43.374 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 164: 0.0022976298350840807, 9.872950613498688e-06
2025-05-08 16:22:43.502 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 144: 0.0027964538894593716, 9.72929410636425e-06
2025-05-08 16:22:43.634 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 148: 0.00252650398761034, 9.96817834675312e-06
2025-05-08 16:22:43.853 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 244: 0.0023826800752431154, 9.979819878935814e-06
2025-05-08 16:22:43.985 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 151: 0.002620214596390724, 9.869225323200226e-06
2025-05-08 16:22:44.125 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 159: 0.0022313857916742563, 9.898561984300613e-06
2025-05-08 16:22:44.265 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 162: 0.0018916797125712037, 9.994604624807835e-06
2025-05-08 16:22:44.487 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 252: 0.0019510366255417466, 9.939656592905521e-06
2025-05-08 16:22:44.550 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 69: 0.0053041549399495125, 9.514391422271729e-06
2025-05-08 16:22:44.725 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 200: 0.00248316815122962, 9.713927283883095e-06
2025-05-08 16:22:44.879 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 178: 0.001902697840705514, 9.977840818464756e-06
2025-05-08 16:22:44.984 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 120: 0.0015478215645998716, 9.942799806594849e-06
2025-05-08 16:22:45.165 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 207: 0.001961691305041313, 9.938608855009079e-06
2025-05-08 16:22:45.303 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 154: 0.001942677074111998, 9.635929018259048e-06
2025-05-08 16:22:45.435 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 150: 0.00280376011505723, 9.921612218022346e-06
2025-05-08 16:22:45.580 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 164: 0.0019110547145828605, 9.842566214501858e-06
[7]:
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# DINEOF values at the same sea pixels / held-out positions used for STIMP.
eval_mask = mask.bool().cpu()[0].numpy()
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
imputed_dineof = imputed_data_dineof[0][eval_mask]

truth_np = truth.numpy()
imputed_our_np = imputed_our.numpy()

# Long-format frame: one row per (held-out entry, method).
plot_df = pd.DataFrame({
    'truth': np.concatenate([truth_np, truth_np]),
    'imputed': np.concatenate([imputed_our_np, imputed_dineof]),
    'method': ['STIMP'] * len(imputed_our_np) + ['DINEOF'] * len(imputed_dineof),
})

color = ["#9F0000", "#576fa0"]
# `fill` replaces the deprecated `shade` keyword of kdeplot.
g = sns.jointplot(data=plot_df, x="truth", y="imputed", hue="method", kind='kde', palette=color,
                  marginal_kws={'common_norm': False, 'fill': True})
ax = g.ax_joint
# Identity line y = x spanning the current x-range: perfect imputation lies on it.
lims = ax.get_xlim()
ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
ax.set_xlabel("truth", size=24)
ax.set_ylabel("imputed", size=24)

pcc1 = stat.pearsonr(truth_np, imputed_our_np).statistic.item()
pcc2 = stat.pearsonr(truth_np, imputed_dineof).statistic.item()
ax.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=ax.transAxes, fontsize=12,
        verticalalignment='top', horizontalalignment='left', color="#9F0000")
ax.text(0.40, 0.88, 'PCC={:.4f}'.format(pcc2), transform=ax.transAxes, fontsize=12,
        verticalalignment='top', horizontalalignment='left', color="#576fa0")
ax.tick_params(labelsize=15)
plt.show()
[7]:
(array([-1.5, -1. , -0.5, 0. , 0.5, 1. , 1.5, 2. , 2.5]),
[Text(0, -1.5, '−1.5'),
Text(0, -1.0, '−1.0'),
Text(0, -0.5, '−0.5'),
Text(0, 0.0, '0.0'),
Text(0, 0.5, '0.5'),
Text(0, 1.0, '1.0'),
Text(0, 1.5, '1.5'),
Text(0, 2.0, '2.0'),
Text(0, 2.5, '2.5')])
Fig. 5d: STIMP vs. DINEOF, missing ratio = 0.5
[9]:
from torch.utils.data import DataLoader

# Missing ratio for this section; the checkpoint path below is derived from it.
config.missing_ratio = 0.5

# Use the GPU when present; fall back to CPU so the notebook still runs without one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# First sample of the test split (batch size 1, no shuffling).
test_dloader = DataLoader(PRE8dDataset(config, mode='test'), 1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# Load the trained STIMP checkpoint for this area / missing ratio.
# NOTE(review): torch.load unpickles a whole nn.Module -- only load trusted checkpoints.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
model = torch.load(
    "./log_bak/imputation/{}/STIMP/best_{}.pt".format(config.area, config.missing_ratio),
    map_location=device,
)
model = model.to(device)
# Inference mode. Presumably a no-op if the checkpoint was saved in eval mode.
# TODO confirm impute() does not depend on train-mode layers.
model.eval()

# Reloaded here so this cell does not depend on `adj` left over from an earlier cell.
adj = torch.from_numpy(np.load("./data/{}/adj.npy".format(config.area))).float().to(device)

cond_mask = data_gt_masks
# Seed so repeated runs draw the same samples (impute is presumably stochastic).
torch.manual_seed(0)
with torch.no_grad():
    # Draw config.num_samples samples (stacked on dim 1) and keep the element-wise median.
    imputed_data_our = model.impute(datas, cond_mask, adj, config.num_samples).median(1).values

# Entries observed in the data but not given to the model as conditioning
# (presumably the held-out evaluation targets).
mask = data_ob_masks - cond_mask
imputed_our = imputed_data_our[0][mask.bool().cpu()[0]]
truth = datas[0][mask.bool()[0]].cpu()
[10]:
# DINEOF baseline: reconstruct each time step independently from the conditioning observations.
from einops import rearrange
from model.dineof_per_step import DINEOF

is_sea = np.load("./data/{}/is_sea.npy".format(config.area))
sea = is_sea.astype(bool)
H, W = is_sea.shape
# Number of time steps; assumes datas is (batch, time, channel, n_sea_pixels) -- TODO confirm.
n_steps = datas.shape[1]

# Scatter the flattened sea-pixel vectors back onto the full H x W grid; land stays 0.
# Kept on CPU because DINEOF consumes numpy arrays.
datas_image = torch.zeros(1, n_steps, 1, H, W)
datas_image[:, :, :, sea] = datas.cpu().float()
cond_mask_image = torch.zeros(1, n_steps, 1, H, W)
cond_mask_image[:, :, :, sea] = cond_mask.cpu().float()
ob_mask_image = torch.zeros(1, n_steps, 1, H, W)
ob_mask_image[:, :, :, sea] = data_ob_masks.cpu().float()

# Separate name so the STIMP `model` global is not clobbered.
dineof_model = DINEOF(10, [H, W], keep_non_negative_only=False)
impute_data_list = []
for t in range(n_steps):
    frame = datas_image[0, t, 0]
    # Pixels without conditioning (including land, where the mask stays 0) become NaN,
    # presumably so DINEOF treats them as gaps to fill.
    frame_masked = torch.where(cond_mask_image[0, t, 0] == 0, float("nan"), frame)
    dineof_model.fit(frame_masked.numpy())
    frame_imputed = dineof_model.predict()
    frame_imputed = rearrange(frame_imputed, "(b t c h) w->b t c h w", b=1, t=1, c=1, h=H, w=W)
    impute_data_list.append(torch.from_numpy(frame_imputed))
# Full reconstruction, shape (1, n_steps, 1, H, W).
imputed_data = torch.cat(impute_data_list, dim=1).numpy()
2025-05-08 16:24:09.692 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 159: 0.0017314578872174025, 9.730109013617039e-06
2025-05-08 16:24:09.819 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 140: 0.0013141802046447992, 9.925337508320808e-06
2025-05-08 16:24:09.946 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 139: 0.0011402089148759842, 9.956886060535908e-06
2025-05-08 16:24:09.998 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 58: 0.0002077865501632914, 9.97736060526222e-06
2025-05-08 16:24:10.126 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 140: 0.0019777193665504456, 9.919516742229462e-06
2025-05-08 16:24:10.254 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 141: 0.0016729923663660884, 9.933137334883213e-06
2025-05-08 16:24:10.360 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 116: 0.0013833654811605811, 9.932555258274078e-06
2025-05-08 16:24:10.491 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 144: 0.0015701536322012544, 9.930459782481194e-06
2025-05-08 16:24:10.642 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 167: 0.0014423009706661105, 9.912881068885326e-06
2025-05-08 16:24:10.762 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 132: 0.0015777467051520944, 9.909039363265038e-06
2025-05-08 16:24:10.869 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 117: 0.0018124505877494812, 9.889365173876286e-06
2025-05-08 16:24:10.979 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 121: 0.0015570628456771374, 9.691575542092323e-06
2025-05-08 16:24:11.106 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 139: 0.001378058921545744, 9.941053576767445e-06
2025-05-08 16:24:11.216 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 119: 0.0011673335684463382, 9.831972420215607e-06
2025-05-08 16:24:11.339 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 134: 0.0014079093234613538, 9.920448064804077e-06
2025-05-08 16:24:11.457 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 130: 0.0012561194598674774, 9.982846677303314e-06
2025-05-08 16:24:11.575 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 129: 0.001343117794021964, 9.854324162006378e-06
2025-05-08 16:24:11.682 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 117: 0.0012106687063351274, 9.76806040853262e-06
2025-05-08 16:24:11.797 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 125: 0.0013469601981341839, 9.960262104868889e-06
2025-05-08 16:24:11.902 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 113: 0.0010996171040460467, 9.777722880244255e-06
2025-05-08 16:24:12.003 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 109: 0.0007880739867687225, 9.949260856956244e-06
2025-05-08 16:24:12.123 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 130: 0.0011422740062698722, 9.816372767090797e-06
2025-05-08 16:24:12.232 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 117: 0.0012993597192689776, 9.9564203992486e-06
2025-05-08 16:24:12.344 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 123: 0.001426519243977964, 9.90019179880619e-06
2025-05-08 16:24:12.448 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 114: 0.0014031181344762444, 9.758980013430119e-06
2025-05-08 16:24:12.573 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 137: 0.001387285185046494, 9.922194294631481e-06
2025-05-08 16:24:12.676 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 113: 0.0018661736976355314, 9.951181709766388e-06
2025-05-08 16:24:12.819 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 159: 0.0016807927750051022, 9.940355084836483e-06
2025-05-08 16:24:12.955 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 150: 0.0015708176651969552, 9.777606464922428e-06
2025-05-08 16:24:13.033 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 86: 0.0006294612539932132, 9.950948879122734e-06
2025-05-08 16:24:13.180 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 163: 0.0015713460743427277, 9.724986739456654e-06
2025-05-08 16:24:13.322 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 156: 0.0016466426895931363, 9.843846783041954e-06
2025-05-08 16:24:13.441 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 130: 0.0014522622805088758, 9.742914699018002e-06
2025-05-08 16:24:13.613 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 191: 0.0015897249104455113, 9.904615581035614e-06
2025-05-08 16:24:13.743 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 143: 0.0013609019806608558, 9.95676964521408e-06
2025-05-08 16:24:13.860 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 129: 0.0019183253170922399, 9.921728633344173e-06
2025-05-08 16:24:13.974 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 125: 0.0013899278128519654, 9.924056939780712e-06
2025-05-08 16:24:14.109 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 149: 0.0014694456476718187, 9.860144928097725e-06
2025-05-08 16:24:14.242 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 147: 0.001763472449965775, 9.818701073527336e-06
2025-05-08 16:24:14.367 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 138: 0.0018544064369052649, 9.933952242136002e-06
2025-05-08 16:24:14.479 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 122: 0.0014743668725714087, 9.922659955918789e-06
2025-05-08 16:24:14.583 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 113: 0.0006675503682345152, 9.782204870134592e-06
2025-05-08 16:24:14.745 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 162: 0.0015250594587996602, 9.976094588637352e-06
2025-05-08 16:24:14.860 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 122: 0.0016461563063785434, 9.995768778026104e-06
2025-05-08 16:24:15.057 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 219: 0.001722136395983398, 9.750830940902233e-06
2025-05-08 16:24:15.173 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 127: 0.0015540922759100795, 9.79006290435791e-06
[11]:
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# DINEOF values at the same sea pixels / held-out positions used for STIMP.
eval_mask = mask.bool().cpu()[0].numpy()
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
imputed_dineof = imputed_data_dineof[0][eval_mask]

truth_np = truth.numpy()
imputed_our_np = imputed_our.numpy()

# Long-format frame: one row per (held-out entry, method).
plot_df = pd.DataFrame({
    'truth': np.concatenate([truth_np, truth_np]),
    'imputed': np.concatenate([imputed_our_np, imputed_dineof]),
    'method': ['STIMP'] * len(imputed_our_np) + ['DINEOF'] * len(imputed_dineof),
})

color = ["#9F0000", "#576fa0"]
# `fill` replaces the deprecated `shade` keyword of kdeplot.
g = sns.jointplot(data=plot_df, x="truth", y="imputed", hue="method", kind='kde', palette=color,
                  marginal_kws={'common_norm': False, 'fill': True})
ax = g.ax_joint
# Identity line y = x spanning the current x-range: perfect imputation lies on it.
lims = ax.get_xlim()
ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
ax.set_xlabel("truth", size=24)
ax.set_ylabel("imputed", size=24)

pcc1 = stat.pearsonr(truth_np, imputed_our_np).statistic.item()
pcc2 = stat.pearsonr(truth_np, imputed_dineof).statistic.item()
ax.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=ax.transAxes, fontsize=12,
        verticalalignment='top', horizontalalignment='left', color="#9F0000")
ax.text(0.40, 0.88, 'PCC={:.4f}'.format(pcc2), transform=ax.transAxes, fontsize=12,
        verticalalignment='top', horizontalalignment='left', color="#576fa0")
ax.tick_params(labelsize=15)
plt.show()
[11]:
(array([-1.5, -1. , -0.5, 0. , 0.5, 1. , 1.5, 2. , 2.5]),
[Text(0, -1.5, '−1.5'),
Text(0, -1.0, '−1.0'),
Text(0, -0.5, '−0.5'),
Text(0, 0.0, '0.0'),
Text(0, 0.5, '0.5'),
Text(0, 1.0, '1.0'),
Text(0, 1.5, '1.5'),
Text(0, 2.0, '2.0'),
Text(0, 2.5, '2.5')])
Fig. 5d: STIMP vs. DINEOF, missing ratio = 0.9
[12]:
from torch.utils.data import DataLoader

# Missing ratio for this section; the checkpoint path below is derived from it.
config.missing_ratio = 0.9

# Use the GPU when present; fall back to CPU so the notebook still runs without one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# First sample of the test split (batch size 1, no shuffling).
test_dloader = DataLoader(PRE8dDataset(config, mode='test'), 1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# Load the trained STIMP checkpoint for this area / missing ratio.
# NOTE(review): torch.load unpickles a whole nn.Module -- only load trusted checkpoints.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
model = torch.load(
    "./log_bak/imputation/{}/STIMP/best_{}.pt".format(config.area, config.missing_ratio),
    map_location=device,
)
model = model.to(device)
# Inference mode. Presumably a no-op if the checkpoint was saved in eval mode.
# TODO confirm impute() does not depend on train-mode layers.
model.eval()

# Reloaded here so this cell does not depend on `adj` left over from an earlier cell.
adj = torch.from_numpy(np.load("./data/{}/adj.npy".format(config.area))).float().to(device)

cond_mask = data_gt_masks
# Seed so repeated runs draw the same samples (impute is presumably stochastic).
torch.manual_seed(0)
with torch.no_grad():
    # Draw config.num_samples samples (stacked on dim 1) and keep the element-wise median.
    imputed_data_our = model.impute(datas, cond_mask, adj, config.num_samples).median(1).values

# Entries observed in the data but not given to the model as conditioning
# (presumably the held-out evaluation targets).
mask = data_ob_masks - cond_mask
imputed_our = imputed_data_our[0][mask.bool().cpu()[0]]
truth = datas[0][mask.bool()[0]].cpu()
[13]:
# DINEOF baseline: reconstruct each time step independently from the conditioning observations.
from einops import rearrange
from model.dineof_per_step import DINEOF

is_sea = np.load("./data/{}/is_sea.npy".format(config.area))
sea = is_sea.astype(bool)
H, W = is_sea.shape
# Number of time steps; assumes datas is (batch, time, channel, n_sea_pixels) -- TODO confirm.
n_steps = datas.shape[1]

# Scatter the flattened sea-pixel vectors back onto the full H x W grid; land stays 0.
# Kept on CPU because DINEOF consumes numpy arrays.
datas_image = torch.zeros(1, n_steps, 1, H, W)
datas_image[:, :, :, sea] = datas.cpu().float()
cond_mask_image = torch.zeros(1, n_steps, 1, H, W)
cond_mask_image[:, :, :, sea] = cond_mask.cpu().float()
ob_mask_image = torch.zeros(1, n_steps, 1, H, W)
ob_mask_image[:, :, :, sea] = data_ob_masks.cpu().float()

# Separate name so the STIMP `model` global is not clobbered.
dineof_model = DINEOF(10, [H, W], keep_non_negative_only=False)
impute_data_list = []
for t in range(n_steps):
    frame = datas_image[0, t, 0]
    # Pixels without conditioning (including land, where the mask stays 0) become NaN,
    # presumably so DINEOF treats them as gaps to fill.
    frame_masked = torch.where(cond_mask_image[0, t, 0] == 0, float("nan"), frame)
    dineof_model.fit(frame_masked.numpy())
    frame_imputed = dineof_model.predict()
    frame_imputed = rearrange(frame_imputed, "(b t c h) w->b t c h w", b=1, t=1, c=1, h=H, w=W)
    impute_data_list.append(torch.from_numpy(frame_imputed))
# Full reconstruction, shape (1, n_steps, 1, H, W).
imputed_data = torch.cat(impute_data_list, dim=1).numpy()
2025-05-08 16:25:46.860 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 60: 0.00013774332182947546, 9.50467074289918e-06
2025-05-08 16:25:46.913 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 56: 0.0001352647814201191, 9.51731635723263e-06
2025-05-08 16:25:46.937 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 22: 0.007202259264886379, 5.593523383140564e-06
2025-05-08 16:25:46.961 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 26: 4.797338624484837e-05, 8.481751137878746e-06
2025-05-08 16:25:47.023 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 67: 0.00022221436665859073, 9.723269613459706e-06
2025-05-08 16:25:47.087 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 69: 0.00014980521518737078, 9.589522960595787e-06
2025-05-08 16:25:47.189 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 107: 0.00020429532742127776, 9.966766810975969e-06
2025-05-08 16:25:47.251 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 67: 0.0001600992400199175, 9.530049283057451e-06
2025-05-08 16:25:47.309 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 62: 0.0003824131563305855, 9.706564014777541e-06
2025-05-08 16:25:47.385 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 79: 0.0004236020031385124, 9.884592145681381e-06
2025-05-08 16:25:47.469 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 91: 0.0003311572945676744, 9.79736796580255e-06
2025-05-08 16:25:47.538 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 74: 0.0002372031012782827, 9.232040611095726e-06
2025-05-08 16:25:47.598 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 65: 0.00016072219295892864, 9.979601600207388e-06
2025-05-08 16:25:47.663 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 70: 0.00023136925301514566, 9.905328624881804e-06
2025-05-08 16:25:47.725 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 66: 0.0001833214337239042, 9.494164260104299e-06
2025-05-08 16:25:47.779 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 56: 0.00014705469948239625, 9.364623110741377e-06
2025-05-08 16:25:47.830 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 55: 0.00015558071027044207, 9.921423043124378e-06
2025-05-08 16:25:47.882 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 55: 0.0002287646639160812, 9.493494872003794e-06
2025-05-08 16:25:47.947 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 70: 0.00015768100274726748, 9.70863038673997e-06
2025-05-08 16:25:48.001 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 57: 0.0002158077695639804, 9.82517667580396e-06
2025-05-08 16:25:48.046 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 47: 0.00011789263953687623, 9.304123523179442e-06
2025-05-08 16:25:48.096 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 54: 0.00015408174658659846, 9.693350875750184e-06
2025-05-08 16:25:48.153 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 59: 0.0001644960866542533, 9.591953130438924e-06
2025-05-08 16:25:48.206 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 56: 0.0001693888334557414, 9.87284875009209e-06
2025-05-08 16:25:48.258 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 56: 0.0001659023982938379, 9.27771907299757e-06
2025-05-08 16:25:48.323 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 70: 0.00017116148956120014, 9.291557944379747e-06
2025-05-08 16:25:48.415 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 101: 0.00021653684962075204, 9.828043403103948e-06
2025-05-08 16:25:48.467 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 56: 0.00018471057410351932, 9.786526788957417e-06
2025-05-08 16:25:48.515 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 52: 0.00028928351821377873, 8.90000956133008e-06
2025-05-08 16:25:48.558 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 43: 0.000747488287743181, 8.304486982524395e-06
2025-05-08 16:25:48.613 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 59: 0.0001940839138114825, 9.864699677564204e-06
2025-05-08 16:25:48.675 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 67: 0.00014330112026073039, 9.880095603875816e-06
2025-05-08 16:25:48.737 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 65: 0.00018797042139340192, 9.63297497946769e-06
2025-05-08 16:25:48.795 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 63: 0.0002210592938354239, 9.56872827373445e-06
2025-05-08 16:25:48.859 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 68: 0.00037349379272200167, 9.214330930262804e-06
2025-05-08 16:25:48.923 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 65: 0.00015307463763747364, 9.705196134746075e-06
2025-05-08 16:25:48.976 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 57: 0.000520966190379113, 9.903335012495518e-06
2025-05-08 16:25:49.033 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 61: 0.00022527057444676757, 9.968134691007435e-06
2025-05-08 16:25:49.093 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 65: 0.00016766160842962563, 9.550771210342646e-06
2025-05-08 16:25:49.161 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 73: 0.0003789346374105662, 9.836279787123203e-06
2025-05-08 16:25:49.225 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 68: 0.00023893857724033296, 9.76253068074584e-06
2025-05-08 16:25:49.288 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 70: 3.1005638447823e-05, 7.696908141952008e-06
2025-05-08 16:25:49.342 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 59: 0.00017916885553859174, 9.662500815466046e-06
2025-05-08 16:25:49.396 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 57: 0.00033074163366109133, 9.978772141039371e-06
2025-05-08 16:25:49.452 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 61: 0.00014282879419624805, 9.454699466004968e-06
2025-05-08 16:25:49.513 | INFO | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 66: 0.00012892748054582626, 9.723647963255644e-06
[14]:
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# DINEOF values at the same sea pixels / held-out positions used for STIMP.
eval_mask = mask.bool().cpu()[0].numpy()
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
imputed_dineof = imputed_data_dineof[0][eval_mask]

truth_np = truth.numpy()
imputed_our_np = imputed_our.numpy()

# Long-format frame: one row per (held-out entry, method).
plot_df = pd.DataFrame({
    'truth': np.concatenate([truth_np, truth_np]),
    'imputed': np.concatenate([imputed_our_np, imputed_dineof]),
    'method': ['STIMP'] * len(imputed_our_np) + ['DINEOF'] * len(imputed_dineof),
})

color = ["#9F0000", "#576fa0"]
# `fill` replaces the deprecated `shade` keyword of kdeplot.
g = sns.jointplot(data=plot_df, x="truth", y="imputed", hue="method", kind='kde', palette=color,
                  marginal_kws={'common_norm': False, 'fill': True})
ax = g.ax_joint
# Identity line y = x spanning the current x-range: perfect imputation lies on it.
lims = ax.get_xlim()
ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
ax.set_xlabel("truth", size=24)
ax.set_ylabel("imputed", size=24)

pcc1 = stat.pearsonr(truth_np, imputed_our_np).statistic.item()
pcc2 = stat.pearsonr(truth_np, imputed_dineof).statistic.item()
ax.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=ax.transAxes, fontsize=12,
        verticalalignment='top', horizontalalignment='left', color="#9F0000")
ax.text(0.40, 0.88, 'PCC={:.4f}'.format(pcc2), transform=ax.transAxes, fontsize=12,
        verticalalignment='top', horizontalalignment='left', color="#576fa0")
ax.tick_params(labelsize=15)
plt.show()
[14]:
(array([-1.5, -1. , -0.5, 0. , 0.5, 1. , 1.5, 2. , 2.5]),
[Text(0, -1.5, '−1.5'),
Text(0, -1.0, '−1.0'),
Text(0, -0.5, '−0.5'),
Text(0, 0.0, '0.0'),
Text(0, 0.5, '0.5'),
Text(0, 1.0, '1.0'),
Text(0, 1.5, '1.5'),
Text(0, 2.0, '2.0'),
Text(0, 2.5, '2.5')])