otavioon / ssl_tools

generate_embeddings error when EvaluatorBase is used with non-SimpleClassificationNet models

viniandrs opened this issue · comments

πŸ› Describe the bug

When testing with a child class of EvaluatorBase, generate_embeddings raises an error if the model is not a SimpleClassificationNet (e.g. SSLDiscriminator), because it tries to access the model's fully connected layers through the fc attribute.

Error logs

File "/workspaces/workdir/ssl_tools/ssl_tools/experiments/har_classification/_classification_base.py", line 51, in generate_embeddings
    old_fc = model.fc
File "/home/vini_meta4/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'SSLDiscriminator' object has no attribute 'fc'
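
The failure mode can be illustrated without the library. The sketch below assumes, based only on the `old_fc = model.fc` line in the traceback, that generate_embeddings temporarily bypasses the `fc` head to read the backbone output. `WithHead`, `WithoutHead`, `Identity` and `generate_embeddings_guarded` are hypothetical stand-ins written in plain Python, not ssl_tools code, and the guard is one possible way to fail with a clearer message rather than the fix that was adopted.

```python
class Identity:
    """Stand-in for an identity layer: returns its input unchanged."""

    def __call__(self, x):
        return x


class WithHead:
    """Hypothetical model that exposes a backbone and an `fc` head."""

    def __init__(self):
        self.backbone = lambda x: [v * 2 for v in x]
        self.fc = lambda z: sum(z)

    def __call__(self, x):
        return self.fc(self.backbone(x))


class WithoutHead:
    """Hypothetical model with no `fc` attribute, like the failing case."""

    def __init__(self):
        self.backbone = lambda x: [v * 2 for v in x]

    def __call__(self, x):
        return self.backbone(x)


def generate_embeddings_guarded(model, x):
    # Assumed behaviour: swap the head for an identity to expose embeddings.
    if not hasattr(model, "fc"):
        raise TypeError(
            f"{type(model).__name__} has no 'fc' head; "
            "embeddings cannot be extracted this way"
        )
    old_fc = model.fc
    model.fc = Identity()
    try:
        return model(x)
    finally:
        model.fc = old_fc  # always restore the original head
```

With this guard, a model without an `fc` attribute fails early with a message that names the model type, and a model that has one gets its head restored even if the forward pass raises.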

Solved by merging the SSLDiscriminator class into SimpleClassificationNet. SimpleClassificationNet now also accepts the update_backbone parameter, which allows the backbone to be frozen during fine-tuning.
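
As a rough illustration of what an update_backbone flag could do, the sketch below freezes the backbone parameters when the flag is false. `ClassifierSketch` is a hypothetical stand-in, not the actual SimpleClassificationNet, and the layer sizes are arbitrary; only `update_backbone` and `fc` are names taken from this issue.

```python
import torch
from torch import nn


class ClassifierSketch(nn.Module):
    """Hypothetical stand-in, not the actual SimpleClassificationNet."""

    def __init__(self, backbone: nn.Module, fc: nn.Module, update_backbone: bool = True):
        super().__init__()
        self.backbone = backbone
        self.fc = fc
        if not update_backbone:
            # Freeze the backbone so that only the head receives gradients.
            for param in self.backbone.parameters():
                param.requires_grad = False

    def forward(self, x):
        return self.fc(self.backbone(x))


# Arbitrary layer sizes, chosen only for the example.
model = ClassifierSketch(nn.Linear(8, 4), nn.Linear(4, 2), update_backbone=False)

# Names of the parameters that would still be updated by an optimizer.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
```

Because the model always exposes an `fc` attribute in this layout, code that reads `model.fc` works regardless of whether the backbone is frozen.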