iMoonLab / DeepHypergraph

A PyTorch library for graph and hypergraph computation.

Home Page: https://deephypergraph.com/

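Setting largest=False in torch.topk selects the smallest distances, i.e. the nearest neighbors. This approach may be more efficient than converting the tensor to a numpy array and using scipy.spatial.cKDTree, because it avoids transferring data between the CPU and the GPU.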

@staticmethod
def _e_list_from_feature_kNN(features: torch.Tensor, k: int):
    r"""Construct hyperedges from the feature matrix. Each hyperedge in the hypergraph is constructed by the central vertex and its :math:`k-1` neighbor vertices.

    Args:
        ``features`` (``torch.Tensor``): The feature matrix.
        ``k`` (``int``): The number of nearest neighbors.
    """
    assert features.ndim == 2, "The feature matrix should be 2-D."
    assert (
        k <= features.shape[0]
    ), "The number of nearest neighbors should be less than or equal to the number of vertices."

    # Pairwise Euclidean distances between all vertices.
    dist_matrix = torch.cdist(features, features, p=2)
    # Indices of the k smallest distances in each row, i.e. the k nearest vertices.
    _, nbr_indices = torch.topk(dist_matrix, k, largest=False)

    return nbr_indices.tolist()
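
To make the behavior easier to check, here is a minimal standalone sketch of the same idea on toy data. The function name `knn_hyperedges` and the sample points are assumptions made up for illustration; they are not part of the library.

```python
import torch


def knn_hyperedges(features: torch.Tensor, k: int) -> list:
    """Hypothetical standalone version of the kNN hyperedge construction."""
    assert features.ndim == 2, "The feature matrix should be 2-D."
    assert k <= features.shape[0], "k must not exceed the number of vertices."
    # Dense pairwise Euclidean distances, computed on the tensor's own device.
    dist_matrix = torch.cdist(features, features, p=2)
    # k smallest distances per row. A vertex is at distance 0 from itself,
    # so each row normally holds the vertex plus its k-1 nearest neighbors.
    _, nbr_indices = torch.topk(dist_matrix, k, largest=False)
    return nbr_indices.tolist()


# Toy data (assumed): two well-separated pairs of points on a line.
features = torch.tensor([[0.0], [1.0], [10.0], [11.0]])
hyperedges = knn_hyperedges(features, k=2)
print(hyperedges)
```

With this data, vertices 0 and 1 should end up in each other's hyperedges, and likewise vertices 2 and 3. One trade-off to keep in mind: the distance matrix is dense, so memory grows quadratically with the number of vertices, which may matter for large feature matrices.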

Thanks, I'll update this part soon!