Lackel / DPN

AAAI 2023 paper "Generalized Category Discovery with Decoupled Prototypical Network"

Unable to replicate results for any dataset. The results deviate too much from those reported in the paper.

Chandan-IITI opened this issue

I tried to run the default code given for the StackOverflow dataset and the CLINC dataset. I am getting the following results:

Obtained for the StackOverflow dataset: {'ACC_all': 75.6, 'ACC_known': 81.2, 'ACC_novel': 58.8}

Reported in the paper for the StackOverflow dataset: {'ACC_all': 84.23, 'ACC_known': 85.29, 'ACC_novel': 81.07}

Please let me know how I can reproduce the results. I also tried varying parameters, such as 'predict_k'=2 or 3, but was unable to reproduce them.

Hi, we just tested our code in our environment with the default settings (e.g., cluster_num_factor=1, rather than 2 or 3), and it worked well: for StackOverflow with seed 0, we get {'ACC_all': 85.8, 'ACC_known': 85.47, 'ACC_novel': 86.8}. Can you check the hyper-parameters (you can simply run sh run.sh without any changes) and the versions of your packages (the versions we used are listed in the README)? If it still doesn't work, please provide more information about the environment you used.
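For anyone comparing environments as suggested above, a minimal sketch of a version report follows. The package names in PACKAGES are placeholders, not the list from the repository README, and installed_versions is a hypothetical helper; substitute the packages the README actually names.

```python
from importlib import metadata

# Hypothetical package list: replace with the packages named in the README.
PACKAGES = ["torch", "transformers", "numpy", "scikit-learn"]


def installed_versions(names):
    """Return {package name: installed version, or 'not installed'}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    # Print one "name==version" line per package, ready to paste into a comment.
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}=={version}")
```

Pasting this output next to the README versions makes a mismatch easy to spot.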

Thanks a lot @Lackel for the quick response. I tried to run the code without changing any hyperparameter. Let me check the version of the packages and get back to you.

@Lackel Thank you for your suggestion. I am able to replicate the results after changing the package versions. I didn't expect a 10% difference due to the versions.

Hi @Lackel,

I ran the code for all datasets. I am able to replicate the results for the StackOverflow dataset. However, I am getting a 2-4% deviation on the CLINC and BANKING datasets, as follows:

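| Method | BANKING All | BANKING Known | BANKING Novel | StackOverflow All | StackOverflow Known | StackOverflow Novel | CLINC All | CLINC Known | CLINC Novel |
|---|---|---|---|---|---|---|---|---|---|
| DPN (AAAI-2023) | 72.96 | 80.93 | 48.6 | 84.23 | 85.29 | 81.07 | 89.06 | 92.97 | 77.54 |
| DPN (Reproduced) | 70.36 | 76.77 | 50.79 | 85.6 | 85.2 | 86.8 | 85.73 | 92.98 | 64.39 |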

Am I missing something? Please suggest what I can do to reproduce the results, as I took care of the package versions in the current run. I also ran the code multiple times to account for randomness.
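To make the size of the gap explicit, the sketch below computes reproduced minus paper accuracy, in percentage points, for each dataset. The numbers are copied from the comparison in this comment; the deviations helper is only illustrative.

```python
# Accuracies (All, Known, Novel) copied from the comparison in this comment.
paper = {
    "BANKING": (72.96, 80.93, 48.6),
    "StackOverflow": (84.23, 85.29, 81.07),
    "CLINC": (89.06, 92.97, 77.54),
}
reproduced = {
    "BANKING": (70.36, 76.77, 50.79),
    "StackOverflow": (85.6, 85.2, 86.8),
    "CLINC": (85.73, 92.98, 64.39),
}


def deviations(paper, reproduced):
    """Return reproduced minus paper, in percentage points, per dataset."""
    return {
        name: tuple(round(r - p, 2) for p, r in zip(paper[name], reproduced[name]))
        for name in paper
    }


for name, (d_all, d_known, d_novel) in deviations(paper, reproduced).items():
    print(f"{name:14s} All {d_all:+.2f}  Known {d_known:+.2f}  Novel {d_novel:+.2f}")
```

The 2-4% figure matches the All column for BANKING and CLINC. The Novel column varies much more, with CLINC about 13 points below the paper.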

Hi, I just tested our code with the default settings on the other two datasets, and it worked well too: with random seed 0, we got {'ACC_all': 89.82, 'ACC_known': 94.23, 'ACC_novel': 76.84} for CLINC and {'ACC_all': 74.90, 'ACC_known': 81.64, 'ACC_novel': 54.34} for BANKING. So I have no idea why you can't reproduce our results in your environment. Maybe you can create a new conda environment, install our recommended package versions, and then download our code and run the experiments without any changes. I hope this works for you.

Hi @Lackel,

Thanks a lot for your help. I created a new environment and installed the same versions of the 7 packages you mentioned. I also didn't change the code. I am sure the problem must be at my end; I just want to figure out what it is. The package versions in my environment are as follows:

[Image: screenshot of the package versions in the environment]

If possible, please provide the versions of all your packages, not only the 7 mentioned in your repository. I am not sure why I cannot replicate the results for the CLINC and BANKING datasets, although it works perfectly for StackOverflow.

@Lackel Thanks for your help. I am not sure what the issue was, but I am able to reproduce the results now. It might have been due to randomness: I just restarted the system and ran the code again, and it yielded results similar to those reported in the paper.

I found the issue; it might be with my environment setup. If anyone else runs into a similar issue, please replace the 'def _read_tsv(self, input_file, quotechar=None):' function in the data.py file with the following:

import csv  # needed at module level in data.py if not already imported

def _read_tsv(self, input_file, quotechar=None):
    """Reads a tab separated value file."""
    # The slice assumes a path with a 5-character prefix followed by the
    # dataset name, e.g. "data/stackoverflow/...".
    if input_file[5:18] == 'stackoverflow':
        print("Processing the StackOverflow dataset")
        encoding = "utf8"  # modified by me: force UTF-8 for StackOverflow
    else:
        print("Processing a dataset other than StackOverflow (CLINC or BANKING)")
        encoding = None  # platform default encoding, as in the original code
    with open(input_file, "r", encoding=encoding) as f:
        reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
        return list(reader)
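One plausible reading of the fix above is that open() without an explicit encoding decodes the file with the platform default, so the same bytes can produce different text on different machines. The sketch below illustrates that effect under stated assumptions: the sample row is made up, read_tsv is a hypothetical helper rather than the repository's function, and cp1252 stands in for a non-UTF-8 default.

```python
import csv
import locale
import os
import tempfile


def read_tsv(path, encoding=None):
    """Read a tab-separated file; encoding=None uses the platform default."""
    with open(path, "r", encoding=encoding, newline="") as f:
        return list(csv.reader(f, delimiter="\t"))


# Write a small UTF-8 file containing one non-ASCII character (made-up data).
with tempfile.NamedTemporaryFile("w", suffix=".tsv", encoding="utf8",
                                 newline="", delete=False) as tmp:
    tmp.write("caf\u00e9 prices\tlabel_a\n")
    sample_path = tmp.name

try:
    # The default that open() falls back to when no encoding is given.
    print("platform default encoding:", locale.getpreferredencoding(False))
    rows_utf8 = read_tsv(sample_path, encoding="utf8")      # decoded as written
    rows_cp1252 = read_tsv(sample_path, encoding="cp1252")  # same bytes, other codec
    print(rows_utf8)
    print(rows_cp1252)
finally:
    os.remove(sample_path)
```

If the two printed rows differ on your machine, the text fed to the tokenizer differs too, which could explain results that shift between environments even with identical package versions.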

It is really a weird bug😂