Rostlab / EAT

Embedding-based annotation transfer (EAT) uses the Euclidean distance between vector representations (embeddings) of proteins to transfer annotations from a set of labeled lookup protein embeddings to query protein embeddings.
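In essence, EAT is a nearest-neighbor lookup in embedding space. A minimal sketch of the idea (not EAT's actual API; the function and variable names are illustrative):

```python
import torch

def transfer_annotations(lookup_embs, lookup_labels, query_embs):
    """Transfer each query's annotation from its nearest lookup protein.

    lookup_embs: [n_lookup, d] tensor of labeled per-protein embeddings
    lookup_labels: list of n_lookup annotation labels
    query_embs: [n_query, d] tensor of query per-protein embeddings
    """
    # Pairwise Euclidean distances between queries and lookup set: [n_query, n_lookup]
    dists = torch.cdist(query_embs, lookup_embs, p=2)
    # Nearest lookup protein per query; the distance can serve as a confidence proxy
    nn_dist, nn_idx = dists.min(dim=1)
    return [lookup_labels[i] for i in nn_idx], nn_dist
```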


Floating point conversion issue with `use_tucker`

jgreener64 opened this issue

The following works fine for me:

python eat.py --lookup test.fasta --queries test.fasta --output test/

But when I add `--use_tucker 1` I get:

Start loading ProtT5...
Finished loading Rostlab/prot_t5_xl_half_uniref50-enc in 28.2[s]
Start generating embeddings for 50 proteins.This process might take a few minutes.Using batch-processing! If you run OOM/RuntimeError, you should use single-sequence embedding by setting max_batch=1.
Creating per-protein embeddings took: 1.4[s]
Start generating embeddings for 50 proteins.This process might take a few minutes.Using batch-processing! If you run OOM/RuntimeError, you should use single-sequence embedding by setting max_batch=1.
Creating per-protein embeddings took: 0.7[s]
No existing model found. Start downloading pre-trained ProtTucker(ProtT5)...
Loading Tucker checkpoint from: temp/tucker_weights.pt
Traceback (most recent call last):
  File "/home/jgreener/soft/EAT/eat.py", line 515, in <module>
    main()
  File "/home/jgreener/soft/EAT/eat.py", line 496, in main
    eater = EAT(lookup_p, query_p, output_d,
  File "/home/jgreener/soft/EAT/eat.py", line 220, in __init__
    self.lookup_embs = self.tucker_embeddings(self.lookup_embs)
  File "/home/jgreener/soft/EAT/eat.py", line 245, in tucker_embeddings
    dataset = model.single_pass(dataset)
  File "/home/jgreener/soft/EAT/eat.py", line 36, in single_pass
    return self.tucker(x)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: expected scalar type Float but found Half

I am on Python 3.9.16, PyTorch 1.10.0, h5py 3.6.0, numpy 1.22.0, scikit-learn 0.24.2 and transformers 4.17.0. test.fasta is uploaded as test.txt.
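For reference, the same dtype mismatch can be reproduced outside EAT by feeding fp16 input to a default fp32 linear layer (a minimal standalone sketch, not code from eat.py):

```python
import torch

layer = torch.nn.Linear(1024, 128)             # fp32 weights, PyTorch's default
x = torch.randn(4, 1024, dtype=torch.float16)  # fp16 input, like ProtT5's embeddings
layer(x)  # raises a dtype-mismatch RuntimeError like the one in the traceback above
```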

Hey :)
first of all: thanks for your feedback!
On your issue: the problem is that the current embedder, ProtT5, is run in half precision, which also produces embeddings in that datatype. In its current version, ProtTucker's weights are still loaded in full precision, which causes this RuntimeError.
So you can either a) up-cast the embeddings to fp32 before feeding them to Tucker, or b) down-cast ProtTucker to fp16.
Depending on the size of your set and how speed-sensitive your application is: I would go for option a) if your set is small enough, and only for option b) if you want to search large sets (millions of proteins) against each other.
For a) you would need to add an up-cast of the embeddings before this line via something like `self.lookup_embs = self.lookup_embs.to(torch.float32)` (do this for both lookup and query embeddings). For b) you would only need to add `model = model.half()` somewhere here.
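For b), the change would look roughly like this (a sketch placed where the checkpoint is loaded in `tucker_embeddings`; not a verified patch):

```python
model = model.half()                  # down-cast ProtTucker's weights to fp16
dataset = model.single_pass(dataset)  # now matches the fp16 ProtT5 embeddings
```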

Hope this helps; let me know if this solves your issue.

That worked, thanks. I added the following lines before https://github.com/Rostlab/EAT/blob/main/eat.py#L220:

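            # Up-cast the fp16 ProtT5 embeddings to fp32 to match ProtTucker's fp32 weights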
            self.lookup_embs = self.lookup_embs.to(torch.float)
            self.query_embs = self.query_embs.to(torch.float)