diff --git a/model_custom.py b/model_custom.py index f829002b87c54452330ac46d4ce6767c4a2e0e64..82c1516a845223c507d42463ae4b32b21520460e 100644 --- a/model_custom.py +++ b/model_custom.py @@ -3,6 +3,7 @@ import torch.nn as nn import torch from tape.models.modeling_bert import ProteinBertModel + class PermuteLayer(nn.Module): def __init__(self, dims): super().__init__() @@ -102,6 +103,7 @@ class Model_Common_Transformer(nn.Module): def forward(self, seq, charge): print(seq.size(),charge.size()) print(seq, charge) + print(torch.cuda.mem_get_info()) meta_ohe = torch.nn.functional.one_hot(charge - 1, self.charge_max).float() seq_emb = torch.nn.functional.one_hot(seq, self.nb_aa).float() print(seq_emb.shape)