DAI-Lab / SteganoGAN

SteganoGAN is a tool for creating steganographic images using adversarial training.

Move `_get_steganogan` logic to `SteganoGAN.load`

csala opened this issue.

The code that loads the pretrained models currently lives in `_get_steganogan`, a "private" function in the `cli.py` module.

This logic should be moved into the `SteganoGAN.load` method so that pretrained models can be loaded directly from the class.

For this, the following changes are needed:

  • Add an `architecture` argument to the `load` method.
  • Make the `path` argument of the `load` method optional (default `None`).

Then, reimplement the load method so that:

  • If a `path` is given, the model is loaded directly from it.
  • If an `architecture` name is given, the path to the model is built dynamically, similar to the current logic in `_get_steganogan`.
  • If both arguments or neither is given, an exception is raised.
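The rules above could be sketched roughly as follows. This is only an illustration under stated assumptions: pretrained models are assumed to live in a `pretrained` folder next to the module and to be named `<architecture>.steg`, the keyword order `architecture=None, path=None` is assumed, and the actual model deserialization is replaced by a hypothetical `_load_from_path` stub, so only the argument handling is shown.

```python
import os


class SteganoGAN:
    """Minimal stand-in for the real class; only the loading logic is sketched."""

    # Assumption: pretrained models are stored in a "pretrained" folder next to
    # this module and are named "<architecture>.steg".
    PRETRAINED_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "pretrained")

    def __init__(self, path):
        self.path = path

    @classmethod
    def _load_from_path(cls, path, cuda=True, verbose=False):
        # Hypothetical placeholder for the real deserialization (e.g. torch.load).
        return cls(path)

    @classmethod
    def load(cls, architecture=None, path=None, cuda=True, verbose=False):
        """Load a model from an explicit path or by pretrained architecture name.

        Exactly one of ``architecture`` and ``path`` must be given.
        """
        # True when both are None or both are set: reject either case.
        if (architecture is None) == (path is None):
            raise ValueError("Please provide either an architecture or a path, but not both.")

        if architecture is not None:
            # Build the pretrained model path from the architecture name.
            path = os.path.join(cls.PRETRAINED_DIR, architecture + ".steg")

        return cls._load_from_path(path, cuda=cuda, verbose=verbose)


# Usage: load by architecture name, or by an explicit file path.
model = SteganoGAN.load(architecture="dense")
print(model.path)
custom = SteganoGAN.load(path="custom.steg")
print(custom.path)
```

Comparing `architecture is None` with `path is None` covers both invalid cases (neither given, both given) in a single check.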

Then, in the `cli.py` module, the `_get_steganogan` implementation can be changed to pass the `architecture` value directly to the `load` method.