SAM-Med3D: An Efficient General-purpose Promptable Segmentation Model for 3D Volumetric Medical Image

SAM-Med3D [Paper]

🔥🌻📰 News 📰🌻🔥

  • [New Checkpoints Release] A newer fine-tuned version of SAM-Med3D, named SAM-Med3D-turbo, is now available. We fine-tuned it on 44 datasets (list) to improve performance. We hope this update helps you 🙂.
  • [New Checkpoints Release] Fine-tuned SAM-Med3D models for organ and brain segmentation are now available! We hope you enjoy the enhanced performance on these specific tasks 😉. Details are in results and ckpt.
  • [Recommendation] If you are interested in computer vision, we recommend checking out OpenGVLab for more exciting projects like SAM-Med2D!

🌟 Highlights

  • 📚 Curated the most extensive volumetric medical dataset to date for training, boasting 131K 3D masks and 247 categories.
  • 🚀 Achieved efficient promptable segmentation, requiring 10 to 100 times fewer prompt points for satisfactory 3D outcomes.
  • 🏆 Conducted a thorough assessment of SAM-Med3D across 15 frequently used volumetric medical image segmentation datasets.

🔨 Usage

Training / Fine-tuning

(We recommend fine-tuning from the SAM-Med3D pre-trained weights available at link.)

To train the SAM-Med3D model on your own data, follow these steps:

0. (Recommended) Prepare the Pre-trained Weights

Download the checkpoint from the ckpt section and move the .pth file to SAM_Med3D/ckpt/sam_med3d.pth.
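
To quickly verify that the downloaded file is intact, you can try loading it. This is only a sketch and assumes the file is an ordinary PyTorch checkpoint; the exact key layout may differ between releases:

import torch

# Load the downloaded checkpoint on CPU and inspect its top-level keys.
# Assumption: the file is a regular checkpoint saved with torch.save.
ckpt = torch.load("ckpt/sam_med3d.pth", map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))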

1. Prepare Your Training Data (from an nnU-Net-style dataset):

Ensure that your training data is organized according to the structure shown in the data/validation directories. The target file structure should look like the following:

data/train
      ├── adrenal
      │ ├── ct_WORD
      │ │ ├── imagesTr
      │ │ │ ├── word_0025.nii.gz
      │ │ │ ├── ...
      │ │ ├── labelsTr
      │ │ │ ├── word_0025.nii.gz
      │ │ │ ├── ...
      ├── ...

If your original data are in the nnU-Net style, the original file structure will look like this:

Dataset10_WORD
     ├── imagesTr
     │ ├── word_0025.nii.gz
     │ ├── ...
     ├── labelsTr
     │ ├── word_0025.nii.gz
     │ ├── ...

If the labels have multiple classes, you should first split them into multiple binary labels (a reference script can be found in script) and then re-organize them into per-class sub-folders as shown below (an illustrative splitting sketch follows the example layout).

data/train
      ├── adrenal
      │ ├── ct_WORD
      │ │ ├── imagesTr
      │ │ │ ├── word_0025.nii.gz
      │ │ │ ├── ...
      │ │ ├── labelsTr
      │ │ │ ├── word_0025.nii.gz (binary label)
      │ │ │ ├── ...
      ├── liver
      │ ├── ct_WORD
      │ │ ├── imagesTr
      │ │ │ ├── word_0025.nii.gz
      │ │ │ ├── ...
      │ │ ├── labelsTr
      │ │ │ ├── word_0025.nii.gz (binary label)
      │ │ │ ├── ...
      ├── ...
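
Prefer the reference split script from the repository; purely as an illustration, the sketch below shows how a multi-class nnU-Net label could be split into per-class binary masks and copied into the layout above. The class-to-folder mapping, the paths, and the nibabel dependency are assumptions to adapt to your own data:

import os
import shutil
import nibabel as nib
import numpy as np

# Hypothetical mapping from label value to category folder (adapt to your dataset).
CLASS_MAP = {1: "liver", 2: "adrenal"}
src = "Dataset10_WORD"    # nnU-Net-style source dataset
dst = "data/train"        # SAM-Med3D training layout
modality = "ct_WORD"      # sub-folder name used in the layout above

# Assumes image and label files share the same name, as in the layout above.
for fname in os.listdir(os.path.join(src, "labelsTr")):
    label_nii = nib.load(os.path.join(src, "labelsTr", fname))
    label = label_nii.get_fdata()
    for value, organ in CLASS_MAP.items():
        img_dir = os.path.join(dst, organ, modality, "imagesTr")
        lbl_dir = os.path.join(dst, organ, modality, "labelsTr")
        os.makedirs(img_dir, exist_ok=True)
        os.makedirs(lbl_dir, exist_ok=True)
        # Keep only the voxels of this class as a binary mask.
        binary = (label == value).astype(np.uint8)
        nib.save(nib.Nifti1Image(binary, label_nii.affine),
                 os.path.join(lbl_dir, fname))
        # The same image is reused in every category sub-folder.
        shutil.copy(os.path.join(src, "imagesTr", fname),
                    os.path.join(img_dir, fname))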

Then, modify utils/data_paths.py according to your own data:

img_datas = [
    "data/train/adrenal/ct_WORD",
    "data/train/liver/ct_WORD",
    ...
]

2. Run the Training Script:

Run bash train.sh to execute the following command in your terminal:

python train.py --multi_gpu --task_name ${tag}

This will start the training process of the SAM-Med3D model on your prepared data. If you use only one GPU, remove the --multi_gpu flag.

The key options are listed below:

  • task_name: name used to tag this training run
  • checkpoint: path to the pre-trained checkpoint
  • work_dir: output folder for logs and checkpoints
  • multi_gpu: use multiple GPUs with DDP
  • gpu_ids: GPU ids used for training
  • num_epochs: number of training epochs
  • batch_size: batch size for training
  • lr: learning rate for training

Hint: Use --checkpoint to set the pre-trained weight path; the model will be trained from scratch if no checkpoint is found at that path!
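
For example, a full fine-tuning command combining these options might look like the following (the task name and hyperparameter values are placeholders, not recommended settings):

python train.py --multi_gpu --task_name ft_my_task --checkpoint ./ckpt/sam_med3d.pth --work_dir ./work_dir --num_epochs 200 --batch_size 4 --lr 8e-4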

Evaluation

Prepare your own dataset following the samples in data/validation, replacing them according to your specific scenario. Then simply run bash infer.sh to test SAM-Med3D on your data. Make sure the masks are processed into the one-hot (binary) format, i.e. they contain only two values: foreground and background (a quick check is sketched after the option list below).

python validation.py --seed 2023 \
 -vp ./results/vis_sam_med3d \
 -cp ./ckpt/sam_med3d.pth \
 -tdp ./data/validation -nc 1 \
 --save_name ./results/sam_med3d.py

  • vp: visualization path, dir to save the final visualization files
  • cp: checkpoint path
  • tdp: test data path, where your data is placed
  • nc: number of clicks of prompt points
  • save_name: filename to save evaluation results
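
As a quick check of the binary-mask requirement mentioned above, the small sketch below (assuming NIfTI labels and the nibabel package; the glob pattern is a guess based on the data/validation layout) reports any mask containing more than two values:

import glob
import nibabel as nib
import numpy as np

# Flag any validation label that is not strictly binary (foreground/background).
# The path pattern is an assumption; adjust it to your validation layout.
for path in glob.glob("data/validation/**/labelsTr/*.nii.gz", recursive=True):
    values = np.unique(nib.load(path).get_fdata())
    if len(values) > 2:
        print(f"{path}: expected a binary mask, found values {values}")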

For validation of SAM and SAM-Med2D on 3D volumetric data, you can refer to infer_sam.sh and infer_med2d.sh for details.

Hint: We also provide a simple script sum_result.py to help summarize the results from files like ./results/sam_med3d.py.

🔗 Checkpoint

Our most recommended version is SAM-Med3D-turbo.

| Model | Google Drive | Baidu NetDisk |
| --- | --- | --- |
| SAM-Med3D | Download | Download (pwd:r5o3) |
| SAM-Med3D-organ | Download | Download (pwd:5t7v) |
| SAM-Med3D-brain | Download | Download (pwd:yp42) |
| SAM-Med3D-turbo | Download | Download (pwd:l6ol) |

Other checkpoints are available from their official links: SAM and SAM-Med2D.

🗼 Method

🏆 Results

💡 Overall Performance

| Model | Prompt | Resolution | Inference Time (s) | Overall Dice (%) |
| --- | --- | --- | --- | --- |
| SAM | N points | 1024×1024×N | 13 | 17.01 |
| SAM-Med2D | N points | 256×256×N | 4 | 42.75 |
| SAM-Med3D | 1 point | 128×128×128 | 2 | 49.91 |
| SAM-Med3D | 10 points | 128×128×128 | 6 | 60.94 |
| SAM-Med3D-turbo | 10 points | 128×128×128 | 6 | 77.60 |

Note: Quantitative comparison of different methods on our evaluation dataset. Here, N denotes the count of slices containing the target object (10 ≤ N ≤ 200). Inference time is calculated with N=100, excluding the time for image processing and simulated prompt generation.

💡 Dice on Different Anatomical Structures and Lesions

| Model | Prompt | A&T | Bone | Brain | Cardiac | Gland | Muscle | Seen Lesion | Unseen Lesion |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| SAM | N points | 17.19 | 22.32 | 17.68 | 2.82 | 11.62 | 3.50 | 12.03 | 8.88 |
| SAM-Med2D | N points | 46.79 | 47.52 | 19.24 | 32.23 | 43.55 | 35.57 | 26.08 | 44.87 |
| SAM-Med3D | 1 point | 46.80 | 54.77 | 34.48 | 46.51 | 57.28 | 53.28 | 42.02 | 40.53 |
| SAM-Med3D | 10 points | 55.81 | 69.13 | 40.71 | 52.86 | 65.01 | 67.28 | 50.52 | 48.44 |
| SAM-Med3D-brain | 10 points | 51.71 | - | 62.77 | 37.93 | 62.95 | 43.70 | 45.89 | 20.51 |
| SAM-Med3D-organ | 10 points | 70.63 | - | 46.49 | 63.14 | 73.01 | 75.29 | 53.02 | 36.44 |
| SAM-Med3D-turbo | 10 points | 83.96 | 85.34 | 46.08 | 69.90 | 90.97 | 91.62 | 64.80 | 61.00 |

Note: Comparison from the perspective of anatomical structure and lesion. A&T represents Abdominal and Thorax targets. N denotes the count of slices containing the target object (10 ≤ N ≤ 200).

💡 Visualization

📬 Citation

@misc{wang2023sammed3d,
      title={SAM-Med3D}, 
      author={Haoyu Wang and Sizheng Guo and Jin Ye and Zhongying Deng and Junlong Cheng and Tianbin Li and Jianpin Chen and Yanzhou Su and Ziyan Huang and Yiqing Shen and Bin Fu and Shaoting Zhang and Junjun He and Yu Qiao},
      year={2023},
      eprint={2310.15161},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

🎫 License

This project is released under the Apache 2.0 license.

💬 Discussion Group

[Discussion group QR code]

(If the QR code has expired, please contact the WeChat accounts EugeneYonng or Small_dark8023.)

You are also welcome to follow our official Zhihu account, where we share more information on medical imaging.

🙏 Acknowledgement

  • We thank all medical workers and dataset owners for making public datasets available to the community.
  • Thanks to the following open-source projects:

👋 Hiring & Global Collaboration

  • Hiring: We are hiring researchers, engineers, and interns in General Vision Group, Shanghai AI Lab. If you are interested in Medical Foundation Models and General Medical AI, including designing benchmark datasets, general models, evaluation systems, and efficient tools, please contact us.
  • Global Collaboration: We're on a mission to redefine medical research, aiming for a more universally adaptable model. Our passionate team is delving into foundational healthcare models, promoting the development of the medical community. Collaborate with us to increase competitiveness, reduce risk, and expand markets.
  • Contact: Junjun He (hejunjun@pjlab.org.cn), Jin Ye (yejin@pjlab.org.cn), and Tianbin Li (litianbin@pjlab.org.cn).
