3D image segmentation involves partitioning volumetric data into distinct regions to extract meaningful information, such as the boundaries of organs or tumors. With applications ranging from medical diagnosis to industrial inspection and robotics, 3D segmentation plays a pivotal role in understanding complex three-dimensional structures and objects. In this guide, we’ll explore the fundamentals of 3D image segmentation in medical imaging and learn how to leverage the MONAI framework with the UNet architecture for segmentation tasks.
This article was published as a part of the Data Science Blogathon.
Image segmentation is a fundamental task in computer vision and medical imaging that involves partitioning an image or a volumetric dataset into multiple regions or segments. Let’s break it down.
The input is an image (or a 3D volume).
The model segments the region of interest.
The output is a segmentation mask.
Every pixel of the image is classified into one of the classes, e.g. foreground or background, by estimating per-pixel class probabilities.
In the next section, we will build an in-depth understanding of how the UNet architecture works. We will explore each element that makes up the encoder and decoder paths, along with its respective task.
UNet employs both a ‘contracting’ and an ‘expansive’ pathway to achieve accurate segmentation. The contracting pathway follows a conventional convolutional network design: it repeatedly applies two 3×3 convolutions, each followed by a ReLU activation, and downsamples through 2×2 max pooling with a stride of 2.
This process doubles the number of feature channels with each iteration, effectively capturing the context of the image.
On the other hand, the expansive pathway focuses on precise localization: existing features are upsampled with a 2×2 ‘up-convolution’ that halves the number of channels, concatenated with the correspondingly cropped feature map from the contracting path, and then passed through another two 3×3 convolutions, each followed by a ReLU activation.
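As a rough PyTorch sketch of a single expansive-path step (the channel sizes are illustrative, and padded convolutions are used for simplicity, whereas the original paper uses unpadded convolutions and crops the skip features):

import torch
import torch.nn as nn

up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)  # 2x2 up-convolution halves the channels
double_conv = nn.Sequential(                                # two 3x3 convolutions, each followed by ReLU
    nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
)

x = torch.randn(1, 128, 32, 32)    # features coming up from the deeper layer
skip = torch.randn(1, 64, 64, 64)  # matching features from the contracting path
x = up(x)                          # -> (1, 64, 64, 64)
x = torch.cat([x, skip], dim=1)    # concatenate with the skip connection -> (1, 128, 64, 64)
x = double_conv(x)                 # -> (1, 64, 64, 64)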
The 3D U-Net architecture is quite similar to the 2D U-Net. It has an analysis path on the left and a synthesis path on the right.
Each layer in the analysis path contains two 3×3×3 convolutions followed by a ReLU, and then a 2×2×2 max pooling with strides of two in each dimension.
Each layer in the synthesis path consists of an up-convolution of 2×2×2 by strides of two in each dimension, followed by two 3×3×3 convolutions each followed by a ReLU.
Shortcut connections from layers of equal resolution in the analysis path provide the essential high-resolution features to the synthesis path. Additionally, a 1×1×1 convolution in the last layer reduces the number of output channels to the desired number of labels, typically three in medical imaging tasks. A batch normalization layer before each ReLU contributes to the stability and efficiency of the network’s training process.
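To make this concrete, one analysis-path layer can be sketched in PyTorch roughly as follows; the channel counts are illustrative, and padded convolutions are used here so that the spatial size simply halves at each layer:

import torch
import torch.nn as nn

def analysis_block(in_ch, out_ch):
    # two 3x3x3 convolutions, each with batch normalization before the ReLU,
    # followed by 2x2x2 max pooling with stride 2 in each dimension
    return nn.Sequential(
        nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_ch),
        nn.ReLU(inplace=True),
        nn.MaxPool3d(kernel_size=2, stride=2),
    )

x = torch.randn(1, 1, 64, 64, 64)      # (batch, channels, depth, height, width)
print(analysis_block(1, 16)(x).shape)  # torch.Size([1, 16, 32, 32, 32])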
MONAI (Medical Open Network for AI) is an open-source, community-driven framework designed to facilitate medical image analysis with deep learning. At its core, MONAI provides a rich set of functionalities to facilitate every stage of the medical image analysis pipeline. From data preprocessing and augmentation to model training, evaluation, and deployment, MONAI offers an intuitive workflow designed to streamline the research process.
One of the key strengths of MONAI lies in its extensive library of pre-built components and algorithms, spanning a wide range of tasks such as image transformation, segmentation, registration, and classification.
Here we will discuss image segmentation in detail, particularly spleen segmentation using MONAI.
The first step is to install MONAI and load all the necessary libraries. You can install MONAI with ‘pip install monai’ and import the necessary libraries.
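For reference, a minimal install in a Colab/Jupyter notebook might look like the following; the extra packages are an assumption based on what the rest of this tutorial uses (nibabel for reading NIfTI files, dicom2nifti for the DICOM conversion step), not an official requirement list.

!pip install monai
!pip install nibabel dicom2nifti tqdm matplotlib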
from monai.utils import first, set_determinism
from monai.transforms import (
AsDiscrete,
AsDiscreted,
EnsureChannelFirstd,
Compose,
CropForegroundd,
LoadImaged,
Orientationd,
RandCropByPosNegLabeld,
SaveImaged,
ScaleIntensityRanged,
Spacingd,
Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
from tqdm import tqdm
When we talk about 3D image segmentation, we typically deal with NIfTI files. A CT scan is a 3D chunk of data stored in the NIfTI file format, and each slice that makes up that 3D chunk is a DICOM file. As an analogy, think of a CT scan as a video and each DICOM file as a single frame of that video.
We will be using the spleen dataset from the Medical Segmentation Decathlon, available at http://medicaldecathlon.com/.
DICOM (Digital Imaging and Communications in Medicine) files are the standard format for storing medical imaging data, encompassing various modalities such as X-rays, MRI scans, CT scans, and ultrasounds.
These files contain both image data and metadata, including patient information and acquisition parameters. DICOM groups are collections of DICOM files that are related to each other, such as images from the same study, series, or patient. NIfTI (Neuroimaging Informatics Technology Initiative) files, on the other hand, are commonly used in neuroimaging for storing volumetric brain imaging data such as MRI scans.
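To get a feel for what a NIfTI volume looks like in code, here is a minimal sketch using the nibabel library; the file name is only an example.

import nibabel as nib

volume = nib.load("spleen_10.nii.gz")  # load an example NIfTI volume
print(volume.shape)                    # e.g. (512, 512, 55): height, width, number of slices
print(volume.affine)                   # voxel-to-world coordinate mapping
data = volume.get_fdata()              # the raw intensity values as a NumPy array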
Creating DICOM groups involves organizing DICOM files based on their attributes. The function below creates folders, each containing a group of a fixed number of DICOM files (slices).
def create_groups(in_dir, out_dir, number_slices):
    for patient in glob.glob(in_dir + '/*'):
        patient_name = os.path.basename(os.path.normpath(patient))
        # number of folders, each holding number_slices DICOM files of the same patient
        number_folders = int(len(glob.glob(patient + '/*')) / number_slices)
        for i in range(number_folders):
            output_path = os.path.join(out_dir, patient_name + '_' + str(i))
            os.mkdir(output_path)
            # copy the slices of this group into their own folder
            dicom_files = glob.glob(patient + '/*')
            for j, file in enumerate(dicom_files[i * number_slices:]):
                if j == number_slices:
                    break
                shutil.copy(file, output_path)
# create groups of image dicom files
create_groups(dicom_files_image_path, dicom_groups_image_path, number_slices=40)
print("Creating Dicom Groups from Image dicoms completed!!\n")
# create groups of label dicom files
create_groups(dicom_files_label_path, dicom_groups_label_path, number_slices=40)
print("Creating Dicom Groups from Label dicoms completed!!\n")
While DICOM is widely used in medical imaging, it may not always be the most convenient format for analysis and processing, so converting the DICOM groups to NIfTI is required. The conversion process typically involves extracting the relevant metadata and pixel data from the DICOM files and reformatting them into NIfTI-compatible structures. This can be done with a dedicated DICOM-to-NIfTI conversion tool such as the dicom2nifti library.
This function will be used to convert the DICOM folder into NIFTI files after creating the groups with the number of slices that you want.
import dicom2nifti

def dcm2nifti(in_dir, out_dir):
    for folder in tqdm(glob.glob(in_dir + '/*')):
        patient_name = os.path.basename(os.path.normpath(folder))
        # convert one DICOM series (folder) into a single compressed NIfTI file
        dicom2nifti.dicom_series_to_nifti(folder, os.path.join(out_dir, patient_name + '.nii.gz'))
dcm2nifti(dicom_groups_image_path, nifti_files_image_path)
print("Conversion from Image Dicom Groups to Nifti files completed!!\n")
Now that we have the NIFTI files required for the segmentation task, let’s dive into the segmentation task using MONAI.
The first step is to prepare the data after we get the NIFTI files. This prepares file paths for training and validation data by locating NIFTI images and label files for spleen segmentation.
It creates a list of dictionaries, each containing an image file path and its corresponding label file path, and then splits them into training and validation sets.
data_dir = "/content/drive/Spleen-Segmentation/Data/Task09_Spleen"
train_images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]
MONAI provides a set of powerful tools for preprocessing and augmenting medical imaging data, called MONAI transforms. These transforms encompass a wide range of operations, including data normalization, resampling, cropping, and intensity adjustments, tailored specifically for medical imaging applications. By applying MONAI transforms to the input data, you can improve the quality, consistency, and relevance of your datasets, which helps improve the performance of deep learning models.
Let’s have a look at a few of the transforms.
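As a quick, self-contained illustration of how the dictionary-based transforms work, the sketch below applies an intensity-scaling transform to a made-up volume (the array is random and not part of the spleen dataset):

import numpy as np
from monai.transforms import Compose, ScaleIntensityRanged

# a toy channel-first "CT" volume with Hounsfield-like intensities
sample = {"image": np.random.uniform(-200, 300, size=(1, 64, 64, 32)).astype(np.float32)}

demo_transforms = Compose([
    # clip intensities to [-57, 164] and rescale them to [0, 1]
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
])
out = demo_transforms(sample)
print(out["image"].min(), out["image"].max())  # values now lie in [0, 1]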
Now that we understand MONAI transforms, let us utilize different Monai transforms for both training and validation data.
The training transforms include loading the images and labels, ensuring a channel-first layout, rescaling the intensity range, cropping away the empty background, reorienting the volumes, resampling the voxel spacing, and random cropping based on positive and negative labels. The validation transforms exclude the random cropping so that evaluation remains consistent.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ]
)
Now let us check the transforms in the Data Loader. We first load a validation dataset, process the first batch, extract an image and its corresponding label, and display a specific slice (at index 80) from both the image and label for visual inspection in a side-by-side plot.
check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")
# plot the slice [:, :, 80]
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, 80], cmap="gray")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, 80])
plt.axis("off")
plt.show()
Let us visualize a few of the intermediate preprocessing outputs, such as the contrast adjustment with intensity scaling and the foreground crop.
Let us now define a 3D U-Net model for semantic segmentation, utilizing GPU if available. The model architecture consists of contracting and expanding paths with specified channels and strides, enhanced by residual units.
The training setup includes the Dice loss, Adam optimizer, and a Dice metric for evaluation, targeting multi-class segmentation with background excluded.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")
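The training loop below iterates over train_loader and val_loader. A minimal sketch of how they can be built from the file lists and transforms defined earlier, using MONAI's CacheDataset and DataLoader, is shown here; the cache_rate, num_workers, and batch sizes are assumptions rather than values from the article.

train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

CacheDataset pre-computes and caches the deterministic transforms in memory, which speeds up training compared with a plain Dataset.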
The next step is to train the U-Net model for semantic segmentation over multiple epochs; here we train for up to 500 epochs and evaluate on the validation dataset at regular intervals. The loop tracks the loss and Dice metric and saves checkpoints of the model state, optimizer state, and training progress so that training can be monitored and resumed later.
max_epochs = 500
val_interval = 2
checkpoint = torch.load("/content/drive/Spleen-Segmentation/ImprovedResults/my_checkpoint.pth.tar")
best_metric = checkpoint["best_metric"]
best_metric_epoch = checkpoint["best_metric_epoch"]
epoch_loss_values = checkpoint["train_loss"]
metric_values = checkpoint["val_dice"]
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])
save_dir = "/content/drive/Spleen-Segmentation/ImprovedResults"
checkpoint = {}
for epoch in range(240, max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in tqdm(train_loader):
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # print(
        #     f"{step}/{len(train_ds) // train_loader.batch_size}, "
        #     f"train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (160, 160, 160)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)
            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(
                    save_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )
        checkpoint["train_loss"] = epoch_loss_values
        checkpoint["val_dice"] = metric_values
        checkpoint["best_metric_epoch"] = best_metric_epoch
        checkpoint["best_metric"] = best_metric
        checkpoint["model_state_dict"] = model.state_dict()
        checkpoint["optimizer_state_dict"] = optimizer.state_dict()
        torch.save(checkpoint, os.path.join(save_dir, "my_checkpoint2.pth.tar"))
Evaluation of the model is a crucial step that measures the agreement between predicted and ground truth segmentations. Here we will be discussing a few evaluation metrics commonly used in image segmentation tasks and carry out visualizations by plotting the loss curves and displaying the outputs.
Evaluation metrics play a crucial role in assessing the performance and accuracy of 3D image segmentation algorithms. Several metrics are commonly used in evaluating 3D image segmentation. Let’s understand a few of them.
Intersection over Union (IoU) computes the ratio of the intersection to the union of the segmented and ground-truth regions, offering a normalized measure of overlap.
The Dice score measures the overlap between the segmented region and the ground truth, providing a comprehensive measure of segmentation accuracy.
Dice Loss = 1 – Dice Score
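For reference, both metrics can be computed from binary masks in a few lines. The sketch below is purely illustrative and independent of the DiceMetric class used elsewhere in this article.

import torch

def dice_and_iou(pred, target, eps=1e-6):
    # pred and target are binary masks of the same shape
    pred = pred.float().flatten()
    target = target.float().flatten()
    intersection = (pred * target).sum()
    dice = (2 * intersection + eps) / (pred.sum() + target.sum() + eps)
    iou = (intersection + eps) / (pred.sum() + target.sum() - intersection + eps)
    return dice.item(), iou.item()

pred = torch.zeros(1, 8, 8, 8); pred[:, :4] = 1
target = torch.zeros(1, 8, 8, 8); target[:, 2:6] = 1
print(dice_and_iou(pred, target))  # roughly (0.5, 0.33) for this toy overlap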
Now we visualize the training and validation performance of a model during training epochs.
val_interval = 2
plt.figure("train", (15, 5))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Dice Loss")
x = [i + 1 for i in range(len(checkpoint["train_loss"]))]
y = checkpoint["train_loss"]
plt.xlabel("#Epochs")
plt.ylabel("Dice Loss")
plt.plot(x, y)
plt.plot(checkpoint["best_metric_epoch"],
         checkpoint["train_loss"][checkpoint["best_metric_epoch"] - 1], 'r*', markersize=8)
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice Score")
x = [val_interval * (i + 1) for i in range(len(checkpoint["val_dice"]))]
y = checkpoint["val_dice"]
plt.xlabel("#Epochs")
plt.plot(x, y)
plt.plot(checkpoint["best_metric_epoch"],
         checkpoint["val_dice"][checkpoint["best_metric_epoch"] // val_interval - 1], 'r*', markersize=10)
plt.annotate("Best Score[470, 0.9516]", xy=(checkpoint["best_metric_epoch"],
             checkpoint["val_dice"][checkpoint["best_metric_epoch"] // val_interval - 1]))
plt.savefig("LearningCurves.png")
plt.show()
The left subplot displays the average dice loss per epoch, with a red star indicating the epoch with the best validation metric. The right subplot illustrates the mean dice score at validation intervals, with an annotation marking the epoch with the best validation score.
After this, we load a trained UNet model from a specified directory, perform inference on validation data using the sliding window inference technique, and visualize the input image, ground truth label, and model output for a slice along the z-axis.
save_dir = "/content/drive/Spleen-Segmentation/ImprovedResults/"
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
model.load_state_dict(torch.load(
    os.path.join(save_dir, "best_metric_model.pth"), map_location=device))
model.eval()
# elapsed_time = 0
with torch.no_grad():
    for i, val_data in enumerate(val_loader):
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        # t = time.time()
        val_outputs = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model
        )
        # elapsed_time += time.time() - t
        # print("Elapse Time : ", time.time()-t)
        # plot the slice [:, :, 80]
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        plt.imshow(val_data["image"][0, 0, :, :, 80], cmap="gray")
        plt.axis("off")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.axis("off")
        plt.imshow(val_data["label"][0, 0, :, :, 80])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        plt.axis("off")
        plt.imshow(torch.argmax(
            val_outputs, dim=1).detach().cpu()[0, :, :, 80])
        plt.show()
        if i == 3:
            break
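If you also want to write a predicted mask back to disk as a NIfTI file, one simple option is to save the argmax output with nibabel, as sketched below; the output file name and the identity affine are placeholders (in practice you would reuse the affine of the corresponding input image), and MONAI's Invertd and SaveImaged transforms, already imported above, can additionally map the prediction back to the original image space before saving.

import numpy as np
import nibabel as nib

# turn the two-channel network output into a label volume and save it
pred_mask = torch.argmax(val_outputs, dim=1).detach().cpu().numpy()[0].astype(np.uint8)
nib.save(nib.Nifti1Image(pred_mask, affine=np.eye(4)), "output.nii.gz")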
Let us visualize the results closely. Some predictions are very close to the ground truth, and some even look better than the ground-truth annotations.
3D image segmentation can extract meaningful insights from complex data. By partitioning images into distinct regions or structures, clinicians can accurately identify and analyze anatomical features, abnormalities, and pathologies. The applications of 3D image segmentation span across various medical specialties and clinical scenarios discussed below.
Accurate tumor detection through 3D image segmentation aids clinicians in diagnosing malignancies and monitoring disease progression for appropriate treatment planning.
Organ segmentation enables clinicians to assess organ function, identify abnormalities, and plan interventions with higher precision and accuracy.
Precise segmentation of anatomical structures supports optimal treatment planning, guiding surgical trajectories, and delivering targeted therapies with minimal damage to healthy tissues.
Let us dive deeper into the case studies:
In a study published in the Journal of Neurosurgery, researchers utilized 3D image segmentation to delineate tumor boundaries in MRI scans of patients with glioblastoma, a type of malignant brain tumor. Accurate segmentation enabled clinicians to assess tumor size, location, and response to treatment, guiding surgical resection and radiation therapy planning.
In a clinical case presented at a cardiology conference, 3D segmentation of the heart from cardiac MRI scans facilitated treatment planning for patients with congenital heart defects. Precise segmentation of cardiac structures allowed cardiologists to assess ventricular function, identify abnormalities, and plan surgical interventions or cardiac catheterization procedures with improved accuracy and outcomes.
A retrospective analysis of liver transplantation cases demonstrated the utility of 3D liver segmentation from CT scans in surgical planning and donor-recipient matching. Accurate segmentation of liver anatomy enabled surgeons to assess liver volume, vascular structures, and disease extent, facilitating donor selection, graft optimization, and post-transplant monitoring for improved patient outcomes.
These case studies illustrate how 3D image segmentation contributes to improved clinical workflows, personalized treatment planning, and better patient outcomes across a wide range of medical specialties and conditions.
For more details, visit this GitHub repo: https://github.com/bbabina/Spleen-Segmentation-using-Monai-and-Pytorch
In conclusion, 3D image segmentation, particularly in medical imaging, has revolutionized healthcare by providing clinicians with powerful tools to extract valuable insights from complex data. Techniques like the UNet architecture implemented with the MONAI framework make it possible to accurately segment anatomical structures, tumors, and abnormalities, aiding diagnosis, treatment planning, and monitoring. Furthermore, the diverse applications of 3D segmentation highlighted in the case studies underscore its profound impact on clinical workflows and patient outcomes, promising a future where medical imaging continues to drive advancements in personalized healthcare.
A. 3D image segmentation involves dividing volumetric data into distinct regions, which is crucial for tasks like identifying organs and tumors. It plays a pivotal role in medical diagnosis, treatment planning, and monitoring.
A. The UNet architecture utilizes both contracting and expansive pathways to achieve accurate segmentation. It captures context through convolutional layers and achieves precise localization by upsampling features. UNet’s skip connections preserve fine details, which aid in the reconstruction of segmentation maps with high accuracy.
A. MONAI offers a set of functionalities tailored for medical image analysis, from data preprocessing to model deployment. Its library of pre-built components and algorithms simplifies tasks like image transformation, segmentation, registration, and classification.
The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.