LSHSelfAttention only works with input lengths different from the init length if use_reference_code is set to True
Description
After it has been initialized, LSHSelfAttention only accepts inputs of a different length if the parameter use_reference_code is set to True; with use_reference_code=False it fails. Because of this I can't use LSHSelfAttention in a Reformer model together with BucketByLength, which produces batches of varying lengths.
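For context, this is roughly the kind of input pipeline I have in mind (a minimal sketch with illustrative boundaries and batch sizes, not my actual code): BucketByLength pads each bucket differently, so consecutive batches reach the attention layer with different sequence lengths.

import numpy as np
import trax

# Toy stream of variable-length (input, target) pairs, just for illustration.
def toy_stream():
    while True:
        n = np.random.randint(10, 100)
        x = np.random.randint(1, 256, size=(n,))
        yield (x, x)

# Each bucket is padded to a different length, so consecutive batches can
# come out with different sequence lengths.
bucketed = trax.data.BucketByLength(boundaries=[32, 64, 128],
                                    batch_sizes=[16, 8, 4, 2])(toy_stream())
batch = next(bucketed)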
Environment information
OS: Ubuntu 18.04.1 LTS
$ pip freeze | grep trax
trax==1.3.9
$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.5.0
tensorflow-datasets==4.3.0
tensorflow-estimator==2.5.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.0.0
tensorflow-text==2.5.0
$ pip freeze | grep jax
jax==0.2.16
jaxlib==0.1.67
jupyter-server-mathjax==0.2.2
$ python -V
Python 3.8.10
For bugs: reproduction and error logs
# Steps to reproduce:
import trax
from trax import layers as tl
import jax.numpy as jnp
# Initialize the layer with a signature of batch 4 and sequence length 32 ...
shapedtype = trax.shapes.ShapeDtype((4, 32, 512), dtype=jnp.int32)
# g = tl.SelfAttention(mode='train')
g = tl.LSHSelfAttention(chunk_len=8,
                        use_reference_code=False,
                        mode='train')
g.init(shapedtype)
# ... then call it with batch 8 and sequence length 16.
valores = jnp.ones((8, 16, 512))
g(valores)  # raises the LayerError below
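For comparison, this is what does run as far as I can tell (a sketch based on the behaviour described above, not a verified minimal test): the call succeeds either when the input length matches the init signature, or when the layer is built with use_reference_code=True.

# Works: the input shape matches the signature used at init.
g(jnp.ones((4, 32, 512)))

# Also works (this is the behaviour from the title): the reference
# implementation accepts a different batch size / sequence length after init.
g_ref = tl.LSHSelfAttention(chunk_len=8, use_reference_code=True, mode='train')
g_ref.init(shapedtype)
g_ref(jnp.ones((8, 16, 512)))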
Error logs:
LayerError Traceback (most recent call last)
<ipython-input> in <module>
12 g.init(shapedtype)
13 valores = jnp.ones((8,16,512))
---> 14 g(valores)
~/anaconda3/envs/ambiente_ipdr/lib/python3.8/site-packages/trax/layers/base.py in __call__(self, x, weights, state, rng)
195 self.state = state # Needed if the model wasn't fully initialized.
196 state = self.state
--> 197 outputs, new_state = self.pure_fn(x, weights, state, rng)
198 self.state = new_state
199 return outputs
~/anaconda3/envs/ambiente_ipdr/lib/python3.8/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
603 # Skipping 3 lines as it's always the uninteresting internal call.
604 name, trace = self._name, _short_traceback(skip=3)
--> 605 raise LayerError(name, 'pure_fn',
606 self._caller, signature(x), trace) from None
607
LayerError: Exception passing through layer LSHSelfAttention (in pure_fn):
layer created in file [...]/layers/research/efficient_attention.py, line 1744
layer input shapes: ShapeDtype{shape:(8, 16, 512), dtype:float32}
File [...]/trax/layers/base.py, line 673, in _do_custom_gradients
output, state = do_forward(self.state, self._rng, x, self.weights)
File [...]/jax/_src/custom_derivatives.py, line 486, in __call__
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat,
File [...]/jax/_src/custom_derivatives.py, line 566, in bind
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
File [...]/site-packages/jax/core.py, line 617, in process_custom_vjp_call
return fun.call_wrapped(*tracers)
File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/trax/fastmath/jax.py, line 167, in _f
return f(*args, **kwargs)
File [...]/trax/layers/base.py, line 651, in _f
res = self.forward(y)
File [...]/layers/research/efficient_attention.py, line 2117, in forward
output, new_state, _, _ = self.forward_and_or_backward(
File [...]/layers/research/efficient_attention.py, line 2538, in forward_and_or_backward
loop_val = fastmath.fori_loop(
File [...]/trax/fastmath/ops.py, line 173, in fori_loop
return backend()['fori_loop'](lower, upper, body_fn, init_val)
File [...]/jax/_src/traceback_util.py, line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File [...]/_src/lax/control_flow.py, line 212, in fori_loop
(_, result), _ = scan(fori_scan_body_fun(body_fun), (lower, init_val),
File [...]/jax/_src/traceback_util.py, line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File [...]/_src/lax/control_flow.py, line 1288, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File [...]/_src/lax/control_flow.py, line 1274, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(
File [...]/jax/_src/util.py, line 186, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File [...]/jax/_src/util.py, line 179, in cached
return f(*args, **kwargs)
File [...]/_src/lax/control_flow.py, line 76, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
File [...]/jax/_src/util.py, line 186, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File [...]/jax/_src/util.py, line 179, in cached
return f(*args, **kwargs)
File [...]/_src/lax/control_flow.py, line 70, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File [...]/jax/interpreters/partial_eval.py, line 1252, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File [...]/jax/interpreters/partial_eval.py, line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/_src/lax/control_flow.py, line 143, in scanned_fun
return (i + 1, body_fun(i, x)), None
File [...]/layers/research/efficient_attention.py, line 2421, in run_inner
s_all = tree_update(s_all, idx, s_h)
File [...]/layers/research/efficient_attention.py, line 2352, in tree_update
return fastmath.nested_map_multiarg(
File [...]/trax/fastmath/numpy.py, line 136, in nested_map_multiarg
return tuple([nested_map_multiarg(f, *[o[i] for o in objs])
File [...]/trax/fastmath/numpy.py, line 136, in <listcomp>
return tuple([nested_map_multiarg(f, *[o[i] for o in objs])
File [...]/trax/fastmath/numpy.py, line 143, in nested_map_multiarg
return f(*objs)
File [...]/layers/research/efficient_attention.py, line 2353, in <lambda>
lambda x, y: fastmath.index_update(x, jax.ops.index[indices], y),
File [...]/trax/fastmath/ops.py, line 199, in index_update
return backend()['index_update'](*args, **kwargs)
File [...]/_src/ops/scatter.py, line 351, in index_update
return _scatter_update(
File [...]/_src/ops/scatter.py, line 68, in _scatter_update
return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
File [...]/_src/ops/scatter.py, line 90, in _scatter_impl
y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
File [...]/_src/numpy/lax_numpy.py, line 1816, in broadcast_to
raise ValueError(msg.format(arr_shape, shape))
ValueError: Incompatible shapes for broadcasting: (16,) and requested shape (32,)
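My guess, not verified: the (32,) in the error matches the sequence length from the init signature, which suggests the fast (non-reference) path keeps per-position state sized at init time and then tries to scatter the length-16 result into buffers sized for length 32. A quick way to inspect the cached shapes, purely as a debugging sketch:

from trax import fastmath

# Print the shapes held in the layer's cached state after init; on my reading
# of the traceback, some of these should still reflect sequence length 32.
print(fastmath.nested_map(lambda x: getattr(x, 'shape', x), g.state))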