Deduplicate argparse namespace creation for tests
achalddave opened this issue
Achal Dave commented
Right now, each test constructs its full argument namespace by hand, duplicating defaults it doesn't need. As a result, we have to touch every one of these tests whenever we add a parameter:
Some places:
open_lm/open_lm/tests/test_accumulation.py
Lines 42 to 59 in b5f9beb
```python
args = {
    "device": "cpu",
    "precision": "fp16",
    "accum_freq": 1,
    "seq_len": 9,
    "vocab_size": 10,
    "batch_size": 16,
    "log_logit_mean": False,
    "grad_clip_norm": 1.0,
    "skip_scheduler": True,
    "rank": 0,
    "local_rank": 0,
    "world_size": 1,
    "wandb": False,
    "log_every_n_steps": 1,
    "target_mask_left": None,
    "target_mask_individual": None,
}
```
open_lm/tests/test_generate_kv_cache_time.py
Lines 21 to 37 in b5f9beb
```python
args = argparse.Namespace(
    **{
        # Generation params:
        "model": "open_lm_160m",
        "input_text": "random",
        "max_gen_len": max_gen_len,
        "context_len": context_len,
        "temperature": 0.0,
        "top_p": 1.0,
        "use_cache": False,
        # Model params that might not be in config:
        "model_norm": "gain_only_layer_norm",
        "qk_norm": False,
        "positional_embedding_type": "rotary",
        "ffn_type": "swiglu",
    }
)
```
Lines 13 to 79 in b5f9beb
```python
class MockTrainArgs:
    def __init__(self, model, **kwargs):
        data_path = download_val_data("shard_00000000.tar", "./tests/assets/")

        self.model = model
        # part of model config
        self.model_norm = "gain_only_layer_norm"
        self.qk_norm = False
        self.train_data = [
            data_path,
        ]
        self.log_logit_mean = False
        self.device = "cpu"
        self.precision = "float32"
        self.wd = 0.033
        self.lr = 3e-3
        self.beta1 = 0.9
        self.beta2 = 0.95
        self.eps = 1e-8
        self.warmup = 2
        self.skip_scheduler = False
        self.accum_freq = 1
        self.batch_size = 8
        self.grad_clip_norm = 1.0
        self.rank = 0
        self.local_rank = 0
        self.log_every_n_steps = 1e8
        self.save_logs = False
        self.logs = None
        self.name = "test_model_name"
        self.dataset_type = "webdataset"
        self.data_key = "json"
        self.ffn_type = "swiglu"
        self.train_num_samples = 250000
        self.train_data_mix_weights = None
        self.train_data_upsampling_factors = None
        self.disable_buffer = False
        self.seed = 1
        self.vocab_size = 50432
        self.seq_len = 300
        self.epochs = 1
        self.save_frequency = 1
        self.checkpoint_path = "./tests/assets/checkpoints/"
        self.resume = None
        self.distributed = False
        self.delete_previous_checkpoint = False
        self.workers = 1
        self.world_size = 1
        self.val_data = None
        self.lr_cooldown_end = 3e-5
        self.force_min_lr = 0.0
        self.scaler = None
        self.accum_freq = 1
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.wandb = False
        self.fsdp = False
        self.fsdp_amp = False
        self.positional_embedding_type = "rotary"
        self.dist_backend = "nccl"
        self.dist_url = "env://"
        self.dataset_manifest = None
        self.target_mask_left = None
        self.target_mask_individual = None
        self.ignore_parse_errors = False

        for k, v in kwargs.items():
            if hasattr(self, k):
                setattr(self, k, v)
```
open_lm/tests/test_generate_load_kv_cache_equal.py
Lines 29 to 46 in b5f9beb
```python
args = argparse.Namespace(
    **{
        # Generation params:
        "model": "open_lm_1b_old",
        "input_text": "random",
        "max_gen_len": None,
        "context_len": None,
        "temperature": 0.0,
        "top_p": 1.0,
        "use_cache": False,
        "checkpoint": "checkpoints/open_lm_1b_old.pt",
        # Model params that might not be in config:
        "model_norm": "default_layer_norm",
        "qk_norm": False,
        "positional_embedding_type": "head_rotary",
        "ffn_type": "swiglu",
    }
)
```
We should instead just call `parse_args`, or at the very least, define these args in a single shared place in the tests.
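One possible shape for this (a sketch, not a worked-out fix): a small shared helper that builds the namespace from the real parser, so each test only spells out the fields it actually cares about and new flags pick up their CLI defaults automatically. This assumes `parse_args` lives in `open_lm.params`, accepts a list of CLI-style argument strings (the usual argparse convention), and has defaults for every flag; `make_test_args` and its location are hypothetical.

```python
# tests/utils.py (hypothetical location) -- sketch of a shared helper.
# Assumes open_lm.params.parse_args takes a list of CLI-style argument
# strings and that every flag has a default, so parse_args([]) succeeds.
from open_lm.params import parse_args


def make_test_args(**overrides):
    """Build an args namespace from the real parser's defaults.

    Tests pass only the fields they care about; everything else comes
    from the CLI defaults, so adding a new parameter no longer requires
    touching every test.
    """
    args = parse_args([])  # empty argv -> all defaults
    for key, value in overrides.items():
        if not hasattr(args, key):
            raise ValueError(f"Unknown argument: {key}")
        setattr(args, key, value)
    return args
```

With something like this, the setup in `test_accumulation.py` would reduce to `args = make_test_args(precision="fp16", seq_len=9, vocab_size=10, batch_size=16, skip_scheduler=True)`, and `MockTrainArgs` could become a thin wrapper around the same helper (or go away entirely).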