From d8c19c08a7a6c6bc4f373e88ca4a72bafa1fc23f Mon Sep 17 00:00:00 2001
From: Tetiana Yemelianenko <tyemel.mzeom@gmail.com>
Date: Tue, 12 Nov 2024 14:39:18 +0000
Subject: [PATCH] Upload New File

---
 lora_finetine_dino_large.py | 229 ++++++++++++++++++++++++++++++++++++
 1 file changed, 229 insertions(+)
 create mode 100644 lora_finetine_dino_large.py

diff --git a/lora_finetine_dino_large.py b/lora_finetine_dino_large.py
new file mode 100644
index 0000000..23802c0
--- /dev/null
+++ b/lora_finetine_dino_large.py
@@ -0,0 +1,229 @@
+import transformers
+import accelerate
+import peft
+from peft import PeftConfig, PeftModel
+from peft import LoraConfig, get_peft_model
+from datasets import Dataset
+from sklearn.model_selection import train_test_split
+from datasets import load_from_disk
+from transformers import AutoImageProcessor
+from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
+from sklearn.metrics import classification_report
+import numpy as np
+import evaluate
+from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
+import torch
+
model_checkpoint = "facebook/dinov2-large"

# Genre-classification subset of WikiArt, prepared offline.
dataset = load_from_disk('wikiart_train')
print(dataset)
# NOTE(review): the final genre name is deliberately dropped here — presumably
# a placeholder/"unknown" class; confirm against the dataset builder.
labels = dataset.features["genre"].names[:-1]
print(labels)

# Bidirectional label <-> id maps required by the HF classification head.
label2id = {name: idx for idx, name in enumerate(labels)}
id2label = {idx: name for idx, name in enumerate(labels)}

image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
+
+from torchvision.transforms import (
+    CenterCrop,
+    Compose,
+    Normalize,
+    RandomHorizontalFlip,
+    RandomResizedCrop,
+    RandomAdjustSharpness,
+    Resize,
+    ToTensor,
+)
+
img_size = 224
# Normalize with the statistics the DINOv2 image processor ships with.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Training-time augmentation: random crop, sharpness jitter and horizontal
# flip, then tensor conversion + normalization.
train_transforms = Compose([
    RandomResizedCrop(size=(img_size, img_size), scale=(0.8, 1)),
    RandomAdjustSharpness(sharpness_factor=3),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])
+
# Validation pipeline: deterministic only.
# BUG FIX: the original reused the random train-time augmentations
# (RandomAdjustSharpness, RandomHorizontalFlip) here, so every evaluation pass
# saw different inputs and validation metrics were noisy/non-reproducible.
# This now matches the deterministic test-time pipeline used below.
val_transforms = Compose(
    [
        Resize(size=(img_size, img_size)),
        ToTensor(),
        normalize,
    ]
)
+
def preprocess_train(example_batch):
    """Attach augmented ``pixel_values`` to a batch of training examples."""
    images = example_batch["image"]
    example_batch["pixel_values"] = [train_transforms(img.convert("RGB")) for img in images]
    return example_batch
+
def preprocess_val(example_batch):
    """Attach validation ``pixel_values`` to a batch of examples."""
    images = example_batch["image"]
    example_batch["pixel_values"] = [val_transforms(img.convert("RGB")) for img in images]
    return example_batch
+
# Hold out 10% of the training data for validation.
# NOTE(review): no fixed seed — the split differs between runs; pass seed=...
# to train_test_split if reproducibility is required.
splits = dataset.train_test_split(test_size=0.1)
train_ds, val_ds = splits["train"], splits["test"]

# Transforms are applied on access, so augmentation stays random per epoch.
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)
+
def print_trainable_parameters(model):
    """Print how many of *model*'s parameters are trainable, and the ratio."""
    sizes = [(p.numel(), p.requires_grad) for _, p in model.named_parameters()]
    all_param = sum(n for n, _ in sizes)
    trainable_params = sum(n for n, trainable in sizes if trainable)
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )
+
+
# DINOv2 backbone with a fresh classification head sized by the label maps.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    ignore_mismatched_sizes=True,  # tolerate a head-size mismatch when starting from a fine-tuned checkpoint
    label2id=label2id,
    id2label=id2label,
)

print_trainable_parameters(model)
+
# LoRA adapters on the attention query/value projections; the classification
# head is trained in full and saved alongside the adapter weights.
config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"],
)

lora_model = get_peft_model(model, config)
print_trainable_parameters(lora_model)

model_name = model_checkpoint.split("/")[-1]
batch_size = 16
+
# Effective train batch size = 16 * 4 gradient-accumulation steps = 64.
output_dir = f"{model_name}-finetuned-lora"
args = TrainingArguments(
    output_dir,
    remove_unused_columns=False,  # keep raw columns so collate_fn can read "genre"
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    fp16=True,
    num_train_epochs=40,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
    label_names=["labels"],
)
+
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Compute accuracy for one evaluation pass over logits + label ids."""
    predicted_ids = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predicted_ids, references=eval_pred.label_ids)
+
def collate_fn(examples):
    """Stack per-example image tensors and genre ids into a training batch."""
    batch = {
        "pixel_values": torch.stack([ex["pixel_values"] for ex in examples]),
        "labels": torch.tensor([ex["genre"] for ex in examples]),
    }
    return batch
+
trainer = Trainer(
    lora_model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor,  # saved with checkpoints so inference can reload it
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)


# Train the model.
train_results = trainer.train()
# BUG FIX: the original assigned train_results.metrics to `metrics` and then
# immediately overwrote it with the validation metrics, discarding the
# training-side numbers. Keep each set under its own name.
train_metrics = train_results.metrics
metrics = trainer.evaluate(val_ds)

# Publish the LoRA adapters and the processor to the Hub.
# NOTE(review): the "..." placeholder must be replaced with a real Hub
# namespace before this script can run end-to-end.
repo_name = f".../{model_name}-finetuned-lora-dino_large"
lora_model.push_to_hub(repo_name)
image_processor.push_to_hub(repo_name)
+
# Rebuild the base model from the published adapter config, then attach the
# LoRA weights for inference.
config = PeftConfig.from_pretrained(repo_name)

model = AutoModelForImageClassification.from_pretrained(
    config.base_model_name_or_path,
    ignore_mismatched_sizes=True,  # head is re-created from the label maps
    label2id=label2id,
    id2label=id2label,
)

# Load the LoRA model.
inference_model = PeftModel.from_pretrained(model, repo_name)

image_processor = AutoImageProcessor.from_pretrained(repo_name)
+
# Held-out test split with a deterministic resize-only pipeline.
test_ds = load_from_disk('wikiart_test')
print(test_ds)

test_transforms = Compose([
    Resize((img_size, img_size)),
    ToTensor(),
    normalize,
])

def preprocess_test(example_batch):
    """Attach deterministic ``pixel_values`` to a batch of test examples."""
    images = example_batch["image"]
    example_batch["pixel_values"] = [test_transforms(img.convert("RGB")) for img in images]
    return example_batch


test_ds.set_transform(preprocess_test)
+
predictions = []
true_labels = []

# Per-image inference over the test split.
# NOTE(review): preprocessing here goes through image_processor rather than
# the test_transforms pipeline attached above, so the "pixel_values" produced
# by preprocess_test are never read in this loop — confirm which pipeline is
# intended.
for example in test_ds:
    encoding = image_processor(example["image"], return_tensors="pt")

    with torch.no_grad():
        logits = inference_model(**encoding).logits

    # Append predictions and true labels.
    predictions.append(logits.argmax(-1).item())
    true_labels.append(example['genre'])
+
# Row-normalized confusion matrix over ALL known classes.
# BUG FIX: the original built display labels from sorted(set(true_labels)),
# i.e. only classes present in the ground truth. If any class appeared only
# in the predictions, confusion_matrix's inferred axis would not match
# display_labels and plotting would fail; it also showed integer ids instead
# of genre names. Pinning labels= to the full id range fixes both, and the
# misleading `plt =` binding of disp.plot()'s return value is dropped.
class_ids = list(range(len(labels)))
conf_matrix = confusion_matrix(true_labels, predictions, labels=class_ids, normalize='true')

# Display confusion matrix with human-readable genre names.
disp = ConfusionMatrixDisplay(conf_matrix, display_labels=labels)
disp.plot(values_format=".2f", cmap="Blues")

metrics = trainer.evaluate(test_ds)
print(metrics)

print(classification_report(true_labels, predictions, labels=class_ids, target_names=labels, digits=4))
-- 
GitLab