MoU.py
__all__ = ['MoU']
# Cell
from typing import Callable, Optional
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
import numpy as np
from layers.MoU_backbone import MoU_backbone
from layers.PatchTST_layers import series_decomp


class Model(nn.Module):
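    """MoU forecasting model wrapper.

    Wraps `MoU_backbone`. When `configs.decomposition` is set, the input series is
    split into trend and residual components via `series_decomp`, each component is
    modelled by its own backbone, and the two forecasts are summed; otherwise a
    single backbone is used.
    """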
    def __init__(self, configs, max_seq_len:Optional[int]=1024, d_k:Optional[int]=None, d_v:Optional[int]=None, norm:str='BatchNorm', attn_dropout:float=0.,
                 act:str="gelu", key_padding_mask:str='auto', padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True,
                 pre_norm:bool=False, store_attn:bool=False, pe:str='zeros', learn_pe:bool=True, pretrain_head:bool=False, head_type:str='flatten', verbose:bool=False, **kwargs):
        super().__init__()

        # load parameters
        c_in = configs.enc_in
        context_window = configs.seq_len
        target_window = configs.pred_len

        n_layers = configs.e_layers
        n_heads = configs.n_heads
        d_model = configs.d_model
        d_ff = configs.d_ff
        dropout = configs.dropout
        fc_dropout = configs.fc_dropout
        head_dropout = configs.head_dropout

        K = configs.K
        conv_stride = configs.conv_stride
        conv_kernel_size = configs.conv_kernel_size
        entype = configs.entype
        postype = configs.postype
        ltencoder = configs.ltencoder
        individual = configs.individual

        patch_len = configs.patch_len
        stride = configs.stride
        padding_patch = configs.padding_patch
        revin = configs.revin
        affine = configs.affine
        subtract_last = configs.subtract_last

        decomposition = configs.decomposition
        kernel_size = configs.kernel_size

        d_state = configs.d_state
        dps = configs.dps
        num_x = configs.num_x
        topk = configs.topk
        expand = configs.expand

        device = ":".join(["cuda", str(configs.gpu)]) if torch.cuda.is_available() else "cpu"
        # model
        self.decomposition = decomposition
        if self.decomposition:
            self.decomp_module = series_decomp(kernel_size)
            self.model_trend = MoU_backbone(
                c_in=c_in, context_window=context_window, target_window=target_window, patch_len=patch_len, stride=stride,
                max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model,
                entype=entype, postype=postype, ltencoder=ltencoder,
                K=K, conv_stride=conv_stride, conv_kernel_size=conv_kernel_size,
                dps=dps, d_state=d_state, num_x=num_x, topk=topk, expand=expand,
                n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout,
                dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,
                attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
                pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch=padding_patch,
                pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine,
                subtract_last=subtract_last, verbose=verbose,
                device=device, **kwargs)
            self.model_res = MoU_backbone(
                c_in=c_in, context_window=context_window, target_window=target_window, patch_len=patch_len, stride=stride,
                max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model,
                entype=entype, postype=postype, ltencoder=ltencoder,
                K=K, conv_stride=conv_stride, conv_kernel_size=conv_kernel_size,
                dps=dps, d_state=d_state, num_x=num_x, topk=topk, expand=expand,
                n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout,
                dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,
                attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
                pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch=padding_patch,
                pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine,
                subtract_last=subtract_last, verbose=verbose,
                device=device, **kwargs)
        else:
            self.model = MoU_backbone(
                c_in=c_in, context_window=context_window, target_window=target_window, patch_len=patch_len, stride=stride,
                max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model,
                entype=entype, postype=postype, ltencoder=ltencoder,
                K=K, conv_stride=conv_stride, conv_kernel_size=conv_kernel_size,
                dps=dps, d_state=d_state, num_x=num_x, topk=topk, expand=expand,
                n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout,
                dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,
                attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
                pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch=padding_patch,
                pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine,
                subtract_last=subtract_last, verbose=verbose,
                device=device, **kwargs)
    def forward(self, x):   # x: [Batch, Input length, Channel]
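        """Forward pass.

        Args:
            x: input tensor of shape [Batch, Input length, Channel].

        Returns:
            Forecast tensor of shape [Batch, Prediction length, Channel].
        """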
        if self.decomposition:
            res_init, trend_init = self.decomp_module(x)
            res_init, trend_init = res_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)   # x: [Batch, Channel, Input length]
            res = self.model_res(res_init)
            trend = self.model_trend(trend_init)
            x = res + trend
            x = x.permute(0, 2, 1)   # x: [Batch, Input length, Channel]
        else:
            x = x.permute(0, 2, 1)   # x: [Batch, Channel, Input length]
            x = self.model(x)
            x = x.permute(0, 2, 1)   # x: [Batch, Input length, Channel]
        return x
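
# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# All config values below are placeholder assumptions; in particular, the valid
# choices for `entype`, `postype` and `ltencoder` are defined by
# `layers.MoU_backbone` and are not documented in this file, so the sketch is
# kept commented out rather than presented as runnable code.
#
# if __name__ == "__main__":
#     from types import SimpleNamespace
#
#     configs = SimpleNamespace(
#         enc_in=7, seq_len=336, pred_len=96,               # channels, look-back, horizon (placeholders)
#         e_layers=2, n_heads=8, d_model=128, d_ff=256,
#         dropout=0.2, fc_dropout=0.2, head_dropout=0.0,
#         K=4, conv_stride=8, conv_kernel_size=16,          # placeholder values
#         entype=..., postype=..., ltencoder=...,           # backbone-specific choices (unknown here)
#         individual=False,
#         patch_len=16, stride=8, padding_patch='end',
#         revin=True, affine=False, subtract_last=False,
#         decomposition=False, kernel_size=25,
#         d_state=16, dps=0.1, num_x=4, topk=2, expand=2,
#         gpu=0,
#     )
#
#     model = Model(configs)                                 # builds the MoU_backbone wrapper
#     x = torch.randn(32, configs.seq_len, configs.enc_in)   # [Batch, Input length, Channel]
#     y = model(x)                                           # [Batch, Prediction length, Channel]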