chaitjo / graph-convnet-tsp

Code for the paper 'An Efficient Graph Convolutional Network Technique for the Travelling Salesman Problem' (INFORMS Annual Meeting Session 2019)

Home Page: https://arxiv.org/abs/1906.01227


A code issue

Wastedzz opened this issue · comments

Hi, thanks for your open-source code. I hit a bug when running the beam search code at utils/beamsearch.py line 96:
self.mask = self.mask.gather(1, perm_mask)
Here `perm_mask` should be a LongTensor, but it is a FloatTensor, which makes several beam-search-based functions fail.
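For context, the failure is easy to reproduce outside the repo: `torch.gather` requires an int64 index tensor, so passing a FloatTensor index raises a RuntimeError on recent PyTorch versions. A minimal standalone sketch (not the repo's code, just an illustration of the dtype constraint):

```python
import torch

mask = torch.ones(2, 3)
# A float-typed index, like the perm_mask in the issue.
perm_mask = torch.tensor([[0.0, 1.0, 2.0], [2.0, 1.0, 0.0]])

try:
    mask.gather(1, perm_mask)          # fails: index must be int64
except RuntimeError as e:
    print("gather failed:", e)

# Casting the index to int64 makes the same call succeed.
out = mask.gather(1, perm_mask.long())
print(out.shape)  # torch.Size([2, 3])
```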

Hi @Wastedzz @maoxiaowei97, thank you for your interest. I believe you may be using an incorrect version of PyTorch -- the code was tested with a now ancient version 0.4, but PyTorch has undergone several changes since then. For reproducing exactly, you may have to downgrade your PyTorch version.


Here are some related issues and discussions for reference:

* [Error in beamsearch.py #11](https://github.com/chaitjo/graph-convnet-tsp/issues/11)

* [GNN Encoder learning-tsp#1 (comment)](https://github.com/chaitjo/learning-tsp/issues/1#issuecomment-814688655)

Maybe one simple thing to try first to get the code to run would be to update backpointers via integer division:

```python
# Update backpointers
prev_k = bestScoresId // self.num_nodes
```
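To see why floor division matters here: since PyTorch 1.5, true division (`/`) of an integer tensor returns a float tensor, so the backpointers would become floats and break later integer-only indexing ops, while `//` keeps them int64. A standalone sketch (the `bestScoresId` and `num_nodes` names mirror the repo, but the values are made up):

```python
import torch

# bestScoresId indexes into a flattened (beam_size * num_nodes) array.
num_nodes = 5
bestScoresId = torch.tensor([7, 12, 3])   # int64 by default

# True division produces a float tensor on modern PyTorch,
# which later breaks integer-only ops like gather/index_select.
print((bestScoresId / num_nodes).dtype)   # torch.float32

# Floor division keeps the backpointers as int64.
prev_k = bestScoresId // num_nodes
print(prev_k, prev_k.dtype)               # tensor([1, 2, 0]) torch.int64
```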

You are right, thanks for your reply!


You are right. I've figured it out, and it worked. Thank you so much!

Great, happy to help, no worries.

Just add this before the `gather()` line to change the datatype to int64:

```python
perm_mask = perm_mask.type(torch.int64)
```
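A minimal sketch of that one-line cast in isolation (the surrounding beamsearch state is assumed; the tensor values are placeholders, only the dtype handling matters):

```python
import torch

mask = torch.ones(2, 4)
perm_mask = torch.zeros(2, 4)            # ends up as a FloatTensor

# The suggested fix: cast the index to int64 before gather().
perm_mask = perm_mask.type(torch.int64)  # equivalently: perm_mask.long()
mask = mask.gather(1, perm_mask)         # now succeeds
print(mask.dtype, perm_mask.dtype)       # torch.float32 torch.int64
```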