fabiotosi92 / NeRF-Supervised-Deep-Stereo

A novel paradigm for collecting and generating stereo training data using neural rendering

Home Page: https://nerfstereo.github.io/

About custom dataset preparation

CaptainEven opened this issue

Thanks for the excellent work and contribution!
I have a quick question about preparing my own dataset for stereo training.
In the Supplementary Material you mention: "As a pre-processing step, we adjust the rendered disparity maps generated by Instant-NGP by fitting a scale-shift pair of values for each triplet". Could you please provide the code or script for this disparity scale-shift fitting step? Looking forward to your reply. Thanks again!

Hi, below is the pseudocode that represents the approach we used.

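import torch
import torch.optim as optim

def model(x_input):
    # Affine correction of the rendered disparity.
    # scale and shift are the learnable parameters initialized below.
    out = x_input * scale + shift
    return out

# Convert and prepare tensors im1, im2, im3, disp_3nerf, confidence
th = 0.5  # confidence threshold

# Keep only the pixels whose confidence exceeds the threshold
valid = (confidence > th).float()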

# Initialize scale and shift
scale = torch.nn.Parameter(torch.ones(1).cuda(), requires_grad=True)
shift = torch.nn.Parameter(torch.zeros(1).cuda(), requires_grad=True)

# Create an optimizer
optimizer = optim.SGD([scale, shift], lr=0.001)

# Define the number of epochs
num_epochs = 30

# Initialize minimum loss
min_loss = float("inf")

# Training loop
for epoch in range(num_epochs):
    disp = model(disp_3nerf)

    # Warp images
    im2_warped_0, mask_23_0 = warp_images(im3, disp, r2l=False)
    im2_warped_1, mask_23_1 = warp_images(im1, disp, r2l=True)

    # Compute photometric loss
    loss_0 = compute_photometric_loss(im2_warped_0, im2, mask_23_0, valid)
    loss_1 = compute_photometric_loss(im2_warped_1, im2, mask_23_1, valid)

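    # Choose the minimum loss between the two
    loss = torch.min(loss_0, loss_1)

    # Keep the parameters that produced the lowest loss so far.
    # Clone them: optimizer.step() updates scale and shift in place.
    if loss.item() < min_loss:
        min_loss = loss.item()
        fscale = scale.detach().clone()  # final scale
        fshift = shift.detach().clone()  # final shift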

    # Update parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
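
The pseudocode calls two helpers, warp_images and compute_photometric_loss, that are not shown in this thread. Below is a minimal sketch of what they might look like. It is an assumption, not the authors' implementation. It assumes rectified images stored as (B, C, H, W) tensors, a disparity map of shape (B, 1, H, W) defined on the center image im2, a sign convention in which r2l=False samples the source image at x - d and r2l=True at x + d, and a plain L1 photometric error, swapped in because the actual loss is not specified here.

```python
import torch
import torch.nn.functional as F


def warp_images(src, disp, r2l=False):
    """Resample src (B, C, H, W) into the reference view using disp (B, 1, H, W).

    Assumed convention: r2l=False samples src at x - d, r2l=True at x + d.
    Returns the warped image and a (B, 1, H, W) mask that is 1 where the
    sampled coordinate falls inside the source image.
    """
    b, _, h, w = src.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=src.dtype, device=src.device),
        torch.arange(w, dtype=src.dtype, device=src.device),
        indexing="ij",
    )
    xs = xs.expand(b, h, w)
    ys = ys.expand(b, h, w)

    # Horizontal source coordinate of every reference pixel
    x_src = xs + disp[:, 0] if r2l else xs - disp[:, 0]

    # grid_sample expects (x, y) coordinates normalized to [-1, 1]
    grid = torch.stack(
        (2.0 * x_src / (w - 1) - 1.0, 2.0 * ys / (h - 1) - 1.0), dim=-1
    )
    warped = F.grid_sample(
        src, grid, mode="bilinear", padding_mode="zeros", align_corners=True
    )
    mask = ((x_src >= 0) & (x_src <= w - 1)).unsqueeze(1).to(src.dtype)
    return warped, mask


def compute_photometric_loss(warped, target, mask, valid):
    """Mean absolute color difference over in-bounds, confident pixels."""
    weight = mask * valid
    diff = (warped - target).abs().mean(dim=1, keepdim=True)
    # clamp avoids a division by zero when no pixel is selected
    return (diff * weight).sum() / weight.sum().clamp(min=1.0)
```

Because bilinear sampling is differentiable with respect to the sampling grid, the loss propagates gradients back to scale and shift through disp, which is what the SGD loop above needs. The in-bounds mask itself carries no gradient.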

Thanks for the reply and help!