diff --git a/common_dataset.py b/common_dataset.py index 6e468458a2e17c03a65e851ece6db955df3f9bb6..5ab8a5f21583168a1371968847032ef911f2f156 100644 --- a/common_dataset.py +++ b/common_dataset.py @@ -95,7 +95,7 @@ def alphabetical_to_numerical(seq, vocab): else : for i in range(len(seq) - 2 * seq.count('-')): if seq[i + dec] != '-': - num.append(IUPAC_VOCAB[seq[i + dec]]) + num.append(ALPHABET_UNMOD[seq[i + dec]]) else: if seq[i + dec + 1:i + dec + 4] == 'CaC': num.append(21) diff --git a/model_custom.py b/model_custom.py index d2888b40567abc2e1ff1f1384cf3b6c26f75dd88..96cfc64035ba8c8c40d79b94b81246c05cc54005 100644 --- a/model_custom.py +++ b/model_custom.py @@ -25,17 +25,14 @@ class PositionalEncoding(nn.Module): pe = torch.zeros(max_len, 1, d_model) pe[:, 0, 0::2] = torch.sin(position * div_term) pe[:, 0, 1::2] = torch.cos(position * div_term) - print(d_model, max_len) self.register_buffer('pe', pe) def forward(self, x): - x = torch.permute(x, (1, 0, 2)) """ Arguments: x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` """ - x = x + self.pe[:x.size(0)] return self.dropout(x) @@ -106,8 +103,6 @@ class Model_Common_Transformer(nn.Module): def forward(self, seq, charge): meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float() seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float() - seq_emb = torch.permute(seq_emb, (0, 2, 1)) - print(seq_emb.size()) emb = self.pos_embedding(self.emb(seq_emb)) meta_enc = self.meta_enc(meta_ohe) @@ -121,7 +116,6 @@ class Model_Common_Transformer(nn.Module): def forward_rt(self, seq): seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float() - seq_emb = torch.permute(seq_emb, (0, 2, 1)) emb = self.pos_embedding(self.emb(seq_emb)) enc = self.encoder(emb) out_rt = self.decoder_RT(enc) @@ -131,7 +125,6 @@ class Model_Common_Transformer(nn.Module): def forward_int(self, seq, charge): meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float() seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float() - seq_emb = torch.permute(seq_emb, (0, 2, 1)) emb = self.pos_embedding(self.emb(seq_emb)) meta_enc = self.meta_enc(meta_ohe) enc = self.encoder(emb)