Adding a custom dataset for multiclass classification
Sandipriz opened this issue
Hi @Sandipriz It seems that `filepath` and `filepath1` are not RasterDatasets but are just strings. Could you post some additional code?
Now I see the problem.
Here is my file path:
filepath = '/content/drive/My Drive/BCMCA/bcmca_3b.tif'
filepath1 = '/content/drive/My Drive/BCMCA/bcmca_mask.tif'
The resolution of the image is 0.3 m and the CRS is EPSG:32617. I am not able to define it as a raster dataset and mask.
Should it be like `raster = function(filepath, crs=naip.crs, res=0.3)`? I can see Sentinel2/ChesapeakeDE in place of `function` in the custom raster dataset tutorial, but I am not sure what to use for WorldView-2.
@Sandipriz if I understand correctly, your current code looks like:
filepath = '/content/drive/My Drive/BCMCA/bcmca_3b.tif'
filepath1 = '/content/drive/My Drive/BCMCA/bcmca_mask.tif'
train_dataset = filepath & filepath1
The correct code would be:
from torchgeo.datasets import RasterDataset
filepath = '/content/drive/My Drive/BCMCA/bcmca_3b.tif'
filepath1 = '/content/drive/My Drive/BCMCA/bcmca_mask.tif'
image_dataset = RasterDataset(filepath)
mask_dataset = RasterDataset(filepath1)
train_dataset = image_dataset & mask_dataset
There are other steps that may be required, but this should at least get you farther.
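One of those steps, following the custom raster dataset tutorial's pattern, is usually to subclass `RasterDataset` so that the mask file is treated as labels rather than imagery. A minimal sketch (class names and `filename_glob` patterns are illustrative assumptions, not from this thread):

from torchgeo.datasets import RasterDataset

class WorldViewImage(RasterDataset):
    # Match the image file; the pattern is an assumption based on the paths above
    filename_glob = 'bcmca_3b.tif'
    is_image = True

class WorldViewMask(RasterDataset):
    filename_glob = 'bcmca_mask.tif'
    is_image = False  # treat pixel values as integer labels, not image bands

image_dataset = WorldViewImage('/content/drive/My Drive/BCMCA')
mask_dataset = WorldViewMask('/content/drive/My Drive/BCMCA')
train_dataset = image_dataset & mask_dataset  # spatial intersection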
I was able to move ahead. I intend to patchify the image and make training and validation datasets for the model.
Another way I tried was to patchify the image manually into 256x256 tiles and feed them to the model. For this, I moved the images and masks into train and validation sub-folders.
Number of images in validation images directory: 539
Number of masks in validation mask directory: 539
Number of images in training images directory: 1255
Number of masks in training mask directory: 1255
I got stuck again.
First, to clarify some confusion, `len(dataset)` is not the number of possible samples, it's the number of overlapping images. `len(sampler)` or `len(dataloader)` is the number of possible samples.
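For example (a sketch reusing the `dataset` from your snippet; the numbers are illustrative):

from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
from torchgeo.samplers import RandomGeoSampler

sampler = RandomGeoSampler(dataset, size=256, length=20)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=4, collate_fn=stack_samples)

print(len(dataset))     # number of overlapping images in the spatial index
print(len(sampler))     # samples drawn per epoch: 20
print(len(dataloader))  # batches per epoch: 20 / 4 = 5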
If you tell me where you got stuck or share the code (not a picture of the code) I can try to help.
Thank you very much for your support.
import numpy as np
import matplotlib.pyplot as plt

from torchgeo.datasets import RasterDataset

filepath = '/content/drive/My Drive/BCMCA/bcmca_3b.tif'
filepath1 = '/content/drive/My Drive/BCMCA/bcmca_mask.tif'

image = RasterDataset(filepath)
mask = RasterDataset(filepath1)
dataset = image & mask

# Check len(dataset) (the number of overlapping images, not possible samples)
num_samples = len(dataset)
print(f"Number of overlapping images: {num_samples}")
from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
from torchgeo.samplers import GridGeoSampler, RandomGeoSampler

# Define the size of the patches
patch_size = 256  # Example patch size; you can adjust this

# Create a random sampler to generate patches
sampler = RandomGeoSampler(dataset, size=patch_size, length=20)
print(sampler)

train_loader = DataLoader(dataset, sampler=sampler, batch_size=16, collate_fn=stack_samples)
# Let's print a random element of the training dataset
random_element = np.random.randint(0, len(train_loader))

for idx, sample in enumerate(train_loader):
    if idx != random_element:
        continue

    # Select the first sample from the batch
    image = sample["image"][0]
    target = sample["mask"][0]

    # Create a figure and a 1x2 grid of axes
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))

    # Plot the image in the left axis:
    # select only the first 3 bands and cast to uint8
    rgb_image = np.transpose(image.numpy().squeeze()[0:3], (1, 2, 0)).astype('uint8')
    axes[0].imshow(rgb_image)
    axes[0].set_title('RGB image')

    # Plot the label mask in the right axis
    target_image = target.numpy().squeeze()
    axes[1].imshow(target_image)
    axes[1].set_title('Mask')

    # Adjust layout to prevent clipping of titles
    plt.tight_layout()
    plt.show()
I wanted to check whether the image and mask are properly paired in the dataloader.
My image size is 11719x9922, which makes 37 non-overlapping patches of size 256x256. I want to use 20 of those samples as a training dataset and the rest as a validation dataset.
My suggestion would be to use one of TorchGeo's splitting functions to split the dataset into non-overlapping train and validation datasets. Otherwise, there's no guarantee that the 20 random tiles you sample during training and the 17 random tiles you sample during validation have no overlap.
Also note that `RandomGeoSampler` makes no guarantees of non-overlapping samples. For that, you probably want `GridGeoSampler`.
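For example (a sketch assuming `random_grid_cell_assignment` from torchgeo.datasets, available in torchgeo >= 0.5):

from torchgeo.datasets import random_grid_cell_assignment
from torchgeo.samplers import GridGeoSampler

# Assign grid cells of the scene to 80% train / 20% validation, without overlap
train_ds, val_ds = random_grid_cell_assignment(dataset, [0.8, 0.2], grid_size=6)

# GridGeoSampler tiles each split on a regular grid, so patches never overlap
train_sampler = GridGeoSampler(train_ds, size=256, stride=256)
val_sampler = GridGeoSampler(val_ds, size=256, stride=256)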
I patchified the images manually.
Number of images in /content/patches256/train/images/: 1254
Number of masks in /content/patches256/train/mask/: 1254
The images have the same names as the masks, are 256x256 in size, and are non-overlapping.
The rest I kept aside for validation.
Number of images in /content/patches256/val/images/: 539
Number of masks in /content/patches256/val/mask/: 539
It seems like I can now skip the sampling part and go directly to the dataloader part.
I can't find anything in the documentation about how to load the patchified data directly from the directory as training data.
You can either write a `NonGeoDataset` for loading patchified data, or use a `GeoDataset` with a sampler. For writing a `NonGeoDataset`, see a tutorial like: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html.
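A minimal sketch of such a `NonGeoDataset` for the folder layout described above (the rasterio-based loading is an assumption, not TorchGeo code):

import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset

class PatchifiedDataset(NonGeoDataset):
    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Image and mask patches share file names, so one listing is enough
        self.filenames = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        name = self.filenames[idx]
        with rasterio.open(os.path.join(self.image_dir, name)) as f:
            image = torch.from_numpy(f.read()).float()
        with rasterio.open(os.path.join(self.mask_dir, name)) as f:
            mask = torch.from_numpy(f.read(1)).long()
        return {'image': image, 'mask': mask}

train_dataset = PatchifiedDataset('/content/patches256/train/images', '/content/patches256/train/mask')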
@Sandipriz do you have any other questions, or can we close this issue?
Three issues:
- I am able to feed my 3-band image but not the 8-band one.
- I am lost when I try to augment the data for training.
- In particular, for augmentation I want to use vegetation indices like NDVI, but even though I am going through the documentation, I am unable to adapt it (see the sketch after this list).
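For the NDVI point, TorchGeo ships index transforms such as `AppendNDVI` that append a computed band to the image. A sketch (the 0-based WorldView-2 band positions red=4, nir=6 are assumptions, and exact usage varies by torchgeo version):

from torchgeo.transforms import AppendNDVI, AugmentationSequential

# Append NDVI as an extra channel; band indices are assumptions for your sensor
transforms = AugmentationSequential(
    AppendNDVI(index_nir=6, index_red=4),
    data_keys=['image'],
)
batch = transforms(batch)  # batch['image'] gains one NDVI channel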
import numpy as np
import rasterio
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset, RandomSampler

# Define paths to local .tif files
image_path = "C:/Users/Chris/Documents/BCMCA/bcmca_small_true.tif"
mask_path = "C:/Users/Chris/Documents/BCMCA/bcmca_mask.tif"
class CustomGeoDataset(Dataset):
    def __init__(self, image_path, mask_path, patch_size=256):
        self.image_path = image_path
        self.mask_path = mask_path
        self.patch_size = patch_size
        # Open the rasters once and keep the handles
        self.image_src = rasterio.open(self.image_path)
        self.mask_src = rasterio.open(self.mask_path)
        self.height, self.width = self.image_src.height, self.image_src.width

    def __len__(self):
        # We will define an arbitrary number of samples
        return 1000  # Modify as per your requirement

    def __getitem__(self, idx):
        # Pick a random top-left corner and read a patch_size x patch_size window
        x = np.random.randint(0, self.width - self.patch_size)
        y = np.random.randint(0, self.height - self.patch_size)
        image = self.image_src.read(window=rasterio.windows.Window(x, y, self.patch_size, self.patch_size))
        mask = self.mask_src.read(1, window=rasterio.windows.Window(x, y, self.patch_size, self.patch_size))

        # Convert to torch tensors
        image = torch.from_numpy(image).float()
        mask = torch.from_numpy(mask).long()

        # Normalize the image to range [0, 1]
        image = image / 255.0

        sample = {"image": image, "mask": mask}
        return sample
# Create dataset instances
train_dataset = CustomGeoDataset(image_path, mask_path)
val_dataset = CustomGeoDataset(image_path, mask_path) # Use the same for validation for now
# Set the ratio for splitting into training and validation sets
train_ratio = 0.8 # 80% for training, 20% for validation
total_samples = 100  # Total number of samples to draw (after augmentation)
train_size = int(train_ratio * total_samples)
val_size = total_samples - train_size
# Use RandomSampler to sample regions from the dataset
train_sampler = RandomSampler(train_dataset, num_samples=train_size, replacement=True)
val_sampler = RandomSampler(val_dataset, num_samples=val_size, replacement=True)
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=16)
val_loader = DataLoader(val_dataset, sampler=val_sampler, batch_size=16)
# Let's print a random element of the training dataset
random_element = np.random.randint(0, len(train_loader))

for idx, sample in enumerate(train_loader):
    if idx != random_element:
        continue

    # Select the first sample from the batch
    image = sample["image"][0]
    target = sample["mask"][0]

    # Ensure the image has 3 bands (RGB)
    if image.shape[0] == 3:
        # Create a figure and a 1x2 grid of axes
        fig, axes = plt.subplots(1, 2, figsize=(8, 4))

        # Plot the image in the left axis (undo the [0, 1] normalization for display)
        rgb_image = np.transpose(image.numpy().squeeze(), (1, 2, 0))
        axes[0].imshow((rgb_image * 255).astype('uint8'))
        axes[0].set_title('RGB image')

        # Plot the label mask in the right axis
        target_image = target.numpy().squeeze()
        axes[1].imshow(target_image, cmap='gray')
        axes[1].set_title('Mask')

        # Adjust layout to prevent clipping of titles
        plt.tight_layout()
        plt.show()
    else:
        print("Image does not have 3 bands.")
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so it matches x2's spatial size before concatenation
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
from tqdm import tqdm

# Define necessary variables
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
nb_channels = 3  # Number of input channels (RGB)
nb_classes = 1  # Number of output classes (binary mask)

# Initialize model, loss function, and optimizer
model = UNet(n_channels=nb_channels, n_classes=nb_classes).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 10
log_dict = {'loss_per_batch': [], 'loss_per_epoch': []}
best_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch in tqdm(train_loader):
        # The dataset already scales images to [0, 1], so don't divide by 255 again
        inputs = batch["image"][:, 0:3, :, :]
        targets = batch["mask"].unsqueeze(1)  # Add a channel dimension

        # Move data to device
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets.float())  # BCEWithLogitsLoss applies the sigmoid internally

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        log_dict['loss_per_batch'].append(loss.item())

    # Evaluation phase
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for val_batch in val_loader:
            val_inputs = val_batch["image"][:, 0:3, :, :].to(device)
            val_targets = val_batch["mask"].unsqueeze(1).to(device)

            val_outputs = model(val_inputs)
            val_loss = criterion(val_outputs, val_targets.float())
            total_val_loss += val_loss.item()

    # Calculate the average validation loss
    average_val_loss = total_val_loss / len(val_loader)
    log_dict['loss_per_epoch'].append(average_val_loss)

    # Save a checkpoint if this is the best validation loss so far
    if average_val_loss < best_loss:
        best_loss = average_val_loss
        torch.save(model.state_dict(), 'best_model.pt')

    # Divide by len(train_sampler), the number of samples actually drawn per epoch
    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {running_loss / len(train_sampler)}, Validation Loss: {average_val_loss}")
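Since this issue is about multiclass classification, note that the loop above is set up for a single binary mask (nb_classes = 1 with BCEWithLogitsLoss). A hedged sketch of the multiclass variant, with a hypothetical class count of 5:

# Sketch only (not the original code): multiclass variant of the setup above
model = UNet(n_channels=nb_channels, n_classes=5).to(device)  # 5 = hypothetical class count
criterion = nn.CrossEntropyLoss()

# With CrossEntropyLoss, targets stay class-index maps of shape (N, H, W):
# drop the unsqueeze(1) and the .float() cast used above
outputs = model(inputs)  # logits of shape (N, 5, H, W)
loss = criterion(outputs, batch["mask"].long().to(device))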
@Sandipriz you are no longer using TorchGeo in your latest code snippet. If you have any questions about TorchGeo, let me know.
I was unable to feed my custom data to TorchGeo. A suggestion on that would be helpful, because the documentation didn't help me ingest the 8-band image with its mask into the model.
If you upload the failing TorchGeo code and data needed to reproduce the issue, I can take a look.
https://torchgeo.readthedocs.io/en/stable/tutorials/custom_raster_dataset.html gives an example with 4-band Sentinel-2 imagery; it shouldn't be too hard to modify that example to support 8-band imagery.
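A hedged sketch of that modification for 8-band WorldView-2 imagery (the class name, glob pattern, and band names are illustrative assumptions, not from the tutorial):

from torchgeo.datasets import RasterDataset

class WorldView2(RasterDataset):
    filename_glob = 'bcmca_8b*.tif'  # hypothetical file name pattern
    is_image = True
    # WorldView-2's standard 8-band order
    all_bands = ['coastal', 'blue', 'green', 'yellow', 'red', 'red_edge', 'nir1', 'nir2']
    rgb_bands = ['red', 'green', 'blue']

dataset = WorldView2('/content/drive/My Drive/BCMCA')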