目录

SegRNN-源码理解PMF的多步并行预测

【SegRNN 源码理解】PMF的多步并行预测

位置编码

        elif self.dec_way == "pmf":
            if self.channel_id:
                # m,d//2 -> 1,m,d//2 -> c,m,d//2
                # c,d//2 -> c,1,d//2 -> c,m,d//2
                # c,m,d -> cm,1,d -> bcm, 1, d
                pos_emb = torch.cat([
                    self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1),
                    self.channel_emb.unsqueeze(1).repeat(1, self.seg_num_y, 1)
                ], dim=-1).view(-1, 1, self.d_model).repeat(batch_size,1,1)

https://i-blog.csdnimg.cn/direct/b86a235db12449188d223f6e66079865.png

https://i-blog.csdnimg.cn/direct/925d09677c664b8e8639bb09287ea3ae.png https://i-blog.csdnimg.cn/direct/053f9cb7506d40a08ae65e8a8020f17c.png

https://i-blog.csdnimg.cn/direct/e53fcf3093974597ac9d8a1b8a8e2c68.png

┌─────────────┐              ┌─────────────┐
  pos_emb                   channel_emb 
  [2, 256]                   [7, 256]   
└──────┬──────┘              └──────┬──────┘
                                   
                                   
┌─────────────┐              ┌─────────────┐
 unsqueeze(0)              unsqueeze(1) 
  [1,2,256]                 [7,1,256]   
└──────┬──────┘              └──────┬──────┘
                                   
                                   
┌─────────────┐              ┌─────────────┐
 repeat                     repeat      
 [7,2,256]                  [7,2,256]   
└──────┬──────┘              └──────┬──────┘
                                   
       └────────────┬───────────────┘
                    
                    
            ┌───────────────┐
             concat(dim=-1)
              [7,2,512]    
            └───────┬───────┘
                    
                    
            ┌───────────────┐
             view(-1,1,512)
              [14,1,512]   
            └───────┬───────┘
                    
                    
            ┌───────────────┐
             repeat        
             [224,1,512]   
            └───────┬───────┘
                    
                    
              最终组合嵌入

https://i-blog.csdnimg.cn/direct/9ba3ceedd1eb424a943658c2e37ef06c.png

理解时间序列数据的训练集、序列长度和批次大小

我又有了一个新的问题,训练集大小是 593 个样本,怎么 batchsize=16,seqlen=60

用自己的话,593 个样本,随机选择了 16 个,这16 个样本的时间步并不一定是连续的,但是一个 batch 内部封装的 60 个时间步一定是连续的

https://i-blog.csdnimg.cn/direct/7aa8dbd526d7477d8a37687b8a3040fe.png

https://i-blog.csdnimg.cn/direct/d11760c233ac4e6b8c6f9b0469b4315e.png

https://i-blog.csdnimg.cn/direct/8150ea6399134198b182344fbd6c7697.png

https://i-blog.csdnimg.cn/direct/6b7898c09fda4eacaf64f52e59d5d9e4.png

shuffle=True

https://i-blog.csdnimg.cn/direct/76727f33d11b417bb38284add61c5eca.png

https://i-blog.csdnimg.cn/direct/45f01b5ad6614c6ea9720b6fdcf0647c.png https://i-blog.csdnimg.cn/direct/ff9827e333204c079e096fbe7665623b.png

https://i-blog.csdnimg.cn/direct/5a2f465e092b4dc3952c2d23202dfae5.png https://i-blog.csdnimg.cn/direct/347436d34fdc4150896e9cdc2719264a.png