facebookresearch / theseus

A library for differentiable nonlinear optimization

"RuntimeError: Cannot access data pointer of Tensor that doesn't have storage" when using autograd and knn from pytorch3d in a cost function

TimoRST opened this issue

Hi!
I asked this question in the pytorch3d repo as well, but it might be helpful to ask here too.

Basically, I'm trying to combine your libraries pytorch3d and theseus.
I'm using an autograd-based function as the objective in Theseus. Inside it, I need to find the K nearest neighbors of one point cloud in another to perform a modified GICP.
When Theseus calls autograd on the objective, I get a RuntimeError in pytorch3d's knn.py at line 69, "idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)": "RuntimeError: Cannot access data pointer of Tensor that doesn't have storage".
Earlier calls (outside of the Theseus layer) work without problems. The error only occurs when autograd is called on the function. At that point, p1 and p2 are "BatchedTensors".
E.g.:
p1 = {Tensor: (1, 1024, 3)} BatchedTensor(lvl=1, bdim=0, value=
    tensor([[[[ 4.0704, -1.7715, -1.6834],
              [ 44.4077, 61.2256, 0.9029]...7, -1.8939],
              [-13.8112, -16.6522, 0.1097],
              [-12.5197, -11.3855, 0.1289]]]], device='cuda:0')
)
p2 = {Tensor: (1, 1024, 3)} BatchedTensor(lvl=1, bdim=0, value=
    tensor([[[[-1.4440e+01, -8.8393e+00, -3.6783e-02],
              [ 4.3913e+01, 6.1490... [ 3.0349e+01, -6.0585e+01, -6.5257e-01],
              [ 1.6265e+01, -5.2071e+00, 5.0570e-01]]]], device='cuda:0')
)

I modified the knn class slightly to work with PyTorch 2 (basically, I moved the ctx calls into the "setup_context" function and set "generate_vmap_rule = True"):

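from typing import Any, Tuple

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from pytorch3d import _C


class _knn_points(Function):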
    """
    Torch autograd Function wrapper for KNN C++/CUDA implementations.
    """
    generate_vmap_rule = True
    @staticmethod
    # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
    def forward(
        p1,
        p2,
        lengths1,
        lengths2,
        K,
        version,
        norm: int = 2,
        return_sorted: bool = True,
    ):
        """
        K-Nearest neighbors on point clouds.

        Args:
            p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
                containing up to P1 points of dimension D.
            p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
                containing up to P2 points of dimension D.
            lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
                length of each pointcloud in p1. Or None to indicate that every cloud has
                length P1.
            lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
                length of each pointcloud in p2. Or None to indicate that every cloud has
                length P2.
            K: Integer giving the number of nearest neighbors to return.
            version: Which KNN implementation to use in the backend. If version=-1,
                the correct implementation is selected based on the shapes of the inputs.
            norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2).
            return_sorted: (bool) whether to return the nearest neighbors sorted in
                ascending order of distance.

        Returns:
            p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
                the nearest neighbors. This is padded with zeros both where a cloud in p2
                has fewer than K points and where a cloud in p1 has fewer than P1 points.

            p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
                K nearest neighbors from points in p1 to points in p2.
                Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
                neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
                in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
        """
        if not ((norm == 1) or (norm == 2)):
            raise ValueError("Support for 1 or 2 norm.")

        idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)

        # sort KNN in ascending order if K > 1
        if K > 1 and return_sorted:
            if lengths2.min() < K:
                P1 = p1.shape[1]
                mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
                # mask has shape [N, K], true where dists irrelevant
                mask = mask[:, None].expand(-1, P1, -1)
                # mask has shape [N, P1, K], true where dists irrelevant
                dists[mask] = float("inf")
                dists, sort_idx = dists.sort(dim=2)
                dists[mask] = 0
            else:
                dists, sort_idx = dists.sort(dim=2)
            idx = idx.gather(2, sort_idx)

        # ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
        # ctx.mark_non_differentiable(idx)
        # ctx.norm = norm
        return dists, idx

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
        p1, p2, lengths1, lengths2, K, version, norm, return_sorted = inputs
        dists, idx = output
        ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
        ctx.mark_non_differentiable(idx)
        ctx.norm = norm

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_dists, grad_idx):
        p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
        norm = ctx.norm
        # TODO(gkioxari) Change cast to floats once we add support for doubles.
        if not (grad_dists.dtype == torch.float32):
            grad_dists = grad_dists.float()
        if not (p1.dtype == torch.float32):
            p1 = p1.float()
        if not (p2.dtype == torch.float32):
            p2 = p2.float()
        grad_p1, grad_p2 = _C.knn_points_backward(
            p1, p2, lengths1, lengths2, idx, norm, grad_dists
        )
        return grad_p1, grad_p2, None, None, None, None, None, None

Torch version: 2.0.0 with CUDA 11.8

Any idea how to deal with this? I've got it running with a KNN implementation from pointnet2, but the one from pytorch3d is much more efficient.

I have absolutely no experience with jacrev, vmap or BatchedTensors. What I found out is that you can call .storage() on a Tensor, but BatchedTensor doesn't seem to support that, which might be related to the error:
NotImplementedError('Cannot access storage of BatchedTensorImpl')
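To see the storage issue in isolation, here is a minimal sketch, assuming PyTorch 2.x's torch.func.vmap and no pytorch3d at all (probe and results are hypothetical names). The same function is called once on an ordinary tensor and once under vmap, where its argument is a BatchedTensor; data_ptr() is used as a rough Python-level stand-in for what a C++/CUDA kernel does with its inputs.

```python
import torch

# Collects what happens each time probe() touches its input's raw memory.
results = []


def probe(x):
    # Extension kernels (e.g. pytorch3d's _C ops) read the input's raw data
    # pointer; calling data_ptr() here approximates that from Python.
    try:
        x.data_ptr()
        results.append("data_ptr ok")
    except Exception as err:  # exact exception type/message may vary by version
        results.append(f"data_ptr failed: {err}")
    return x * 2


probe(torch.randn(4))                      # ordinary tensor: has storage
torch.func.vmap(probe)(torch.randn(3, 4))  # inside vmap, x is a BatchedTensor
print(results)
```

This is consistent with the traceback above: when forward is traced under vmap, the BatchedTensor inputs reach _C.knn_points_idx, which needs a real data pointer.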

Has anyone had similar issues with functions implemented in CUDA before? Is there something special needed to deal with these batched tensors?

This can be closed. Automatic generation of the vmap rule ("generate_vmap_rule = True") did not work in this case. Writing a short vmap function by hand solved the issue.
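The thread doesn't include the vmap function that fixed it, so below is a minimal sketch of what such a hand-written rule can look like, assuming PyTorch 2.0's torch.autograd.Function vmap staticmethod API. It is not the poster's code and does not use pytorch3d: opaque_knn and KNNPoints are hypothetical stand-ins (brute-force KNN in plain PyTorch) for _C.knn_points_idx and _knn_points, and lengths1, lengths2, version and norm are omitted. generate_vmap_rule = True only works when forward itself is transformable by torch.func, which a raw extension call is not. A manual rule avoids that: it moves the vmapped dimension to the front, folds it into the point-cloud batch dimension N, and calls the Function once on ordinary tensors.

```python
import torch
from torch.autograd import Function


def opaque_knn(p1, p2, K):
    # Hypothetical stand-in for a compiled kernel like _C.knn_points_idx.
    # Treat it as a black box that only accepts plain (N, P, D) tensors.
    sq_dists = ((p1[:, :, None, :] - p2[:, None, :, :]) ** 2).sum(-1)  # (N, P1, P2)
    dists, idx = sq_dists.topk(K, dim=2, largest=False)  # K smallest, ascending
    return dists, idx


class KNNPoints(Function):
    # generate_vmap_rule stays at its default (False) because vmap() is
    # defined explicitly below.

    @staticmethod
    def forward(p1, p2, K):
        return opaque_knn(p1, p2, K)

    @staticmethod
    def setup_context(ctx, inputs, output):
        p1, p2, _ = inputs
        _, idx = output
        ctx.save_for_backward(p1, p2, idx)
        ctx.mark_non_differentiable(idx)

    @staticmethod
    def backward(ctx, grad_dists, grad_idx):
        # Written with plain PyTorch ops so it stays vmappable if it runs
        # under nested transforms such as vmap(jacrev(...)).
        p1, p2, idx = ctx.saved_tensors
        N, P1, K = idx.shape
        D = p1.shape[2]
        flat_idx = idx.reshape(N, P1 * K, 1).expand(N, P1 * K, D)
        # nn[n, i, k] = p2[n, idx[n, i, k]]
        nn = p2.gather(1, flat_idx).reshape(N, P1, K, D)
        # d|p1 - nn|^2 / dp1 = 2 (p1 - nn); the gradient w.r.t. nn is its negative.
        g = 2.0 * grad_dists.unsqueeze(-1) * (p1.unsqueeze(2) - nn)
        grad_p1 = g.sum(dim=2)
        grad_p2 = torch.zeros_like(p2).scatter_add(1, flat_idx, -g.reshape(N, P1 * K, D))
        return grad_p1, grad_p2, None

    @staticmethod
    def vmap(info, in_dims, p1, p2, K):
        B = info.batch_size
        p1_bdim, p2_bdim, _ = in_dims
        # Move the vmapped dim to the front; broadcast inputs that are not batched.
        p1 = p1.movedim(p1_bdim, 0) if p1_bdim is not None else p1.expand(B, *p1.shape)
        p2 = p2.movedim(p2_bdim, 0) if p2_bdim is not None else p2.expand(B, *p2.shape)
        N = p1.shape[1]
        # Fold (B, N, P, D) -> (B * N, P, D) so the kernel only sees plain tensors.
        dists, idx = KNNPoints.apply(
            p1.reshape(B * N, *p1.shape[2:]), p2.reshape(B * N, *p2.shape[2:]), K
        )
        # Unfold to (B, N, P1, K); both outputs are batched along dim 0.
        out = (dists.reshape(B, N, *dists.shape[1:]), idx.reshape(B, N, *idx.shape[1:]))
        return out, (0, 0)
```

Usage would look like torch.func.vmap(lambda a, b: KNNPoints.apply(a, b, K))(p1, p2). For the real _knn_points, lengths1 and lengths2 (shape (N,)) would presumably need the same treatment: move their batch dim to the front (or expand them) and reshape to (B * N,).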

@TimoRST glad to know you were able to figure it out, and thanks for sharing the fix!