From b54c200dde4ee3bc24a67952cebcf87a6ca31809 Mon Sep 17 00:00:00 2001 From: schne <leo.schneider@ecl19.ec-lyon.fr> Date: Thu, 4 Apr 2024 18:02:28 +0200 Subject: [PATCH] distance Levenshtein --- requirements.txt | 3 +++ res_viz.py | 38 +++++++++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 46644aa..ad56964 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,6 @@ seaborn~=0.13.0 pyopenms~=3.1.0 dlomix~=0.0.6 scikit-learn~=1.4.1.post1 + +Levenshtein~=0.25.0 +keras~=2.15.0 \ No newline at end of file diff --git a/res_viz.py b/res_viz.py index 4117e65..9793ecd 100644 --- a/res_viz.py +++ b/res_viz.py @@ -3,6 +3,8 @@ import matplotlib import numpy as np import matplotlib.pyplot as plt from dlomix.data import RetentionTimeDataset +import pandas as pd +import Levenshtein as lv epoch = 100 number = 8 BATCH_SIZE=64 @@ -10,6 +12,14 @@ BATCH_SIZE=64 test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv', seq_length=30, batch_size=BATCH_SIZE, test=True) +rtdata = RetentionTimeDataset(data_source='database/data_train.csv', + seq_length=30, batch_size=BATCH_SIZE, val_ratio=0.2, test=False) + +data_train = pd.read_csv('database/data_train.csv') +data_test = pd.read_csv('database/data_holdout.csv') +train_seq = data_train['sequence'] +test_seq = data_test['sequence'] + target = test_rtdata.get_split_targets(split="test") index = random.sample(range(target.size), number) target_plot = target[index] @@ -73,4 +83,30 @@ def compare_error(): for i in range(number): plt.plot(error_plot_1[:,i], c = colors[i], ls = 'dashed') plt.plot(error_plot_2[:, i], c = colors[i]) - plt.show() \ No newline at end of file + plt.show() + +def get_worse_seq(n,epoch): + error = error_1[epoch,:] + ind = np.argpartition(error, n)[:n] + return test_seq[ind].to_list(),ind + +def get_nb_iteration(seq_list): + nb_list=[] + counts = train_seq.value_counts() + for seq in seq_list: + nb_list.append(counts[seq]) + + +def compute_min_distance(seq_list): + res=[] + for ref in seq_list : + min_dist = 10000 + for seq in train_seq: + d = 1 - lv.ratio(seq,ref) + if d < min_dist: + min_dist = d + res.append(min_dist) + return res + +radom_1000 = index = random.sample(range(test_seq.size), 1000) + -- GitLab