from nd_unet import UNet2d  # or UNet1d or UNet3d

unet = UNet2d(
    in_channels=3,  # Mandatory. Number of input channels
    out_channels=5,  # Mandatory. Number of output channels (or classes)
    num_stages=4,  # Optional, default is 4. Number of stages
    initial_num_channels=32,  # Optional, default is 32. Number of channels in the first stage, doubled in each subsequent stage
    norm=None,  # Optional, default is None. Type of normalization. Can be None (no normalization), 'bn' (batch norm),
                # 'gnx' (group norm, where x is optionally the number of groups), 'in' (instance norm), 'ln' (layer norm)
    non_lin='relu',  # Optional, default is 'relu'. Type of activation function. Can be None, 'relu', 'leaky_relu', 'gelu', 'elu'
    kernel_size=3,  # Optional, default is 3. Kernel size for the convolutions
    pooling='max',  # Optional, default is 'max'. Can be 'max' or 'avg'
    bias=True,  # Optional. Whether to add a bias to the convolutions
    padding='same',  # Optional. Can be 'same' (i.e. padding=kernel_size//2 when kernel_size is odd) or an int specifying the padding.
                     # Beware: a value other than 'same' can produce an output whose size differs from the input
    padding_mode='zeros',  # Optional. Can be any padding mode supported by PyTorch convolutions ('zeros', 'reflect', 'replicate', or 'circular')
    stride_sequence=None,  # Optional. A sequence of strides of length (num_stages - 1) that controls pooling. For example,
                           # stride_sequence=[(1, 2), (2, 2), (2, 2)] does not reduce the first axis in the first pooling step.
                           # Default is stride_sequence=[2] * (num_stages - 1)
    skip_connections=True,  # Optional, default is True. Whether to use skip connections
)
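To see how `num_stages`, `initial_num_channels` and `stride_sequence` interact, the sketch below computes the number of channels and the spatial shape at each encoder stage. It is a hypothetical helper written for illustration, not part of the `nd_unet` API, and it assumes that each axis must divide evenly by its stride so that decoder upsampling restores the exact size needed for the skip connections; the library itself may handle uneven sizes differently.

```python
from typing import List, Optional, Sequence, Tuple, Union

# A stride is either one int applied to every axis, or one int per spatial axis.
Stride = Union[int, Tuple[int, ...]]


def stage_schedule(
    spatial_shape: Tuple[int, ...],
    num_stages: int = 4,
    initial_num_channels: int = 32,
    stride_sequence: Optional[Sequence[Stride]] = None,
) -> List[Tuple[int, Tuple[int, ...]]]:
    """Return (num_channels, spatial_shape) per encoder stage (hypothetical helper)."""
    ndim = len(spatial_shape)
    if stride_sequence is None:
        # Default described above: halve every axis between consecutive stages.
        stride_sequence = [2] * (num_stages - 1)
    if len(stride_sequence) != num_stages - 1:
        raise ValueError("stride_sequence must have length num_stages - 1")

    shape = tuple(spatial_shape)
    schedule = [(initial_num_channels, shape)]
    for stage, stride in enumerate(stride_sequence, start=1):
        strides = (stride,) * ndim if isinstance(stride, int) else tuple(stride)
        if len(strides) != ndim:
            raise ValueError(f"stage {stage}: expected {ndim} strides, got {len(strides)}")
        # Assumption: every axis must divide evenly by its stride.
        for size, s in zip(shape, strides):
            if size % s != 0:
                raise ValueError(f"stage {stage}: size {size} not divisible by stride {s}")
        shape = tuple(size // s for size, s in zip(shape, strides))
        # Channels double at every stage, starting from initial_num_channels.
        schedule.append((initial_num_channels * 2 ** stage, shape))
    return schedule


for channels, shape in stage_schedule((64, 64), stride_sequence=[(1, 2), (2, 2), (2, 2)]):
    print(channels, shape)
# 32 (64, 64)
# 64 (64, 32)
# 128 (32, 16)
# 256 (16, 8)
```

With the default `stride_sequence`, a 64×64 input would instead shrink to 32×32, 16×16 and 8×8, so the deepest stage of a 4-stage network sees 256 channels at 1/8 resolution.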
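The warning on `padding` follows from the standard output-length formula for a convolution, the one given in the PyTorch `Conv` documentation. The sketch below applies it along a single axis; `conv_output_size` and `same_padding` are illustrative names, not functions exported by `nd_unet`.

```python
def conv_output_size(n: int, kernel_size: int, padding: int,
                     stride: int = 1, dilation: int = 1) -> int:
    """Output length of one convolution along one axis (PyTorch Conv formula)."""
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def same_padding(kernel_size: int) -> int:
    """The padding described above as 'same' for odd kernel sizes."""
    return kernel_size // 2


for k in (3, 5, 7):
    # Columns: kernel size, size with 'same' padding, size with padding=0.
    print(k, conv_output_size(64, k, same_padding(k)), conv_output_size(64, k, 0))
# 3 64 62
# 5 64 60
# 7 64 58
```

For an odd kernel, `padding = kernel_size // 2` adds exactly the `kernel_size - 1` elements the convolution removes, so the size is preserved. With `padding=0` every convolution shrinks the axis, and with an even kernel `kernel_size // 2` overshoots by one (`conv_output_size(64, 4, 2)` is 65), which is why 'same' is only defined here for odd kernels.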
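The `norm` argument packs the normalization type and, for group norm, an optional group count into one string. The following is a hypothetical parser showing one way to read that format; it is not the library's own code, and it returns `None` for an unspecified group count rather than guessing the default that `nd_unet` applies.

```python
import re
from typing import Optional, Tuple

# Hypothetical grammar: a known prefix, optionally followed by a group count.
_NORM_PATTERN = re.compile(r"(bn|in|ln|gn)(\d+)?")


def parse_norm(norm: Optional[str]) -> Tuple[Optional[str], Optional[int]]:
    """Split a norm spec such as 'gn8' into ('gn', 8); None means no normalization."""
    if norm is None:
        return None, None
    match = _NORM_PATTERN.fullmatch(norm)
    if match is None:
        raise ValueError(f"unknown norm spec: {norm!r}")
    kind, digits = match.group(1), match.group(2)
    num_groups = int(digits) if digits is not None else None
    # Only group norm accepts a count, and the count must be positive.
    if num_groups is not None and (kind != "gn" or num_groups < 1):
        raise ValueError(f"invalid group count in {norm!r}")
    return kind, num_groups


print(parse_norm("gn8"))  # ('gn', 8)
print(parse_norm("gn"))   # ('gn', None)
print(parse_norm("bn"))   # ('bn', None)
```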