Project dependencies may have API risk issues
PyDeps opened this issue · comments
Hi, in mesh-transformer-jax, inappropriate dependency version constraints can introduce risks.
Below are the dependencies and version constraints that the project currently uses:
numpy~=1.19.5
tqdm>=4.45.0
wandb>=0.11.2
einops~=0.3.0
requests~=2.25.1
fabric~=2.6.0
optax==0.0.9
dm-haiku==0.0.5
git+https://github.com/EleutherAI/lm-evaluation-harness/
ray[default]==1.4.1
jax~=0.2.12
Flask~=1.1.2
cloudpickle~=1.3.0
tensorflow-cpu~=2.6.0
google-cloud-storage~=1.36.2
transformers
smart_open[gcs]
func_timeout
ftfy
fastapi
uvicorn
lm_dataformat
pathy
The version constraint == introduces a risk of dependency conflicts because it pins the dependency to a single version, which is too strict.
A version constraint with no upper bound (or *) introduces a risk of missing-API errors because the latest version of a dependency may have removed some APIs.
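The two risk categories above can be sketched as a small classifier over requirement lines. This is only an illustration, not the tool that produced this report: `classify_constraint` and its labels are hypothetical, it uses only the standard library, and it ignores environment markers and other less common requirement syntax.

```python
import re


def classify_constraint(requirement: str) -> str:
    """Return a rough risk label for one requirements.txt line (hypothetical helper)."""
    line = requirement.strip()
    # Direct URL / VCS requirements carry no version specifier to inspect.
    if "://" in line:
        return "url"
    # Split "name[extras]" from the specifier part, e.g. "ray[default]" and "==1.4.1".
    match = re.match(r"^([A-Za-z0-9_.\-]+)(\[[^\]]*\])?\s*(.*)$", line)
    if match is None:
        return "unparsed"
    spec = match.group(3).strip()
    # A bare name (or "*") accepts any future release.
    if spec in ("", "*"):
        return "no-upper-bound"
    clauses = [c.strip() for c in spec.split(",")]
    # "==" without a trailing wildcard allows exactly one version.
    if any(c.startswith("==") and not c.endswith("*") for c in clauses):
        return "pinned"
    # "<", "<=", "~=" and "==X.*" all cap the acceptable versions.
    if any(c.startswith(("<", "~=", "==")) for c in clauses):
        return "bounded"
    return "no-upper-bound"


if __name__ == "__main__":
    for req in ["optax==0.0.9", "tqdm>=4.45.0", "transformers", "numpy~=1.19.5"]:
        print(req, "->", classify_constraint(req))
```

Lines labelled "pinned" correspond to the conflict risk, and lines labelled "no-upper-bound" correspond to the missing-API risk.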
After further analysis of this project:
The version constraint of tqdm can be changed to >=4.36.0,<=4.64.0.
The version constraint of wandb can be changed to >=0.5.16,<=0.9.7.
The version constraint of dm-haiku can be changed to >=0.0.1,<=0.0.8.
The version constraint of google-cloud-storage can be changed to >=1.17.0,<=2.4.0.
The version constraint of fastapi can be changed to >=0.1.2,<=0.78.0.
These changes reduce dependency conflicts as far as possible
while allowing the newest versions that do not cause call errors in the project.
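As a rough sketch of how the suggested ranges could be applied to a requirements list: the `SUGGESTED` mapping copies the ranges from this report, while `apply_suggestions` is a hypothetical helper, not part of the project or of pip. Lines for other packages, and URL requirements, are passed through unchanged.

```python
import re

# Ranges copied from the suggestions in this report.
SUGGESTED = {
    "tqdm": ">=4.36.0,<=4.64.0",
    "wandb": ">=0.5.16,<=0.9.7",
    "dm-haiku": ">=0.0.1,<=0.0.8",
    "google-cloud-storage": ">=1.17.0,<=2.4.0",
    "fastapi": ">=0.1.2,<=0.78.0",
}


def apply_suggestions(lines, suggested=SUGGESTED):
    """Return a new list of requirement lines with suggested ranges substituted."""
    out = []
    for line in lines:
        # Capture the package name and optional extras at the start of the line.
        m = re.match(r"^([A-Za-z0-9_.\-]+)(\[[^\]]*\])?", line.strip())
        name = m.group(1).lower() if m else None
        if name in suggested and "://" not in line:
            # Keep the original name and extras, replace only the specifier.
            out.append(f"{m.group(1)}{m.group(2) or ''}{suggested[name]}")
        else:
            out.append(line)
    return out


if __name__ == "__main__":
    current = ["tqdm>=4.45.0", "dm-haiku==0.0.5", "fastapi", "numpy~=1.19.5"]
    for new_line in apply_suggestions(current):
        print(new_line)
```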
The project currently calls all of the following methods.
Methods called from tqdm:
tqdm.tqdm.set_postfix tqdm.tqdm.update tqdm.tqdm
Methods called from wandb:
wandb.init wandb.log
Methods called from dm-haiku:
haiku.PRNGSequence.take haiku.without_apply_rng haiku.experimental.optimize_rng_use haiku.data_structures.tree_size
Methods called from google-cloud-storage:
google.cloud.storage.Client.list_blobs
Methods called from fastapi:
fastapi.FastAPI.post fastapi.FastAPI
All methods called in the project:
optax.chain rotate_every_two mesh_transformer.checkpoint.write_ckpt x.reshape jax.eval_shape flask.Flask multiprocessing.pool.ThreadPool haiku.Flatten get_bearer jax.numpy.exp queue.Queue mesh_transformer.TPU_cluster.TPUCluster mesh_transformer.layers.TransformerLayerShardV2 json.load time.sleep n.run.remote flask.make_response config.config.TransformerLayerShardV2.get_init_decode_state super self.proj.loss.hk.remat self.move_weights_pjit shard_strategy.count vocab.example_shape.key.next.jax.random.uniform.astype numpy.sum jax.experimental.pjit.pjit val_loss.np.array.mean.append eval_apply_fn self._relative_position_bucket f.read to_id jax.lax.dot_general ops.get_gptj_model jax.nn.softmax numpy.logical_and.tolist jax.numpy.stack self.queue.put mesh_transformer.layers.ProjectionShard jax.lax.sort_key_val self.config.get response.headers.add l.decode_once jax.local_device_count lm_eval.evaluator.evaluate.items rotate_every_two_v2 threading.Lock jax.numpy.linalg.norm x.jnp.transpose.reshape argparse.ArgumentParser.add_argument mesh_transformer.layers.TransformerLayerShard ckpt_step.str.meta.get p.map val_set.reset mesh_transformer.layers.Projection projection_apply_fn jax.numpy.array_equal tqdm.tqdm.update numpy.zeros_like qids.append jax.tree_unflatten tensorflow.data.TFRecordDataset param_init_fn self.rpe _unshard self.queue_ids.get compile_model mesh_transformer.util.head_print n.eval.remote jax.numpy.log ops.get_gptj_model.add_to_queue f_psum mesh_transformer.build_model.build_model.eval haiku.transform jax.random.uniform jax.host_count haiku.get_parameter numpy.array self.convert_requests n.train.remote jax.numpy.var config.config.TransformerLayerShardV2 grad_norm.np.array.mean x.self.k.reshape batch_items.append x.reshape.reshape getattr tfrecord_loader.TFRecordNewInputs.get_state self.proj self.output_q.get threading.Thread data.items json.load.append numpy.prod jax.lax.stop_gradient x.numpy node.load_ckpt.remote input numpy.sqrt numpy.where jax.numpy.sort min 
queue.Queue.put transformers.GPT2TokenizerFast.from_pretrained.decode fastapi.FastAPI.on_event self.tokenizer.encode numpy.arange self.input jax.tree_map f_psum.defvjp _corsify_actual_response _unshard.append global_norm self.eval config.get tree_flatten_with_names numpy.empty states.append self.train_pjit flask.jsonify state.items is_leaf json.load.items jax.numpy.transpose jax.numpy.array self.input_q.get given_length.astype numpy.log jax.random.split repr ClipByGlobalNormState exit int i.decode apply_fns new_states.append iter all_array_equal z_loss.sum_exp_logits.jnp.log.jnp.square.mean GPTJ self.to_data l.hk.remat read_remote.remote mesh_transformer.transformer_shard.CausalTransformer.generate getnorm requests.post mesh_transformer.transformer_shard.CausalTransformer x.reshape.astype batch_flattened.append wandb.log loss.np.array.mean functools.partial jax.lax.pmax jax.lax.rsqrt jax.lax.broadcasted_iota tokenizer self.train_xmap generate_fn numpy.finfo jax.nn.one_hot str self.network_builder.generate numpy.tril self.ff mesh_transformer.transformer_shard.CausalTransformer.eval payloads.QueueResponse _build_cors_prelight_response mesh_transformer.layers.EmbeddingShard self.proj.max attention_vec.reshape.reshape ray.get.append google.cloud.storage.Client jax.random.categorical itertools.cycle numpy.zeros embed_apply_fn x.self.q.reshape self.self_attn x.jnp.zeros_like.astype itertools.zip_longest transformers.GPT2TokenizerFast.from_pretrained.encode jax.lax.all_gather unstacked.append jax.tree_multimap jax.experimental.maps.ResourceEnv g_psum jax.numpy.sqrt jax.lax.pmean train_step jax.lax.axis_index loss.append mesh_transformer.util.additive_weight_decay output.append mesh_transformer.util.clip_by_global_norm qid.self.queue_ids.get numpy.array.mean NotImplementedError ValueError get_project self.dense_proj_o self.prepare_item.get ops.get_gptj_model.load_model tfrecord_loader.TFRecordNewInputs json.dumps jax.experimental.PartitionSpec 
mesh_transformer.util.global_norm numpy.ones Exception zip mesh_transformer.util.to_f32.to_bf16.early_cast id process_request next jax.experimental.maps.Mesh file.prefetch.apply init_decode_apply file_index.dir.np.load.keys self.input_proj jax.numpy.argsort lm_eval.tasks.get_task_dict batch_items.items mesh_transformer.build_model.build_model.move last.astype app.run.threading.Thread.start fabric.Connection config.EmbeddingShardV2 requests.delete.json self.generate_xmap ftfy.fix_text sum jax.numpy.multiply timer haiku.data_structures.tree_size self.init_pjit config.config.TransformerLayerShardV2.decode_once CausalTransformerShard.generate_initial infer CausalTransformerShard mesh_transformer.util.to_f32.to_bf16.bf16_optimizer jax.host_id self.reset file.prefetch.prefetch config.Projection.loss open apply_rotary_pos_emb nucleaus_filter jax.random.PRNGKey config.Projection json.dump p.imap self.output_proj val_grad_fn mesh_transformer.build_model.build_model.train mesh_transformer.transformer_shard.CausalTransformer.move_xmap io.BytesIO x.x.all x.self.v.reshape self.network_builder.train val_set.get_samples optax.scale all_top_p.append setuptools.find_packages self.output iter_decode_apply jax.experimental.pjit.with_sharding_constraint smart_open.open random.randint residual.hk.remat pytree.items tasks.util.shrink_seq pad_amount.tokens.np.pad.astype decode tensorflow.io.VarLenFeature self.network_builder.write_ckpt fastapi.FastAPI node.write_ckpt.remote qid.self.queue_ids.put last_loss.np.array.mean all_ctx.append check_tpu conn.run list func_timeout.func_set_timeout numpy.maximum.astype jax.numpy.concatenate conn.put tasks.eval_harness.EvalHarnessAdaptor self.infer_batch jax.lax.scan delete_tpu ops.get_gptj_model.start_background argparse.ArgumentParser lm_eval.evaluator.evaluate fastapi.FastAPI.post tensorflow.sparse.to_dense threading.Thread.start numpy.minimum numpy.sum.tolist jax.value_and_grad mesh_transformer.train_actor.NetworkRunner.options range 
new_vals.append reshard.all optax.apply_updates split self.network.generate tensorflow.sparse.reorder self.dense_proj optax.scale_by_schedule ReplicatedLayerNorm self.q jax.numpy.where all_temp.append google.cloud.storage.Client.list_blobs f_pmean.defvjp eval_step optax.GradientTransformation jax.lax.psum parallel_write init_fns blob.delete params.append eval_loss_fn tasks.EvalHarnessAdaptor projection_init_fn jax.tree_leaves reshard params.get res.ray.get.i.i.np.array.mean numpy.savez wandb.init grouper numpy.max compression.i.tf.data.TFRecordDataset.map optax.scale_by_adam haiku.PRNGSequence jax.numpy.zeros_like time.time logging.getLogger self.embed.hk.remat multiprocessing.Pool x.reshape.append jax.numpy.square last_loss.append numpy.pad mesh_transformer.build_model.build_model mesh_transformer.transformer_shard.CausalTransformer.train self.o mesh_transformer.checkpoint.load_ckpt_v2 self.network_builder self.glu jax.experimental.maps.mesh glob.glob ray.init tensorflow.io.FixedLenFeature self.tokenizer.add_special_tokens self.norm sch self.pool.imap reshaped.reshape.reshape numpy.cos jax.tree_flatten tqdm.tqdm step.step.all self.sample_once self.map_fn batch.append noise_scale_stats.update ray_tpu.wait_til optimizer.update super.__init__ read_sharded_v2 val_set.sample_once max logging.getLogger.debug save queue.Queue.get jax.devices.np.array.reshape self.v haiku.remat self.qvk_proj fastapi.FastAPI.add_middleware self.network_builder.load_ckpt haiku.initializers.TruncatedNormal all_length.append haiku.experimental.optimize_rng_use json.load.pop temp.logits.key.jax.random.categorical.astype mesh_transformer.util.to_f32.to_bf16.config val_sets.values numpy.sin jax.numpy.zeros tasks.util.sample_batch requests.delete self.eval_xmap haiku.Linear self.network_builder.eval tfrecord_loader.TFRecordNewInputs.get_samples optax.additive_weight_decay task_res.items jax.numpy.zeros.block_until_ready haiku.LayerNorm self.used.append jax.numpy.sum fixed_pos_embedding 
self.input_q.put setuptools.setup jax.numpy.arange jax.numpy.asarray apply_rotary_pos_emb_v2 warnings.filterwarnings functools.lru_cache softmax_sample format traceback.print_exc numpy.dtype transformer_init_fn n.get_params.remote v.attention_weights.jnp.einsum.reshape payloads.CompletionResponse itertools.chain jax.numpy.exp.sum mesh_transformer.build_model.build_model.save ray.is_initialized aux.get reshard.reshape jax.devices ray.remote self.prepare_item scheduler mesh_transformer.util.maybe_shard ray_tpu.get_connection jax.numpy.clip ray.get node.move_params.remote json.load.get divmod p.imap_unordered jax.numpy.einsum grad_norm_micro.np.array.mean CausalTransformerShard.loss gpt3_schedule index_weights parallel_read fix_dtype jax.lax.pmean.mean subprocess.check_output.decode optax.AdditiveWeightDecayState numpy.load jax.numpy.mean numpy.concatenate TFRecordWIT n.generate.remote jax.experimental.maps.xmap outputs.append index_fname.open.read float CausalTransformerShard.generate_once ray_tpu.create_tpu val_sets.items mesh_transformer.checkpoint.read_ckpt bcast_iota.Ellipsis.jnp.newaxis.rp_bucket.jnp.array.astype self.nodes.append mesh_transformer.util.g_psum self.dim_per_head.np.sqrt.astype transformers.GPT2TokenizerFast.from_pretrained samples.append dp.mp.key.next.jax.random.uniform.astype tree_leaves_with_names subprocess.check_output get_inital queue.Queue.qsize conn.sudo q.put transformer_apply_fn jax.numpy.cumsum l.get_init_decode_state mesh_transformer.util.gpt3_schedule self.tpu.eval optimizer.init haiku.next_rng_key jax.numpy.split parse_args all print ops.get_gptj_model.wait_for_queue val_loss.np.array.mean argparse.ArgumentParser.parse_args self.queue.get self.embed RMSNorm mesh_transformer.layers.RelativePositionEmbs numpy.stack self.head_split requests.get g_psum.defvjp process_init self.network_builder.move_xmap mesh_transformer.checkpoint.write_ckpt_v2 logging.getLogger.info map filter jax.numpy.broadcast_to self.eval_pjit uvicorn.run 
einops.repeat mesh_transformer.train_actor.NetworkRunner.options.remote jax.numpy.reshape self.d_head.np.sqrt.astype len self.get_samples gc.collect mesh_transformer.util.to_f32 self.norm.reshape TFRecordWIT.sample_once haiku.PRNGSequence.take einops.rearrange sampler float.update index_fname.open.read.splitlines numpy.einsum out.reshape.reshape haiku.without_apply_rng tensorflow.io.parse_single_example self.output_q.put all_tokenized.append shrink_seq self.transformer_layers.append network.state.count.item bool embed_init_fn self.tokenizer.decode RuntimeError all_q.append multiprocessing.set_start_method haiku.initializers.Constant tensorflow.data.experimental.dense_to_ragged_batch tensorflow.cast jax.nn.gelu mesh_transformer.build_model.build_model.load jax.device_put max_exact.num_buckets.max_exact.max_distance.np.log.np.float32.np.finfo.eps.max_exact.np.float32.n.astype.np.log.astype mesh_transformer.util.to_bf16 convert_fn max_lengths.append tqdm.tqdm.set_postfix subprocess.check_output.decode.strip isinstance numpy.array_split numpy.logical_and jax.numpy.cos mesh_transformer.util.f_psum enumerate flask.Flask.route mesh_transformer.layers.EmbeddingShardV2 self.init_xmap seq.np.zeros.astype os.path.expanduser jax.device_count contexts.append self.k numpy.maximum
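One way to check whether an installed dependency version still exposes the library calls listed above is to resolve each dotted path at import time. The sketch below is an assumption about how such a check could look, not the analysis used to produce this report. `api_exists` is a hypothetical helper, and it only applies to paths rooted in an importable module (for example `wandb.log`), not to entries on local variables such as `x.reshape`.

```python
import importlib


def api_exists(dotted_path: str) -> bool:
    """Return True if dotted_path resolves to an importable module attribute."""
    parts = dotted_path.split(".")
    # Try the longest importable module prefix first, then walk the remaining attributes.
    for split in range(len(parts), 0, -1):
        module_name = ".".join(parts[:split])
        try:
            obj = importlib.import_module(module_name)
        except ImportError:
            continue
        try:
            for attr in parts[split:]:
                obj = getattr(obj, attr)
        except AttributeError:
            return False
        return True
    # No prefix of the path could be imported at all.
    return False


if __name__ == "__main__":
    # Standard-library paths taken from the list above, so this runs anywhere.
    for path in ["json.load", "queue.Queue.put", "functools.partial"]:
        print(path, api_exists(path))
```

Running the same check for the third-party paths in each candidate environment would show which versions have dropped an API.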
@developer
Could you please help me check this issue?
May I open a pull request to fix it?
Thank you very much.