diff --git a/recognition/3D-UNT 48790835/README.md b/recognition/3D-UNT 48790835/README.md
new file mode 100644
index 000000000..4b3bd372a
--- /dev/null
+++ b/recognition/3D-UNT 48790835/README.md
@@ -0,0 +1,68 @@
# 3D UNet for Prostate Segmentation

## Introduction

This project trains a 3D UNet on the Prostate 3D dataset to achieve precise volumetric medical image segmentation. Segmentation performance is evaluated with the Dice similarity coefficient, targeting a minimum score of 0.7 for all labels on the test set. Image segmentation transforms a volumetric image into labelled regions represented by masks, which supports medical condition analysis, symptom prediction, and treatment planning.

## Background

### UNet-3D

The 3D UNet is an extension of the original UNet architecture, which is widely used for segmenting 2D medical images. While the standard UNet processes 2D images, UNet-3D extends this functionality to volumetric (3D) images, allowing more accurate segmentation of the complex anatomical structures found in modalities such as MRI or CT.

The UNet architecture combines convolutional neural networks (CNNs) with skip connections, improving performance by merging high-resolution features from the contracting path with low-resolution context from the expansive path. This design preserves spatial information throughout the segmentation process, which is critical in medical imaging.

![3D U-Net Architecture](https://raw.githubusercontent.com/Han1zen/PatternAnalysis-2024/refs/heads/topic-recognition/recognition/3D-UNT%2048790835/picture/3D%20U-Net.webp)

### Dataset

For this project, we segment the downsampled Prostate 3D dataset. Sample code for loading and processing the Nifti file format is provided in Appendix B. We also encourage the use of data augmentation libraries for TensorFlow (TF), or the corresponding transformations in PyTorch, to improve the robustness of the model.

### Evaluation Metric

We use the Dice similarity coefficient as the primary evaluation metric. The Dice coefficient measures the overlap between the predicted segmentation and the ground truth, mathematically expressed as:

\[ \text{Dice} = \frac{2 |A \cap B|}{|A| + |B|} \]

where \( A \) and \( B \) are the predicted and ground-truth regions, respectively. A Dice coefficient of 0.7 or greater indicates a substantial degree of segmentation accuracy.

## Objectives

- Implement the 3D Improved UNet architecture for the Prostate dataset.
- Achieve a minimum Dice similarity coefficient of 0.7 for all labels on the test set.
- Utilize data augmentation techniques to improve model generalization.
- Load and preprocess Nifti files for volumetric data analysis.

## Quick Start

To get started with the 3D UNet model for prostate segmentation, follow these steps (example commands and a usage sketch are given below):

1. **Clone the Repository**: Clone the repository to your local machine.
2. **Install Dependencies**: Ensure the required libraries are installed (PyTorch, MONAI, TorchIO, NumPy, and Matplotlib).
3. **Prepare the Dataset**: Download the Prostate 3D dataset and place it in a local directory; the path is passed to the training script via `--dataset_root`.
4. **Run Training**: Execute the training script to begin training the model on the Prostate 3D dataset.
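### Example Usage

Training is launched with `python train.py --dataset_root <path-to-HipMRI_study_complete_release_v1>` (the other flags defined in `train.py` are `--lr`, `--epoch`, `--device`, and `--loss`), and a saved checkpoint is evaluated with `python predict.py`. Note that `dataset.py` expects `train_list.txt` and `test_list.txt`, listing the MR volume filenames, to be present in the working directory. The snippet below is a minimal sketch of how the pieces in this repository fit together; the dataset path and checkpoint filename are placeholders to be replaced with your own.

```python
import torch
from torch.utils.data import DataLoader

from dataset import CustomDataset
from modules import UNet3D

dataset_root = "/path/to/HipMRI_study_complete_release_v1"  # placeholder path

# Build the test split and the 3D UNet (1 input channel, 6 output classes)
test_dataset = CustomDataset(mode='test', dataset_path=dataset_root)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

model = UNet3D(in_channels=1, out_channels=6)
model.load_state_dict(torch.load("epoch_19_lossdice.pth", map_location="cpu"))  # placeholder checkpoint
model.eval()

with torch.no_grad():
    image, label = next(iter(test_loader))
    # Note: a full-resolution volume may require a GPU or substantial memory
    prediction = model(image).argmax(dim=1)  # per-voxel class indices
    print(prediction.shape, label.shape)
```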
## Results

### Training and Validation Loss

![Training and Validation Loss](https://github.com/Han1zen/PatternAnalysis-2024/blob/topic-recognition/recognition/3D-UNT%2048790835/picture/train_loss_and_valid_loss.png)

- The **training loss** curve drops rapidly in the early stages of training, indicating that the model is learning effectively from the training data.
- As training progresses, the loss stabilizes at around **0.6**, suggesting that the model fits the training set well and learns useful features.
- The **validation loss** curve also trends downward and remains close to the training loss in the later stages of training.
- This indicates good generalization on the validation set, with no significant signs of overfitting. The validation loss stabilizes at approximately **0.62**, further supporting the model's effectiveness.

### Dice Similarity Coefficient

![Dice](https://github.com/Han1zen/PatternAnalysis-2024/blob/topic-recognition/recognition/3D-UNT%2048790835/picture/dice.png)

- The model achieves a **Dice similarity coefficient** above **0.7** for all labels, meeting the established target.
- This indicates that the model performs well on the segmentation task, accurately identifying and delineating the different regions of the prostate.

## References

1. Sik-Ho Tsang. "Review: 3D U-Net — Volumetric Segmentation (Medical Image Segmentation)." [Towards Data Science](https://towardsdatascience.com/review-3d-u-net-volumetric-segmentation-medical-image-segmentation-8b592560fac1).

diff --git a/recognition/3D-UNT 48790835/dataset.py b/recognition/3D-UNT 48790835/dataset.py
new file mode 100644
index 000000000..2af554eb8
--- /dev/null
+++ b/recognition/3D-UNT 48790835/dataset.py
@@ -0,0 +1,122 @@
import os
import torch
from torch.utils.data import Dataset, DataLoader
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureTyped,
    RandFlipd,
    Lambdad,
    Resized,
    EnsureChannelFirstd,
    ScaleIntensityd,
    RandRotate90d,
)

# Transforms for training data: load, resize, and apply random flips and rotations
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    ScaleIntensityd(keys="image"),  # Normalize intensity
    Lambdad(keys="image", func=lambda x: (x - x.min()) / (x.max() - x.min())),  # Min-max rescale to [0, 1]
    RandRotate90d(keys=("image", "label"), prob=0.5),  # Random 90-degree rotations
    RandFlipd(keys=("image", "label"), prob=0.5, spatial_axis=[0]),
    RandFlipd(keys=("image", "label"), prob=0.5, spatial_axis=[1]),
    RandFlipd(keys=("image", "label"), prob=0.5, spatial_axis=[2]),
    # Resize; use nearest-neighbour for the label so class indices are not interpolated
    Resized(keys=["image", "label"], spatial_size=(256, 256, 128), mode=("area", "nearest")),
    EnsureTyped(keys=("image", "label"), dtype=torch.float32),
])

# Transforms for validation/test data: load and normalize only (no augmentation)
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    ScaleIntensityd(keys="image"),
    Lambdad(keys="image", func=lambda x: (x - x.min()) / (x.max() - x.min())),
    EnsureTyped(keys=("image", "label"), dtype=torch.float32),
])

class CustomDataset(Dataset):
    """
    Dataset class for reading pelvic MRI data.
    """

    def __init__(self, mode, dataset_path):
        """
        Args:
            mode (str): One of 'train' or 'test'.
            dataset_path (str): Root directory of the dataset.
        """
        self.mode = mode
        self.train_transform = train_transforms
        self.test_transform = val_transforms

        # Load image and label file paths based on mode
        if self.mode == 'train':
            with open('train_list.txt', 'r') as f:
                select_list = [_.strip() for _ in f.readlines()]
            self.img_list = [os.path.join(dataset_path, 'semantic_MRs_anon', _) for _ in select_list]
            self.label_list = [os.path.join(dataset_path, 'semantic_labels_anon', _.replace('_LFOV', '_SEMANTIC_LFOV'))
                               for _ in select_list]

        elif self.mode == 'test':
            with open('test_list.txt', 'r') as f:
                select_list = [_.strip() for _ in f.readlines()]
            self.img_list = [os.path.join(dataset_path, 'semantic_MRs_anon', _) for _ in select_list]
            self.label_list = [os.path.join(dataset_path, 'semantic_labels_anon', _.replace('_LFOV', '_SEMANTIC_LFOV'))
                               for _ in select_list]

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        img_path = self.img_list[index]
        label_path = self.label_list[index]

        # Use the augmented transform for training and the plain transform otherwise
        transform = self.train_transform if self.mode == 'train' else self.test_transform
        augmented = transform({'image': img_path, 'label': label_path})
        image = augmented['image']
        label = augmented['label']

        # Ensure image and label are 4D tensors (channels, x, y, z);
        # squeeze out the extra singleton dimension if a 5D tensor slips through
        if image.dim() == 5:
            image = image.squeeze(1)
        if label.dim() == 5:
            label = label.squeeze(1)

        return image, label


if __name__ == '__main__':
    # Test the dataset
    test_dataset = CustomDataset(mode='test', dataset_path=r"path_to_your_dataset")
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=2, shuffle=False)
    print(len(test_dataset))
    for batch_ndx, sample in enumerate(test_dataloader):
        print('test')
        print(sample[0].shape)  # Should print (batch_size, channels, x, y, z)
        print(sample[1].shape)  # Should print (batch_size, channels, x, y, z)
        break

    train_dataset = CustomDataset(mode='train', dataset_path=r"path_to_your_dataset")
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=False)
    for batch_ndx, sample in enumerate(train_dataloader):
        print('train')
        print(sample[0].shape)  # Should print (batch_size, channels, x, y, z)
        print(sample[1].shape)  # Should print (batch_size, channels, x, y, z)
        break
diff --git a/recognition/3D-UNT 48790835/modules.py b/recognition/3D-UNT 48790835/modules.py
new file mode 100644
index 000000000..9e2d0b3f9
--- /dev/null
+++ b/recognition/3D-UNT 48790835/modules.py
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool3d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use trilinear upsampling followed by convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # inputs are (N, C, D, H, W); pad x1 so its spatial size matches the skip connection x2
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class UNet3D(nn.Module):
    def __init__(self, in_channels=1, out_channels=6):
        super(UNet3D, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.inc = DoubleConv(in_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # With trilinear upsampling the channel count is unchanged by Up.up, so the
        # in_channels of each Up block is the concatenated (skip + upsampled) count.
        self.up1 = Up(1024, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, 64)
        self.outc = nn.Conv3d(64, out_channels, kernel_size=1)

    def forward(self, x):

        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
diff --git a/recognition/3D-UNT 48790835/picture/3D U-Net.webp b/recognition/3D-UNT 48790835/picture/3D U-Net.webp
new file mode 100644
index 000000000..bfa47324f
Binary files /dev/null and b/recognition/3D-UNT 48790835/picture/3D U-Net.webp differ
diff --git a/recognition/3D-UNT 48790835/picture/dice.png b/recognition/3D-UNT 48790835/picture/dice.png
new file mode 100644
index 000000000..2de00cc45
Binary files /dev/null and b/recognition/3D-UNT 48790835/picture/dice.png differ
diff --git a/recognition/3D-UNT 48790835/picture/loss.jpg b/recognition/3D-UNT 48790835/picture/loss.jpg
new file mode 100644
index 000000000..cb6dd49ee
Binary files /dev/null and b/recognition/3D-UNT 48790835/picture/loss.jpg differ
diff --git a/recognition/3D-UNT 48790835/picture/train_loss_and_valid_loss.png b/recognition/3D-UNT 48790835/picture/train_loss_and_valid_loss.png
new file mode 100644
index 000000000..3411ac1b1
Binary files /dev/null and b/recognition/3D-UNT 48790835/picture/train_loss_and_valid_loss.png differ
diff --git a/recognition/3D-UNT 48790835/predict.py b/recognition/3D-UNT 48790835/predict.py
new file mode 100644
index 000000000..ac204e03a
--- /dev/null
+++ b/recognition/3D-UNT 48790835/predict.py
@@ -0,0 +1,69 @@
import torch
import numpy as np
import random
import argparse
from modules import UNet3D
from dataset import CustomDataset
from torch.utils.data import DataLoader
import torch.nn as nn

# Set random seed for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Load the model
model = UNet3D(in_channels=1, out_channels=6).cuda()
model.load_state_dict(torch.load(r'epoch_2_lossdice1.pth'))
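# Note: the checkpoint name above is assumed to be one produced by train.py
# (which saves files named epoch_<epoch>_loss<loss>.pth); adjust it to the
# checkpoint you actually trained. On a CPU-only machine the weights could
# instead be loaded with torch.load(..., map_location='cpu') and the .cuda()
# calls dropped.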
model.eval()

# Define the test dataloader
test_dataset = CustomDataset(mode='test', dataset_path=r'C:\Users\111\Desktop\3710\新建文件夹\数据集\Labelled_weekly_MR_images_of_the_male_pelvis-Xken7gkM-\data\HipMRI_study_complete_release_v1')
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

# Define weighted Dice loss function
class WeightedDiceLoss(nn.Module):
    def __init__(self, weights=None, smooth=1):
        super(WeightedDiceLoss, self).__init__()
        self.weights = weights
        self.smooth = smooth

    def forward(self, inputs, targets):
        # Flatten the input and target tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        total = inputs.sum() + targets.sum()

        # Calculate Dice coefficient
        dice = (2. * intersection + self.smooth) / (total + self.smooth)

        if self.weights is not None:
            return (1 - dice) * self.weights
        return 1 - dice

valid_dice = []
num_classes = 6
with torch.no_grad():
    for idx, (data_x, data_y) in enumerate(test_dataloader):
        data_x = data_x.to(torch.float32).cuda()
        data_y = data_y.to(torch.float32).cuda().squeeze()

        # Get model outputs and the predicted class for each voxel
        outputs = model(data_x)
        outputs_class = torch.argmax(outputs, dim=1).squeeze()
        assert outputs_class.size() == data_y.size()

        # Calculate the Dice coefficient for each label
        dice_per_class = []
        for c in range(num_classes):
            pred_c = (outputs_class == c)
            true_c = (data_y == c)
            intersection = torch.sum(pred_c & true_c).item()
            denominator = pred_c.sum().item() + true_c.sum().item()
            dice_per_class.append(2.0 * intersection / denominator if denominator > 0 else 1.0)

        print('Dice Coefficients per label:', dice_per_class)
        valid_dice.append(dice_per_class)

# Print the average Dice coefficient per label over the test set
average_dice = np.mean(np.array(valid_dice), axis=0)
print('Average Dice Coefficient per label:', average_dice)
diff --git a/recognition/3D-UNT 48790835/train.py b/recognition/3D-UNT 48790835/train.py
new file mode 100644
index 000000000..a0d9e9a09
--- /dev/null
+++ b/recognition/3D-UNT 48790835/train.py
@@ -0,0 +1,165 @@
import os
import torch
import numpy as np
import random
import argparse
from modules import UNet3D
from dataset import CustomDataset  # Ensure this matches the class name defined in dataset.py
from torch.utils.data import DataLoader
import time
import matplotlib.pyplot as plt
import torchio as tio

# Set random seed
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--device', default='cuda')
parser.add_argument('--loss', default='dice')
parser.add_argument('--dataset_root', type=str,
                    default=r'C:\Users\111\Desktop\3710\新建文件夹\数据集\Labelled_weekly_MR_images_of_the_male_pelvis-Xken7gkM-\data\HipMRI_study_complete_release_v1',
                    help='Root directory of the dataset')
args = parser.parse_args()

# Define the model
model = UNet3D(in_channels=1, out_channels=6).to(args.device)

class DiceLoss(torch.nn.Module):
    def __init__(self, smooth=1):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        assert inputs.shape == targets.shape, f"Shapes don't match {inputs.shape} != {targets.shape}"

        # Convert logits to per-class probabilities so the soft Dice score is bounded in [0, 1]
        inputs = torch.softmax(inputs, dim=1)

        # Skip the background class
        inputs = inputs[:, 1:]
        targets = targets[:, 1:]

        # Sum over the spatial dimensions of each sample and class (dims 2+ of the 5D tensor)
        axes = tuple(range(2, len(inputs.shape)))
        intersection = torch.sum(inputs * targets, axes)
        addition = torch.sum(torch.square(inputs) + torch.square(targets), axes)

        # Dice loss = 1 - mean soft Dice score
        dice_score = (2 * intersection + self.smooth) / (addition + self.smooth)
        return 1 - torch.mean(dice_score)

criterion = DiceLoss().to(args.device)

# Define the data augmentation class.
# Image and mask are wrapped in a single torchio Subject so that the random flip is
# applied identically to both; flipping them independently would misalign the labels.
class Augment:
    def __init__(self):
        self.transform = tio.Compose([
            tio.CropOrPad((16, 32, 32)),
            tio.RandomFlip(axes=0, flip_probability=0.5),
        ])

    def __call__(self, image, mask):
        # torchio expects 4D channels-first tensors, so drop the batch dimension
        # (batch size is 1 here) and restore it afterwards
        subject = tio.Subject(
            image=tio.ScalarImage(tensor=image.squeeze(0).cpu()),
            mask=tio.LabelMap(tensor=mask.squeeze(0).cpu()),
        )
        transformed = self.transform(subject)
        return (transformed['image'].data.unsqueeze(0).to(image.device),
                transformed['mask'].data.unsqueeze(0).to(mask.device))

# Define the train and test dataloaders
train_dataset = CustomDataset(mode='train', dataset_path=args.dataset_root)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)
test_dataset = CustomDataset(mode='test', dataset_path=args.dataset_root)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

train_loss = []
valid_loss = []
train_epochs_loss = []
valid_epochs_loss = []

# Initialize data augmentation
augment = Augment()

# The training and validation process
start_time = time.time()
for epoch in range(args.epoch):
    model.train()
    train_epoch_loss = []

    for idx, (data_x, data_y) in enumerate(train_dataloader):
        data_x = data_x.to(torch.float32).to(args.device)
        data_y = data_y.to(torch.float32).to(args.device)

        # Ensure data_x is 5D
        if data_x.dim() == 4:  # If it's a 4D tensor
            data_x = data_x.unsqueeze(1)  # Add a channel dimension

        data_x, data_y = augment(data_x, data_y)  # Apply augmentation

        # Ensure data_y is 5D
        if data_y.dim() == 4:  # If it's a 4D tensor
            data_y = data_y.unsqueeze(1)  # Add a channel dimension

        labely = torch.nn.functional.one_hot(data_y.squeeze(1).long(), num_classes=6).permute(0, 4, 1, 2, 3).float().to(args.device)
        outputs = model(data_x)
        optimizer.zero_grad()
        loss = criterion(outputs, labely)
        loss.backward()
        optimizer.step()
        train_epoch_loss.append(loss.item())
        train_loss.append(loss.item())

    train_epochs_loss.append(np.average(train_epoch_loss))
    epoch_time = time.time() - start_time
    print(f'Epoch {epoch}: Train Loss: {train_epochs_loss[-1]:.4f}, Elapsed: {epoch_time:.1f}s')

    if epoch % 1 == 0:
        model.eval()
        valid_epoch_loss = []
        with torch.no_grad():
            for idx, (data_x, data_y) in enumerate(test_dataloader):
                data_x = data_x.to(torch.float32).to(args.device)
                data_y = data_y.to(torch.float32).to(args.device)

                # Ensure data_x is 5D
                if data_x.dim() == 4:  # If it's a 4D tensor
                    data_x = data_x.unsqueeze(1)  # Add a channel dimension

                # Ensure data_y is 5D
                if data_y.dim() == 4:  # If it's a 4D tensor
                    data_y = data_y.unsqueeze(1)  # Add a channel dimension

                labely = torch.nn.functional.one_hot(data_y.squeeze(1).long(), num_classes=6).permute(0, 4, 1, 2, 3).float().to(args.device)
                outputs = model(data_x)
                loss = criterion(outputs, labely)
                valid_epoch_loss.append(loss.item())
                valid_loss.append(loss.item())

        valid_epochs_loss.append(np.average(valid_epoch_loss))
        # Save the trained model
        torch.save(model.state_dict(), f'epoch_{epoch}_loss{args.loss}.pth')

# Plotting the training and validation loss
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Left plot: Training loss
ax1.plot(train_loss, label='Train Loss (Dice)', color='blue', linewidth=2)
ax1.set_title("Train Loss (Dice)", fontsize=16)
ax1.set_xlabel("Iterations", fontsize=14)
ax1.set_ylabel("Loss", fontsize=14)
ax1.legend()
ax1.grid(True)

# Right plot: Training and validation loss comparison
ax2.plot(np.arange(0, len(train_epochs_loss)), train_epochs_loss, '-o', label='Epoch Train Loss', color='orange', markersize=4)
ax2.plot(np.arange(0, len(valid_epochs_loss)), valid_epochs_loss, '-o', label='Epoch Valid Loss', color='green', markersize=4)
ax2.set_title("Train and Validation Loss", fontsize=16)
ax2.set_xlabel("Epochs", fontsize=14)
ax2.set_ylabel("Loss", fontsize=14)
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.savefig("train_loss_and_valid_loss.png")
plt.show()
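# Optionally, the raw loss curves computed above could also be saved for later
# re-plotting without retraining, e.g. with
# np.save('train_epochs_loss.npy', np.array(train_epochs_loss)) and
# np.save('valid_epochs_loss.npy', np.array(valid_epochs_loss)).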