mlfoundations/open_lm

A repository for research on medium-sized language models.


Deduplicate argparse namespace creation for tests

achalddave opened this issue

Right now, we construct argument namespaces by hand in tests where we don't need to. As a result, we end up having to change these tests every time we add a parameter:

Some places:

1. ```python
   args = {
       "device": "cpu",
       "precision": "fp16",
       "accum_freq": 1,
       "seq_len": 9,
       "vocab_size": 10,
       "batch_size": 16,
       "log_logit_mean": False,
       "grad_clip_norm": 1.0,
       "skip_scheduler": True,
       "rank": 0,
       "local_rank": 0,
       "world_size": 1,
       "wandb": False,
       "log_every_n_steps": 1,
       "target_mask_left": None,
       "target_mask_individual": None,
   }
   ```
2. ```python
   args = argparse.Namespace(
       **{
           # Generation params:
           "model": "open_lm_160m",
           "input_text": "random",
           "max_gen_len": max_gen_len,
           "context_len": context_len,
           "temperature": 0.0,
           "top_p": 1.0,
           "use_cache": False,
           # Model params that might not be in config:
           "model_norm": "gain_only_layer_norm",
           "qk_norm": False,
           "positional_embedding_type": "rotary",
           "ffn_type": "swiglu",
       }
   )
   ```
3. `open_lm/tests/shared.py`, lines 13 to 79 in b5f9beb:

   ```python
   class MockTrainArgs:
       def __init__(self, model, **kwargs):
           data_path = download_val_data("shard_00000000.tar", "./tests/assets/")
           self.model = model  # part of model config
           self.model_norm = "gain_only_layer_norm"
           self.qk_norm = False
           self.train_data = [
               data_path,
           ]
           self.log_logit_mean = False
           self.device = "cpu"
           self.precision = "float32"
           self.wd = 0.033
           self.lr = 3e-3
           self.beta1 = 0.9
           self.beta2 = 0.95
           self.eps = 1e-8
           self.warmup = 2
           self.skip_scheduler = False
           self.accum_freq = 1
           self.batch_size = 8
           self.grad_clip_norm = 1.0
           self.rank = 0
           self.local_rank = 0
           self.log_every_n_steps = 1e8
           self.save_logs = False
           self.logs = None
           self.name = "test_model_name"
           self.dataset_type = "webdataset"
           self.data_key = "json"
           self.ffn_type = "swiglu"
           self.train_num_samples = 250000
           self.train_data_mix_weights = None
           self.train_data_upsampling_factors = None
           self.disable_buffer = False
           self.seed = 1
           self.vocab_size = 50432
           self.seq_len = 300
           self.epochs = 1
           self.save_frequency = 1
           self.checkpoint_path = "./tests/assets/checkpoints/"
           self.resume = None
           self.distributed = False
           self.delete_previous_checkpoint = False
           self.workers = 1
           self.world_size = 1
           self.val_data = None
           self.lr_cooldown_end = 3e-5
           self.force_min_lr = 0.0
           self.scaler = None
           self.accum_freq = 1
           self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
           self.wandb = False
           self.fsdp = False
           self.fsdp_amp = False
           self.positional_embedding_type = "rotary"
           self.dist_backend = "nccl"
           self.dist_url = "env://"
           self.dataset_manifest = None
           self.target_mask_left = None
           self.target_mask_individual = None
           self.ignore_parse_errors = False

           for k, v in kwargs.items():
               if hasattr(self, k):
                   setattr(self, k, v)
   ```
4. ```python
   args = argparse.Namespace(
       **{
           # Generation params:
           "model": "open_lm_1b_old",
           "input_text": "random",
           "max_gen_len": None,
           "context_len": None,
           "temperature": 0.0,
           "top_p": 1.0,
           "use_cache": False,
           "checkpoint": "checkpoints/open_lm_1b_old.pt",
           # Model params that might not be in config:
           "model_norm": "default_layer_norm",
           "qk_norm": False,
           "positional_embedding_type": "head_rotary",
           "ffn_type": "swiglu",
       }
   )
   ```

We should instead just call `parse_args`, or at the very least, define these args in a single place shared by the tests.
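As a minimal sketch of the first option: a small test helper could build the namespace from the real parser's defaults and apply only test-specific overrides. This assumes `open_lm.params.parse_args` accepts an argv list and has no required arguments at parse time (if some are required, tests would pass a minimal argv instead of `[]`); the helper name `create_test_args` is hypothetical.

```python
import argparse

from open_lm.params import parse_args  # assumed entry point


def create_test_args(**overrides) -> argparse.Namespace:
    """Build a Namespace from the real parser's defaults, then apply
    test-specific overrides."""
    # parse_args([]) yields every argument at its default value, so a new
    # parameter added to the parser is picked up by all tests automatically.
    args = parse_args([])
    for key, value in overrides.items():
        # Reject overrides the parser doesn't know about, so tests can't
        # drift out of sync with the CLI.
        if not hasattr(args, key):
            raise ValueError(f"unknown argument: {key}")
        setattr(args, key, value)
    return args


# e.g. in a test:
# args = create_test_args(device="cpu", precision="fp16", batch_size=16)
```

With something like this in place, adding a parser argument no longer requires editing every test, and a misspelled override fails loudly instead of silently creating an attribute that nothing reads.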