rosikand / torchplate

🍽 A minimal and simple experiment module for machine learning research workflows in PyTorch.

Home Page: https://rosikand.github.io/torchplate/

Open thread on proper style

rosikand opened this issue

Some thoughts:

  • Users should limit the amount of code that lives outside the experiment subclass. It may make sense, however, to create one parent experiment subclass that defines the constants and then, for each subsequent experiment where you want to have a specific variable as a control, subclass this parent and change only what is needed. The main point is that reproducibility is easier to achieve when most of the code is contained within the same class (or class family). In practice, it may make sense for users to define their models and dataloaders inside the experiment subclass rather than in separate modules.
  • Additionally, the runner module should stay relatively clean and should not handle experiment-related configuration. To achieve this, pass the experiment configuration (e.g., a config.yaml file or a YACS object) directly to the experiment subclass and handle it there. For example, with YACS:
import torchplate.experiment

class SampleExp(torchplate.experiment.Experiment):
    def __init__(self, config_object):
        # set up the config here based on config_object
        ...
        super().__init__(
            ...
        )

    def evaluate(self, batch):
        ...

# config_object is built elsewhere (e.g., a YACS CfgNode) and passed in here
exp = SampleExp(config_object)
exp.train(num_epochs=100)
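
A minimal sketch of the two points above, under stated assumptions: the `Experiment` class here is a hypothetical stand-in that only mimics the constructor/`train`/`evaluate` shape used in this thread, so the example runs without PyTorch or torchplate installed. Its `train` returning a loss history is also an assumption, not torchplate behavior. `BaseExp`, `ScaledExp`, `scale`, `offset`, and the `SimpleNamespace` config are illustrative names; a YACS object could be passed in the same way.

```python
from types import SimpleNamespace


class Experiment:
    """Hypothetical stand-in for torchplate.experiment.Experiment."""

    def __init__(self, batches):
        self.batches = batches

    def evaluate(self, batch):
        raise NotImplementedError

    def train(self, num_epochs):
        # Record the mean of evaluate() over all batches for each epoch.
        history = []
        for _ in range(num_epochs):
            losses = [self.evaluate(b) for b in self.batches]
            history.append(sum(losses) / len(losses))
        return history


class BaseExp(Experiment):
    """Parent experiment: holds every constant shared by the family."""

    scale = 1.0  # the one variable child experiments are meant to change

    def __init__(self, config):
        # All config handling lives here, not in the runner script.
        self.offset = config.offset
        super().__init__(batches=config.batches)

    def evaluate(self, batch):
        # Toy "loss": a scaled batch mean plus a constant offset.
        return self.scale * sum(batch) / len(batch) + self.offset


class ScaledExp(BaseExp):
    """Child experiment: overrides only the variable under study."""

    scale = 2.0


# The runner stays small: build a config, hand it over, train.
config = SimpleNamespace(offset=0.5, batches=[[1.0, 3.0], [2.0, 4.0]])
base_history = BaseExp(config).train(num_epochs=2)
scaled_history = ScaledExp(config).train(num_epochs=2)
```

Because `ScaledExp` differs from `BaseExp` in a single class attribute, the difference between the two runs is visible in one place, which is the reproducibility argument made in the first bullet.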

To add on to the first comment about configs, I think it would make sense to do away with YAML files and just use plain Python classes. This lets you specify the actual attributes in code rather than using getattr everywhere. You could have one base class that provides sensible defaults, and then each experiment's config subclasses it and changes only the relevant parameters.

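# `models` is assumed to be your own module of model classes
class BaseConfig:
    seed = 1
    model = models.ViT()  # note: built once, when the class body runs
    ...

class TestExperimentConfig(BaseConfig):
    model = models.CNN()  # override only what differs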
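
A runnable variant of the class-based config idea, as a sketch rather than a recommendation from the library. Two things here are my assumptions: the config stores a model name instead of an instantiated model, so nothing heavy is built when the module is imported, and `as_dict`, `build_model`, `MODEL_REGISTRY`, `learning_rate`, and the empty `ViT`/`CNN` classes are hypothetical placeholders for whatever your project defines.

```python
class ViT:
    """Placeholder for a real model class."""


class CNN:
    """Placeholder for a real model class."""


# Hypothetical lookup table from config names to model classes.
MODEL_REGISTRY = {"vit": ViT, "cnn": CNN}


class BaseConfig:
    """Sensible defaults shared by every experiment."""

    seed = 1
    model_name = "vit"
    learning_rate = 1e-3

    @classmethod
    def as_dict(cls):
        # Public, non-callable attributes, including inherited ones.
        return {
            key: getattr(cls, key)
            for key in dir(cls)
            if not key.startswith("_") and not callable(getattr(cls, key))
        }


class TestExperimentConfig(BaseConfig):
    # Only the parameter that differs from the defaults is restated.
    model_name = "cnn"


def build_model(config):
    # Instantiate the model lazily, at the point where it is needed.
    return MODEL_REGISTRY[config.model_name]()


model = build_model(TestExperimentConfig)
settings = TestExperimentConfig.as_dict()
```

Attributes are read directly (`TestExperimentConfig.seed`), so a typo fails with an `AttributeError` instead of silently falling back to a default, and `as_dict()` gives a plain dictionary that can be logged alongside the results.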