"""LoRA fine-tuning of facebook/dinov2-large for WikiArt genre classification.

Pipeline:
  1. Load the WikiArt train split from disk and attach train/val transforms.
  2. Wrap the base model with a LoRA adapter (query/value projections) and
     train with the HF ``Trainer``.
  3. Push the adapter + image processor to the Hub, reload them, and evaluate
     on the held-out test split (confusion matrix + classification report).
"""

import evaluate
import numpy as np
import torch
from datasets import load_from_disk
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    classification_report,
    confusion_matrix,
)
from torchvision.transforms import (
    Compose,
    Normalize,
    RandomAdjustSharpness,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
)

model_checkpoint = "facebook/dinov2-large"

# --- Data -------------------------------------------------------------------
dataset = load_from_disk('wikiart_train')
print(dataset)

# NOTE(review): the final genre name is deliberately dropped here — confirm
# the last entry really is an unused/placeholder class in this dataset.
labels = dataset.features["genre"].names[:-1]
print(labels)

label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

image_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large')

img_size = 224
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

train_transforms = Compose(
    [
        RandomResizedCrop(size=(img_size, img_size), scale=(0.8, 1)),
        RandomAdjustSharpness(sharpness_factor=3),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

val_transforms = Compose(
    [
        Resize(size=(img_size, img_size)),
        RandomAdjustSharpness(sharpness_factor=3),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)


def preprocess_train(example_batch):
    """Apply train_transforms across a batch."""
    example_batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch


def preprocess_val(example_batch):
    """Apply val_transforms across a batch."""
    example_batch["pixel_values"] = [
        val_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch


splits = dataset.train_test_split(test_size=0.1)
train_ds = splits["train"]
val_ds = splits["test"]

train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)


def print_trainable_parameters(model):
    """Print trainable vs. total parameter counts (sanity check for LoRA)."""
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )


# --- Model + LoRA -----------------------------------------------------------
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,  # classifier head is re-initialised for our label set
)
print_trainable_parameters(model)

config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["query", "value"],  # attention projections only
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"],  # the new classifier head is trained fully
)

lora_model = get_peft_model(model, config)
print_trainable_parameters(lora_model)

model_name = model_checkpoint.split("/")[-1]
batch_size = 16

args = TrainingArguments(
    f"{model_name}-finetuned-lora",
    remove_unused_columns=False,  # keep "image" so set_transform still sees it
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,  # effective train batch = 64 per device
    per_device_eval_batch_size=batch_size,
    fp16=True,
    num_train_epochs=40,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
    label_names=["labels"],
)

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Computes accuracy on a batch of predictions"""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)


def collate_fn(examples):
    """Stack transformed images and pair them with their genre labels."""
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["genre"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}


trainer = Trainer(
    lora_model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

# --- Train ------------------------------------------------------------------
train_results = trainer.train()
# FIX: the original assigned train_results.metrics and immediately overwrote
# it with the eval metrics, silently dropping the training stats — log both.
trainer.log_metrics("train", train_results.metrics)
metrics = trainer.evaluate(val_ds)
trainer.log_metrics("eval", metrics)

# --- Push adapter to the Hub and reload it ----------------------------------
# TODO(review): "..." is a placeholder — set your Hub username/organisation.
repo_name = f".../{model_name}-finetuned-lora-dino_large"
lora_model.push_to_hub(repo_name)
image_processor.push_to_hub(repo_name)

config = PeftConfig.from_pretrained(repo_name)

model = AutoModelForImageClassification.from_pretrained(
    config.base_model_name_or_path,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,  # classifier head re-initialised, weights come from the adapter
)

# Load the LoRA adapter on top of the freshly loaded base model.
inference_model = PeftModel.from_pretrained(model, repo_name)
# FIX: run inference on GPU when available and force eval mode so LoRA
# dropout (0.1) cannot perturb the test predictions.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inference_model.to(device)
inference_model.eval()

image_processor = AutoImageProcessor.from_pretrained(repo_name)

# --- Test-set evaluation -----------------------------------------------------
test_ds = load_from_disk('wikiart_test')
print(test_ds)

test_transforms = Compose(
    [
        Resize((img_size, img_size)),
        ToTensor(),
        normalize,
    ]
)


def preprocess_test(example_batch):
    """Apply test_transforms across a batch."""
    example_batch["pixel_values"] = [
        test_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch


test_ds.set_transform(preprocess_test)

predictions = []
true_labels = []

for example in test_ds:
    # NOTE: inference uses the image processor's own preprocessing, not
    # test_transforms — kept from the original; confirm this is intentional.
    encoding = image_processor(example["image"], return_tensors="pt").to(device)
    with torch.no_grad():
        logits = inference_model(**encoding).logits
    predictions.append(logits.argmax(-1).item())
    true_labels.append(example["genre"])

# FIX: pin the label set so the matrix axes always line up with the genre
# names, even when a class is missing from the test split (the original used
# the raw integer ids present in true_labels as display labels).
class_ids = list(range(len(labels)))
conf_matrix = confusion_matrix(
    true_labels, predictions, labels=class_ids, normalize="true"
)
disp = ConfusionMatrixDisplay(conf_matrix, display_labels=labels)
display = disp.plot(values_format=".2f", cmap="Blues")
# FIX: the figure was created but never shown or persisted; save it so the
# script leaves an artifact when run non-interactively.
display.figure_.savefig(f"{model_name}-confusion-matrix.png", bbox_inches="tight")

metrics = trainer.evaluate(test_ds)
print(metrics)

print(
    classification_report(
        true_labels,
        predictions,
        labels=class_ids,  # keeps target_names aligned even for absent classes
        target_names=labels,
        digits=4,
    )
)