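Imputation with a trained STIMP model in the Yangtze River Estuary

This notebook compares STIMP with the DINEOF baseline against the ground truth (Fig. 5g) at missing rates of 0.1, 0.5 and 0.9.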

[1]:
import torch
import numpy as np
import sys
import os

# Put the parent of the working directory on the import path so the project packages
# (`dataset`, `model`) can be imported from there. Guarded so that re-running this cell
# in the same kernel does not keep prepending duplicate entries.
project_root = os.path.join(os.getcwd(), "..")
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from dataset.dataset_imputation import PRE8dDataset
['/home/mafzhang/code/STIMP/..', '/home/mafzhang/miniconda3/envs/torch/lib/python39.zip', '/home/mafzhang/miniconda3/envs/torch/lib/python3.9', '/home/mafzhang/miniconda3/envs/torch/lib/python3.9/lib-dynload', '', '/home/mafzhang/.local/lib/python3.9/site-packages', '/home/mafzhang/miniconda3/envs/torch/lib/python3.9/site-packages']
[2]:
import argparse

# Experiment configuration. It is declared through argparse (presumably mirroring the
# training scripts' CLI) but parsed from an empty argv below, so the defaults are what
# this notebook actually uses.
parser = argparse.ArgumentParser(description='Imputation')

# args for area and methods
parser.add_argument('--area', type=str, default='Yangtze', help='which bay area we focus')

# basic args
parser.add_argument('--epochs', type=int, default=500, help='epochs')
parser.add_argument('--batch_size', type=int, default=16, help='batch size')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--wd', type=float, default=1e-4, help='weight decay')
parser.add_argument('--test_freq', type=int, default=500, help='test per n epochs')
parser.add_argument('--embedding_size', type=int, default=32)
parser.add_argument('--hidden_channels', type=int, default=32)
parser.add_argument('--diffusion_embedding_size', type=int, default=64)
parser.add_argument('--side_channels', type=int, default=1)

# args for tasks
parser.add_argument('--in_len', type=int, default=46)
parser.add_argument('--out_len', type=int, default=46)
parser.add_argument('--missing_ratio', type=float, default=0.1)

# args for diffusion
parser.add_argument('--beta_start', type=float, default=0.0001, help='beta start from this')
parser.add_argument('--beta_end', type=float, default=0.2, help='beta end to this')
# A step count must be an integer; with type=float a CLI override such as "--num_steps 20"
# would produce 20.0.
parser.add_argument('--num_steps', type=int, default=50, help='denoising steps')
parser.add_argument('--num_samples', type=int, default=10, help='n datasets')
parser.add_argument('--schedule', type=str, default='quad', help='noise schedule type')
parser.add_argument('--target_strategy', type=str, default='random', help='mask')

# args for mae
parser.add_argument('--num_heads', type=int, default=8, help='n heads for self attention')
# Parse an explicit empty list: otherwise argparse reads sys.argv, which inside Jupyter
# holds the kernel's own command-line arguments.
config = parser.parse_args([])

Fig. 5g: missing rate of 0.1. STIMP and DINEOF imputations are compared with the ground truth on the held-out points.

[3]:
# Evaluate STIMP on the first test window at missing rate 0.1.
from torch.utils.data import DataLoader

# Set the rate explicitly: the later sections mutate `config.missing_ratio` in place, so
# re-running this cell after them would otherwise pair a 0.5/0.9 dataset with this checkpoint.
config.missing_ratio = 0.1

# Seed the global RNGs so any random masking in the dataset and STIMP's stochastic
# sampling produce the same figure on every run.
np.random.seed(0)
torch.manual_seed(0)

test_dloader = DataLoader(PRE8dDataset(config, mode='test'), 1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))
device = "cuda" if torch.cuda.is_available() else "cpu"
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# Load the checkpoint trained for this missing rate. torch.load unpickles a whole model
# object, so only point it at trusted files (on torch >= 2.6 this may need weights_only=False).
ckpt_path = "./log_bak/imputation/{}/STIMP/best_{}.pt".format(config.area, config.missing_ratio)
model = torch.load(ckpt_path, map_location=device)
model = model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm updates, if the model has any

cond_mask = data_gt_masks
adj = np.load("./data/{}/adj.npy".format(config.area))
adj = torch.from_numpy(adj).float().to(device)

with torch.no_grad():  # sampling only; no autograd graph is needed
    imputed_data_our = model.impute(datas, cond_mask, adj, config.num_samples)
# Median over dim 1 (presumably the sample dimension, since num_samples draws are requested),
# moved to the CPU to match the CPU boolean index used below.
imputed_data_our = imputed_data_our.median(1).values.cpu()

# Evaluation points: observed in the data but not handed to the model as conditioning
# (assumes 0/1 masks with cond_mask a subset of data_ob_masks; TODO confirm).
mask = data_ob_masks - cond_mask
imputed_our = imputed_data_our[0][mask.bool().cpu()[0]]
truth = datas[0][mask.bool()[0]].cpu()
[5]:
# DINEOF baseline at the same missing rate. DINEOF works on full (H, W) images, so the
# per-sea-pixel series are scattered back onto the grid. One DINEOF fit is run per time
# step, with every non-conditioning pixel set to NaN.
from einops import rearrange
from model.dineof_per_step import DINEOF

is_sea = np.load("./data/{}/is_sea.npy".format(config.area))
sea = is_sea.astype(bool)
H, W = is_sea.shape
n_times = datas.shape[1]  # time axis of the batch (46 with the default config)

# Built directly on the CPU: DINEOF consumes NumPy arrays, so a GPU copy only adds transfers.
datas_image = torch.zeros(1, n_times, 1, H, W)
datas_image[:, :, :, sea] = datas.cpu()
cond_mask_image = torch.zeros(1, n_times, 1, H, W)
cond_mask_image[:, :, :, sea] = cond_mask.cpu().float()

# A separate name keeps the STIMP `model` from the previous cell intact.
dineof = DINEOF(10, [H, W], keep_non_negative_only=False)
step_preds = []
for t in range(n_times):
    # Explicit indexing (rather than .squeeze()) keeps the frame 2-D even if H or W is 1.
    frame = datas_image[0, t, 0]
    observed = cond_mask_image[0, t, 0] != 0
    dineof.fit(frame.masked_fill(~observed, float("nan")).numpy())
    step_pred = rearrange(dineof.predict(), "(b t c h) w->b t c h w", b=1, t=1, c=1, h=H, w=W)
    step_preds.append(torch.from_numpy(step_pred))

imputed_data = torch.cat(step_preds, dim=1).numpy()
2025-05-08 16:42:57.570 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 21: 4.708132109954022e-05, 8.019527740543708e-06
2025-05-08 16:42:57.701 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 112: 0.0010282495059072971, 9.766197763383389e-06
2025-05-08 16:42:57.773 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 71: 0.0007443492067977786, 9.97487222775817e-06
2025-05-08 16:42:57.776 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 4.399793951392894e-08, 4.399793951392894e-08
2025-05-08 16:42:57.906 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 123: 0.0016074057202786207, 9.941053576767445e-06
2025-05-08 16:42:58.003 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 102: 0.0008586125331930816, 9.865849278867245e-06
2025-05-08 16:42:58.130 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 124: 0.0017285322537645698, 9.996467269957066e-06
2025-05-08 16:42:58.227 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 98: 0.0008929798495955765, 9.876443073153496e-06
2025-05-08 16:42:58.326 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 101: 0.0013608409790322185, 9.842333383858204e-06
2025-05-08 16:42:58.392 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 68: 0.001529828761704266, 9.713578037917614e-06
2025-05-08 16:42:58.502 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 109: 0.0025927559472620487, 9.823590517044067e-06
2025-05-08 16:42:58.634 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 132: 0.0012813438661396503, 9.85537189990282e-06
2025-05-08 16:42:58.834 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 201: 0.001912857056595385, 9.500537998974323e-06
2025-05-08 16:42:58.927 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 89: 0.0013037541648373008, 9.968411177396774e-06
2025-05-08 16:42:59.056 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 129: 0.0009492257377132773, 9.716430213302374e-06
2025-05-08 16:42:59.190 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 135: 0.0013640134129673243, 9.834766387939453e-06
2025-05-08 16:42:59.215 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 24: 7.357676076935604e-05, 9.848386980593204e-06
2025-05-08 16:42:59.281 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 67: 0.0006225646357052028, 9.807816240936518e-06
2025-05-08 16:42:59.296 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 13: 0.0002542529546190053, 6.905669579282403e-06
2025-05-08 16:42:59.319 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 22: 0.0004059072816744447, 9.747775038704276e-06
2025-05-08 16:42:59.427 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 108: 0.001281833741813898, 9.945477358996868e-06
2025-05-08 16:42:59.545 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 118: 0.00228595151565969, 9.91998240351677e-06
2025-05-08 16:42:59.631 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 75: 0.003958098590373993, 9.817536920309067e-06
2025-05-08 16:42:59.782 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 150: 0.0019288202747702599, 9.90286935120821e-06
2025-05-08 16:42:59.993 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 190: 0.001833006739616394, 9.884708561003208e-06
2025-05-08 16:43:00.132 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 139: 0.001557467970997095, 9.996467269957066e-06
2025-05-08 16:43:00.192 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 60: 0.0008546045864932239, 9.848503395915031e-06
2025-05-08 16:43:00.291 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 86: 0.0026618584524840117, 9.825685992836952e-06
2025-05-08 16:43:00.496 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 192: 0.001992049627006054, 9.892042726278305e-06
2025-05-08 16:43:00.717 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 221: 0.0017417825292795897, 9.877607226371765e-06
2025-05-08 16:43:00.892 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 160: 0.0019981060177087784, 9.859679266810417e-06
2025-05-08 16:43:01.035 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 126: 0.0018632980063557625, 9.83639620244503e-06
2025-05-08 16:43:01.236 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 177: 0.0018766681896522641, 9.935698471963406e-06
2025-05-08 16:43:01.365 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 128: 0.0018958383006975055, 9.955023415386677e-06
2025-05-08 16:43:01.510 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 144: 0.0015788031741976738, 9.9594471976161e-06
2025-05-08 16:43:01.543 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 32: 0.00016102954396046698, 9.845243766903877e-06
2025-05-08 16:43:01.633 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 91: 0.0007083074888214469, 9.8293530754745e-06
2025-05-08 16:43:01.706 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 73: 0.0014121433487161994, 9.9940225481987e-06
2025-05-08 16:43:01.752 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 45: 0.0006760748801752925, 7.245165761560202e-06
2025-05-08 16:43:01.754 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 3.099742329482069e-08, 3.099742329482069e-08
2025-05-08 16:43:01.779 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 23: 0.0005124595481902361, 9.853160008788109e-06
2025-05-08 16:43:01.781 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 1.3740951487761777e-07, 1.3740951487761777e-07
2025-05-08 16:43:01.883 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 102: 0.0014482949627563357, 9.968061931431293e-06
2025-05-08 16:43:02.158 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 299: nan, nan
2025-05-08 16:43:02.162 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 2.021827896214745e-07, 2.021827896214745e-07
2025-05-08 16:43:02.252 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 90: 0.0013474671868607402, 9.796698577702045e-06
[6]:
# Compare STIMP and DINEOF against the ground truth on the held-out points.
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# Grid -> sea pixels -> the same evaluation points that were used for STIMP.
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu().numpy()[0]]
truth_np = truth.numpy()
imputed_our_np = imputed_our.numpy()

# The DINEOF fit log can report nan/inf errors. Non-finite predictions would make Pearson's r
# nan and can break the KDE, so they are excluded from the DINEOF series, and the count is reported.
finite = np.isfinite(imputed_dineof)
if not finite.all():
    print("DINEOF: dropping {} non-finite predictions out of {}".format(int((~finite).sum()), finite.size))

n_our, n_dineof = imputed_our_np.shape[0], int(finite.sum())
comparison = pd.DataFrame({
    'truth': np.concatenate([truth_np, truth_np[finite]]),
    'imputed': np.concatenate([imputed_our_np, imputed_dineof[finite]]),
    'method': ['STIMP'] * n_our + ['DINEOF'] * n_dineof,
})
color = ["#9F0000", "#576fa0"]
# `fill` replaces the deprecated `shade` keyword of seaborn's kdeplot.
g = sns.jointplot(data=comparison, x="truth", y="imputed", hue="method", kind='kde', palette=color,
                  marginal_kws={'common_norm': False, 'fill': True})
ax = g.ax_joint
# y = x reference line across the current x-range.
lims = ax.get_xlim()
ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
g.set_axis_labels("truth", "imputed", size=24)

pcc1 = stat.pearsonr(truth_np, imputed_our_np).statistic.item()
pcc2 = stat.pearsonr(truth_np[finite], imputed_dineof[finite]).statistic.item()
for ypos, pcc_val, col in [(0.92, pcc1, color[0]), (0.88, pcc2, color[1])]:
    ax.text(0.40, ypos, 'PCC={:.4f}'.format(pcc_val), transform=ax.transAxes, fontsize=12,
            verticalalignment='top', horizontalalignment='left', color=col)
ax.tick_params(labelsize=15)
plt.show()
[6]:
(array([-2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5]),
 [Text(0, -2.0, '−2.0'),
  Text(0, -1.5, '−1.5'),
  Text(0, -1.0, '−1.0'),
  Text(0, -0.5, '−0.5'),
  Text(0, 0.0, '0.0'),
  Text(0, 0.5, '0.5'),
  Text(0, 1.0, '1.0'),
  Text(0, 1.5, '1.5'),
  Text(0, 2.0, '2.0'),
  Text(0, 2.5, '2.5')])
../../_images/analysis_Yangtze_11-imputation-yangtze-estuary_6_1.png

Fig. 5g: missing rate of 0.5. STIMP and DINEOF imputations are compared with the ground truth on the held-out points.

[8]:
# Evaluate STIMP at missing rate 0.5 (re-uses `adj` loaded in the 0.1 section).
config.missing_ratio = 0.5
from torch.utils.data import DataLoader

# Seed the global RNGs so any random masking in the dataset and STIMP's stochastic
# sampling produce the same figure on every run.
np.random.seed(0)
torch.manual_seed(0)

test_dloader = DataLoader(PRE8dDataset(config, mode='test'), 1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))
device = "cuda" if torch.cuda.is_available() else "cpu"
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# Checkpoint path follows the configured missing rate. torch.load unpickles a whole model
# object, so only point it at trusted files (on torch >= 2.6 this may need weights_only=False).
ckpt_path = "./log_bak/imputation/{}/STIMP/best_{}.pt".format(config.area, config.missing_ratio)
model = torch.load(ckpt_path, map_location=device)
model = model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm updates, if the model has any
cond_mask = data_gt_masks
adj = adj.to(device)

with torch.no_grad():  # sampling only; no autograd graph is needed
    imputed_data_our = model.impute(datas, cond_mask, adj, config.num_samples)
# Median over dim 1 (presumably the sample dimension), on the CPU to match the index below.
imputed_data_our = imputed_data_our.median(1).values.cpu()

# Evaluation points: observed in the data but not handed to the model as conditioning
# (assumes 0/1 masks with cond_mask a subset of data_ob_masks; TODO confirm).
mask = data_ob_masks - cond_mask
imputed_our = imputed_data_our[0][mask.bool().cpu()[0]]
truth = datas[0][mask.bool()[0]].cpu()
[9]:
# DINEOF baseline at missing rate 0.5. Sea-pixel series are scattered onto the (H, W) grid
# and one DINEOF fit is run per time step, with non-conditioning pixels set to NaN.
from einops import rearrange
from model.dineof_per_step import DINEOF

is_sea = np.load("./data/{}/is_sea.npy".format(config.area))
sea = is_sea.astype(bool)
H, W = is_sea.shape
n_times = datas.shape[1]  # time axis of the batch (46 with the default config)

# Built directly on the CPU: DINEOF consumes NumPy arrays, so a GPU copy only adds transfers.
datas_image = torch.zeros(1, n_times, 1, H, W)
datas_image[:, :, :, sea] = datas.cpu()
cond_mask_image = torch.zeros(1, n_times, 1, H, W)
cond_mask_image[:, :, :, sea] = cond_mask.cpu().float()

# A separate name keeps the STIMP `model` from the previous cell intact.
dineof = DINEOF(10, [H, W], keep_non_negative_only=False)
step_preds = []
for t in range(n_times):
    # Explicit indexing (rather than .squeeze()) keeps the frame 2-D even if H or W is 1.
    frame = datas_image[0, t, 0]
    observed = cond_mask_image[0, t, 0] != 0
    dineof.fit(frame.masked_fill(~observed, float("nan")).numpy())
    step_pred = rearrange(dineof.predict(), "(b t c h) w->b t c h w", b=1, t=1, c=1, h=H, w=W)
    step_preds.append(torch.from_numpy(step_pred))

imputed_data = torch.cat(step_preds, dim=1).numpy()
2025-05-08 16:44:13.178 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 60: 0.00018172245472669601, 9.632494766265154e-06
2025-05-08 16:44:13.287 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 103: 0.001003644079901278, 9.98412724584341e-06
2025-05-08 16:44:13.367 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 81: 0.00034616910852491856, 9.852286893874407e-06
2025-05-08 16:44:13.370 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 3.6166316164099044e-08, 3.6166316164099044e-08
2025-05-08 16:44:13.489 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 114: 0.001367044635117054, 9.751529432833195e-06
2025-05-08 16:44:13.559 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 72: 0.0005230333190411329, 9.848852641880512e-06
2025-05-08 16:44:13.679 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 114: 0.001115604885853827, 9.942566975951195e-06
2025-05-08 16:44:13.750 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 70: 0.0005284450016915798, 9.563169442117214e-06
2025-05-08 16:44:13.851 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 89: 0.0007980529335327446, 9.904848411679268e-06
2025-05-08 16:44:13.925 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 74: 0.0005747053073719144, 9.889889042824507e-06
2025-05-08 16:44:14.097 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 154: 0.0016806392231956124, 9.83488280326128e-06
2025-05-08 16:44:14.216 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 114: 0.0011925606522709131, 9.8345335572958e-06
2025-05-08 16:44:14.410 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 182: 0.0016186311841011047, 9.693787433207035e-06
2025-05-08 16:44:14.540 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 126: 0.0013545361580327153, 9.890994988381863e-06
2025-05-08 16:44:14.639 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 97: 0.0009067094069905579, 9.997456800192595e-06
2025-05-08 16:44:14.767 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 123: 0.001609492814168334, 9.93290450423956e-06
2025-05-08 16:44:14.785 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 16: 2.015665995713789e-05, 7.93514664110262e-06
2025-05-08 16:44:14.841 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 54: 0.00024294178001582623, 9.629933629184961e-06
2025-05-08 16:44:14.860 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 18: 3.5876695619663224e-05, 8.131828508339822e-06
2025-05-08 16:44:14.880 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 19: 0.0003714900813065469, 3.156892489641905e-07
2025-05-08 16:44:14.992 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 110: 0.0007659235852770507, 9.788316674530506e-06
2025-05-08 16:44:15.160 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 160: 0.0017956334631890059, 9.888550266623497e-06
2025-05-08 16:44:15.345 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 172: 0.0020615363027900457, 9.840354323387146e-06
2025-05-08 16:44:15.487 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 133: 0.0015492900274693966, 9.929994121193886e-06
2025-05-08 16:44:15.633 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 136: 0.002137102885171771, 9.997747838497162e-06
2025-05-08 16:44:15.756 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 117: 0.001405011280439794, 9.96212475001812e-06
2025-05-08 16:44:15.812 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 55: 0.0003225987602490932, 9.592913556843996e-06
2025-05-08 16:44:15.950 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 125: 0.0019885250367224216, 9.818002581596375e-06
2025-05-08 16:44:16.114 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 154: 0.0015739976661279798, 9.79576725512743e-06
2025-05-08 16:44:16.260 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 134: 0.0015807206509634852, 9.996816515922546e-06
2025-05-08 16:44:16.415 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 148: 0.0014943223213776946, 9.724986739456654e-06
2025-05-08 16:44:16.602 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 172: 0.001525407424196601, 9.94314905256033e-06
2025-05-08 16:44:16.733 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 120: 0.0029906926210969687, 9.904149919748306e-06
2025-05-08 16:44:16.843 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 105: 0.0012036649277433753, 9.967712685465813e-06
2025-05-08 16:44:16.963 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 114: 0.0009628551779314876, 9.952287655323744e-06
2025-05-08 16:44:16.982 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 18: 3.529783134581521e-05, 8.971062925411388e-06
2025-05-08 16:44:17.054 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 70: 0.00032338700839318335, 9.956740541383624e-06
2025-05-08 16:44:17.123 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 68: 0.0004974923795089126, 9.769981261342764e-06
2025-05-08 16:44:17.180 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 58: 0.00024812668561935425, 9.839219273999333e-06
2025-05-08 16:44:17.186 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 2.4436809908934265e-08, 2.4436809908934265e-08
2025-05-08 16:44:17.220 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 33: 0.00036679705954156816, 9.863695595413446e-06
2025-05-08 16:44:17.222 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 1.701689029687259e-07, 1.701689029687259e-07
2025-05-08 16:44:17.334 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 108: 0.0010512343142181635, 9.797979146242142e-06
2025-05-08 16:44:17.602 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 299: nan, nan
2025-05-08 16:44:17.604 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 9.214258511747175e-08, 9.214258511747175e-08
2025-05-08 16:44:17.712 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 104: 0.0018026444595307112, 9.951996617019176e-06
[10]:
# Compare STIMP and DINEOF against the ground truth on the held-out points.
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# Grid -> sea pixels -> the same evaluation points that were used for STIMP.
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu().numpy()[0]]
truth_np = truth.numpy()
imputed_our_np = imputed_our.numpy()

# The DINEOF fit log can report nan/inf errors. Non-finite predictions would make Pearson's r
# nan and can break the KDE, so they are excluded from the DINEOF series, and the count is reported.
finite = np.isfinite(imputed_dineof)
if not finite.all():
    print("DINEOF: dropping {} non-finite predictions out of {}".format(int((~finite).sum()), finite.size))

n_our, n_dineof = imputed_our_np.shape[0], int(finite.sum())
comparison = pd.DataFrame({
    'truth': np.concatenate([truth_np, truth_np[finite]]),
    'imputed': np.concatenate([imputed_our_np, imputed_dineof[finite]]),
    'method': ['STIMP'] * n_our + ['DINEOF'] * n_dineof,
})
color = ["#9F0000", "#576fa0"]
# `fill` replaces the deprecated `shade` keyword of seaborn's kdeplot.
g = sns.jointplot(data=comparison, x="truth", y="imputed", hue="method", kind='kde', palette=color,
                  marginal_kws={'common_norm': False, 'fill': True})
ax = g.ax_joint
# y = x reference line across the current x-range.
lims = ax.get_xlim()
ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
g.set_axis_labels("truth", "imputed", size=24)

pcc1 = stat.pearsonr(truth_np, imputed_our_np).statistic.item()
pcc2 = stat.pearsonr(truth_np[finite], imputed_dineof[finite]).statistic.item()
for ypos, pcc_val, col in [(0.92, pcc1, color[0]), (0.88, pcc2, color[1])]:
    ax.text(0.40, ypos, 'PCC={:.4f}'.format(pcc_val), transform=ax.transAxes, fontsize=12,
            verticalalignment='top', horizontalalignment='left', color=col)
ax.tick_params(labelsize=15)
plt.show()
[10]:
(array([-2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5]),
 [Text(0, -2.0, '−2.0'),
  Text(0, -1.5, '−1.5'),
  Text(0, -1.0, '−1.0'),
  Text(0, -0.5, '−0.5'),
  Text(0, 0.0, '0.0'),
  Text(0, 0.5, '0.5'),
  Text(0, 1.0, '1.0'),
  Text(0, 1.5, '1.5'),
  Text(0, 2.0, '2.0'),
  Text(0, 2.5, '2.5')])
../../_images/analysis_Yangtze_11-imputation-yangtze-estuary_10_1.png

Fig. 5g: missing rate of 0.9. STIMP and DINEOF imputations are compared with the ground truth on the held-out points.

[12]:
# Evaluate STIMP at missing rate 0.9 (re-uses `adj` loaded in the 0.1 section).
config.missing_ratio = 0.9
from torch.utils.data import DataLoader

# Seed the global RNGs so any random masking in the dataset and STIMP's stochastic
# sampling produce the same figure on every run.
np.random.seed(0)
torch.manual_seed(0)

test_dloader = DataLoader(PRE8dDataset(config, mode='test'), 1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))
device = "cuda" if torch.cuda.is_available() else "cpu"
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# Checkpoint path follows the configured missing rate. torch.load unpickles a whole model
# object, so only point it at trusted files (on torch >= 2.6 this may need weights_only=False).
ckpt_path = "./log_bak/imputation/{}/STIMP/best_{}.pt".format(config.area, config.missing_ratio)
model = torch.load(ckpt_path, map_location=device)
model = model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm updates, if the model has any
cond_mask = data_gt_masks
adj = adj.to(device)

with torch.no_grad():  # sampling only; no autograd graph is needed
    imputed_data_our = model.impute(datas, cond_mask, adj, config.num_samples)
# Median over dim 1 (presumably the sample dimension), on the CPU to match the index below.
imputed_data_our = imputed_data_our.median(1).values.cpu()

# Evaluation points: observed in the data but not handed to the model as conditioning
# (assumes 0/1 masks with cond_mask a subset of data_ob_masks; TODO confirm).
mask = data_ob_masks - cond_mask
imputed_our = imputed_data_our[0][mask.bool().cpu()[0]]
truth = datas[0][mask.bool()[0]].cpu()
[14]:
# DINEOF baseline at missing rate 0.9. Sea-pixel series are scattered onto the (H, W) grid
# and one DINEOF fit is run per time step, with non-conditioning pixels set to NaN.
from einops import rearrange
from model.dineof_per_step import DINEOF

is_sea = np.load("./data/{}/is_sea.npy".format(config.area))
sea = is_sea.astype(bool)
H, W = is_sea.shape
n_times = datas.shape[1]  # time axis of the batch (46 with the default config)

# Built directly on the CPU: DINEOF consumes NumPy arrays, so a GPU copy only adds transfers.
datas_image = torch.zeros(1, n_times, 1, H, W)
datas_image[:, :, :, sea] = datas.cpu()
cond_mask_image = torch.zeros(1, n_times, 1, H, W)
cond_mask_image[:, :, :, sea] = cond_mask.cpu().float()

# nitemax=3 caps the DINEOF iterations for this rate (the fit log stops at iteration 2).
# NOTE(review): presumably chosen because fits diverge at this sparsity; confirm.
# A separate name keeps the STIMP `model` from the previous cell intact.
dineof = DINEOF(10, [H, W], keep_non_negative_only=False, nitemax=3)
step_preds = []
for t in range(n_times):
    # Explicit indexing (rather than .squeeze()) keeps the frame 2-D even if H or W is 1.
    frame = datas_image[0, t, 0]
    observed = cond_mask_image[0, t, 0] != 0
    dineof.fit(frame.masked_fill(~observed, float("nan")).numpy())
    step_pred = rearrange(dineof.predict(), "(b t c h) w->b t c h w", b=1, t=1, c=1, h=H, w=W)
    step_preds.append(torch.from_numpy(step_pred))

imputed_data = torch.cat(step_preds, dim=1).numpy()
2025-05-08 16:49:09.361 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 7.668309365271853e-08, 7.668309365271853e-08
2025-05-08 16:49:09.375 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.05715912580490112, 0.022987596690654755
2025-05-08 16:49:09.386 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.03161339461803436, 0.02170996367931366
2025-05-08 16:49:09.403 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: inf, nan
2025-05-08 16:49:09.419 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.0353475846350193, 0.013851732015609741
2025-05-08 16:49:09.428 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.019878100603818893, 0.016625836491584778
2025-05-08 16:49:09.439 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.025722039863467216, 0.011239001527428627
2025-05-08 16:49:09.448 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.015220475383102894, 0.004366486333310604
2025-05-08 16:49:09.456 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.017662154510617256, 0.003993446007370949
2025-05-08 16:49:09.464 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.021306386217474937, 0.012515569105744362
2025-05-08 16:49:09.474 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.055746566504240036, 0.01922396942973137
2025-05-08 16:49:09.483 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.03499658405780792, 0.013764653354883194
2025-05-08 16:49:09.492 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.058387160301208496, 0.02401900291442871
2025-05-08 16:49:09.501 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.026054155081510544, 0.010700307786464691
2025-05-08 16:49:09.510 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.021644258871674538, 0.008256109431385994
2025-05-08 16:49:09.515 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.03321734815835953, 0.014978978782892227
2025-05-08 16:49:09.518 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 2.3848681252047754e-08, 2.3848681252047754e-08
2025-05-08 16:49:09.523 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.009581639431416988, 0.004611657932400703
2025-05-08 16:49:09.525 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 2.346368255246034e-08, 2.346368255246034e-08
2025-05-08 16:49:09.527 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 3.7093514038133435e-08, 3.7093514038133435e-08
2025-05-08 16:49:09.533 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.022515250369906425, 0.009211746975779533
2025-05-08 16:49:09.538 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.06756709516048431, 0.028331279754638672
2025-05-08 16:49:09.544 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.09469026327133179, 0.03785631060600281
2025-05-08 16:49:09.550 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.03914198279380798, 0.010798364877700806
2025-05-08 16:49:09.556 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.04378948733210564, 0.019532058387994766
2025-05-08 16:49:09.561 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.03323502838611603, 0.014274928718805313
2025-05-08 16:49:09.567 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.011877693235874176, 0.0035524489358067513
2025-05-08 16:49:09.573 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.046788524836301804, 0.019438933581113815
2025-05-08 16:49:09.579 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.03796656057238579, 0.015091147273778915
2025-05-08 16:49:09.584 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.04216518625617027, 0.016477413475513458
2025-05-08 16:49:09.590 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.03814290836453438, 0.016480151563882828
2025-05-08 16:49:09.596 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.055455874651670456, 0.024559568613767624
2025-05-08 16:49:09.602 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.05383404344320297, 0.02300872653722763
2025-05-08 16:49:09.607 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.03453878313302994, 0.013379789888858795
2025-05-08 16:49:09.613 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.02408694103360176, 0.012295190244913101
2025-05-08 16:49:09.615 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 2.2849956593518073e-08, 2.2849956593518073e-08
2025-05-08 16:49:09.620 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.015276920981705189, 0.008797475136816502
2025-05-08 16:49:09.626 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.02445455826818943, 0.01103702001273632
2025-05-08 16:49:09.631 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.00716039165854454, 0.004828255623579025
2025-05-08 16:49:09.633 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 7.721340011812572e-09, 7.721340011812572e-09
2025-05-08 16:49:09.637 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.001410915981978178, 0.0010338590946048498
2025-05-08 16:49:09.640 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 0: 3.007425064538438e-08, 3.007425064538438e-08
2025-05-08 16:49:09.645 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.0503329262137413, 0.020350046455860138
2025-05-08 16:49:09.649 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: nan, nan
2025-05-08 16:49:09.654 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: inf, nan
2025-05-08 16:49:09.660 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 2: 0.08589810878038406, 0.04171190410852432
[15]:
# Compare STIMP and DINEOF against the ground truth on the held-out points.
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# Grid -> sea pixels -> the same evaluation points that were used for STIMP.
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu().numpy()[0]]
truth_np = truth.numpy()
imputed_our_np = imputed_our.numpy()

# The DINEOF fit log can report nan/inf errors. Non-finite predictions would make Pearson's r
# nan and can break the KDE, so they are excluded from the DINEOF series, and the count is reported.
finite = np.isfinite(imputed_dineof)
if not finite.all():
    print("DINEOF: dropping {} non-finite predictions out of {}".format(int((~finite).sum()), finite.size))

n_our, n_dineof = imputed_our_np.shape[0], int(finite.sum())
comparison = pd.DataFrame({
    'truth': np.concatenate([truth_np, truth_np[finite]]),
    'imputed': np.concatenate([imputed_our_np, imputed_dineof[finite]]),
    'method': ['STIMP'] * n_our + ['DINEOF'] * n_dineof,
})
color = ["#9F0000", "#576fa0"]
# `fill` replaces the deprecated `shade` keyword of seaborn's kdeplot.
g = sns.jointplot(data=comparison, x="truth", y="imputed", hue="method", kind='kde', palette=color,
                  marginal_kws={'common_norm': False, 'fill': True})
ax = g.ax_joint
# y = x reference line across the current x-range.
lims = ax.get_xlim()
ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
g.set_axis_labels("truth", "imputed", size=24)

pcc1 = stat.pearsonr(truth_np, imputed_our_np).statistic.item()
pcc2 = stat.pearsonr(truth_np[finite], imputed_dineof[finite]).statistic.item()
for ypos, pcc_val, col in [(0.92, pcc1, color[0]), (0.88, pcc2, color[1])]:
    ax.text(0.40, ypos, 'PCC={:.4f}'.format(pcc_val), transform=ax.transAxes, fontsize=12,
            verticalalignment='top', horizontalalignment='left', color=col)
ax.tick_params(labelsize=15)
plt.show()
[15]:
(array([-1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5]),
 [Text(0, -1.5, '−1.5'),
  Text(0, -1.0, '−1.0'),
  Text(0, -0.5, '−0.5'),
  Text(0, 0.0, '0.0'),
  Text(0, 0.5, '0.5'),
  Text(0, 1.0, '1.0'),
  Text(0, 1.5, '1.5'),
  Text(0, 2.0, '2.0'),
  Text(0, 2.5, '2.5')])
../../_images/analysis_Yangtze_11-imputation-yangtze-estuary_14_1.png