Floating point conversion issue with `use_tucker`
jgreener64 opened this issue · comments
The following works fine for me:
python eat.py --lookup test.fasta --queries test.fasta --output test/
But when I add --use_tucker 1
I get:
Start loading ProtT5...
Finished loading Rostlab/prot_t5_xl_half_uniref50-enc in 28.2[s]
Start generating embeddings for 50 proteins.This process might take a few minutes.Using batch-processing! If you run OOM/RuntimeError, you should use single-sequence embedding by setting max_batch=1.
Creating per-protein embeddings took: 1.4[s]
Start generating embeddings for 50 proteins.This process might take a few minutes.Using batch-processing! If you run OOM/RuntimeError, you should use single-sequence embedding by setting max_batch=1.
Creating per-protein embeddings took: 0.7[s]
No existing model found. Start downloading pre-trained ProtTucker(ProtT5)...
Loading Tucker checkpoint from: temp/tucker_weights.pt
Traceback (most recent call last):
File "/home/jgreener/soft/EAT/eat.py", line 515, in <module>
main()
File "/home/jgreener/soft/EAT/eat.py", line 496, in main
eater = EAT(lookup_p, query_p, output_d,
File "/home/jgreener/soft/EAT/eat.py", line 220, in __init__
self.lookup_embs = self.tucker_embeddings(self.lookup_embs)
File "/home/jgreener/soft/EAT/eat.py", line 245, in tucker_embeddings
dataset = model.single_pass(dataset)
File "/home/jgreener/soft/EAT/eat.py", line 36, in single_pass
return self.tucker(x)
File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
return F.linear(input, self.weight, self.bias)
File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/functional.py", line 1848, in linear
return torch._C._nn.linear(input, weight, bias)
RuntimeError: expected scalar type Float but found Half
I am on Python 3.9.16, PyTorch 1.10.0, h5py 3.6.0, numpy 1.22.0, scikit-learn 0.24.2 and transformers 4.17.0. test.fasta
is uploaded as test.txt.
Hey :)
first of all: thanks for your feedback!
On your issue: the problem is that the current embedder, ProtT5, is run in half precision, so it also produces embeddings of that datatype. In its current version, ProtTucker's weights are still loaded in full precision, which causes this RuntimeError.
So you can either a) up-cast the embeddings to fp32 before feeding them to Tucker or b) down-cast ProtTucker to fp16.
Depending on the size of your set and how speed-sensitive your application is: I would go for option a) if your set is small enough, and for option b) only if you want to search large sets (millions of proteins) against each other.
For a) you would need to up-cast the embeddings before this line, via something like self.lookup_embs = self.lookup_embs.astype(np.float)
(do this for both lookup and queries). For b) you would only need to add model = model.half()
somewhere here.
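To make the two options concrete, here is a minimal sketch using a toy `nn.Linear` as a stand-in for ProtTucker (the sizes are hypothetical, not ProtTucker's actual dimensions):

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 128)        # fp32 weights, as ProtTucker is loaded
embs = torch.randn(4, 1024).half()  # ProtT5 embeddings arrive in fp16

# Calling model(embs) directly would raise:
# RuntimeError: expected scalar type Float but found Half

# Option a) up-cast the embeddings to fp32 before the forward pass
out = model(embs.to(torch.float))
print(out.dtype)  # torch.float32

# Option b) down-cast the model weights to fp16 instead
model_fp16 = nn.Linear(1024, 128).half()
print(model_fp16.weight.dtype)  # torch.float16
```

Either way, inputs and weights end up with matching dtypes, which is all `F.linear` requires.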
Hope this helps; let me know if this solves your issue.
That worked, thanks. I added the following lines before https://github.com/Rostlab/EAT/blob/main/eat.py#L220:
self.lookup_embs = self.lookup_embs.to(torch.float)
self.query_embs = self.query_embs.to(torch.float)
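For reference, the effect of that up-cast can be checked in isolation (a minimal sketch; the shape here is arbitrary):

```python
import torch

# Stand-in for the half-precision embeddings loaded before L220
embs = torch.randn(3, 1024, dtype=torch.float16)

# Same up-cast as in the fix above
embs = embs.to(torch.float)
print(embs.dtype)  # torch.float32
```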