kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

Project dependencies may have API risk issues

PyDeps opened this issue

Hi, in mesh-transformer-jax, inappropriate dependency version constraints can introduce risks.

Below are the dependencies and version constraints that the project currently uses:

numpy~=1.19.5
tqdm>=4.45.0
wandb>=0.11.2
einops~=0.3.0
requests~=2.25.1
fabric~=2.6.0
optax==0.0.9
dm-haiku==0.0.5
git+https://github.com/EleutherAI/lm-evaluation-harness/
ray[default]==1.4.1
jax~=0.2.12
Flask~=1.1.2
cloudpickle~=1.3.0
tensorflow-cpu~=2.6.0
google-cloud-storage~=1.36.2
transformers
smart_open[gcs]
func_timeout
ftfy
fastapi
uvicorn
lm_dataformat
pathy

The == constraint introduces a risk of dependency conflicts, because pinning a single exact version is too strict and leaves no room to share a version with other packages.
Constraints with no upper bound (or *) introduce a risk of missing-API errors, because the latest version of a dependency may remove APIs that the project calls.
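To make the two risks concrete, here is a simplified, stdlib-only sketch of how a version is checked against a constraint. parse_version and satisfies are hypothetical helpers written for illustration; they only understand plain numeric versions such as 0.0.9 and are not a full PEP 440 implementation (installers use the complete rules, for example through the packaging library's SpecifierSet).

```python
import operator

# Comparison operators supported by this simplified checker.
_OPS = {
    "==": operator.eq, "!=": operator.ne,
    ">=": operator.ge, "<=": operator.le,
    ">": operator.gt, "<": operator.lt,
}


def parse_version(text):
    """'4.45.0' -> (4, 45, 0); only plain numeric release versions are supported."""
    return tuple(int(part) for part in text.strip().split("."))


def _pad(a, b):
    """Pad both tuples with zeros so that 1.19 compares equal to 1.19.0."""
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)), b + (0,) * (width - len(b))


def satisfies(version, spec):
    """True if `version` meets every comma-separated clause of `spec`."""
    v = parse_version(version)
    for clause in (c.strip() for c in spec.split(",")):
        if clause.startswith("~="):
            # Compatible release: ~=1.19.5 means >=1.19.5 and ==1.19.*
            bound = parse_version(clause[2:])
            if len(bound) < 2:
                raise ValueError(f"~= needs at least two components: {clause!r}")
            lhs, rhs = _pad(v, bound)
            prefix = len(bound) - 1
            if lhs < rhs or lhs[:prefix] != rhs[:prefix]:
                return False
            continue
        # Two-character operators come first so ">=" is not read as ">".
        for symbol in ("==", "!=", ">=", "<=", ">", "<"):
            if clause.startswith(symbol):
                lhs, rhs = _pad(v, parse_version(clause[len(symbol):]))
                if not _OPS[symbol](lhs, rhs):
                    return False
                break
        else:
            raise ValueError(f"unsupported clause: {clause!r}")
    return True


# An exact pin such as optax==0.0.9 rejects even a hypothetical bug-fix release.
print(satisfies("0.0.10", "==0.0.9"))
# An open range such as tqdm>=4.45.0 accepts any future release, however far ahead.
print(satisfies("99.0.0", ">=4.45.0"))
```

A bounded range like >=4.36.0,<=4.64.0 sits between these two extremes: it admits several releases (which eases conflict resolution) while excluding untested future versions.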

After further analysis of this project, the following changes are suggested:
The version constraint of tqdm can be changed to >=4.36.0,<=4.64.0.
The version constraint of wandb can be changed to >=0.5.16,<=0.9.7.
The version constraint of dm-haiku can be changed to >=0.0.1,<=0.0.8.
The version constraint of google-cloud-storage can be changed to >=1.17.0,<=2.4.0.
The version constraint of fastapi can be changed to >=0.1.2,<=0.78.0.

These suggested changes reduce dependency conflicts as much as possible,
while still allowing the newest versions that the project can use without errors.
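If the suggestions were applied, the requirement lines would need to be rewritten. Below is a small sketch of that step; apply_suggestions is a hypothetical helper (not part of pip), it treats package names case-insensitively with runs of -, _ and . considered equivalent (as PEP 503 specifies), keeps any extras such as [gcs], and leaves VCS URLs and unlisted packages untouched.

```python
import re

# Package name, optionally followed by extras like "[gcs]".
_REQ = re.compile(r"^([A-Za-z0-9][A-Za-z0-9._-]*)(\[[^\]]*\])?")


def _normalize(name):
    """PEP 503 name normalization: lowercase, runs of -_. collapse to '-'."""
    return re.sub(r"[-_.]+", "-", name).lower()


def apply_suggestions(lines, suggestions):
    """Replace the version constraint of every package listed in `suggestions`."""
    wanted = {_normalize(k): v for k, v in suggestions.items()}
    updated = []
    for line in lines:
        match = _REQ.match(line.strip())
        # Skip URLs such as git+https://... and packages without a suggestion.
        if match and "://" not in line and _normalize(match.group(1)) in wanted:
            name, extras = match.group(1), match.group(2) or ""
            updated.append(name + extras + wanted[_normalize(name)])
        else:
            updated.append(line)
    return updated


current = ["tqdm>=4.45.0", "dm-haiku==0.0.5", "numpy~=1.19.5"]
print(apply_suggestions(current, {"tqdm": ">=4.36.0,<=4.64.0",
                                  "dm-haiku": ">=0.0.1,<=0.0.8"}))
```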

The current project calls all of the following methods.

Methods called from tqdm:
tqdm.tqdm.set_postfix
tqdm.tqdm.update
tqdm.tqdm
Methods called from wandb:
wandb.init
wandb.log
Methods called from dm-haiku:
haiku.PRNGSequence.take
haiku.without_apply_rng
haiku.experimental.optimize_rng_use
haiku.data_structures.tree_size
Methods called from google-cloud-storage:
google.cloud.storage.Client.list_blobs
Methods called from fastapi:
fastapi.FastAPI.post
fastapi.FastAPI
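One way to guard against the missing-API risk when loosening these ranges is a smoke test that tries to resolve every dotted name listed per dependency (for example haiku.experimental.optimize_rng_use) after an upgrade. The sketch below is an assumption of how such a check could look, with hypothetical helpers resolves and missing_apis; it only catches ImportError and AttributeError, so a package that fails to import for another reason will still raise.

```python
import importlib


def resolves(dotted):
    """True if a dotted path such as 'json.load' can be imported and resolved."""
    parts = dotted.split(".")
    # Try the longest importable module prefix first, then walk attributes.
    for i in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:i]))
        except ImportError:
            continue
        try:
            for attr in parts[i:]:
                obj = getattr(obj, attr)
        except AttributeError:
            return False
        return True
    return False


def missing_apis(names):
    """Return the names that do not resolve in the current environment."""
    return [name for name in names if not resolves(name)]


# After an upgrade, an empty list means every listed API is still present.
print(missing_apis(["json.load", "os.path.expanduser"]))
```

In practice the argument would be the per-dependency names above (e.g. tqdm.tqdm.set_postfix, wandb.init, google.cloud.storage.Client.list_blobs) run in an environment with the candidate versions installed.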
All called methods:
optax.chain
rotate_every_two
mesh_transformer.checkpoint.write_ckpt
x.reshape
jax.eval_shape
flask.Flask
multiprocessing.pool.ThreadPool
haiku.Flatten
get_bearer
jax.numpy.exp
queue.Queue
mesh_transformer.TPU_cluster.TPUCluster
mesh_transformer.layers.TransformerLayerShardV2
json.load
time.sleep
n.run.remote
flask.make_response
config.config.TransformerLayerShardV2.get_init_decode_state
super
self.proj.loss.hk.remat
self.move_weights_pjit
shard_strategy.count
vocab.example_shape.key.next.jax.random.uniform.astype
numpy.sum
jax.experimental.pjit.pjit
val_loss.np.array.mean.append
eval_apply_fn
self._relative_position_bucket
f.read
to_id
jax.lax.dot_general
ops.get_gptj_model
jax.nn.softmax
numpy.logical_and.tolist
jax.numpy.stack
self.queue.put
mesh_transformer.layers.ProjectionShard
jax.lax.sort_key_val
self.config.get
response.headers.add
l.decode_once
jax.local_device_count
lm_eval.evaluator.evaluate.items
rotate_every_two_v2
threading.Lock
jax.numpy.linalg.norm
x.jnp.transpose.reshape
argparse.ArgumentParser.add_argument
mesh_transformer.layers.TransformerLayerShard
ckpt_step.str.meta.get
p.map
val_set.reset
mesh_transformer.layers.Projection
projection_apply_fn
jax.numpy.array_equal
tqdm.tqdm.update
numpy.zeros_like
qids.append
jax.tree_unflatten
tensorflow.data.TFRecordDataset
param_init_fn
self.rpe
_unshard
self.queue_ids.get
compile_model
mesh_transformer.util.head_print
n.eval.remote
jax.numpy.log
ops.get_gptj_model.add_to_queue
f_psum
mesh_transformer.build_model.build_model.eval
haiku.transform
jax.random.uniform
jax.host_count
haiku.get_parameter
numpy.array
self.convert_requests
n.train.remote
jax.numpy.var
config.config.TransformerLayerShardV2
grad_norm.np.array.mean
x.self.k.reshape
batch_items.append
x.reshape.reshape
getattr
tfrecord_loader.TFRecordNewInputs.get_state
self.proj
self.output_q.get
threading.Thread
data.items
json.load.append
numpy.prod
jax.lax.stop_gradient
x.numpy
node.load_ckpt.remote
input
numpy.sqrt
numpy.where
jax.numpy.sort
min
queue.Queue.put
transformers.GPT2TokenizerFast.from_pretrained.decode
fastapi.FastAPI.on_event
self.tokenizer.encode
numpy.arange
self.input
jax.tree_map
f_psum.defvjp
_corsify_actual_response
_unshard.append
global_norm
self.eval
config.get
tree_flatten_with_names
numpy.empty
states.append
self.train_pjit
flask.jsonify
state.items
is_leaf
json.load.items
jax.numpy.transpose
jax.numpy.array
self.input_q.get
given_length.astype
numpy.log
jax.random.split
repr
ClipByGlobalNormState
exit
int
i.decode
apply_fns
new_states.append
iter
all_array_equal
z_loss.sum_exp_logits.jnp.log.jnp.square.mean
GPTJ
self.to_data
l.hk.remat
read_remote.remote
mesh_transformer.transformer_shard.CausalTransformer.generate
getnorm
requests.post
mesh_transformer.transformer_shard.CausalTransformer
x.reshape.astype
batch_flattened.append
wandb.log
loss.np.array.mean
functools.partial
jax.lax.pmax
jax.lax.rsqrt
jax.lax.broadcasted_iota
tokenizer
self.train_xmap
generate_fn
numpy.finfo
jax.nn.one_hot
str
self.network_builder.generate
numpy.tril
self.ff
mesh_transformer.transformer_shard.CausalTransformer.eval
payloads.QueueResponse
_build_cors_prelight_response
mesh_transformer.layers.EmbeddingShard
self.proj.max
attention_vec.reshape.reshape
ray.get.append
google.cloud.storage.Client
jax.random.categorical
itertools.cycle
numpy.zeros
embed_apply_fn
x.self.q.reshape
self.self_attn
x.jnp.zeros_like.astype
itertools.zip_longest
transformers.GPT2TokenizerFast.from_pretrained.encode
jax.lax.all_gather
unstacked.append
jax.tree_multimap
jax.experimental.maps.ResourceEnv
g_psum
jax.numpy.sqrt
jax.lax.pmean
train_step
jax.lax.axis_index
loss.append
mesh_transformer.util.additive_weight_decay
output.append
mesh_transformer.util.clip_by_global_norm
qid.self.queue_ids.get
numpy.array.mean
NotImplementedError
ValueError
get_project
self.dense_proj_o
self.prepare_item.get
ops.get_gptj_model.load_model
tfrecord_loader.TFRecordNewInputs
json.dumps
jax.experimental.PartitionSpec
mesh_transformer.util.global_norm
numpy.ones
Exception
zip
mesh_transformer.util.to_f32.to_bf16.early_cast
id
process_request
next
jax.experimental.maps.Mesh
file.prefetch.apply
init_decode_apply
file_index.dir.np.load.keys
self.input_proj
jax.numpy.argsort
lm_eval.tasks.get_task_dict
batch_items.items
mesh_transformer.build_model.build_model.move
last.astype
app.run.threading.Thread.start
fabric.Connection
config.EmbeddingShardV2
requests.delete.json
self.generate_xmap
ftfy.fix_text
sum
jax.numpy.multiply
timer
haiku.data_structures.tree_size
self.init_pjit
config.config.TransformerLayerShardV2.decode_once
CausalTransformerShard.generate_initial
infer
CausalTransformerShard
mesh_transformer.util.to_f32.to_bf16.bf16_optimizer
jax.host_id
self.reset
file.prefetch.prefetch
config.Projection.loss
open
apply_rotary_pos_emb
nucleaus_filter
jax.random.PRNGKey
config.Projection
json.dump
p.imap
self.output_proj
val_grad_fn
mesh_transformer.build_model.build_model.train
mesh_transformer.transformer_shard.CausalTransformer.move_xmap
io.BytesIO
x.x.all
x.self.v.reshape
self.network_builder.train
val_set.get_samples
optax.scale
all_top_p.append
setuptools.find_packages
self.output
iter_decode_apply
jax.experimental.pjit.with_sharding_constraint
smart_open.open
random.randint
residual.hk.remat
pytree.items
tasks.util.shrink_seq
pad_amount.tokens.np.pad.astype
decode
tensorflow.io.VarLenFeature
self.network_builder.write_ckpt
fastapi.FastAPI
node.write_ckpt.remote
qid.self.queue_ids.put
last_loss.np.array.mean
all_ctx.append
check_tpu
conn.run
list
func_timeout.func_set_timeout
numpy.maximum.astype
jax.numpy.concatenate
conn.put
tasks.eval_harness.EvalHarnessAdaptor
self.infer_batch
jax.lax.scan
delete_tpu
ops.get_gptj_model.start_background
argparse.ArgumentParser
lm_eval.evaluator.evaluate
fastapi.FastAPI.post
tensorflow.sparse.to_dense
threading.Thread.start
numpy.minimum
numpy.sum.tolist
jax.value_and_grad
mesh_transformer.train_actor.NetworkRunner.options
range
new_vals.append
reshard.all
optax.apply_updates
split
self.network.generate
tensorflow.sparse.reorder
self.dense_proj
optax.scale_by_schedule
ReplicatedLayerNorm
self.q
jax.numpy.where
all_temp.append
google.cloud.storage.Client.list_blobs
f_pmean.defvjp
eval_step
optax.GradientTransformation
jax.lax.psum
parallel_write
init_fns
blob.delete
params.append
eval_loss_fn
tasks.EvalHarnessAdaptor
projection_init_fn
jax.tree_leaves
reshard
params.get
res.ray.get.i.i.np.array.mean
numpy.savez
wandb.init
grouper
numpy.max
compression.i.tf.data.TFRecordDataset.map
optax.scale_by_adam
haiku.PRNGSequence
jax.numpy.zeros_like
time.time
logging.getLogger
self.embed.hk.remat
multiprocessing.Pool
x.reshape.append
jax.numpy.square
last_loss.append
numpy.pad
mesh_transformer.build_model.build_model
mesh_transformer.transformer_shard.CausalTransformer.train
self.o
mesh_transformer.checkpoint.load_ckpt_v2
self.network_builder
self.glu
jax.experimental.maps.mesh
glob.glob
ray.init
tensorflow.io.FixedLenFeature
self.tokenizer.add_special_tokens
self.norm
sch
self.pool.imap
reshaped.reshape.reshape
numpy.cos
jax.tree_flatten
tqdm.tqdm
step.step.all
self.sample_once
self.map_fn
batch.append
noise_scale_stats.update
ray_tpu.wait_til
optimizer.update
super.__init__
read_sharded_v2
val_set.sample_once
max
logging.getLogger.debug
save
queue.Queue.get
jax.devices.np.array.reshape
self.v
haiku.remat
self.qvk_proj
fastapi.FastAPI.add_middleware
self.network_builder.load_ckpt
haiku.initializers.TruncatedNormal
all_length.append
haiku.experimental.optimize_rng_use
json.load.pop
temp.logits.key.jax.random.categorical.astype
mesh_transformer.util.to_f32.to_bf16.config
val_sets.values
numpy.sin
jax.numpy.zeros
tasks.util.sample_batch
requests.delete
self.eval_xmap
haiku.Linear
self.network_builder.eval
tfrecord_loader.TFRecordNewInputs.get_samples
optax.additive_weight_decay
task_res.items
jax.numpy.zeros.block_until_ready
haiku.LayerNorm
self.used.append
jax.numpy.sum
fixed_pos_embedding
self.input_q.put
setuptools.setup
jax.numpy.arange
jax.numpy.asarray
apply_rotary_pos_emb_v2
warnings.filterwarnings
functools.lru_cache
softmax_sample
format
traceback.print_exc
numpy.dtype
transformer_init_fn
n.get_params.remote
v.attention_weights.jnp.einsum.reshape
payloads.CompletionResponse
itertools.chain
jax.numpy.exp.sum
mesh_transformer.build_model.build_model.save
ray.is_initialized
aux.get
reshard.reshape
jax.devices
ray.remote
self.prepare_item
scheduler
mesh_transformer.util.maybe_shard
ray_tpu.get_connection
jax.numpy.clip
ray.get
node.move_params.remote
json.load.get
divmod
p.imap_unordered
jax.numpy.einsum
grad_norm_micro.np.array.mean
CausalTransformerShard.loss
gpt3_schedule
index_weights
parallel_read
fix_dtype
jax.lax.pmean.mean
subprocess.check_output.decode
optax.AdditiveWeightDecayState
numpy.load
jax.numpy.mean
numpy.concatenate
TFRecordWIT
n.generate.remote
jax.experimental.maps.xmap
outputs.append
index_fname.open.read
float
CausalTransformerShard.generate_once
ray_tpu.create_tpu
val_sets.items
mesh_transformer.checkpoint.read_ckpt
bcast_iota.Ellipsis.jnp.newaxis.rp_bucket.jnp.array.astype
self.nodes.append
mesh_transformer.util.g_psum
self.dim_per_head.np.sqrt.astype
transformers.GPT2TokenizerFast.from_pretrained
samples.append
dp.mp.key.next.jax.random.uniform.astype
tree_leaves_with_names
subprocess.check_output
get_inital
queue.Queue.qsize
conn.sudo
q.put
transformer_apply_fn
jax.numpy.cumsum
l.get_init_decode_state
mesh_transformer.util.gpt3_schedule
self.tpu.eval
optimizer.init
haiku.next_rng_key
jax.numpy.split
parse_args
all
print
ops.get_gptj_model.wait_for_queue
val_loss.np.array.mean
argparse.ArgumentParser.parse_args
self.queue.get
self.embed
RMSNorm
mesh_transformer.layers.RelativePositionEmbs
numpy.stack
self.head_split
requests.get
g_psum.defvjp
process_init
self.network_builder.move_xmap
mesh_transformer.checkpoint.write_ckpt_v2
logging.getLogger.info
map
filter
jax.numpy.broadcast_to
self.eval_pjit
uvicorn.run
einops.repeat
mesh_transformer.train_actor.NetworkRunner.options.remote
jax.numpy.reshape
self.d_head.np.sqrt.astype
len
self.get_samples
gc.collect
mesh_transformer.util.to_f32
self.norm.reshape
TFRecordWIT.sample_once
haiku.PRNGSequence.take
einops.rearrange
sampler
float.update
index_fname.open.read.splitlines
numpy.einsum
out.reshape.reshape
haiku.without_apply_rng
tensorflow.io.parse_single_example
self.output_q.put
all_tokenized.append
shrink_seq
self.transformer_layers.append
network.state.count.item
bool
embed_init_fn
self.tokenizer.decode
RuntimeError
all_q.append
multiprocessing.set_start_method
haiku.initializers.Constant
tensorflow.data.experimental.dense_to_ragged_batch
tensorflow.cast
jax.nn.gelu
mesh_transformer.build_model.build_model.load
jax.device_put
max_exact.num_buckets.max_exact.max_distance.np.log.np.float32.np.finfo.eps.max_exact.np.float32.n.astype.np.log.astype
mesh_transformer.util.to_bf16
convert_fn
max_lengths.append
tqdm.tqdm.set_postfix
subprocess.check_output.decode.strip
isinstance
numpy.array_split
numpy.logical_and
jax.numpy.cos
mesh_transformer.util.f_psum
enumerate
flask.Flask.route
mesh_transformer.layers.EmbeddingShardV2
self.init_xmap
seq.np.zeros.astype
os.path.expanduser
jax.device_count
contexts.append
self.k
numpy.maximum

@developer
Could you please help me check this issue?
May I submit a pull request to fix it?
Thank you very much.