
Fine-tune BLIP using Hugging Face transformers and datasets 🤗

This tutorial is largely based on the GiT tutorial on how to fine-tune GiT on a custom image captioning dataset. Here we will use a dummy dataset of football players ⚽ that is uploaded to the Hub. The images have been manually selected together with their captions. Check the 🤗 documentation to learn how to create and upload your own image-text dataset.
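For reference, creating and pushing such an image-text dataset could look like the minimal sketch below, using 🤗 datasets. The file paths, captions, and the your-username/my-image-captioning-dataset repo id are placeholders for illustration only; pushing also requires being logged in to the Hub (e.g. via huggingface-cli login).

from datasets import Dataset, Image

# hypothetical local image files and captions, for illustration only
data = {
    "image": ["images/player_1.jpg", "images/player_2.jpg"],
    "text": ["a football player celebrating a goal", "a goalkeeper diving for the ball"],
}

# turn the file paths into an actual image column, then push to the Hub
my_dataset = Dataset.from_dict(data).cast_column("image", Image())
my_dataset.push_to_hub("your-username/my-image-captioning-dataset")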

Set-up environment

!pip install git+https://github.com/huggingface/transformers.git@main
!pip install -q datasets
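Optionally, you can confirm the installed versions before moving on:

import datasets
import transformers

# print the versions that were just installed
print(transformers.__version__)
print(datasets.__version__)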

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

from transformers.utils import send_example_telemetry

send_example_telemetry("image_captioning_blip_notebook", framework="pytorch")

Load the image captioning dataset

Let's load the image captioning dataset; you only need a few lines of code for that.

from datasets import load_dataset

dataset = load_dataset("ybelkada/football-dataset", split="train")

Let's retrieve the caption of the first example:

dataset[0]["text"]

And the corresponding image:

dataset[0]["image"]

Create PyTorch Dataset

The lines below are copied directly from the original GiT notebook!

from torch.utils.data import Dataset, DataLoader

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], text=item["text"], padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        return encoding

Load model and processor

from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

Now that we have loaded the processor, let's wrap the dataset in our PyTorch dataset and create the dataloader:

train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2)
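As an optional sanity check, you can inspect one batch to confirm which keys the processor produces and their shapes:

# grab a single batch and print the tensor shapes
batch = next(iter(train_dataloader))
for k, v in batch.items():
    print(k, v.shape)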

Train the model

Let's train the model! Simply run the cell below to train it.

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()

for epoch in range(50):
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)

        outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=input_ids)

        loss = outputs.loss
        print("Loss:", loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Inference

Let's check the results on our train dataset

# load image
example = dataset[0]
image = example["image"]
image
# prepare image for the model
inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values

generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)
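For comparison, you can print the ground-truth caption of that same example:

# ground-truth caption of the first training example
print(example["text"])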

Load from the Hub

Once trained, you can push the model and processor to the Hub to reuse them later (see the sketch below). Meanwhile, you can play with the model that we have already fine-tuned!
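Pushing your own fine-tuned checkpoint could look like this minimal sketch, where your-username/blip-football-finetuned is a placeholder repo id and you need to be authenticated with the Hub (e.g. via huggingface-cli login):

# assumes you are logged in to the Hub; the repo id below is a placeholder
model.push_to_hub("your-username/blip-football-finetuned")
processor.push_to_hub("your-username/blip-football-finetuned")

The cell below loads the checkpoint that was already fine-tuned and pushed for this tutorial: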

from transformers import BlipForConditionalGeneration, AutoProcessor

model = BlipForConditionalGeneration.from_pretrained("ybelkada/blip-image-captioning-base-football-finetuned").to(device)
processor = AutoProcessor.from_pretrained("ybelkada/blip-image-captioning-base-football-finetuned")

Let's check the results on our train dataset!

from matplotlib import pyplot as plt

fig = plt.figure(figsize=(18, 14))

for i, example in enumerate(dataset):
    image = example["image"]

    # prepare image for the model
    inputs = processor(images=image, return_tensors="pt").to(device)
    pixel_values = inputs.pixel_values

    generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    fig.add_subplot(2, 3, i + 1)
    plt.imshow(image)
    plt.axis("off")
    plt.title(f"Generated caption: {generated_caption}")