microsoft / SpareNet

Style-based Point Generator with Adversarial Rendering for Point Cloud Completion (CVPR 2021)

Home Page: https://arxiv.org/abs/2103.02535

Cannot Train with Depth Maps without Chamfer Distance

Wi-sc opened this issue.

Hi, thanks for releasing your code. I'm interested in your method of training with depth maps, so I designed a toy problem in which I train only with depth maps, but the points become NaN. Do you have any suggestions? Does this mean the depth maps cannot help with training?

import torch
from p2i_utils import ComputeDepthMaps

device = torch.device("cuda:0")
compute_depth_maps = ComputeDepthMaps(projection="perspective", eyepos_scale=1.0, image_size=224).float()

def get_depth_render(points, requires_grad=True):
    # Render 8 views of the point cloud and concatenate the depth maps along dim 1
    depth_list = []
    for view_id in range(8):
        _depth = compute_depth_maps(points, view_id=view_id)
        depth_list.append(_depth)
    if requires_grad:
        return torch.cat(depth_list, dim=1)
    else:
        return torch.cat(depth_list, dim=1).detach()

# load_from_file() is a placeholder for loading the target point cloud as a [1, N, 3] tensor
trg_points = load_from_file().to(device)
# The 2500 point coordinates are optimized directly, starting with every point at the origin
input_points = torch.full([1, 2500, 3], 0.0, device=device, requires_grad=True)
optimizer = torch.optim.SGD([input_points], lr=1.0, momentum=0.9)
Niter = 2000
loss_list = []

# The target depth maps never change, so render them once before the loop
gt_depth = get_depth_render(trg_points, requires_grad=False)

for i in range(Niter):
    optimizer.zero_grad()
    pred_depth = get_depth_render(input_points, requires_grad=True)
    loss_depth = torch.nn.L1Loss()(pred_depth, gt_depth)
    loss = loss_depth
    loss_list.append(loss.detach().cpu().item())
    loss.backward()
    optimizer.step()

Hi @Wi-sc, thanks for your interest! I don't think the missing Chamfer Distance is the reason for the NaN values in your experiment.

In our method, the depth-map supervision helps optimize the point encoder and generator networks. In your case, however, you want the depth maps to optimize the 3D points themselves directly. That could be an interesting application of our point renderer.
Could you decrease the learning rate? lr=1.0 might be too large.

@AlphaPav Thanks for your reply. I have changed the learning rate to 0.01, but it still outputs NaN. I also tried to train an auto-encoder network, but the L1 distance between depth maps cannot guide the training on its own, which does not seem reasonable to me. Have you tested the depth renderer in this setting?

This is the auto-encoder code.

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.io import loadmat
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance  # only used by the commented-out Chamfer term below
import matplotlib as mpl
mpl.use('agg')  # select the non-interactive backend before importing pyplot
import matplotlib.pyplot as plt
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80

from p2i_utils import ComputeDepthMaps
compute_depth_maps = ComputeDepthMaps(projection="perspective", eyepos_scale=1.0, image_size=224).float()

def get_depth_render(points, requires_grad=True):
    depth_list = []
    for view_id in range(8):
        _depth = compute_depth_maps(points, view_id=view_id)
        depth_list.append(_depth)
    if requires_grad:
        return torch.cat(depth_list, dim=1)
    else:
        return torch.cat(depth_list, dim=1).detach()

class Encoder(nn.Module):
    def __init__(self, output_dim=1024):
        super(Encoder, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, output_dim, 1)

        self.bn1 = torch.nn.BatchNorm1d(64)
        self.bn2 = torch.nn.BatchNorm1d(128)
        self.bn3 = torch.nn.BatchNorm1d(output_dim)

    def forward(self, shapes):
        x = F.relu(self.bn1(self.conv1(shapes.transpose(1,2))))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x, _ = torch.max(x, 2)
        x = x.view(shapes.size(0), -1)
        return x

class Decoder(nn.Module):
    def __init__(self, bottleneck_size=1024):
        super(Decoder, self).__init__()
        self.conv1 = torch.nn.Conv1d(bottleneck_size, bottleneck_size, 1)
        self.conv2 = torch.nn.Conv1d(bottleneck_size, bottleneck_size//2, 1)
        self.conv3 = torch.nn.Conv1d(bottleneck_size//2, bottleneck_size//4, 1)
        self.conv4 = torch.nn.Conv1d(bottleneck_size//4, 3, 1)
        self.bn1 = torch.nn.BatchNorm1d(bottleneck_size)
        self.bn2 = torch.nn.BatchNorm1d(bottleneck_size//2)
        self.bn3 = torch.nn.BatchNorm1d(bottleneck_size//4)
        self.th = nn.Tanh()

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)), inplace=True)
        x = F.relu(self.bn2(self.conv2(x)), inplace=True)
        x = F.relu(self.bn3(self.conv3(x)), inplace=True)
        x = self.th(self.conv4(x))
        return x

class Net(nn.Module):
    def __init__(self,  feature_dim=1024):
        super(Net, self).__init__()
        self.feature_dim = feature_dim
        self.encoder = Encoder(output_dim=self.feature_dim)
        self.decoder = Decoder(bottleneck_size=self.feature_dim+3)
    
    def forward(self, x, sphere_points):
        feature = self.encoder(x)
        sphere_points = sphere_points.transpose(1,2)
        feature = feature.unsqueeze(2).expand(feature.size(0), feature.size(1), sphere_points.size(2)).contiguous()
        feature = torch.cat([sphere_points, feature], dim=1)
        out = self.decoder(feature) + sphere_points
        return out.transpose(1,2)

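device = torch.device("cuda:0")

mat_file = loadmat('path')  # 'path' is a placeholder for the .mat file holding 'verts' and 'faces'
# Flip the z-axis; cast to float32 so the sampled points match the network's dtype
verts = (torch.from_numpy(mat_file['verts']) * torch.FloatTensor([[1, 1, -1]])).float()
faces = torch.from_numpy(mat_file['faces']).long()  # Meshes expects int64 face indices
trg_mesh = Meshes(verts=[verts], faces=[faces]).to(device)
src_mesh = ico_sphere(4, device)  # level-4 icosphere used as the template shape

network = Net().to(device)
optimizer = torch.optim.SGD(network.parameters(), lr=0.01, momentum=0.9)
# Number of optimization steps
Niter = 20000
plot_period = 2000
os.makedirs('./test_loss', exist_ok=True)  # output directory for the plots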

loss_list = []

for i in range(Niter):
    # Initialize optimizer
    optimizer.zero_grad()
    sample_src = sample_points_from_meshes(src_mesh, 2048).detach()
    sample_trg = sample_points_from_meshes(trg_mesh, 2048).detach()
    pred_points = network(sample_trg, sample_src)
    gt_depth = get_depth_render(sample_trg, requires_grad=False)
    pred_depth = get_depth_render(pred_points, requires_grad=True)
    loss_depth = torch.nn.L1Loss()(pred_depth, gt_depth)
    
    # loss_chamfer, _ = chamfer_distance(sample_trg, pred_points)
    
    loss = loss_depth
    loss_list.append(loss.detach().cpu().item())
    
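    if i % plot_period == 0:
        print(i, 'total_loss = %.6f' % loss.item())
        print(pred_points[0, :10, :])
        # Compare the first view of the predicted and ground-truth depth maps side by side
        fig, (ax1, ax2) = plt.subplots(1, 2)
        ax1.imshow(pred_depth[0, 0, :, :].detach().cpu().numpy())
        ax2.imshow(gt_depth[0, 0, :, :].cpu().numpy())
        plt.savefig('./test_loss/%s_chamfer.png' % i)
        plt.close(fig)  # free the figure so memory does not grow over the run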

    # Optimization step
    loss.backward()
    optimizer.step()

fig = plt.figure(figsize=(13, 5))
ax = fig.gca()
ax.plot(loss_list, label="loss")
ax.legend(fontsize="16")
ax.set_xlabel("Iteration", fontsize="16")
ax.set_ylabel("Loss", fontsize="16")
ax.set_title("Loss vs iterations", fontsize="16")
plt.draw()
plt.savefig('./test_loss/loss.png')

@AlphaPav Hi, I still cannot train a simple encoder-decoder network under the supervision of depth maps alone.

Hi @Wi-sc, in my opinion, using depth maps as the only supervision, without CD/EMD, should not be expected to work.

Just imagine that only one predicted point appears in pred_depth and one ground-truth point appears in gt_depth, and that the two points are some distance apart on the 2D rendering plane. When the L1 loss between these two depth maps is computed pixel by pixel, the pixels around the predicted point see mostly close-to-zero values in gt_depth (because the ground-truth point is elsewhere), so the gradients back-propagated to those pred_depth pixels carry almost no information about the direction in which the ground-truth point lies. Therefore, the pixel-wise L1 loss will hardly pull the two points closer together.
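This argument can be checked on a toy problem. The sketch below is not the p2i renderer: it assumes a hypothetical 1D "depth image" of 64 pixels, renders a single point as an untruncated Gaussian splat, and computes the gradient of the mean pixel-wise L1 loss with respect to the point's position analytically (the quantity autograd would return). When the target is 0.7 px away, the gradient is negative, so gradient descent moves the point toward it; when the target is about 30 px away, the gradient is numerically negligible and essentially unchanged when the target moves even farther, i.e. it carries no directional information.

```python
import math

WIDTH = 64   # pixels in the hypothetical 1D "depth image"
SIGMA = 1.0  # hypothetical splat width, in pixels


def render(x):
    """Render one point at sub-pixel position x as an untruncated 1D Gaussian splat."""
    return [math.exp(-((p - x) ** 2) / (2 * SIGMA ** 2)) for p in range(WIDTH)]


def l1_grad(x, target):
    """Analytic d/dx of mean_p |render(x)[p] - render(target)[p]|."""
    pred, gt = render(x), render(target)
    g = 0.0
    for p in range(WIDTH):
        diff = pred[p] - gt[p]
        sign = (diff > 0) - (diff < 0)  # derivative of |diff| (taken as 0 at diff == 0)
        # d render(x)[p] / dx = (p - x) / SIGMA**2 * render(x)[p]
        g += sign * (p - x) / SIGMA ** 2 * pred[p]
    return g / WIDTH


grad_near = l1_grad(10.3, 11.0)     # target 0.7 px away
grad_far = l1_grad(10.3, 40.0)      # target about 30 px away
grad_farther = l1_grad(10.3, 55.0)  # target even farther away
print(f"near: {grad_near:+.2e}  far: {grad_far:+.2e}  farther: {grad_farther:+.2e}")
```

The Gaussian splat is only a convenient smooth stand-in for a point renderer; any kernel whose footprint decays quickly away from the point shows the same behavior.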

Furthermore, for rendering efficiency, our p2i code renders each point only within a limited range around it: pixels far away from the rendered point are directly set to an arbitrary value. With this implementation, it is even less likely that an L1 loss between depth maps alone could ever work.
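A truncated kernel makes the effect exact. The sketch below is a hypothetical stand-in for "rendering with only a limited range", not the actual p2i code: the splat is evaluated only within RADIUS = 3 pixels of the point, and every other pixel gets a constant background value, assumed here to be 0. Only the six in-window pixels depend on x, so once the target's footprint lies outside that window, the gradient is identical for any target location. The small non-zero remainder comes from sub-pixel edge effects of the truncated window, not from the target, and is about two orders of magnitude smaller than the near-target gradient.

```python
import math

WIDTH, SIGMA, RADIUS = 64, 1.0, 3  # hypothetical toy settings, in pixels


def in_window(p, x):
    """True if pixel p lies within RADIUS pixels of the point at x."""
    return abs(p - x) <= RADIUS


def render(x):
    """Truncated Gaussian splat: pixels outside the window get a constant
    background value, assumed to be 0 in this sketch."""
    return [math.exp(-((p - x) ** 2) / (2 * SIGMA ** 2)) if in_window(p, x) else 0.0
            for p in range(WIDTH)]


def l1_grad(x, target):
    """Autograd-style d/dx of the mean L1 loss: out-of-window pixels are
    constants, so only in-window pixels contribute."""
    pred, gt = render(x), render(target)
    g = 0.0
    for p in range(WIDTH):
        if in_window(p, x):
            diff = pred[p] - gt[p]
            sign = (diff > 0) - (diff < 0)
            g += sign * (p - x) / SIGMA ** 2 * pred[p]
    return g / WIDTH


x = 10.3
active = [p for p in range(WIDTH) if in_window(p, x)]
grad_near = l1_grad(x, 11.0)
grad_far = l1_grad(x, 40.0)
grad_farther = l1_grad(x, 55.0)
print("pixels carrying gradient:", active)
print(f"near: {grad_near:+.2e}  far: {grad_far:+.2e}  farther: {grad_farther:+.2e}")
```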

In my view, the L1 loss between depth maps can only work when the predicted and ground-truth points are already close enough on the 2D rendering plane. As far as I am concerned, combining it with CD/EMD losses is therefore necessary.
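The benefit of adding a Chamfer term can be sketched on the same kind of toy. Assumptions: one predicted and one ground-truth point in 1D, a hypothetical truncated Gaussian renderer, a symmetric Chamfer distance between two single points (which reduces to 2(x - t)^2), equal loss weights, and plain gradient descent with hand-written gradients. With the depth term alone, a point starting about 30 px from the target barely moves; with CD + depth, the Chamfer term brings it close to the target, where the depth term also points the right way.

```python
import math

WIDTH, SIGMA, RADIUS = 64, 1.0, 3  # hypothetical toy settings, in pixels


def render(x):
    """Truncated 1D Gaussian splat of one point at sub-pixel position x."""
    return [math.exp(-((p - x) ** 2) / (2 * SIGMA ** 2)) if abs(p - x) <= RADIUS else 0.0
            for p in range(WIDTH)]


def depth_l1_grad(x, target):
    """Autograd-style gradient of the mean pixel-wise L1 loss between depth renders."""
    pred, gt = render(x), render(target)
    g = 0.0
    for p in range(WIDTH):
        if abs(p - x) <= RADIUS:  # only in-window pixels depend on x
            diff = pred[p] - gt[p]
            sign = (diff > 0) - (diff < 0)
            g += sign * (p - x) / SIGMA ** 2 * pred[p]
    return g / WIDTH


def chamfer_grad(x, target):
    """Gradient of the symmetric Chamfer distance between two single points,
    CD = (x - target)**2 + (target - x)**2 = 2 * (x - target)**2."""
    return 4.0 * (x - target)


def optimize(x, target, w_cd, w_depth, lr=0.1, steps=200):
    """Plain gradient descent on w_cd * CD + w_depth * depth L1."""
    for _ in range(steps):
        x -= lr * (w_cd * chamfer_grad(x, target) + w_depth * depth_l1_grad(x, target))
    return x


x_depth_only = optimize(10.3, 40.0, w_cd=0.0, w_depth=1.0)
x_combined = optimize(10.3, 40.0, w_cd=1.0, w_depth=1.0)
print(f"depth only: {x_depth_only:.3f}  CD + depth: {x_combined:.3f}")
```

The weights, learning rate, and step count here are illustrative choices, not values taken from SpareNet; the point is only that the Chamfer term supplies a gradient that depends on where the target is, even when the rendered footprints do not overlap.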