STIMP vs. DINEOF with different temporal window lengths t

[1]:
import torch
import numpy as np
import sys
import os
from dataset.dataset_imputation import PRE8dDataset
[2]:
import argparse

# Hyper-parameters for the imputation experiments in this notebook.
# `parse_args([])` at the end parses an empty argument list instead of the
# kernel's own command line, so `config` holds exactly the defaults below.
parser = argparse.ArgumentParser(description='Imputation')

# args for area and methods
parser.add_argument('--area', type=str, default='PRE', help='which bay area we focus')

# basic args
parser.add_argument('--epochs', type=int, default=500, help='epochs')
parser.add_argument('--batch_size', type=int, default=16, help='batch size')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--wd', type=float, default=1e-4, help='weight decay')
parser.add_argument('--test_freq', type=int, default=500, help='test per n epochs')
parser.add_argument('--embedding_size', type=int, default=32)
parser.add_argument('--hidden_channels', type=int, default=32)
parser.add_argument('--diffusion_embedding_size', type=int, default=64)
parser.add_argument('--side_channels', type=int, default=1)

# args for tasks
parser.add_argument('--in_len', type=int, default=46)
parser.add_argument('--out_len', type=int, default=46)
parser.add_argument('--missing_ratio', type=float, default=0.1)

# args for diffusion
parser.add_argument('--beta_start', type=float, default=0.0001, help='beta start from this')
parser.add_argument('--beta_end', type=float, default=0.2, help='beta end to this')
# A step count is an integer: parse it as int so a value given on the command
# line has the same type as the (already integer) default.
parser.add_argument('--num_steps', type=int, default=50, help='denoising steps')
parser.add_argument('--num_samples', type=int, default=10, help='n datasets')
parser.add_argument('--schedule', type=str, default='quad', help='noise schedule type')
parser.add_argument('--target_strategy', type=str, default='random', help='mask')

# args for mae
parser.add_argument('--num_heads', type=int, default=8, help='n heads for self attention')

config = parser.parse_args([])

Missing rate equal to 0.1

[3]:
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# --- data: first batch of the test split (batch size 1, no shuffling) ---
test_dloader = DataLoader(PRE8dDataset(config, mode='test'), 1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))
# Use the GPU when one is present; otherwise fall back to the CPU instead of
# failing on the first `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# load model
# SECURITY NOTE: the checkpoint stores a whole pickled model object, so loading
# it executes pickle code -- only load checkpoints from a trusted source.
# `weights_only=False` is spelled out because recent PyTorch releases default
# to `weights_only=True`, which refuses pickled model objects;
# `map_location` lets the checkpoint load on whichever device was selected.
model = torch.load("./log_bak/imputation/PRE/STIMP/best_0.1.pt", map_location=device, weights_only=False)
model = model.to(device)
cond_mask = data_gt_masks
adj = np.load("./data/{}/adj.npy".format(config.area))
adj = torch.from_numpy(adj).float().to(device)

# --- imputation ---
# The result of `impute` is reduced with a median over dim 1
# (presumably the axis of the 10 generated samples -- TODO confirm).
imputed_data_our = model.impute(datas, cond_mask, adj, 10)
imputed_data_our = imputed_data_our.median(1).values
# Entries where the observation mask and the conditioning mask differ are the
# ones the comparison below is evaluated on.
mask = data_ob_masks - cond_mask
imputed_our = imputed_data_our[0][mask.bool().cpu()[0]]
truth = datas[0][mask.bool()[0]].cpu()

# --- joint KDE of imputed vs. true values on the entries selected by `mask` ---
method = ['STIMP'] * imputed_our.shape[0]
data = pd.DataFrame({'truth': truth.numpy(),
                     'imputed': imputed_our,
                     'method': method})
color =["#9F0000","#576fa0"]
# `fill` is the current name of the KDE shading option; the older `shade`
# keyword is deprecated in seaborn.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Identity line across the current x-range; the y-range is set to match.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, printed on the figure.
pcc1=stat.pearsonr(truth, imputed_our).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
[3]:
Text(0.4, 0.92, 'PCC=0.9960')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_4_1.png
[4]:
# DINEOF fitted once on the whole sequence (t=46)
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_pixels = is_sea.astype(bool)

# Scatter the values and masks onto a (1, 46, 1, 60, 96) grid; grid cells not
# selected by `is_sea` keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_pixels]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_pixels]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_pixels]=data_ob_masks

# Every entry whose conditioning mask is 0 becomes NaN before fitting, then the
# grid is rearranged to (h, w, t) since b = c = 1.
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()
model.fit(tmp_data)

# The prediction is laid out as (h*w, t); bring it back to (b, t, c, h, w).
imputed_data = model.predict()
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
imputed_data_dineof = imputed_data[:,:,:,sea_pixels]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# --- joint KDE of imputed vs. true values on the entries selected by `mask` ---
method = ['DINEOF (t=46)'] * imputed_dineof.shape[0]
data = pd.DataFrame({'truth': truth.numpy(),
                     'imputed': imputed_dineof,
                     'method': method})
color =["#9F0000","#576fa0"]
# `fill` is the current name of the KDE shading option; the older `shade`
# keyword is deprecated in seaborn.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Identity line across the current x-range; the y-range is set to match.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, printed on the figure.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:39:32.140 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 187: 0.0019914316944777966, 9.987736120820045e-06
[4]:
Text(0.4, 0.92, 'PCC=0.9094')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_5_2.png
[5]:
# DINEOF fitted separately on consecutive windows of t=5 time steps
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(4, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_pixels = is_sea.astype(bool)

# Scatter the values and masks onto a (1, 46, 1, 60, 96) grid; grid cells not
# selected by `is_sea` keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_pixels]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_pixels]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_pixels]=data_ob_masks

t=5
# Every entry whose conditioning mask is 0 becomes NaN before fitting, then the
# grid is rearranged to (h, w, t) since b = c = 1.
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()

n_steps = tmp_data.shape[-1]
n_windows = n_steps // t
leftover = n_steps - n_windows * t

imputed_data_list = []
for i in range(n_windows):
    model.fit(tmp_data[:,:,i*t:(i+1)*t])
    imputed_data_list.append(model.predict())
if leftover:
    # n_steps is not a multiple of t: fit one more window covering the LAST t
    # steps and keep only the columns not covered yet. Without this fit the
    # trailing steps would be filled with predictions that belong to earlier
    # time steps of the previous window.
    model.fit(tmp_data[:,:,-t:])
    imputed_data_list.append(model.predict()[:,-leftover:])
imputed_data = np.concatenate(imputed_data_list, axis=-1)
# The prediction is laid out as (h*w, t); bring it back to (b, t, c, h, w).
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
imputed_data_dineof = imputed_data[:,:,:,sea_pixels]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# --- joint KDE of imputed vs. true values on the entries selected by `mask` ---
method = ['DINEOF (t=5)'] * imputed_dineof.shape[0]
data = pd.DataFrame({'truth': truth.numpy(),
                     'imputed': imputed_dineof,
                     'method': method})
color =["#9F0000","#576fa0"]
# `fill` is the current name of the KDE shading option; the older `shade`
# keyword is deprecated in seaborn.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Identity line across the current x-range; the y-range is set to match.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, printed on the figure.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:39:42.556 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 77: 0.0026193836238235235, 9.065261110663414e-06
2025-06-03 14:39:42.811 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 72: 0.0014016376808285713, 9.566894732415676e-06
2025-06-03 14:39:42.988 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 67: 0.0008826245903037488, 9.261711966246367e-06
2025-06-03 14:39:43.147 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 69: 0.0005015823990106583, 9.681680239737034e-06
2025-06-03 14:39:43.248 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 34: 0.000318384263664484, 9.886338375508785e-06
2025-06-03 14:39:43.268 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 0: 1.001445326664907e-07, 1.001445326664907e-07
2025-06-03 14:39:43.440 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 68: 0.0011385581456124783, 9.389128535985947e-06
2025-06-03 14:39:43.684 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 103: 0.001417607069015503, 9.933486580848694e-06
2025-06-03 14:39:43.826 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 53: 0.0011201013112440705, 9.623472578823566e-06
[5]:
Text(0.4, 0.92, 'PCC=0.7522')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_6_2.png
[6]:
# DINEOF fitted separately on consecutive windows of t=11 time steps
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_pixels = is_sea.astype(bool)

# Scatter the values and masks onto a (1, 46, 1, 60, 96) grid; grid cells not
# selected by `is_sea` keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_pixels]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_pixels]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_pixels]=data_ob_masks

t=11
# Every entry whose conditioning mask is 0 becomes NaN before fitting, then the
# grid is rearranged to (h, w, t) since b = c = 1.
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()

n_steps = tmp_data.shape[-1]
n_windows = n_steps // t
leftover = n_steps - n_windows * t

imputed_data_list = []
for i in range(n_windows):
    model.fit(tmp_data[:,:,i*t:(i+1)*t])
    imputed_data_list.append(model.predict())
if leftover:
    # n_steps is not a multiple of t: fit one more window covering the LAST t
    # steps and keep only the columns not covered yet. Without this fit the
    # trailing steps would be filled with predictions that belong to earlier
    # time steps of the previous window.
    model.fit(tmp_data[:,:,-t:])
    imputed_data_list.append(model.predict()[:,-leftover:])
imputed_data = np.concatenate(imputed_data_list, axis=-1)
# The prediction is laid out as (h*w, t); bring it back to (b, t, c, h, w).
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
imputed_data_dineof = imputed_data[:,:,:,sea_pixels]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# --- joint KDE of imputed vs. true values on the entries selected by `mask` ---
method = ['DINEOF (t=11)'] * imputed_dineof.shape[0]
data = pd.DataFrame({'truth': truth.numpy(),
                     'imputed': imputed_dineof,
                     'method': method})
color =["#9F0000","#576fa0"]
# `fill` is the current name of the KDE shading option; the older `shade`
# keyword is deprecated in seaborn.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Identity line across the current x-range; the y-range is set to match.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, printed on the figure.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:39:53.557 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 92: 0.0008462776313535869, 9.692506864666939e-06
2025-06-03 14:39:54.342 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 79: 0.0003628540434874594, 9.699346264824271e-06
2025-06-03 14:39:54.367 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 0: 2.0300009850870993e-07, 2.0300009850870993e-07
2025-06-03 14:39:55.027 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 73: 0.0006849292549304664, 9.863113518804312e-06
[6]:
Text(0.4, 0.92, 'PCC=0.8131')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_7_2.png
[7]:
# DINEOF fitted separately on consecutive windows of t=23 time steps
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_pixels = is_sea.astype(bool)

# Scatter the values and masks onto a (1, 46, 1, 60, 96) grid; grid cells not
# selected by `is_sea` keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_pixels]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_pixels]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_pixels]=data_ob_masks

t=23
# Every entry whose conditioning mask is 0 becomes NaN before fitting, then the
# grid is rearranged to (h, w, t) since b = c = 1.
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()

n_steps = tmp_data.shape[-1]
n_windows = n_steps // t
leftover = n_steps - n_windows * t

imputed_data_list = []
for i in range(n_windows):
    model.fit(tmp_data[:,:,i*t:(i+1)*t])
    imputed_data_list.append(model.predict())
if leftover:
    # Only reached when n_steps is not a multiple of t (not the case for
    # 46 steps and t=23): fit the last t steps and keep the uncovered tail.
    model.fit(tmp_data[:,:,-t:])
    imputed_data_list.append(model.predict()[:,-leftover:])
imputed_data = np.concatenate(imputed_data_list, axis=-1)
# The prediction is laid out as (h*w, t); bring it back to (b, t, c, h, w).
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])

# Extract the evaluated entries once (this step was previously duplicated).
imputed_data_dineof = imputed_data[:,:,:,sea_pixels]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# --- joint KDE of imputed vs. true values on the entries selected by `mask` ---
method = ['DINEOF (t=23)'] * imputed_dineof.shape[0]
data = pd.DataFrame({'truth': truth.numpy(),
                     'imputed': imputed_dineof,
                     'method': method})
color =["#9F0000","#576fa0"]
# `fill` is the current name of the KDE shading option; the older `shade`
# keyword is deprecated in seaborn.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Identity line across the current x-range; the y-range is set to match.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, printed on the figure.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:40:05.888 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 173: 0.0021644565276801586, 9.856652468442917e-06
2025-06-03 14:40:08.135 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 153: 0.00187805260065943, 9.957118891179562e-06
[7]:
Text(0.4, 0.92, 'PCC=0.7706')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_8_2.png
[8]:
# DINEOF fitted independently on every single time step (t=1)
from einops import rearrange
from model.dineof_per_step import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(10, [60, 96], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_pixels = is_sea.astype(bool)

# Scatter the values and masks onto a (1, 46, 1, 60, 96) grid; grid cells not
# selected by `is_sea` keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_pixels]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_pixels]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_pixels]=data_ob_masks

impute_data_list = []
for t in range(datas_image.shape[1]):
    # 2-D (h, w) slice of step t; entries whose conditioning mask is 0 become NaN.
    frame = datas_image[:, t, :, :, :].squeeze()
    tmp_data = torch.where(cond_mask_image[:,t].cpu().squeeze()==0, float("nan"), frame.cpu())
    model.fit(tmp_data.numpy())

    # Re-insert the singleton batch / time / channel axes so the steps can be
    # concatenated along the time axis below.
    imputed_data = model.predict()
    imputed_data = rearrange(imputed_data, "(b t c h) w->b t c h w", b=1, t=1, c=1, h=frame.shape[-2], w=frame.shape[-1])
    impute_data_list.append(imputed_data)

imputed_data = np.concatenate(impute_data_list, axis=1)
imputed_data_dineof = imputed_data[:,:,:,sea_pixels]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# --- joint KDE of imputed vs. true values on the entries selected by `mask` ---
method = ['DINEOF (t=1)'] * imputed_dineof.shape[0]
data = pd.DataFrame({'truth': truth.numpy(),
                     'imputed': imputed_dineof,
                     'method': method})
color =["#9F0000","#576fa0"]
# `fill` is the current name of the KDE shading option; the older `shade`
# keyword is deprecated in seaborn.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Identity line across the current x-range; the y-range is set to match.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, printed on the figure.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:40:16.510 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 20: 0.003975672647356987, 1.2898817658424377e-07
2025-06-03 14:40:16.684 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 161: 0.0020144523587077856, 9.949551895260811e-06
2025-06-03 14:40:16.866 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 178: 0.0018819323740899563, 9.995768778026104e-06
2025-06-03 14:40:17.049 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 178: 0.0022644158452749252, 9.86829400062561e-06
2025-06-03 14:40:17.203 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 153: 0.002110376488417387, 9.974930435419083e-06
2025-06-03 14:40:17.310 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 106: 0.004458888433873653, 9.761657565832138e-06
2025-06-03 14:40:17.498 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 195: 0.0020867022685706615, 9.883660823106766e-06
2025-06-03 14:40:17.618 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 121: 0.00197081221267581, 9.909737855195999e-06
2025-06-03 14:40:17.766 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 144: 0.001497694756835699, 9.916257113218307e-06
2025-06-03 14:40:17.887 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 113: 0.0014181643491610885, 9.966548532247543e-06
2025-06-03 14:40:18.006 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 117: 0.0020637616980820894, 9.940704330801964e-06
2025-06-03 14:40:18.128 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 124: 0.0016065065283328295, 9.834999218583107e-06
2025-06-03 14:40:18.229 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 101: 0.0014993386575952172, 9.959563612937927e-06
2025-06-03 14:40:18.347 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 120: 0.001413665246218443, 9.934185072779655e-06
2025-06-03 14:40:18.475 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 125: 0.0026759603060781956, 9.981682524085045e-06
2025-06-03 14:40:18.562 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 90: 0.0006393226212821901, 9.612878784537315e-06
2025-06-03 14:40:18.662 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 103: 0.001422219444066286, 9.986921213567257e-06
2025-06-03 14:40:18.754 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 92: 0.001823773025535047, 9.940238669514656e-06
2025-06-03 14:40:18.914 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 164: 0.0015448789345100522, 9.940122254192829e-06
2025-06-03 14:40:19.074 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 165: 0.0015842188149690628, 9.893206879496574e-06
2025-06-03 14:40:19.175 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 100: 0.0025141809601336718, 9.589595720171928e-06
2025-06-03 14:40:19.305 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 133: 0.001975116552785039, 9.931623935699463e-06
2025-06-03 14:40:19.410 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 106: 0.001957172527909279, 9.74256545305252e-06
2025-06-03 14:40:19.515 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 107: 0.0011938064126297832, 9.875744581222534e-06
2025-06-03 14:40:19.580 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 67: 0.000746228382922709, 9.880692232400179e-06
2025-06-03 14:40:19.718 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 141: 0.002045275643467903, 9.821029379963875e-06
2025-06-03 14:40:19.814 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 98: 0.0007599775562994182, 9.876617696136236e-06
2025-06-03 14:40:20.083 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 299: nan, nan
2025-06-03 14:40:20.177 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 95: 0.0016601908719167113, 9.773299098014832e-06
2025-06-03 14:40:20.265 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 91: 0.0007317067356780171, 9.893381502479315e-06
2025-06-03 14:40:20.411 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 149: 0.002038968028500676, 9.984010830521584e-06
2025-06-03 14:40:20.565 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 156: 0.00195962842553854, 9.837094694375992e-06
2025-06-03 14:40:20.644 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 81: 0.0007577511132694781, 9.7480951808393e-06
2025-06-03 14:40:20.796 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 155: 0.002205200493335724, 9.92720015347004e-06
2025-06-03 14:40:20.915 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 122: 0.0013895472511649132, 9.71788540482521e-06
2025-06-03 14:40:21.106 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 196: 0.0022645434364676476, 9.779119864106178e-06
2025-06-03 14:40:21.187 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 84: 0.0009352737688459456, 9.85770020633936e-06
2025-06-03 14:40:21.331 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 147: 0.0012988743837922812, 9.927898645401001e-06
2025-06-03 14:40:21.530 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 201: 0.002475715707987547, 9.78703610599041e-06
2025-06-03 14:40:21.708 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 179: 0.002565964125096798, 9.855953976511955e-06
2025-06-03 14:40:21.838 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 131: 0.0013297094264999032, 9.847921319305897e-06
2025-06-03 14:40:21.982 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 147: 0.001986656803637743, 9.72743146121502e-06
2025-06-03 14:40:22.090 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 108: 0.0013636482181027532, 9.835697710514069e-06
2025-06-03 14:40:22.204 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 116: 0.0013005349319428205, 9.927665814757347e-06
2025-06-03 14:40:22.285 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 82: 0.000990861328318715, 9.592738933861256e-06
2025-06-03 14:40:22.447 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 167: 0.0031685391440987587, 9.688083082437515e-06
[8]:
Text(0.4, 0.92, 'PCC=0.9877')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_9_2.png

Missing rate equal to 0.3

[9]:
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# Switch the experiment to a missing ratio of 0.3. This mutates the shared
# `config` object, so every cell run afterwards sees 0.3 as well.
config.missing_ratio=0.3

# --- data: first batch of the test split (batch size 1, no shuffling) ---
test_dloader = DataLoader(PRE8dDataset(config, mode='test'), 1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))
# Use the GPU when one is present; otherwise fall back to the CPU instead of
# failing on the first `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# SECURITY NOTE: the checkpoint stores a whole pickled model object, so loading
# it executes pickle code -- only load checkpoints from a trusted source.
# `weights_only=False` is spelled out because recent PyTorch releases default
# to `weights_only=True`, which refuses pickled model objects;
# `map_location` lets the checkpoint load on whichever device was selected.
model = torch.load("./log_bak/imputation/PRE/STIMP/best_0.3.pt", map_location=device, weights_only=False)
model = model.to(device)
cond_mask = data_gt_masks
adj = np.load("./data/{}/adj.npy".format(config.area))
adj = torch.from_numpy(adj).float().to(device)

# --- imputation ---
# The result of `impute` is reduced with a median over dim 1
# (presumably the axis of the 10 generated samples -- TODO confirm).
imputed_data_our = model.impute(datas, cond_mask, adj, 10)
imputed_data_our = imputed_data_our.median(1).values
# Entries where the observation mask and the conditioning mask differ are the
# ones the comparison below is evaluated on.
mask = data_ob_masks - cond_mask
imputed_our = imputed_data_our[0][mask.bool().cpu()[0]]
truth = datas[0][mask.bool()[0]].cpu()

# --- joint KDE of imputed vs. true values on the entries selected by `mask` ---
method = ['STIMP'] * imputed_our.shape[0]
data = pd.DataFrame({'truth': truth.numpy(),
                     'imputed': imputed_our,
                     'method': method})
color =["#9F0000","#576fa0"]
# `fill` is the current name of the KDE shading option; the older `shade`
# keyword is deprecated in seaborn.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Identity line across the current x-range; the y-range is set to match.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, printed on the figure.
pcc1=stat.pearsonr(truth, imputed_our).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
[9]:
Text(0.4, 0.92, 'PCC=0.9929')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_11_1.png
[10]:
# DINEOF fitted once on the whole sequence (t=46)
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_pixels = is_sea.astype(bool)

# Scatter the values and masks onto a (1, 46, 1, 60, 96) grid; grid cells not
# selected by `is_sea` keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_pixels]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_pixels]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_pixels]=data_ob_masks

# Every entry whose conditioning mask is 0 becomes NaN before fitting, then the
# grid is rearranged to (h, w, t) since b = c = 1.
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()
model.fit(tmp_data)

# The prediction is laid out as (h*w, t); bring it back to (b, t, c, h, w).
imputed_data = model.predict()
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
imputed_data_dineof = imputed_data[:,:,:,sea_pixels]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# --- joint KDE of imputed vs. true values on the entries selected by `mask` ---
method = ['DINEOF (t=46)'] * imputed_dineof.shape[0]
data = pd.DataFrame({'truth': truth.numpy(),
                     'imputed': imputed_dineof,
                     'method': method})
color =["#9F0000","#576fa0"]
# `fill` is the current name of the KDE shading option; the older `shade`
# keyword is deprecated in seaborn.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Identity line across the current x-range; the y-range is set to match.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, printed on the figure.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:41:20.321 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 155: 0.0017548399046063423, 9.990297257900238e-06
[10]:
Text(0.4, 0.92, 'PCC=0.8779')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_12_2.png
[11]:
# DINEOF fitted separately on consecutive windows of t=5 time steps
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(4, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_pixels = is_sea.astype(bool)

# Scatter the values and masks onto a (1, 46, 1, 60, 96) grid; grid cells not
# selected by `is_sea` keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_pixels]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_pixels]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_pixels]=data_ob_masks

t=5
# Every entry whose conditioning mask is 0 becomes NaN before fitting, then the
# grid is rearranged to (h, w, t) since b = c = 1.
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()

n_steps = tmp_data.shape[-1]
n_windows = n_steps // t
leftover = n_steps - n_windows * t

imputed_data_list = []
for i in range(n_windows):
    model.fit(tmp_data[:,:,i*t:(i+1)*t])
    imputed_data_list.append(model.predict())
if leftover:
    # n_steps is not a multiple of t: fit one more window covering the LAST t
    # steps and keep only the columns not covered yet. Without this fit the
    # trailing steps would be filled with predictions that belong to earlier
    # time steps of the previous window.
    model.fit(tmp_data[:,:,-t:])
    imputed_data_list.append(model.predict()[:,-leftover:])
imputed_data = np.concatenate(imputed_data_list, axis=-1)
# The prediction is laid out as (h*w, t); bring it back to (b, t, c, h, w).
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
imputed_data_dineof = imputed_data[:,:,:,sea_pixels]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# --- joint KDE of imputed vs. true values on the entries selected by `mask` ---
method = ['DINEOF (t=5)'] * imputed_dineof.shape[0]
data = pd.DataFrame({'truth': truth.numpy(),
                     'imputed': imputed_dineof,
                     'method': method})
color =["#9F0000","#576fa0"]
# `fill` is the current name of the KDE shading option; the older `shade`
# keyword is deprecated in seaborn.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Identity line across the current x-range; the y-range is set to match.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, printed on the figure.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:41:48.607 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 53: 0.0027655723970383406, 9.028706699609756e-06
2025-06-03 14:41:48.825 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 74: 0.0006474465481005609, 9.669165592640638e-06
2025-06-03 14:41:48.928 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 35: 0.0010408916277810931, 6.753136403858662e-06
2025-06-03 14:41:49.083 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 57: 0.0010452123824506998, 9.997398592531681e-06
2025-06-03 14:41:49.186 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 37: 0.0003766003646887839, 9.427341865375638e-06
2025-06-03 14:41:49.205 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 0: 6.371696059659371e-08, 6.371696059659371e-08
2025-06-03 14:41:49.297 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 33: 0.0015013111988082528, 7.3642004281282425e-06
2025-06-03 14:41:49.474 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 67: 0.0006101377075538039, 9.5631112344563e-06
2025-06-03 14:41:49.672 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 75: 0.0002087982720695436, 9.611743735149503e-06
[11]:
Text(0.4, 0.92, 'PCC=0.7057')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_13_2.png
[12]:
# DINEOF fitted separately on consecutive windows of t=11 time steps
from einops import rearrange
from model.dineof import DINEOF
model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_pixels = is_sea.astype(bool)

# Scatter the values and masks onto a (1, 46, 1, 60, 96) grid; grid cells not
# selected by `is_sea` keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_pixels]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_pixels]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_pixels]=data_ob_masks

t=11
# Every entry whose conditioning mask is 0 becomes NaN before fitting, then the
# grid is rearranged to (h, w, t) since b = c = 1.
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()

n_steps = tmp_data.shape[-1]
n_windows = n_steps // t
leftover = n_steps - n_windows * t

imputed_data_list = []
for i in range(n_windows):
    model.fit(tmp_data[:,:,i*t:(i+1)*t])
    imputed_data_list.append(model.predict())
if leftover:
    # n_steps is not a multiple of t: fit one more window covering the LAST t
    # steps and keep only the columns not covered yet. Without this fit the
    # trailing steps would be filled with predictions that belong to earlier
    # time steps of the previous window.
    model.fit(tmp_data[:,:,-t:])
    imputed_data_list.append(model.predict()[:,-leftover:])
imputed_data = np.concatenate(imputed_data_list, axis=-1)
# The prediction is laid out as (h*w, t); bring it back to (b, t, c, h, w).
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
imputed_data_dineof = imputed_data[:,:,:,sea_pixels]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd
method = []
method.extend(['DINEOF (t=11)' for i in range(imputed_dineof.shape[0])])
data = {'truth': truth.numpy(),
        'imputed':imputed_dineof,
        'method':method}
data = pd.DataFrame.from_dict(data)
color =["#9F0000","#576fa0"]
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'shade':True})
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
# plt.legend()
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
pcc1=stat.pearsonr(truth[:], imputed_dineof).statistic.item()
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:42:18.213 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 31: 0.0011603409657254815, 9.620096534490585e-06
2025-06-03 14:42:18.945 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 72: 0.0003450763178989291, 9.726471034809947e-06
2025-06-03 14:42:18.970 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 0: 2.0481138562900014e-07, 2.0481138562900014e-07
2025-06-03 14:42:19.449 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 47: 0.0004008303221780807, 9.724462870508432e-06
[12]:
Text(0.4, 0.92, 'PCC=0.7840')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_14_2.png
[13]:
# DINEOF fitted independently on consecutive windows of t=23 time steps.
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_mask = is_sea.astype(bool)

# Scatter the batch tensors onto zero-initialised (b, t, c, h, w) grids at the
# positions selected by is_sea; every other grid cell stays 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_mask]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_mask]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_mask]=data_ob_masks

t=23
n_steps = datas_image.shape[1]
# Entries whose conditioning mask is 0 are replaced by NaN before fitting.
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()

imputed_data_list = []
n_full = n_steps // t
for i in range(n_full):
    model.fit(tmp_data[:,:,i*t:(i+1)*t])
    imputed_data_list.append(model.predict())

# With t=23 the windows tile all 46 steps, so remainder is 0 and this branch is
# skipped; it only matters if t is changed to a value that does not divide the
# number of steps (the tail is then fitted on the last t steps).
remainder = n_steps - n_full * t
if remainder:
    model.fit(tmp_data[:,:,-t:])
    imputed_data_list.append(model.predict()[:, -remainder:])

imputed_data = np.concatenate(imputed_data_list, axis=-1)
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])

# Back to the is_sea positions, then keep only the entries selected by `mask`.
imputed_data_dineof = imputed_data[:,:,:,sea_mask]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# Joint KDE of ground truth vs. DINEOF imputation.
method = ['DINEOF (t=23)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_dineof,
                               'method': method})
color =["#9F0000","#576fa0"]
# 'fill' replaces the deprecated 'shade' keyword of seaborn.kdeplot.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# y = x reference line over the current x-range; y-limits are set to the same range.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, annotated on the plot.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:42:47.936 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 161: 0.0018636604072526097, 9.829876944422722e-06
2025-06-03 14:42:51.022 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 134: 0.0013445770600810647, 9.900075383484364e-06
[13]:
Text(0.4, 0.92, 'PCC=0.7324')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_15_2.png
[14]:
# DINEOF fitted separately on every single time step (t=1).
from einops import rearrange
from model.dineof_per_step import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(10, [60, 96], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_mask = is_sea.astype(bool)

# Scatter the batch tensors onto zero-initialised (b, t, c, h, w) grids at the
# positions selected by is_sea; every other grid cell stays 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_mask]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_mask]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_mask]=data_ob_masks

impute_data_list = []
for t in range(datas_image.shape[1]):
    # One 2-D (h, w) frame and its conditioning mask for time step t.
    frame = datas_image[:, t, :, :, :].squeeze()
    frame_cond = cond_mask_image[:,t].cpu().squeeze()

    # Entries whose conditioning mask is 0 are replaced by NaN before fitting.
    tmp_data = torch.where(frame_cond==0, float("nan"), frame.cpu())
    model.fit(tmp_data.numpy())

    # Restore the singleton b, t and c axes so the frames can be joined on dim 1.
    imputed_data = model.predict()
    imputed_data = rearrange(imputed_data, "(b t c h) w->b t c h w", b=1, t=1, c=1, h=frame.shape[-2], w=frame.shape[-1])
    impute_data_list.append(torch.from_numpy(imputed_data))

imputed_data = torch.cat(impute_data_list, dim=1).numpy()
# Back to the is_sea positions, then keep only the entries selected by `mask`.
imputed_data_dineof = imputed_data[:,:,:,sea_mask]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# Joint KDE of ground truth vs. DINEOF imputation.
method = ['DINEOF (t=1)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_dineof,
                               'method': method})
color =["#9F0000","#576fa0"]
# 'fill' replaces the deprecated 'shade' keyword of seaborn.kdeplot.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# y = x reference line over the current x-range; y-limits are set to the same range.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, annotated on the plot.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:43:17.904 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 48: 0.004655675962567329, 8.131377398967743e-06
2025-06-03 14:43:18.071 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 160: 0.0019928442779928446, 9.9940225481987e-06
2025-06-03 14:43:18.226 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 148: 0.002047802321612835, 9.770272299647331e-06
2025-06-03 14:43:18.394 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 163: 0.002119346521794796, 9.896466508507729e-06
2025-06-03 14:43:18.551 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 152: 0.0017108683241531253, 9.92603600025177e-06
2025-06-03 14:43:18.728 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 172: 0.0022290218621492386, 9.857118129730225e-06
2025-06-03 14:43:18.876 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 144: 0.0018798939418047667, 9.973184205591679e-06
2025-06-03 14:43:19.011 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 133: 0.00172597193159163, 9.88272950053215e-06
2025-06-03 14:43:19.125 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 110: 0.0016323969466611743, 9.995303116738796e-06
2025-06-03 14:43:19.237 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 111: 0.001017807167954743, 9.887618944048882e-06
2025-06-03 14:43:19.372 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 130: 0.001411044504493475, 9.873183444142342e-06
2025-06-03 14:43:19.483 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 105: 0.0014871939783915877, 9.82615165412426e-06
2025-06-03 14:43:19.609 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 123: 0.001167740672826767, 9.941868484020233e-06
2025-06-03 14:43:19.732 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 122: 0.0011805256363004446, 9.973999112844467e-06
2025-06-03 14:43:19.886 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 150: 0.0017491979524493217, 9.903451427817345e-06
2025-06-03 14:43:19.968 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 84: 0.0005303048528730869, 9.882147423923016e-06
2025-06-03 14:43:20.078 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 110: 0.00118404277600348, 9.70461405813694e-06
2025-06-03 14:43:20.208 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 129: 0.0013348772190511227, 9.870622307062149e-06
2025-06-03 14:43:20.346 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 134: 0.0014239137526601553, 9.969458915293217e-06
2025-06-03 14:43:20.469 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 120: 0.0019584931433200836, 9.825453162193298e-06
2025-06-03 14:43:20.594 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 124: 0.0015320188831537962, 9.776325896382332e-06
2025-06-03 14:43:20.757 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 161: 0.0014761799247935414, 9.734882041811943e-06
2025-06-03 14:43:20.887 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 129: 0.0012096795253455639, 9.93593130260706e-06
2025-06-03 14:43:20.910 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 18: 0.005009130109101534, 1.3993121683597565e-06
2025-06-03 14:43:20.970 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 62: 0.00033705425448715687, 9.87551175057888e-06
2025-06-03 14:43:21.137 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 167: 0.001635870081372559, 9.87970270216465e-06
2025-06-03 14:43:21.198 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 63: 0.0007587539730593562, 9.657815098762512e-06
2025-06-03 14:43:21.465 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 299: nan, nan
2025-06-03 14:43:21.601 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 137: 0.0016998598584905267, 9.828712791204453e-06
2025-06-03 14:43:21.672 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 72: 0.0008733900031074882, 9.939016308635473e-06
2025-06-03 14:43:21.805 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 129: 0.0018647173419594765, 9.874231182038784e-06
2025-06-03 14:43:21.925 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 119: 0.0018401555716991425, 9.884359315037727e-06
2025-06-03 14:43:22.022 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 98: 0.0009357224917039275, 9.852403309196234e-06
2025-06-03 14:43:22.170 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 145: 0.0018999038729816675, 9.903684258460999e-06
2025-06-03 14:43:22.314 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 145: 0.0022125658579170704, 9.517418220639229e-06
2025-06-03 14:43:22.491 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 174: 0.0018526016501709819, 9.742798283696175e-06
2025-06-03 14:43:22.572 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 84: 0.0010193741181865335, 9.939656592905521e-06
2025-06-03 14:43:22.684 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 109: 0.0013175209751352668, 9.90019179880619e-06
2025-06-03 14:43:22.857 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 165: 0.0021723092067986727, 9.910669177770615e-06
2025-06-03 14:43:23.012 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 150: 0.001580399926751852, 9.829993359744549e-06
2025-06-03 14:43:23.137 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 122: 0.0010253838263452053, 9.854207746684551e-06
2025-06-03 14:43:23.300 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 160: 0.0014198306016623974, 9.61648765951395e-06
2025-06-03 14:43:23.418 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 116: 0.0011437155772000551, 9.907875210046768e-06
2025-06-03 14:43:23.525 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 105: 0.0014255134155973792, 9.8842428997159e-06
2025-06-03 14:43:23.617 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 94: 0.0008255602442659438, 9.94786387309432e-06
2025-06-03 14:43:23.805 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 184: 0.0018437991384416819, 9.90438275039196e-06
[14]:
Text(0.4, 0.92, 'PCC=0.9678')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_16_2.png

Missing rate is equal to 0.5 (the STIMP and DINEOF comparisons below are repeated with `config.missing_ratio = 0.5`)

[15]:
# STIMP imputation on the first test sample with missing_ratio = 0.5.
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# Rebuild the test loader so the dataset is created with the new missing ratio.
config.missing_ratio=0.5
test_dloader = DataLoader(PRE8dDataset(config, mode='test'), 1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))

# Fall back to CPU when no GPU is available instead of failing on .to("cuda").
# NOTE(review): whether model.impute itself runs on CPU is not visible here.
device = "cuda" if torch.cuda.is_available() else "cpu"
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# NOTE(review): torch.load unpickles a complete model object; only load
# checkpoints from a trusted source.
model = torch.load("./log_bak/imputation/PRE/STIMP/best_0.5.pt", map_location=device)
model = model.to(device)
cond_mask = data_gt_masks
adj = np.load("./data/{}/adj.npy".format(config.area))
adj = torch.from_numpy(adj).float().to(device)

# Last argument is presumably the number of generated samples (TODO confirm);
# the median over dim 1 reduces them to a single estimate.
imputed_data_our = model.impute(datas, cond_mask, adj, 10)
imputed_data_our = imputed_data_our.median(1).values
# Evaluate on entries that are set in data_ob_masks but not in cond_mask.
mask = data_ob_masks - cond_mask
imputed_our = imputed_data_our[0][mask.bool().cpu()[0]]
truth = datas[0][mask.bool()[0]].cpu()

# Joint KDE of ground truth vs. STIMP imputation.
method = ['STIMP'] * imputed_our.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_our,
                               'method': method})
color =["#9F0000","#576fa0"]
# 'fill' replaces the deprecated 'shade' keyword of seaborn.kdeplot.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# y = x reference line over the current x-range; y-limits are set to the same range.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, annotated on the plot.
pcc1=stat.pearsonr(truth, imputed_our).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
[15]:
Text(0.4, 0.92, 'PCC=0.9869')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_18_1.png
[16]:
# DINEOF fitted once on the whole sequence (t=46).
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_mask = is_sea.astype(bool)

# Scatter the batch tensors onto zero-initialised (b, t, c, h, w) grids at the
# positions selected by is_sea; every other grid cell stays 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_mask]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_mask]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_mask]=data_ob_masks

# Entries whose conditioning mask is 0 are replaced by NaN before fitting.
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()
model.fit(tmp_data)

imputed_data = model.predict()
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
# Back to the is_sea positions, then keep only the entries selected by `mask`.
imputed_data_dineof = imputed_data[:,:,:,sea_mask]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# Joint KDE of ground truth vs. DINEOF imputation.
method = ['DINEOF (t=46)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_dineof,
                               'method': method})
color =["#9F0000","#576fa0"]
# 'fill' replaces the deprecated 'shade' keyword of seaborn.kdeplot.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# y = x reference line over the current x-range; y-limits are set to the same range.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, annotated on the plot.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:44:57.624 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 144: 0.0018268038984388113, 9.925919584929943e-06
[16]:
Text(0.4, 0.92, 'PCC=0.8193')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_19_2.png
[17]:
# DINEOF fitted independently on consecutive windows of t=5 time steps.
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(4, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_mask = is_sea.astype(bool)

# Scatter the batch tensors onto zero-initialised (b, t, c, h, w) grids at the
# positions selected by is_sea; every other grid cell stays 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_mask]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_mask]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_mask]=data_ob_masks

t=5
n_steps = datas_image.shape[1]
# Entries whose conditioning mask is 0 are replaced by NaN before fitting.
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()

imputed_data_list = []
n_full = n_steps // t
for i in range(n_full):
    # Copy so tmp_data keeps its NaN pattern even if fit() writes into its
    # input (not visible from this notebook).
    model.fit(tmp_data[:,:,i*t:(i+1)*t].copy())
    imputed_data_list.append(model.predict())

# The full windows cover only the first n_full*t steps. Fit one extra window on
# the LAST t steps and keep just the columns that are not covered yet. Calling
# predict() without this fit would return the previous window again, so the
# tail would be filled with predictions that belong to earlier time steps.
remainder = n_steps - n_full * t
if remainder:
    model.fit(tmp_data[:,:,-t:].copy())
    imputed_data_list.append(model.predict()[:, -remainder:])

imputed_data = np.concatenate(imputed_data_list, axis=-1)
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
# Back to the is_sea positions, then keep only the entries selected by `mask`.
imputed_data_dineof = imputed_data[:,:,:,sea_mask]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# Joint KDE of ground truth vs. DINEOF imputation.
method = ['DINEOF (t=5)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_dineof,
                               'method': method})
color =["#9F0000","#576fa0"]
# 'fill' replaces the deprecated 'shade' keyword of seaborn.kdeplot.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# y = x reference line over the current x-range; y-limits are set to the same range.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, annotated on the plot.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:45:42.085 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 78: 0.000733412045519799, 9.405659511685371e-06
2025-06-03 14:45:42.381 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 99: 0.00023820197384338826, 9.938448783941567e-06
2025-06-03 14:45:42.590 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 81: 0.00017151898646261543, 9.9015305750072e-06
2025-06-03 14:45:42.913 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 69: 0.0006918624276295304, 9.94885340332985e-06
2025-06-03 14:45:43.047 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 48: 0.00010128105350304395, 8.76635021995753e-06
2025-06-03 14:45:43.067 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 0: 8.490124514537456e-08, 8.490124514537456e-08
2025-06-03 14:45:43.176 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 37: 0.00032473201281391084, 9.391078492626548e-06
2025-06-03 14:45:43.267 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 29: 0.0008388007408939302, 8.374510798603296e-06
2025-06-03 14:45:43.308 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 9: 0.00332482042722404, 9.101582691073418e-06
[17]:
Text(0.4, 0.92, 'PCC=0.6175')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_20_2.png
[18]:
# DINEOF fitted independently on consecutive windows of t=11 time steps.
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_mask = is_sea.astype(bool)

# Scatter the batch tensors onto zero-initialised (b, t, c, h, w) grids at the
# positions selected by is_sea; every other grid cell stays 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_mask]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_mask]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_mask]=data_ob_masks

t=11
n_steps = datas_image.shape[1]
# Entries whose conditioning mask is 0 are replaced by NaN before fitting.
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()

imputed_data_list = []
n_full = n_steps // t
for i in range(n_full):
    # Copy so tmp_data keeps its NaN pattern even if fit() writes into its
    # input (not visible from this notebook).
    model.fit(tmp_data[:,:,i*t:(i+1)*t].copy())
    imputed_data_list.append(model.predict())

# The full windows cover only the first n_full*t steps. Fit one extra window on
# the LAST t steps and keep just the columns that are not covered yet. Calling
# predict() without this fit would return the previous window again, so the
# tail would be filled with predictions that belong to earlier time steps.
remainder = n_steps - n_full * t
if remainder:
    model.fit(tmp_data[:,:,-t:].copy())
    imputed_data_list.append(model.predict()[:, -remainder:])

imputed_data = np.concatenate(imputed_data_list, axis=-1)
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
# Back to the is_sea positions, then keep only the entries selected by `mask`.
imputed_data_dineof = imputed_data[:,:,:,sea_mask]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# Joint KDE of ground truth vs. DINEOF imputation.
method = ['DINEOF (t=11)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_dineof,
                               'method': method})
color =["#9F0000","#576fa0"]
# 'fill' replaces the deprecated 'shade' keyword of seaborn.kdeplot.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# y = x reference line over the current x-range; y-limits are set to the same range.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, annotated on the plot.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:46:30.915 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 54: 0.0002735634334385395, 9.858689736574888e-06
2025-06-03 14:46:31.557 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 38: 0.0004660697595681995, 9.848095942288637e-06
2025-06-03 14:46:31.582 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 0: 1.9639867332443828e-07, 1.9639867332443828e-07
2025-06-03 14:46:32.032 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 48: 0.00019795661501120776, 9.590163244865835e-06
[18]:
Text(0.4, 0.92, 'PCC=0.7293')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_21_2.png
[19]:
# DINEOF fitted independently on consecutive windows of t=23 time steps.
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_mask = is_sea.astype(bool)

# Scatter the batch tensors onto zero-initialised (b, t, c, h, w) grids at the
# positions selected by is_sea; every other grid cell stays 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_mask]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_mask]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_mask]=data_ob_masks

t=23
n_steps = datas_image.shape[1]
# Entries whose conditioning mask is 0 are replaced by NaN before fitting.
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()

imputed_data_list = []
n_full = n_steps // t
for i in range(n_full):
    model.fit(tmp_data[:,:,i*t:(i+1)*t])
    imputed_data_list.append(model.predict())

# With t=23 the windows tile all 46 steps, so remainder is 0 and this branch is
# skipped; it only matters if t is changed to a value that does not divide the
# number of steps (the tail is then fitted on the last t steps).
remainder = n_steps - n_full * t
if remainder:
    model.fit(tmp_data[:,:,-t:])
    imputed_data_list.append(model.predict()[:, -remainder:])

imputed_data = np.concatenate(imputed_data_list, axis=-1)
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])

# Back to the is_sea positions, then keep only the entries selected by `mask`.
imputed_data_dineof = imputed_data[:,:,:,sea_mask]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# Joint KDE of ground truth vs. DINEOF imputation.
method = ['DINEOF (t=23)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_dineof,
                               'method': method})
color =["#9F0000","#576fa0"]
# 'fill' replaces the deprecated 'shade' keyword of seaborn.kdeplot.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# y = x reference line over the current x-range; y-limits are set to the same range.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, annotated on the plot.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:47:19.572 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 135: 0.0013060076162219048, 9.912531822919846e-06
2025-06-03 14:47:21.003 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 107: 0.0009507655631750822, 9.934999980032444e-06
[19]:
Text(0.4, 0.92, 'PCC=0.6760')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_22_2.png
[20]:
# DINEOF fitted separately on every single time step (t=1).
from einops import rearrange
from model.dineof_per_step import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

model = DINEOF(10, [60, 96], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_mask = is_sea.astype(bool)

# Scatter the batch tensors onto zero-initialised (b, t, c, h, w) grids at the
# positions selected by is_sea; every other grid cell stays 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_mask]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_mask]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_mask]=data_ob_masks

impute_data_list = []
for t in range(datas_image.shape[1]):
    # One 2-D (h, w) frame and its conditioning mask for time step t.
    frame = datas_image[:, t, :, :, :].squeeze()
    frame_cond = cond_mask_image[:,t].cpu().squeeze()

    # Entries whose conditioning mask is 0 are replaced by NaN before fitting.
    tmp_data = torch.where(frame_cond==0, float("nan"), frame.cpu())
    model.fit(tmp_data.numpy())

    # Restore the singleton b, t and c axes so the frames can be joined on dim 1.
    imputed_data = model.predict()
    imputed_data = rearrange(imputed_data, "(b t c h) w->b t c h w", b=1, t=1, c=1, h=frame.shape[-2], w=frame.shape[-1])
    impute_data_list.append(torch.from_numpy(imputed_data))

imputed_data = torch.cat(impute_data_list, dim=1).numpy()
# Back to the is_sea positions, then keep only the entries selected by `mask`.
imputed_data_dineof = imputed_data[:,:,:,sea_mask]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

# Joint KDE of ground truth vs. DINEOF imputation.
method = ['DINEOF (t=1)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_dineof,
                               'method': method})
color =["#9F0000","#576fa0"]
# 'fill' replaces the deprecated 'shade' keyword of seaborn.kdeplot.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# y = x reference line over the current x-range; y-limits are set to the same range.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# Pearson correlation between truth and imputation, annotated on the plot.
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:48:06.799 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 49: 0.005038238130509853, 6.91460445523262e-06
2025-06-03 14:48:06.957 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 148: 0.0013725385069847107, 9.842682629823685e-06
2025-06-03 14:48:07.113 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 145: 0.0017285357462242246, 9.748036973178387e-06
2025-06-03 14:48:07.241 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 119: 0.0015530084492638707, 9.872368536889553e-06
2025-06-03 14:48:07.383 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 134: 0.001610918203368783, 9.803683497011662e-06
2025-06-03 14:48:07.540 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 148: 0.0017730167601257563, 9.95490700006485e-06
2025-06-03 14:48:07.694 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 145: 0.0017319935141131282, 9.993673302233219e-06
2025-06-03 14:48:07.854 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 154: 0.001431414857506752, 9.97120514512062e-06
2025-06-03 14:48:07.991 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 131: 0.0014655741397291422, 9.989249520003796e-06
2025-06-03 14:48:08.093 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 99: 0.000777890847530216, 9.931682143360376e-06
2025-06-03 14:48:08.213 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 114: 0.0014321752823889256, 9.897630661725998e-06
2025-06-03 14:48:08.310 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 94: 0.0010301697766408324, 9.978190064430237e-06
2025-06-03 14:48:08.414 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 102: 0.0005962348659522831, 9.701645467430353e-06
2025-06-03 14:48:08.520 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 103: 0.0010563157266005874, 9.98261384665966e-06
2025-06-03 14:48:08.673 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 146: 0.0017648418433964252, 9.884126484394073e-06
2025-06-03 14:48:08.763 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 91: 0.0005355935427360237, 9.874813258647919e-06
2025-06-03 14:48:08.866 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 100: 0.0010803189361467957, 9.809155017137527e-06
2025-06-03 14:48:08.962 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 92: 0.0012320077512413263, 9.81322955340147e-06
2025-06-03 14:48:09.099 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 130: 0.001507043605670333, 9.753857739269733e-06
2025-06-03 14:48:09.216 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 112: 0.0016237872187048197, 9.86072700470686e-06
2025-06-03 14:48:09.348 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 127: 0.0015006443718448281, 9.866082109510899e-06
2025-06-03 14:48:09.477 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 124: 0.0013007342349737883, 9.812647476792336e-06
2025-06-03 14:48:09.578 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 97: 0.0009409699705429375, 9.930517990142107e-06
2025-06-03 14:48:09.669 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 90: 0.0005766808171756566, 9.72225097939372e-06
2025-06-03 14:48:09.705 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 35: 6.190189014887437e-05, 9.855255484580994e-06
2025-06-03 14:48:09.894 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 188: 0.0017637026030570269, 9.837443940341473e-06
2025-06-03 14:48:09.944 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 50: 0.00015990309475455433, 9.815717930905521e-06
2025-06-03 14:48:10.212 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 299: nan, nan
2025-06-03 14:48:10.366 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 155: 0.0014985997695475817, 9.733368642628193e-06
2025-06-03 14:48:10.430 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 65: 0.00033150595845654607, 9.783951099961996e-06
2025-06-03 14:48:10.583 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 144: 0.0014082463458180428, 9.909155778586864e-06
2025-06-03 14:48:10.706 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 119: 0.001446563983336091, 9.836861863732338e-06
2025-06-03 14:48:10.793 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 85: 0.0008855313644744456, 9.855255484580994e-06
2025-06-03 14:48:10.980 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 182: 0.001519459648989141, 9.884359315037727e-06
2025-06-03 14:48:11.166 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 187: 0.0016631347825750709, 9.971670806407928e-06
2025-06-03 14:48:11.324 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 151: 0.001645904267206788, 9.759794920682907e-06
2025-06-03 14:48:11.423 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 102: 0.0006868625059723854, 9.900599252432585e-06
2025-06-03 14:48:11.519 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 93: 0.0011085433652624488, 9.763753041625023e-06
2025-06-03 14:48:11.688 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 159: 0.0016807268839329481, 9.835814125835896e-06
2025-06-03 14:48:11.829 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 133: 0.0018847923493012786, 9.87667590379715e-06
2025-06-03 14:48:11.933 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 100: 0.000951554044149816, 9.78721072897315e-06
2025-06-03 14:48:12.089 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 150: 0.0013353406684473157, 9.924289770424366e-06
2025-06-03 14:48:12.200 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 107: 0.0007168743759393692, 9.949086233973503e-06
2025-06-03 14:48:12.315 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 112: 0.000863533525262028, 9.930925443768501e-06
2025-06-03 14:48:12.417 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 103: 0.0006369405309669673, 9.867828339338303e-06
2025-06-03 14:48:12.557 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 132: 0.001542279147543013, 9.898794814944267e-06
[20]:
Text(0.4, 0.92, 'PCC=0.9018')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_23_2.png

missing rate is equal to 0.7

[21]:
# STIMP evaluated on the first test sample with missing_ratio = 0.7.
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats as stat
import seaborn as sns
from torch.utils.data import DataLoader

config.missing_ratio = 0.7
test_dloader = DataLoader(PRE8dDataset(config, mode='test'), batch_size=1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))

device = "cuda"
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# NOTE(review): torch.load unpickles the whole model object; only load trusted checkpoints.
model = torch.load("./log_bak/imputation/PRE/STIMP/best_0.7.pt").to(device)
cond_mask = data_gt_masks
adj = torch.from_numpy(np.load("./data/{}/adj.npy".format(config.area))).float().to(device)

# Median over dim 1 of the impute() output -- presumably the 10 requested samples; confirm.
imputed_data_our = model.impute(datas, cond_mask, adj, 10)
imputed_data_our = imputed_data_our.median(1).values

# Positions set in data_ob_masks but not in cond_mask -- presumably the held-out
# evaluation points; confirm against the dataset implementation.
mask = data_ob_masks - cond_mask
imputed_our = imputed_data_our[0][mask.bool().cpu()[0]]
truth = datas[0][mask.bool()[0]].cpu()

method = ['STIMP'] * imputed_our.shape[0]
data = pd.DataFrame.from_dict({
    'truth': truth.numpy(),
    'imputed': imputed_our,
    'method': method,
})
color = ["#9F0000", "#576fa0"]
g = sns.jointplot(
    data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color,
    marginal_kws={'common_norm': False, 'shade': True},
)

# Identity line over the current x-range; the y-axis is given the same limits.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.xlabel("truth", size=24)
plt.ylabel("imputed", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)

# Pearson correlation between the reference values and the STIMP imputation.
pcc1 = stat.pearsonr(truth, imputed_our).statistic.item()
plt.text(
    0.40, 0.92, 'PCC={:.4f}'.format(pcc1),
    transform=plt.gca().transAxes, fontsize=12,
    verticalalignment='top', horizontalalignment='left', color="#9F0000",
)
[21]:
Text(0.4, 0.92, 'PCC=0.9738')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_25_1.png
[22]:
# DINEOF fitted once on the complete 46-step window (t=46).
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats as stat
import seaborn as sns
from einops import rearrange
from model.dineof import DINEOF

model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")

# Write each tensor into the is_sea positions of a zero-filled (1, 46, 1, 60, 96) grid.
datas_image = torch.zeros(1, 46, 1, 60, 96).to(device)
datas_image[:, :, :, is_sea.astype(bool)] = datas
cond_mask_image = torch.zeros(1, 46, 1, 60, 96).to(device)
cond_mask_image[:, :, :, is_sea.astype(bool)] = cond_mask
ob_mask_image = torch.zeros(1, 46, 1, 60, 96).to(device)
ob_mask_image[:, :, :, is_sea.astype(bool)] = data_ob_masks

# Entries whose conditioning mask equals 0 become NaN; layout is then changed to (h, w, t).
tmp_data = torch.where(cond_mask_image.cpu() == 0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t").numpy()
model.fit(tmp_data)

# Bring the (h*w, t) reconstruction back to the (b, t, c, h, w) layout.
imputed_data = rearrange(
    model.predict(),
    "(b h w c) t->b t c h w",
    b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1],
)
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
# `mask` and `truth` are reused from the STIMP cell executed earlier for this missing ratio.
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

method = ['DINEOF (t=46)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({
    'truth': truth.numpy(),
    'imputed': imputed_dineof,
    'method': method,
})
color = ["#9F0000", "#576fa0"]
g = sns.jointplot(
    data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color,
    marginal_kws={'common_norm': False, 'shade': True},
)

# Identity line over the current x-range; the y-axis is given the same limits.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.xlabel("truth", size=24)
plt.ylabel("imputed", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)

# Pearson correlation between the reference values and the DINEOF reconstruction.
pcc1 = stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(
    0.40, 0.92, 'PCC={:.4f}'.format(pcc1),
    transform=plt.gca().transAxes, fontsize=12,
    verticalalignment='top', horizontalalignment='left', color="#9F0000",
)
2025-06-03 14:50:21.186 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 134: 0.001426593167707324, 9.857118129730225e-06
[22]:
Text(0.4, 0.92, 'PCC=0.6945')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_26_2.png
[23]:
# DINEOF fitted on consecutive windows of t=5 time steps.
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats as stat
import seaborn as sns
from einops import rearrange
from model.dineof import DINEOF

# NOTE(review): the model is constructed with the full config.in_len although fit()
# receives t-step windows below -- confirm DINEOF does not rely on the declared length.
model = DINEOF(4, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")

# Write each tensor into the is_sea positions of a zero-filled (1, 46, 1, 60, 96) grid.
datas_image = torch.zeros(1, 46, 1, 60, 96).to(device)
datas_image[:, :, :, is_sea.astype(bool)] = datas
cond_mask_image = torch.zeros(1, 46, 1, 60, 96).to(device)
cond_mask_image[:, :, :, is_sea.astype(bool)] = cond_mask
ob_mask_image = torch.zeros(1, 46, 1, 60, 96).to(device)
ob_mask_image[:, :, :, is_sea.astype(bool)] = data_ob_masks

t = 5
# Entries whose conditioning mask equals 0 become NaN; layout is then changed to (h, w, t).
tmp_data = torch.where(cond_mask_image.cpu() == 0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t").numpy()

n_steps = tmp_data.shape[-1]
n_windows = n_steps // t
n_leftover = n_steps - n_windows * t

imputed_data_list = []
for i in range(n_windows):
    model.fit(tmp_data[:, :, i * t:(i + 1) * t])
    imputed_data_list.append(model.predict())

if n_leftover:
    # The trailing steps do not fill a whole window. Previously they were filled with
    # the last columns of the preceding window's prediction, i.e. reconstructions of
    # *earlier* time steps. Fit the last full-length window ending at the final step
    # instead and keep only the columns that are not covered yet.
    model.fit(tmp_data[:, :, n_steps - t:])
    imputed_data_list.append(model.predict()[:, -n_leftover:])

imputed_data = np.concatenate(imputed_data_list, axis=-1)
# Bring the (h*w, t) reconstruction back to the (b, t, c, h, w) layout.
imputed_data = rearrange(
    imputed_data,
    "(b h w c) t->b t c h w",
    b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1],
)
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
# `mask` and `truth` are reused from the STIMP cell executed earlier for this missing ratio.
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

method = ['DINEOF (t=5)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({
    'truth': truth.numpy(),
    'imputed': imputed_dineof,
    'method': method,
})
color = ["#9F0000", "#576fa0"]
g = sns.jointplot(
    data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color,
    marginal_kws={'common_norm': False, 'shade': True},
)

# Identity line over the current x-range; the y-axis is given the same limits.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.xlabel("truth", size=24)
plt.ylabel("imputed", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)

# Pearson correlation between the reference values and the DINEOF reconstruction.
pcc1 = stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(
    0.40, 0.92, 'PCC={:.4f}'.format(pcc1),
    transform=plt.gca().transAxes, fontsize=12,
    verticalalignment='top', horizontalalignment='left', color="#9F0000",
)
2025-06-03 14:51:23.631 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 33: 0.0009264986729249358, 8.64820322021842e-06
2025-06-03 14:51:23.698 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 14: 0.003838603151962161, 1.5683472156524658e-06
2025-06-03 14:51:23.855 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 66: 0.0001603393320692703, 9.902170859277248e-06
2025-06-03 14:51:23.923 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 22: 0.0005455789505504072, 7.616588845849037e-06
2025-06-03 14:51:23.974 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 13: 0.000480195099953562, 1.926033291965723e-06
2025-06-03 14:51:23.994 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 0: 6.187770651422397e-08, 6.187770651422397e-08
2025-06-03 14:51:24.130 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 51: 0.0002570715150795877, 9.377283276990056e-06
2025-06-03 14:51:24.235 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 35: 0.000477699504699558, 8.603790774941444e-06
2025-06-03 14:51:24.331 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 38: 0.00030104166944511235, 9.773211786523461e-06
[23]:
Text(0.4, 0.92, 'PCC=0.5129')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_27_2.png
[24]:
# DINEOF fitted on consecutive windows of t=11 time steps.
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats as stat
import seaborn as sns
from einops import rearrange
from model.dineof import DINEOF

# NOTE(review): the model is constructed with the full config.in_len although fit()
# receives t-step windows below -- confirm DINEOF does not rely on the declared length.
model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")

# Write each tensor into the is_sea positions of a zero-filled (1, 46, 1, 60, 96) grid.
datas_image = torch.zeros(1, 46, 1, 60, 96).to(device)
datas_image[:, :, :, is_sea.astype(bool)] = datas
cond_mask_image = torch.zeros(1, 46, 1, 60, 96).to(device)
cond_mask_image[:, :, :, is_sea.astype(bool)] = cond_mask
ob_mask_image = torch.zeros(1, 46, 1, 60, 96).to(device)
ob_mask_image[:, :, :, is_sea.astype(bool)] = data_ob_masks

t = 11
# Entries whose conditioning mask equals 0 become NaN; layout is then changed to (h, w, t).
tmp_data = torch.where(cond_mask_image.cpu() == 0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t").numpy()

n_steps = tmp_data.shape[-1]
n_windows = n_steps // t
n_leftover = n_steps - n_windows * t

imputed_data_list = []
for i in range(n_windows):
    model.fit(tmp_data[:, :, i * t:(i + 1) * t])
    imputed_data_list.append(model.predict())

if n_leftover:
    # The trailing steps do not fill a whole window. Previously they were filled with
    # the last columns of the preceding window's prediction, i.e. reconstructions of
    # *earlier* time steps. Fit the last full-length window ending at the final step
    # instead and keep only the columns that are not covered yet.
    model.fit(tmp_data[:, :, n_steps - t:])
    imputed_data_list.append(model.predict()[:, -n_leftover:])

imputed_data = np.concatenate(imputed_data_list, axis=-1)
# Bring the (h*w, t) reconstruction back to the (b, t, c, h, w) layout.
imputed_data = rearrange(
    imputed_data,
    "(b h w c) t->b t c h w",
    b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1],
)
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
# `mask` and `truth` are reused from the STIMP cell executed earlier for this missing ratio.
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

method = ['DINEOF (t=11)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({
    'truth': truth.numpy(),
    'imputed': imputed_dineof,
    'method': method,
})
color = ["#9F0000", "#576fa0"]
g = sns.jointplot(
    data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color,
    marginal_kws={'common_norm': False, 'shade': True},
)

# Identity line over the current x-range; the y-axis is given the same limits.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.xlabel("truth", size=24)
plt.ylabel("imputed", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)

# Pearson correlation between the reference values and the DINEOF reconstruction.
pcc1 = stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(
    0.40, 0.92, 'PCC={:.4f}'.format(pcc1),
    transform=plt.gca().transAxes, fontsize=12,
    verticalalignment='top', horizontalalignment='left', color="#9F0000",
)
2025-06-03 14:52:30.310 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 34: 8.404224354308099e-05, 9.235431207343936e-06
2025-06-03 14:52:30.568 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 32: 9.475122351432219e-05, 9.851901268120855e-06
2025-06-03 14:52:30.592 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 0: 1.7444766342578077e-07, 1.7444766342578077e-07
2025-06-03 14:52:31.441 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 33: 0.00024325033882632852, 9.955605491995811e-06
[24]:
Text(0.4, 0.92, 'PCC=0.6387')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_28_2.png
[25]:
# DINEOF fitted on consecutive windows of t=23 time steps.
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats as stat
import seaborn as sns
from einops import rearrange
from model.dineof import DINEOF

model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")

# Write each tensor into the is_sea positions of a zero-filled (1, 46, 1, 60, 96) grid.
datas_image = torch.zeros(1, 46, 1, 60, 96).to(device)
datas_image[:, :, :, is_sea.astype(bool)] = datas
cond_mask_image = torch.zeros(1, 46, 1, 60, 96).to(device)
cond_mask_image[:, :, :, is_sea.astype(bool)] = cond_mask
ob_mask_image = torch.zeros(1, 46, 1, 60, 96).to(device)
ob_mask_image[:, :, :, is_sea.astype(bool)] = data_ob_masks

t = 23
# Entries whose conditioning mask equals 0 become NaN; layout is then changed to (h, w, t).
tmp_data = torch.where(cond_mask_image.cpu() == 0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t").numpy()

# 46 is an exact multiple of 23, so the two windows cover every time step.
imputed_data_list = []
for i in range(46 // t):
    model.fit(tmp_data[:, :, i * t:(i + 1) * t])
    imputed_data_list.append(model.predict())

imputed_data = np.concatenate(imputed_data_list, axis=-1)
# Bring the (h*w, t) reconstruction back to the (b, t, c, h, w) layout.
imputed_data = rearrange(
    imputed_data,
    "(b h w c) t->b t c h w",
    b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1],
)
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
# `mask` and `truth` are reused from the STIMP cell executed earlier for this missing ratio.
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

method = ['DINEOF (t=23)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({
    'truth': truth.numpy(),
    'imputed': imputed_dineof,
    'method': method,
})
color = ["#9F0000", "#576fa0"]
g = sns.jointplot(
    data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color,
    marginal_kws={'common_norm': False, 'shade': True},
)

# Identity line over the current x-range; the y-axis is given the same limits.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.xlabel("truth", size=24)
plt.ylabel("imputed", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)

# Pearson correlation between the reference values and the DINEOF reconstruction.
pcc1 = stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(
    0.40, 0.92, 'PCC={:.4f}'.format(pcc1),
    transform=plt.gca().transAxes, fontsize=12,
    verticalalignment='top', horizontalalignment='left', color="#9F0000",
)
2025-06-03 14:53:39.471 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 109: 0.0006985452491790056, 9.77847957983613e-06
2025-06-03 14:53:41.061 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 81: 0.00047964908299036324, 9.909301297739148e-06
[25]:
Text(0.4, 0.92, 'PCC=0.5700')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_29_2.png
[26]:
# DINEOF fitted independently on every single time step (t=1).
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats as stat
import seaborn as sns
from einops import rearrange
from model.dineof_per_step import DINEOF

model = DINEOF(10, [60, 96], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")

# Write each tensor into the is_sea positions of a zero-filled (1, 46, 1, 60, 96) grid.
datas_image = torch.zeros(1, 46, 1, 60, 96).to(device)
datas_image[:, :, :, is_sea.astype(bool)] = datas
cond_mask_image = torch.zeros(1, 46, 1, 60, 96).to(device)
cond_mask_image[:, :, :, is_sea.astype(bool)] = cond_mask
ob_mask_image = torch.zeros(1, 46, 1, 60, 96).to(device)
ob_mask_image[:, :, :, is_sea.astype(bool)] = data_ob_masks

# Fit and reconstruct one frame at a time; frames are concatenated on dim 1 afterwards.
impute_data_list = []
for t in range(datas_image.shape[1]):
    data = datas_image[:, t].squeeze()
    step_mask = cond_mask_image[:, t].cpu().squeeze()
    # Entries whose conditioning mask equals 0 are passed to DINEOF as NaN.
    tmp_data = torch.where(step_mask == 0, float("nan"), data.cpu())
    model.fit(tmp_data.numpy())

    imputed_data = rearrange(
        model.predict(),
        "(b t c h) w->b t c h w",
        b=1, t=1, c=1, h=data.shape[-2], w=data.shape[-1],
    )
    impute_data_list.append(torch.from_numpy(imputed_data))

imputed_data = torch.cat(impute_data_list, dim=1).numpy()
imputed_data_dineof = imputed_data[:, :, :, is_sea.astype(bool)]
# `mask` and `truth` are reused from the STIMP cell executed earlier for this missing ratio.
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

method = ['DINEOF (t=1)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({
    'truth': truth.numpy(),
    'imputed': imputed_dineof,
    'method': method,
})
color = ["#9F0000", "#576fa0"]
g = sns.jointplot(
    data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color,
    marginal_kws={'common_norm': False, 'shade': True},
)

# Identity line over the current x-range; the y-axis is given the same limits.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.xlabel("truth", size=24)
plt.ylabel("imputed", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)

# Pearson correlation between the reference values and the DINEOF reconstruction.
pcc1 = stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.text(
    0.40, 0.92, 'PCC={:.4f}'.format(pcc1),
    transform=plt.gca().transAxes, fontsize=12,
    verticalalignment='top', horizontalalignment='left', color="#9F0000",
)
2025-06-03 14:54:46.188 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 24: 0.0032232939265668392, 1.2076925486326218e-06
2025-06-03 14:54:46.312 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 115: 0.0009245204273611307, 9.892915841192007e-06
2025-06-03 14:54:46.442 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 122: 0.0013841711916029453, 9.749084711074829e-06
2025-06-03 14:54:46.586 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 137: 0.0011840699007734656, 9.8721357062459e-06
2025-06-03 14:54:46.722 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 129: 0.0011788735864683986, 9.977724403142929e-06
2025-06-03 14:54:46.860 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 132: 0.0011889883317053318, 9.866198524832726e-06
2025-06-03 14:54:46.999 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 131: 0.0013804398477077484, 9.792391210794449e-06
2025-06-03 14:54:47.113 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 111: 0.0006923600449226797, 9.919691365212202e-06
2025-06-03 14:54:47.220 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 103: 0.0007930145366117358, 9.871728252619505e-06
2025-06-03 14:54:47.301 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 79: 0.00036203660420142114, 9.700102964416146e-06
2025-06-03 14:54:47.408 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 103: 0.000813257705885917, 9.981042239814997e-06
2025-06-03 14:54:47.486 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 75: 0.00034457066794857383, 9.93459252640605e-06
2025-06-03 14:54:47.552 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 64: 0.00033426942536607385, 9.856390533968806e-06
2025-06-03 14:54:47.640 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 86: 0.0005215616547502577, 9.883835446089506e-06
2025-06-03 14:54:47.771 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 125: 0.0010331954108551145, 9.938026778399944e-06
2025-06-03 14:54:47.853 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 83: 0.00019353469542693347, 9.689989383332431e-06
2025-06-03 14:54:47.942 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 86: 0.0005104717565700412, 9.731680620461702e-06
2025-06-03 14:54:48.031 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 87: 0.00040019932202994823, 9.671726729720831e-06
2025-06-03 14:54:48.134 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 98: 0.0007176211220212281, 9.98948235064745e-06
2025-06-03 14:54:48.244 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 106: 0.0007542030652984977, 9.798561222851276e-06
2025-06-03 14:54:48.358 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 110: 0.000702509714756161, 9.932438842952251e-06
2025-06-03 14:54:48.464 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 102: 0.0007986527634784579, 9.87231032922864e-06
2025-06-03 14:54:48.544 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 77: 0.0005171404336579144, 9.950483217835426e-06
2025-06-03 14:54:48.588 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 41: 0.0005318302428349853, 2.7730129659175873e-06
2025-06-03 14:54:48.615 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 26: 3.86626306863036e-05, 9.004939784063026e-06
2025-06-03 14:54:48.764 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 146: 0.0008877735235728323, 9.617942851036787e-06
2025-06-03 14:54:48.792 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 27: 3.3736901968950406e-05, 8.972965588327497e-06
2025-06-03 14:54:49.062 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 299: nan, nan
2025-06-03 14:54:49.148 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 85: 0.00019701171549968421, 9.592302376404405e-06
2025-06-03 14:54:49.183 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 34: 0.00010439110337756574, 9.474184480495751e-06
2025-06-03 14:54:49.318 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 128: 0.0012598165776580572, 9.925337508320808e-06
2025-06-03 14:54:49.425 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 103: 0.0006936208228580654, 9.904848411679268e-06
2025-06-03 14:54:49.496 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 70: 0.00021650551934726536, 9.753028280101717e-06
2025-06-03 14:54:49.631 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 132: 0.0009047363419085741, 9.861716534942389e-06
2025-06-03 14:54:49.804 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 174: 0.0010751709342002869, 9.94617585092783e-06
2025-06-03 14:54:49.963 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 153: 0.0011183877941220999, 9.929412044584751e-06
2025-06-03 14:54:50.010 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 47: 9.55200448515825e-05, 9.824914741329849e-06
2025-06-03 14:54:50.095 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 82: 0.00046053717960603535, 9.773240890353918e-06
2025-06-03 14:54:50.239 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 137: 0.0010929569834843278, 9.858980774879456e-06
2025-06-03 14:54:50.370 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 124: 0.0012969275703653693, 9.63243655860424e-06
2025-06-03 14:54:50.455 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 82: 0.00037931231781840324, 9.877898264676332e-06
2025-06-03 14:54:50.570 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 111: 0.0007979008951224387, 9.895593393594027e-06
2025-06-03 14:54:50.650 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 77: 0.00029640551656484604, 9.761744877323508e-06
2025-06-03 14:54:50.736 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 83: 0.0004224222502671182, 9.708455763757229e-06
2025-06-03 14:54:50.790 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 53: 0.00012011536455247551, 9.426570613868535e-06
2025-06-03 14:54:50.938 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 141: 0.0012423800071701407, 9.868643246591091e-06
[26]:
Text(0.4, 0.92, 'PCC=0.7228')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_30_2.png

missing rate is equal to 0.9

[27]:
# STIMP evaluated on the first test sample with missing_ratio = 0.9.
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats as stat
import seaborn as sns
from torch.utils.data import DataLoader

config.missing_ratio = 0.9
test_dloader = DataLoader(PRE8dDataset(config, mode='test'), batch_size=1, shuffle=False)
datas, data_ob_masks, data_gt_masks, labels, label_masks = next(iter(test_dloader))

device = "cuda"
datas = datas.float().to(device)
data_ob_masks = data_ob_masks.to(device)
data_gt_masks = data_gt_masks.to(device)
labels = labels.to(device)
label_masks = label_masks.to(device)

# NOTE(review): torch.load unpickles the whole model object; only load trusted checkpoints.
model = torch.load("./log_bak/imputation/PRE/STIMP/best_0.9.pt").to(device)
cond_mask = data_gt_masks
adj = torch.from_numpy(np.load("./data/{}/adj.npy".format(config.area))).float().to(device)

# Median over dim 1 of the impute() output -- presumably the 10 requested samples; confirm.
imputed_data_our = model.impute(datas, cond_mask, adj, 10)
imputed_data_our = imputed_data_our.median(1).values

# Positions set in data_ob_masks but not in cond_mask -- presumably the held-out
# evaluation points; confirm against the dataset implementation.
mask = data_ob_masks - cond_mask
imputed_our = imputed_data_our[0][mask.bool().cpu()[0]]
truth = datas[0][mask.bool()[0]].cpu()

method = ['STIMP'] * imputed_our.shape[0]
data = pd.DataFrame.from_dict({
    'truth': truth.numpy(),
    'imputed': imputed_our,
    'method': method,
})
color = ["#9F0000", "#576fa0"]
g = sns.jointplot(
    data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color,
    marginal_kws={'common_norm': False, 'shade': True},
)

# Identity line over the current x-range; the y-axis is given the same limits.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.xlabel("truth", size=24)
plt.ylabel("imputed", size=24)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)

# Pearson correlation between the reference values and the STIMP imputation.
pcc1 = stat.pearsonr(truth, imputed_our).statistic.item()
plt.text(
    0.40, 0.92, 'PCC={:.4f}'.format(pcc1),
    transform=plt.gca().transAxes, fontsize=12,
    verticalalignment='top', horizontalalignment='left', color="#9F0000",
)
[27]:
Text(0.4, 0.92, 'PCC=0.9454')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_32_1.png
[28]:
# DINEOF baseline fitted once on the whole 46-step window (t=46), evaluated on
# the same held-out entries (`mask`, `truth`) as the STIMP cell above.
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# NOTE(review): the first argument (10) is presumably the number of EOF modes - confirm.
model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_mask = is_sea.astype(bool)

# Scatter the per-cell series and masks onto the full 60x96 grid;
# grid cells where `is_sea` is false keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_mask]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_mask]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_mask]=data_ob_masks

# Everything outside the conditioning mask is handed to DINEOF as NaN (missing).
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()
model.fit(tmp_data)

# Back to image layout, then keep only the `is_sea` cells and the held-out entries.
imputed_data = model.predict()
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
imputed_data_dineof = imputed_data[:,:,:,sea_mask]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

method = ['DINEOF (t=46)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_dineof,
                               'method': method})
color =["#9F0000","#576fa0"]
# `fill` replaces the deprecated `shade` keyword of seaborn's KDE plots.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Draw the y = x reference line across the current x-range and use the same range for y.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:57:25.741 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 79: 0.00048423392581753433, 9.798997780308127e-06
[28]:
Text(0.4, 0.92, 'PCC=0.4360')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_33_2.png
[29]:
# DINEOF baseline fitted on consecutive windows of t=5 time steps, evaluated on
# the same held-out entries (`mask`, `truth`) as the STIMP cell above.
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# NOTE(review): the first argument (4) is presumably the number of EOF modes,
# kept below the window length t=5 - confirm.
model = DINEOF(4, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_mask = is_sea.astype(bool)

# Scatter the per-cell series and masks onto the full 60x96 grid;
# grid cells where `is_sea` is false keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_mask]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_mask]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_mask]=data_ob_masks

t=5
# Everything outside the conditioning mask is handed to DINEOF as NaN (missing).
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()

n_steps = tmp_data.shape[-1]
n_windows = n_steps // t
n_tail = n_steps - n_windows * t

imputed_data_list = []
for i in range(n_windows):
    model.fit(tmp_data[:,:,i*t:(i+1)*t])
    imputed_data_list.append(model.predict())
if n_tail:
    # The full windows leave the last `n_tail` steps uncovered. Refit on the
    # final t steps and keep only the uncovered columns, so the tail is imputed
    # from its own observations. Reusing the trailing columns of the previous
    # window's prediction would score the reconstruction of earlier time steps
    # against the truth of the tail.
    model.fit(tmp_data[:,:,-t:])
    imputed_data_list.append(model.predict()[:,-n_tail:])
imputed_data = np.concatenate(imputed_data_list, axis=-1)

# Back to image layout, then keep only the `is_sea` cells and the held-out entries.
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
imputed_data_dineof = imputed_data[:,:,:,sea_mask]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

method = ['DINEOF (t=5)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_dineof,
                               'method': method})
color =["#9F0000","#576fa0"]
# `fill` replaces the deprecated `shade` keyword of seaborn's KDE plots.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Draw the y = x reference line across the current x-range and use the same range for y.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 14:58:46.056 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 33: 0.00011798148625530303, 9.892784873954952e-06
2025-06-03 14:58:46.226 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 44: 0.00010314747487427667, 9.284027328249067e-06
2025-06-03 14:58:46.374 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 40: 9.238488564733416e-05, 9.67855885392055e-06
2025-06-03 14:58:46.480 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 31: 0.00014137984544504434, 9.723618859425187e-06
2025-06-03 14:58:46.524 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 13: 0.00020046990539412946, 3.7585559766739607e-06
2025-06-03 14:58:46.544 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 0: 3.993541142222057e-08, 3.993541142222057e-08
2025-06-03 14:58:46.586 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 12: 0.00037573862937279046, 6.721587851643562e-06
2025-06-03 14:58:46.674 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 23: 2.4783144908724353e-05, 8.065791917033494e-06
2025-06-03 14:58:46.739 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 16: 2.9570139304269105e-05, 8.661849278723821e-06
[29]:
Text(0.4, 0.92, 'PCC=0.3142')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_34_2.png
[30]:
# DINEOF baseline fitted on consecutive windows of t=11 time steps, evaluated on
# the same held-out entries (`mask`, `truth`) as the STIMP cell above.
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# NOTE(review): the first argument (10) is presumably the number of EOF modes - confirm.
model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_mask = is_sea.astype(bool)

# Scatter the per-cell series and masks onto the full 60x96 grid;
# grid cells where `is_sea` is false keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_mask]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_mask]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_mask]=data_ob_masks

t=11
# Everything outside the conditioning mask is handed to DINEOF as NaN (missing).
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()

n_steps = tmp_data.shape[-1]
n_windows = n_steps // t
n_tail = n_steps - n_windows * t

imputed_data_list = []
for i in range(n_windows):
    model.fit(tmp_data[:,:,i*t:(i+1)*t])
    imputed_data_list.append(model.predict())
if n_tail:
    # The full windows leave the last `n_tail` steps uncovered. Refit on the
    # final t steps and keep only the uncovered columns, so the tail is imputed
    # from its own observations. Reusing the trailing columns of the previous
    # window's prediction would score the reconstruction of earlier time steps
    # against the truth of the tail.
    model.fit(tmp_data[:,:,-t:])
    imputed_data_list.append(model.predict()[:,-n_tail:])
imputed_data = np.concatenate(imputed_data_list, axis=-1)

# Back to image layout, then keep only the `is_sea` cells and the held-out entries.
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
imputed_data_dineof = imputed_data[:,:,:,sea_mask]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

method = ['DINEOF (t=11)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_dineof,
                               'method': method})
color =["#9F0000","#576fa0"]
# `fill` replaces the deprecated `shade` keyword of seaborn's KDE plots.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Draw the y = x reference line across the current x-range and use the same range for y.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 15:00:12.555 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 23: 3.539193494361825e-05, 9.682029485702515e-06
2025-06-03 15:00:12.699 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 21: 7.153443584684283e-05, 9.207979019265622e-06
2025-06-03 15:00:12.722 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 0: 9.557592761666456e-08, 9.557592761666456e-08
2025-06-03 15:00:12.844 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 19: 4.859522596234456e-05, 8.525963494321331e-06
[30]:
Text(0.4, 0.92, 'PCC=0.4232')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_35_2.png
[31]:
# DINEOF baseline fitted on consecutive windows of t=23 time steps, evaluated on
# the same held-out entries (`mask`, `truth`) as the STIMP cell above.
from einops import rearrange
from model.dineof import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# NOTE(review): the first argument (10) is presumably the number of EOF modes - confirm.
model = DINEOF(10, [60, 96, config.in_len], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_mask = is_sea.astype(bool)

# Scatter the per-cell series and masks onto the full 60x96 grid;
# grid cells where `is_sea` is false keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_mask]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_mask]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_mask]=data_ob_masks

t=23
# Everything outside the conditioning mask is handed to DINEOF as NaN (missing).
tmp_data = torch.where(cond_mask_image.cpu()==0, float("nan"), datas_image.cpu())
tmp_data = rearrange(tmp_data, "b t c h w -> (b h) (w c) t")
tmp_data = tmp_data.cpu().numpy()

n_steps = tmp_data.shape[-1]
n_windows = n_steps // t
n_tail = n_steps - n_windows * t

imputed_data_list = []
for i in range(n_windows):
    model.fit(tmp_data[:,:,i*t:(i+1)*t])
    imputed_data_list.append(model.predict())
if n_tail:
    # Not reached for t=23 (46 = 2 * 23). Kept so that changing t to a value that
    # does not divide the sequence length still yields one prediction per step:
    # refit on the final t steps and keep only the uncovered columns.
    model.fit(tmp_data[:,:,-t:])
    imputed_data_list.append(model.predict()[:,-n_tail:])
imputed_data = np.concatenate(imputed_data_list, axis=-1)

# Back to image layout, then keep only the `is_sea` cells and the held-out entries.
imputed_data = rearrange(imputed_data, "(b h w c) t->b t c h w", b=1, t=datas_image.shape[1], c=1, h=datas_image.shape[-2], w=datas_image.shape[-1])
imputed_data_dineof = imputed_data[:,:,:,sea_mask]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

method = ['DINEOF (t=23)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_dineof,
                               'method': method})
color =["#9F0000","#576fa0"]
# `fill` replaces the deprecated `shade` keyword of seaborn's KDE plots.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Draw the y = x reference line across the current x-range and use the same range for y.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 15:01:38.668 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 50: 0.00020601184223778546, 9.623632649891078e-06
2025-06-03 15:01:39.226 | INFO     | model.dineof:_fit:103 - Error/Relative Error at iteraion 45: 0.00015742632967885584, 9.525800123810768e-06
[31]:
Text(0.4, 0.92, 'PCC=0.3728')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_36_2.png
[32]:
# DINEOF baseline fitted independently on every single time step (t=1),
# evaluated on the same held-out entries (`mask`, `truth`) as the STIMP cell above.
from einops import rearrange
from model.dineof_per_step import DINEOF
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stat
import pandas as pd

# NOTE(review): the first argument (10) is presumably the number of EOF modes - confirm.
model = DINEOF(10, [60, 96], keep_non_negative_only=False)
is_sea = np.load("./data/PRE/is_sea.npy")
sea_mask = is_sea.astype(bool)

# Scatter the per-cell series and masks onto the full 60x96 grid;
# grid cells where `is_sea` is false keep the initial value 0.
datas_image = torch.zeros(1,46,1,60,96).to(device)
datas_image[:,:,:,sea_mask]=datas
cond_mask_image = torch.zeros(1,46,1,60,96).to(device)
cond_mask_image[:,:,:,sea_mask]=cond_mask
ob_mask_image = torch.zeros(1,46,1,60,96).to(device)
ob_mask_image[:,:,:,sea_mask]=data_ob_masks

# NOTE(review): in the recorded run one per-step fit logged `nan` error after
# 299 iterations; check frames with few or no conditioning observations.
impute_data_list = []
for step in range(datas_image.shape[1]):
    # A single 2-D frame (h, w) after dropping the singleton batch/channel axes.
    frame = datas_image[:, step, :, :, :].squeeze()

    # Everything outside the conditioning mask is handed to DINEOF as NaN (missing).
    tmp_data = torch.where(cond_mask_image[:,step].cpu().squeeze()==0, float("nan"), frame.cpu())
    model.fit(tmp_data.numpy())

    imputed_data = model.predict()
    imputed_data = rearrange(imputed_data, "(b t c h) w->b t c h w", b=1, t=1, c=1, h=frame.shape[-2], w=frame.shape[-1])
    impute_data_list.append(torch.from_numpy(imputed_data))

# Stack the per-step reconstructions along time, then keep only the `is_sea`
# cells and the held-out entries.
imputed_data = torch.cat(impute_data_list, dim=1).numpy()
imputed_data_dineof = imputed_data[:,:,:,sea_mask]
imputed_dineof = imputed_data_dineof[0][mask.bool().cpu()[0]]

method = ['DINEOF (t=1)'] * imputed_dineof.shape[0]
data = pd.DataFrame.from_dict({'truth': truth.numpy(),
                               'imputed': imputed_dineof,
                               'method': method})
color =["#9F0000","#576fa0"]
# `fill` replaces the deprecated `shade` keyword of seaborn's KDE plots.
g=sns.jointplot(data=data, x="truth", y="imputed", hue="method", kind='kde', palette=color, marginal_kws={'common_norm':False,'fill':True})
# Draw the y = x reference line across the current x-range and use the same range for y.
xpoints = ypoints = plt.xlim()
plt.plot(xpoints, ypoints, 'k-', alpha=0.75, zorder=0)
plt.ylim(ypoints)
plt.ylabel("imputed", size=24)
plt.xlabel("truth", size=24)
pcc1=stat.pearsonr(truth, imputed_dineof).statistic.item()
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.text(0.40, 0.92, 'PCC={:.4f}'.format(pcc1), transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', color="#9F0000")
2025-06-03 15:03:01.977 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 68: 0.0002469733590260148, 7.116614142432809e-06
2025-06-03 15:03:02.033 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 53: 0.00028734200168401003, 9.937997674569488e-06
2025-06-03 15:03:02.098 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 62: 0.000163619639351964, 9.96104790829122e-06
2025-06-03 15:03:02.169 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 68: 0.00018983388144988567, 9.617462637834251e-06
2025-06-03 15:03:02.228 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 57: 0.00013680299161933362, 9.737603249959648e-06
2025-06-03 15:03:02.305 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 74: 0.00017466569261159748, 9.627037798054516e-06
2025-06-03 15:03:02.369 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 61: 0.0001767356734490022, 9.827723260968924e-06
2025-06-03 15:03:02.455 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 85: 0.0004490058054216206, 9.708455763757229e-06
2025-06-03 15:03:02.508 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 50: 0.00013546067930292338, 9.780691470950842e-06
2025-06-03 15:03:02.578 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 68: 0.00017693002882879227, 9.715004125609994e-06
2025-06-03 15:03:02.634 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 54: 0.0001526108244433999, 9.481576853431761e-06
2025-06-03 15:03:02.687 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 51: 0.00015381944831460714, 9.89530235528946e-06
2025-06-03 15:03:02.741 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 52: 0.00028402096359059215, 5.21860783919692e-07
2025-06-03 15:03:02.759 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 13: 0.008920789696276188, 6.628222763538361e-06
2025-06-03 15:03:02.824 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 64: 0.00016729900380596519, 9.70781547948718e-06
2025-06-03 15:03:02.837 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 8: 0.0022271466441452503, 3.72203066945076e-06
2025-06-03 15:03:02.892 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 52: 0.00022145792900118977, 8.946226444095373e-06
2025-06-03 15:03:02.943 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 49: 0.00021000468404963613, 8.017756044864655e-06
2025-06-03 15:03:03.002 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 57: 0.00014132916112430394, 9.623196092434227e-06
2025-06-03 15:03:03.064 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 59: 0.0001869166299002245, 9.769588359631598e-06
2025-06-03 15:03:03.128 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 63: 0.0002919154358096421, 9.982468327507377e-06
2025-06-03 15:03:03.171 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 39: 0.0008617845014669001, 9.122071787714958e-06
2025-06-03 15:03:03.220 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 46: 0.000902706990018487, 2.9537477530539036e-06
2025-06-03 15:03:03.267 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 46: 0.00012932637764606625, 9.652139851823449e-06
2025-06-03 15:03:03.284 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 15: 2.2864969650981948e-05, 9.293002221966162e-06
2025-06-03 15:03:03.338 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 53: 9.275275806430727e-05, 9.374693036079407e-06
2025-06-03 15:03:03.372 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 35: 7.161369467212353e-06, 6.891042175993789e-06
2025-06-03 15:03:03.641 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 299: nan, nan
2025-06-03 15:03:03.694 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 52: 8.17550826468505e-05, 9.268253052141517e-06
2025-06-03 15:03:03.743 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 49: 2.448569102853071e-05, 6.821412171120755e-06
2025-06-03 15:03:03.805 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 60: 0.00017569950432516634, 9.905576007440686e-06
2025-06-03 15:03:03.871 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 65: 0.00017581140855327249, 9.99892654363066e-06
2025-06-03 15:03:03.915 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 42: 0.0001483052910771221, 9.915122063830495e-06
2025-06-03 15:03:03.978 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 62: 0.00014931590703781694, 9.512848919257522e-06
2025-06-03 15:03:04.042 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 63: 0.00014058890519663692, 9.58556483965367e-06
2025-06-03 15:03:04.110 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 66: 0.00016133516328409314, 9.624112863093615e-06
2025-06-03 15:03:04.134 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 22: 3.2853440643521026e-05, 8.170954970410094e-06
2025-06-03 15:03:04.183 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 46: 0.00019005437206942588, 8.942050044424832e-06
2025-06-03 15:03:04.250 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 65: 0.0001946628326550126, 9.887953638099134e-06
2025-06-03 15:03:04.319 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 67: 0.00022103033552411944, 9.598443284630775e-06
2025-06-03 15:03:04.377 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 55: 0.00017595656390767545, 9.96702874545008e-06
2025-06-03 15:03:04.433 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 54: 0.0001307205529883504, 9.78204479906708e-06
2025-06-03 15:03:04.470 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 33: 0.002378583187237382, 5.877809599041939e-06
2025-06-03 15:03:04.527 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 54: 0.0004299607826396823, 9.49940294958651e-06
2025-06-03 15:03:04.564 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 36: 0.00016774632968008518, 9.971132385544479e-06
2025-06-03 15:03:04.638 | INFO     | model.dineof_per_step:_fit:102 - Error/Relative Error at iteraion 71: 0.0002451682521495968, 9.91017441265285e-06
[32]:
Text(0.4, 0.92, 'PCC=0.3533')
../../_images/supplementary_dineof_14-imputation-pearl-river-estuary-dineof_37_2.png