
fix model
WenWeiTHU committed Nov 4, 2024
1 parent 6b77591 commit 4ff2bc2
Showing 2 changed files with 5 additions and 38 deletions.
39 changes: 3 additions & 36 deletions models/Timer.py
@@ -41,39 +41,6 @@ def __init__(self, configs):
         else:
             raise NotImplementedError
 
-    def encoder_top(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
-        # Normalization from Non-stationary Transformer
-        means = x_enc.mean(1, keepdim=True).detach()
-        x_enc = x_enc - means
-        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
-        x_enc /= stdev
-
-        # do patching and embedding
-        x_enc = x_enc.permute(0, 2, 1)
-        # u: [bs * nvars x patch_num x d_model]
-        dec_in, n_vars = self.enc_embedding(x_enc)
-
-        # Encoder
-        # z: [bs * nvars x patch_num x d_model]
-
-        return dec_in
-
-    def encoder_bottom(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
-        # Normalization from Non-stationary Transformer
-        means = x_enc.mean(1, keepdim=True).detach()
-        x_enc = x_enc - means
-        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
-        x_enc /= stdev
-
-        # do patching and embedding
-        x_enc = x_enc.permute(0, 2, 1)
-        # u: [bs * nvars x patch_num x d_model]
-        dec_in, n_vars = self.enc_embedding(x_enc)  # [B * M, N, D]
-
-        # Encoder
-        dec_out, attns = self.decoder(dec_in)  # [B * M, N, D]
-        return dec_out
-
     def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
         B, L, M = x_enc.shape
 
@@ -87,7 +54,7 @@ def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
         x_enc = x_enc.permute(0, 2, 1)  # [B, M, T]
         dec_in, n_vars = self.enc_embedding(x_enc)  # [B * M, N, D]
 
-        # Encoder
+        # Transformer Blocks
         dec_out, attns = self.decoder(dec_in)  # [B * M, N, D]
         dec_out = self.proj(dec_out)  # [B * M, N, L]
         dec_out = dec_out.reshape(B, M, -1).transpose(1, 2)  # [B, T, M]
@@ -114,7 +81,7 @@ def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
         x_enc = x_enc.permute(0, 2, 1)  # [B, M, T]
         dec_in, n_vars = self.enc_embedding(x_enc)  # [B * M, N, D]
 
-        # Encoder
+        # Transformer Blocks
        dec_out, attns = self.decoder(dec_in)  # [B * M, N, D]
         dec_out = self.proj(dec_out)  # [B * M, N, L]
         dec_out = dec_out.reshape(B, M, -1).transpose(1, 2)  # [B, T, M]
@@ -136,7 +103,7 @@ def anomaly_detection(self, x_enc):
         x_enc = x_enc.permute(0, 2, 1)  # [B, M, T]
         dec_in, n_vars = self.enc_embedding(x_enc)  # [B * M, N, D]
 
-        # Encoder
+        # Transformer Blocks
         dec_out, attns = self.decoder(dec_in)  # [B * M, N, D]
         dec_out = self.proj(dec_out)  # [B * M, N, L]
         dec_out = dec_out.reshape(B, M, -1).transpose(1, 2)  # [B, T, M]
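For readers following the Timer.py hunks above, here is a minimal, self-contained sketch of the forward path they describe: per-variate instance normalization, non-overlapping patching, Transformer blocks applied causally over patch tokens, and a linear head mapping each token back to series values. It deliberately uses plain PyTorch modules (nn.Linear, nn.TransformerEncoder) instead of the repository's PatchEmbedding/Encoder classes; the layer sizes, patch_len, and the final de-normalization step (the diff is truncated before that point) are illustrative assumptions, not the repository's exact code.

```python
# Minimal sketch of the normalize -> patch -> causal Transformer -> project flow.
# Placeholder modules and sizes; not the repository's implementation.
import torch
import torch.nn as nn


class TimerSketch(nn.Module):
    def __init__(self, patch_len=96, d_model=256, n_heads=8, n_layers=4):
        super().__init__()
        self.patch_len = patch_len
        self.embed = nn.Linear(patch_len, d_model)      # one token per patch
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=4 * d_model, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, n_layers)
        self.proj = nn.Linear(d_model, patch_len)       # token -> patch_len values

    def forward(self, x_enc):                           # x_enc: [B, L, M]
        B, L, M = x_enc.shape
        # Normalization from Non-stationary Transformer (per variate, over time)
        means = x_enc.mean(1, keepdim=True).detach()
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x = (x_enc - means) / stdev

        # Channel independence + patching: [B, L, M] -> [B*M, N, patch_len]
        x = x.permute(0, 2, 1).reshape(B * M, L // self.patch_len, self.patch_len)
        tokens = self.embed(x)                          # [B*M, N, D]

        # Transformer blocks with a causal mask (decoder-only behaviour)
        N = tokens.size(1)
        causal = nn.Transformer.generate_square_subsequent_mask(N).to(tokens.device)
        out = self.blocks(tokens, mask=causal)          # [B*M, N, D]

        out = self.proj(out)                            # [B*M, N, patch_len]
        out = out.reshape(B, M, -1).transpose(1, 2)     # [B, T, M]

        # Undo the normalization (assumed; the hunk is cut off before this step)
        return out * stdev + means


x = torch.randn(2, 384, 7)                              # B=2, L=384, M=7 variates
print(TimerSketch()(x).shape)                           # torch.Size([2, 384, 7])
```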
4 changes: 2 additions & 2 deletions models/TimerBackbone.py
@@ -23,10 +23,10 @@ def __init__(self, configs):
         self.patch_embedding = PatchEmbedding(
             self.d_model, self.patch_len, self.stride, padding, self.dropout)
 
-        # Decoder
+        # Decoder-only Transformer: Refer to issue: https://github.com/thuml/Large-Time-Series-Model/issues/23
         self.decoder = Encoder(
             [
-                EncoderLayer(
+                EncoderLayer(
                     AttentionLayer(
                         FullAttention(True, configs.factor, attention_dropout=configs.dropout,
                                       output_attention=True), configs.d_model, configs.n_heads),
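The renamed comment is the substantive point of the TimerBackbone.py change: although the stack is built from the library's Encoder/EncoderLayer classes, FullAttention is constructed with its first argument set to True, which enables the causal mask, so each patch token attends only to itself and earlier tokens — the decoder-only behaviour discussed in issue #23. The toy function below illustrates that distinction; it is a hypothetical single-head example, not the repository's FullAttention implementation.

```python
# Toy illustration (not the repository's FullAttention): the only difference
# between "encoder-style" and "decoder-only" self-attention here is the
# causal mask that blocks attention to future patch tokens.
import torch


def self_attention(x, causal: bool):
    # x: [batch, n_tokens, d_model]; single head, no projections, for clarity
    d = x.size(-1)
    scores = x @ x.transpose(1, 2) / d ** 0.5                # [batch, N, N]
    if causal:
        n = x.size(1)
        future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(future, float('-inf'))   # hide future tokens
    return torch.softmax(scores, dim=-1) @ x


tokens = torch.randn(1, 4, 8)                                 # 4 patch tokens
bidirectional = self_attention(tokens, causal=False)          # encoder-style
autoregressive = self_attention(tokens, causal=True)          # decoder-only, as in Timer
print(bidirectional.shape, autoregressive.shape)              # both torch.Size([1, 4, 8])
```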
