Skip to content
Snippets Groups Projects
Commit b9beaff1 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix

parent 10195944
No related branches found
No related tags found
No related merge requests found
...@@ -5,7 +5,7 @@ import pyarrow.parquet as pq ...@@ -5,7 +5,7 @@ import pyarrow.parquet as pq
import pyarrow as pa import pyarrow as pa
import torch import torch
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from loess.loess_1d import loess_1d # from loess.loess_1d import loess_1d
from model.model import ModelTransformer from model.model import ModelTransformer
from config import load_args from config import load_args
...@@ -95,9 +95,9 @@ def predict(data_pred, model, output_path): ...@@ -95,9 +95,9 @@ def predict(data_pred, model, output_path):
if __name__ =='__main__': if __name__ =='__main__':
df = load_lib('spectral_lib/1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang.parquet') # df = load_lib('spectral_lib/1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang.parquet')
df = extract_sequence(df) # df = extract_sequence(df)
df.to_csv('spectral_lib/1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang.csv') # df.to_csv('spectral_lib/1-240711_ident_resistance_idbioriv_fluoroquinolones_conta_human_sang.csv')
# plt.hist(df['RT']) # plt.hist(df['RT'])
# plt.savefig('test.png') # plt.savefig('test.png')
# #
...@@ -174,19 +174,19 @@ if __name__ =='__main__': ...@@ -174,19 +174,19 @@ if __name__ =='__main__':
# #
# args = load_args() args = load_args()
#
# model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
# n_head=args.n_head, encoder_num_layer=args.encoder_num_layer, n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
# decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate, decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
# embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first, seq_length=30) embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first, seq_length=30)
#
# if torch.cuda.is_available(): if torch.cuda.is_available():
# model = model.cuda() model = model.cuda()
#
# model.load_state_dict(torch.load(args.model_weigh, weights_only=True)) model.load_state_dict(torch.load(args.model_weigh, weights_only=True))
#
# data_test = load_data(data_source=args.dataset_test, batch_size=args.batch_size, length=30, mode=args.split_test, data_test = load_data(data_source=args.dataset_test, batch_size=args.batch_size, length=30, mode=args.split_test,
# seq_col=args.seq_test) seq_col=args.seq_test)
#
# predict(data_test, model, args.output) predict(data_test, model, args.output)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment