Meta’s Segment Anything Model (SAM) has demonstrated its ability to segment objects in virtually any region of an image. The model’s architecture is flexible, and users can guide it with various prompts. Remarkably, it can segment objects that never appeared in its training data.
These features make SAM a highly effective tool for detecting and segmenting objects for almost any purpose. It can also be adapted to specific segmentation tasks, as industry applications such as self-driving vehicles and robotics have shown. Another crucial detail is that the model can segment images using masks and bounding boxes, which is vital to how it works for medical purposes.
That is where MEDSAM comes in: an adaptation of Meta’s Segment Anything Model that plays a huge role in diagnosing and detecting abnormalities in scanned medical images. MEDSAM trains the model on image-mask pairs collected from different sources, a dataset covering over 15 imaging modalities and over 30 cancer types.
We’ll discuss how this model can detect objects from medical images using bounding boxes.
SAM is an image segmentation model developed by Meta to identify objects in almost any region of an image. Its best attribute is its versatility, which lets it generalize to images it has never seen.
The model was trained on roughly 11 million real-world images (and over a billion masks), yet, more intriguingly, it can segment objects that are not present anywhere in that dataset.
There are many image segmentation and object detection models with different structures. Such models can be task-specific or foundation models; SAM, being a ‘segment-it-all’ model, can be both, since it has a strong enough foundation to generalize across millions of images while still leaving room for fine-tuning. That is where researchers come in with various ideas, as with MEDSAM.
A highlight of SAM’s capabilities is its adaptability. It is a prompt-based segmentation model, which means it can receive information about how to perform a segmentation task. Prompts include foreground and background points, rough boxes, bounding boxes, masks, text, and other hints that help the model segment the image, as sketched below.
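For illustration, here is a minimal sketch of how point and box prompts can be passed through the Hugging Face SamProcessor (the checkpoint name, image path, and coordinates are assumptions made for this example):

from PIL import Image
from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
image = Image.open("example.png").convert("RGB")  # any RGB image (hypothetical path)

# a foreground point prompt: one (x, y) click inside the object of interest
point_inputs = processor(image, input_points=[[[450, 600]]], return_tensors="pt")

# a bounding-box prompt: the [x0, y0, x1, y1] corners of the region of interest
box_inputs = processor(image, input_boxes=[[[100, 150, 300, 350]]], return_tensors="pt")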
The core of this model’s architecture consists of an image encoder, a prompt encoder, and a mask decoder. All three components play a huge role in the segmentation task: the image and prompt encoders generate the image and prompt embeddings, and the mask decoder predicts the mask for the region you want to segment, guided by the prompt.
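As a quick illustration, the Hugging Face transformers implementation exposes these three components as submodules (a minimal sketch; the base SAM checkpoint is used as an example):

from transformers import SamModel

model = SamModel.from_pretrained("facebook/sam-vit-base")
# the three building blocks described above
print(type(model.vision_encoder).__name__)  # image encoder (ViT backbone)
print(type(model.prompt_encoder).__name__)  # encodes points, boxes, and masks
print(type(model.mask_decoder).__name__)    # predicts the final segmentation masks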
Using the Segment Anything Model for medical purposes was worth trying: the model has a large dataset behind it and broad capabilities, so why not medical imaging? However, applying it to medical segmentation came with limitations, owing to the nature of medical images and to how the model handles uncertain bounding boxes. With the challenges posed by the nature of masks in medical images, specialization became essential. That brought about MEDSAM, a segmentation model built on SAM’s architecture but tailored to medical images.
This model handles various tasks across anatomical structures and different image instances. Medical imaging gets effective results from it; the 15 imaging modalities and over 30 cancer types indicate the scale of the medical image segmentation training behind MEDSAM.
MEDSAM was built on the pre-trained SAM model. In this framework, the image and prompt encoders generate embeddings, which the mask decoder turns into a mask for the target image.
The image encoder in the Segment Anything Model processes images, along with their positional information, and requires a lot of computing power. To make training more efficient, the researchers behind this model decided to “freeze” both the image encoder and the prompt encoder; that is, they stopped updating these parts during training.
The prompt encoder, which encodes the positions of objects and reuses the bounding-box encoder from SAM, also stayed unchanged. Freezing these components reduced the computing power needed and made the system more efficient.
The researchers also streamlined the model’s output. Before training, they computed the image embeddings for all training images, avoiding repeated computation. The mask decoder (the only component fine-tuned) now produces a single mask instead of three, since the bounding box already clearly defines the area to segment. This approach made training more efficient.
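For illustration, here is a minimal sketch of what freezing the two encoders might look like with the Hugging Face SamModel (an assumption about the implementation, not MEDSAM’s actual training code):

from transformers import SamModel

model = SamModel.from_pretrained("facebook/sam-vit-base")

# freeze the image and prompt encoders: no gradient updates during training
for param in model.vision_encoder.parameters():
    param.requires_grad = False
for param in model.prompt_encoder.parameters():
    param.requires_grad = False

# only the mask decoder's parameters remain trainable
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(f"{len(trainable)} trainable parameter tensors (all in the mask decoder)")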
(Figure: a graphical illustration of how this model works.)
This model needs a few libraries to function, so let’s dive into how you can run a medical imaging segmentation task on an image.
We’ll need a few extra libraries, since we also have to draw the bounding boxes that serve as the prompt. We’ll start with requests, numpy, and matplotlib.
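If any of these packages are missing from your environment, you can usually install them with pip (a typical command; adjust it to your setup):

pip install requests numpy matplotlib pillow transformers torch

With the dependencies in place, import them: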
import requests
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import SamModel, SamProcessor
import torch
The ‘requests’ library helps fetch the image from its source. The ‘numpy’ library is useful because we perform numerical operations involving the coordinates of the bounding box. PIL and matplotlib handle image processing and display, respectively. In addition to the SAM model, the processor and torch (which handles the computation in the code below) are essential packages for running this model.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("flaviagiammarino/medsam-vit-base").to(device)
processor = SamProcessor.from_pretrained("flaviagiammarino/medsam-vit-base")
This code selects the most suitable computing device available (a GPU if present, otherwise the CPU), loads the pre-trained model onto it, and then loads the model’s processor, which prepares image input data.
img_url = "https://huggingface.co/flaviagiammarino/medsam-vit-base/resolve/main/scripts/input.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_boxes = [95., 255., 190., 350.]
Loading the image from a URL is easy with the requests library in our environment. We open the image and convert it to RGB, a format compatible with processing. The ‘input_boxes’ list defines the bounding box with the coordinates [95, 255, 190, 350]; these numbers are the (x, y) positions of the top-left and bottom-right corners of the region of interest. Using the bounding box, the segmentation task can focus on that specific region.
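As a quick sanity check (an optional snippet, not part of the original walkthrough), we can unpack the corners and confirm the size of the region:

x0, y0, x1, y1 = input_boxes
print(f"Region of interest: {x1 - x0:.0f} x {y1 - y0:.0f} pixels")  # 95 x 95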
Next, we process the input, run the segmentation model, and post-process the output mask. The processor converts the raw image and input boxes into a format suitable for the model. The model then runs on the processed input to predict mask probabilities, and post-processing yields a refined, probability-based mask for the segmented region.
inputs = processor(raw_image, input_boxes=[[input_boxes]], return_tensors="pt").to(device)
outputs = model(**inputs, multimask_output=False)
probs = processor.image_processor.post_process_masks(outputs.pred_masks.sigmoid().cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu(), binarize=False)
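Because we passed binarize=False, ‘probs’ holds per-pixel probabilities rather than a hard mask. To inspect the result before plotting, you can threshold it at 0.5 (an illustrative step, not required by the walkthrough):

mask = (probs[0] > 0.5).squeeze().numpy()  # boolean mask at the 0.5 threshold
print(mask.shape, "-", int(mask.sum()), "pixels inside the segmented region")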
def show_mask(mask, ax, random_color):
    # choose an overlay color: random, or the default semi-transparent yellow
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([251/255, 252/255, 30/255, 0.6])
    # broadcast the mask against the RGBA color and draw it on the axes
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
Here, the show_mask function displays the segmentation mask on a plot using ‘ax.imshow’. It can use a random color or the default semi-transparent yellow. The mask is reshaped to fit the image, multiplied by the selected RGBA color, and rendered with ‘ax.imshow’.
Afterward, the show_box function draws a rectangle from the bounding box coordinates and position, as shown below:
def show_box(box, ax):
    # convert the [x0, y0, x1, y1] corners into an origin plus width and height
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2))
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(np.array(raw_image))
show_box(input_boxes, ax[0])
ax[0].set_title("Input Image and Bounding Box")
ax[0].axis("off")
ax[1].imshow(np.array(raw_image))
show_mask(mask=probs[0] > 0.5, ax=ax[1], random_color=False)
show_box(input_boxes, ax[1])
ax[1].set_title("MedSAM Segmentation")
ax[1].axis("off")
plt.show()
This code creates a figure with two side-by-side subplots. The first shows the original image with the bounding box drawn on it; the second shows the same image with the predicted mask overlaid, along with the bounding box.
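If you want to keep the result, a short optional snippet (assuming PNG output suits your needs; this is not part of the original article) can save the binarized mask using the variables defined above:

mask_np = (probs[0] > 0.5).squeeze().cpu().numpy()
Image.fromarray((mask_np * 255).astype(np.uint8)).save("medsam_mask.png")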
SAM, as a foundation model, is a multipurpose tool; with its strong generalization and its training on millions of real-world images, there is a lot this model can do. Common applications include self-driving vehicles, robotics, and, as discussed here, medical imaging.
MEDSAM is a huge milestone in the Segment Anything Model’s use cases. Medical imaging is more complex than everyday imagery, and this model helps us work within that context. Supporting different diagnostic approaches for detecting cancer types and other cells makes it an efficient tool for task-specific detection in medical imaging.
Meta’s Segment Anything Model’s versatility has shown great potential. Its medical imaging capability is a significant milestone in revolutionizing diagnoses and related tasks in the healthcare industry. Integrating bounding boxes makes it even more effective. Medical imaging can only improve as the SAM base model evolves.
Q1. What is the Segment Anything Model (SAM)?
A. SAM is an image segmentation model developed by Meta to detect and segment objects across any region of an image. It can also segment objects that were not in the model’s training data. SAM is built to operate with prompts and masks and is adaptable across various domains.
Q2. How does MEDSAM differ from SAM?
A. MEDSAM is a fine-tuned version of SAM designed specifically for medical imaging. While SAM is general-purpose, MEDSAM is optimized to handle the complexity of medical images, spanning various imaging modalities and cancer detection.
Q3. Can SAM be used in real-time applications?
A. The model’s versatility and processing speed allow it to be used in real-time applications, including self-driving vehicles and robotics, where it can quickly and efficiently detect and interpret objects within images.
The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.