Skip to content

Commit 21ed8e5

Browse files
authored
Remove deprecated img_size args in SwinUNETR (#416)
Remove deprecated `img_size` args in `SwinUNETR` Add `allow_smaller=True` in CropForeground / CropForegroundd since default value has been changed Project-MONAI/MONAI#8430 --------- Signed-off-by: YunLiu <[email protected]>
1 parent 4c18daf commit 21ed8e5

File tree

16 files changed

+15
-22
lines changed

16 files changed

+15
-22
lines changed

DAE/BTCV_Finetune/utils/data_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def get_loader(args):
8080
transforms.ScaleIntensityRanged(
8181
keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
8282
),
83-
transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
83+
transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
8484
transforms.RandCropByPosNegLabeld(
8585
keys=["image", "label"],
8686
label_key="label",
@@ -111,7 +111,7 @@ def get_loader(args):
111111
transforms.ScaleIntensityRanged(
112112
keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
113113
),
114-
transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
114+
transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
115115
transforms.ToTensord(keys=["image", "label"]),
116116
]
117117
)

DAE/Pretrain_full_contrast/data/data_pretrain.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
AsChannelFirstd,
2222
AsDiscrete,
2323
Compose,
24-
CropForegroundd,
2524
LoadImaged,
2625
NormalizeIntensityd,
2726
Orientationd,

SwinMM/WORD/models/swin_unetr.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def __init__(
3838
3939
"""
4040
super().__init__(
41-
img_size,
4241
*args,
4342
num_heads=num_heads,
4443
feature_size=feature_size,

SwinUNETR/BRATS21/README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ Mean Dice refers to average Dice of WT, ET and TC tumor semantic classes.
106106
A Swin UNETR network with standard hyper-parameters for brain tumor semantic segmentation (BraTS dataset) is be defined as:
107107

108108
``` bash
109-
model = SwinUNETR(img_size=(128,128,128),
110-
in_channels=4,
109+
model = SwinUNETR(in_channels=4,
111110
out_channels=3,
112111
feature_size=48,
113112
use_checkpoint=True,

SwinUNETR/BRATS21/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ def main_worker(gpu, args):
127127
pretrained_pth = os.path.join(pretrained_dir, model_name)
128128

129129
model = SwinUNETR(
130-
img_size=(args.roi_x, args.roi_y, args.roi_z),
131130
in_channels=args.in_channels,
132131
out_channels=args.out_channels,
133132
feature_size=args.feature_size,

SwinUNETR/BRATS21/test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ def main():
7070
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7171
pretrained_pth = os.path.join(pretrained_dir, model_name)
7272
model = SwinUNETR(
73-
img_size=128,
7473
in_channels=args.in_channels,
7574
out_channels=args.out_channels,
7675
feature_size=args.feature_size,

SwinUNETR/BRATS21/utils/data_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def get_loader(args):
9999
transforms.LoadImaged(keys=["image", "label"]),
100100
transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
101101
transforms.CropForegroundd(
102-
keys=["image", "label"], source_key="image", k_divisible=[args.roi_x, args.roi_y, args.roi_z]
102+
keys=["image", "label"], source_key="image", k_divisible=[args.roi_x, args.roi_y, args.roi_z], allow_smaller=True
103103
),
104104
transforms.RandSpatialCropd(
105105
keys=["image", "label"], roi_size=[args.roi_x, args.roi_y, args.roi_z], random_size=False

SwinUNETR/BTCV/README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,7 @@ Once the json file is downloaded, please place it in the same folder as the data
8484
A Swin UNETR network with standard hyper-parameters for multi-organ semantic segmentation (BTCV dataset) is be defined as:
8585

8686
``` bash
87-
model = SwinUNETR(img_size=(96,96,96),
88-
in_channels=1,
87+
model = SwinUNETR(in_channels=1,
8988
out_channels=14,
9089
feature_size=48,
9190
use_checkpoint=True,

SwinUNETR/BTCV/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ def main_worker(gpu, args):
127127

128128
pretrained_dir = args.pretrained_dir
129129
model = SwinUNETR(
130-
img_size=(args.roi_x, args.roi_y, args.roi_z),
131130
in_channels=args.in_channels,
132131
out_channels=args.out_channels,
133132
feature_size=args.feature_size,

SwinUNETR/BTCV/test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def main():
7171
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7272
pretrained_pth = os.path.join(pretrained_dir, model_name)
7373
model = SwinUNETR(
74-
img_size=96,
7574
in_channels=args.in_channels,
7675
out_channels=args.out_channels,
7776
feature_size=args.feature_size,

SwinUNETR/BTCV/utils/data_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def get_loader(args):
8080
transforms.ScaleIntensityRanged(
8181
keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
8282
),
83-
transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
83+
transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
8484
transforms.RandCropByPosNegLabeld(
8585
keys=["image", "label"],
8686
label_key="label",
@@ -111,7 +111,7 @@ def get_loader(args):
111111
transforms.ScaleIntensityRanged(
112112
keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
113113
),
114-
transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
114+
transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
115115
transforms.ToTensord(keys=["image", "label"]),
116116
]
117117
)

SwinUNETR/Pretrain/utils/data_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def get_loader(args):
7878
keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
7979
),
8080
SpatialPadd(keys="image", spatial_size=[args.roi_x, args.roi_y, args.roi_z]),
81-
CropForegroundd(keys=["image"], source_key="image", k_divisible=[args.roi_x, args.roi_y, args.roi_z]),
81+
CropForegroundd(keys=["image"], source_key="image", k_divisible=[args.roi_x, args.roi_y, args.roi_z], allow_smaller=True),
8282
RandSpatialCropSamplesd(
8383
keys=["image"],
8484
roi_size=[args.roi_x, args.roi_y, args.roi_z],
@@ -98,7 +98,7 @@ def get_loader(args):
9898
keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
9999
),
100100
SpatialPadd(keys="image", spatial_size=[args.roi_x, args.roi_y, args.roi_z]),
101-
CropForegroundd(keys=["image"], source_key="image", k_divisible=[args.roi_x, args.roi_y, args.roi_z]),
101+
CropForegroundd(keys=["image"], source_key="image", k_divisible=[args.roi_x, args.roi_y, args.roi_z], allow_smaller=True),
102102
RandSpatialCropSamplesd(
103103
keys=["image"],
104104
roi_size=[args.roi_x, args.roi_y, args.roi_z],

UNETR/BTCV/utils/data_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def get_loader(args):
8080
transforms.ScaleIntensityRanged(
8181
keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
8282
),
83-
transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
83+
transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
8484
transforms.RandCropByPosNegLabeld(
8585
keys=["image", "label"],
8686
label_key="label",
@@ -111,7 +111,7 @@ def get_loader(args):
111111
transforms.ScaleIntensityRanged(
112112
keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
113113
),
114-
transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
114+
transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
115115
transforms.ToTensord(keys=["image", "label"]),
116116
]
117117
)

auto3dseg/algorithm_templates/dints/scripts/algo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def fill_template_config(self, data_stats_file, output_path, **kwargs):
158158
"source_key": "@image_key",
159159
"start_coord_key": None,
160160
"end_coord_key": None,
161+
"allow_smaller": True,
161162
},
162163
],
163164
}
@@ -174,7 +175,7 @@ def fill_template_config(self, data_stats_file, output_path, **kwargs):
174175
"b_max": 1.0,
175176
"clip": True,
176177
},
177-
{"_target_": "CropForegroundd", "keys": "@image_key", "source_key": "@image_key"},
178+
{"_target_": "CropForegroundd", "keys": "@image_key", "source_key": "@image_key", "allow_smaller": True},
178179
],
179180
}
180181

auto3dseg/algorithm_templates/swinunetr/configs/network.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
network:
22
_target_: SwinUNETR
33
feature_size: 48
4-
img_size: 96
54
in_channels: "@input_channels"
65
out_channels: "@output_classes"
76
spatial_dims: 3

auto3dseg/algorithm_templates/swinunetr/scripts/algo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def fill_template_config(self, data_stats_file, output_path, **kwargs):
168168
"source_key": "@image_key",
169169
"start_coord_key": None,
170170
"end_coord_key": None,
171+
"allow_smaller": True,
171172
},
172173
],
173174
}
@@ -183,7 +184,7 @@ def fill_template_config(self, data_stats_file, output_path, **kwargs):
183184
"b_max": 1.0,
184185
"clip": True,
185186
},
186-
{"_target_": "CropForegroundd", "keys": "@image_key", "source_key": "@image_key"},
187+
{"_target_": "CropForegroundd", "keys": "@image_key", "source_key": "@image_key", "allow_smaller": True},
187188
],
188189
}
189190
mr_intensity_transform = {

0 commit comments

Comments
 (0)