quantize_embeddings + KeyedJaggedTensor + VBE (variable batch size embeddings) do not work together
yjjinjie opened this issue · comments
yjjinjie commented
import torch
from torchrec import KeyedJaggedTensor
from torchrec import EmbeddingBagConfig,EmbeddingConfig
from torchrec import EmbeddingBagCollection,EmbeddingCollection
# Repro input: a KeyedJaggedTensor using variable-batch-size embeddings (VBE).
# VBE is expressed through `stride_per_key_per_rank` (per-key batch sizes) and
# `inverse_indices` (mapping back to the de-duplicated batch); note that
# 'user_id' and 'raw_4' use stride 1 while every other key uses stride 2.
kt2 = KeyedJaggedTensor(
keys=['user_id', 'item_id', 'id_3', 'id_4', 'id_5', 'raw_1', 'raw_4', 'combo_1', 'lookup_2', 'lookup_3', 'lookup_4', 'match_2', 'match_3', 'match_4', 'click_50_seq__item_id', 'click_50_seq__id_3', 'click_50_seq__raw_1'],
values=torch.tensor([573174, 5073, 3562, 3, 18, 13, 11, 49, 26,
4, 2, 2, 4, 2, 4, 736847, 849333, 997432,
640218, 9926, 9926, 0, 0, 0, 0, 59926, 59926,
0, 0, 0, 0, 2835, 769, 1265, 8232, 6399,
114, 7487, 2876, 953, 7840, 7538, 7998, 7852, 3528,
1475, 7620, 6110, 572, 735, 4405, 5655, 6736, 2173,
3421, 2311, 7122, 2159, 4535, 2162, 4657, 3151, 4522,
1075, 306, 8968, 2056, 2256, 3919, 8624, 5372, 6018,
3861, 4114, 3984, 2287, 1481, 4757, 1189, 2518, 913,
9421, 3093, 5911, 9704, 8168, 9410, 728, 2451, 243,
5187, 5836, 8830, 4894, 614, 7705, 9258, 3518, 4434,
4, 2, 4, 2, 4, 2, 3, 2, 2,
3, 3, 3, 4, 4, 3, 0, 4, 0,
2, 2, 3, 4, 4, 0, 2, 2, 4,
0, 3, 2, 2, 3, 0, 4, 0, 4,
4, 4, 2, 2, 3, 4, 2, 4, 3,
4, 2, 4, 2, 2, 2, 2, 0, 3,
4, 4, 3, 2, 4, 4, 4, 4, 3,
2, 3, 4, 2, 4, 0, 4, 4, 4,
4, 0, 0, 2, 1, 1, 0, 3, 4,
4, 2, 4, 1, 1, 4, 2, 2, 4,
0, 4, 4, 4, 4, 4, 1, 4, 2,
0, 0, 0, 2, 4, 4, 2, 4, 2,
4, 4, 1, 1, 4, 1, 4, 4, 1,
0, 4, 4, 4, 3, 0, 0, 2, 4,
2, 2, 4, 4, 4, 2, 2, 4, 2,
3]),
# One length per (key, sample); the last three keys are sequence features
# with long lengths (24/44).
lengths=torch.tensor([ 1, 1, 1, 1, 0, 0, 1, 2, 2, 1, 1, 4, 2, 2, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 24, 44, 24, 44, 24, 44], dtype=torch.int64),
# Per-key batch size on this (single) rank — this is what makes the KJT "VBE".
stride_per_key_per_rank=[[1], [2], [2], [2], [2], [2], [1], [2], [2], [2], [2], [2], [2], [2], [2], [2], [2]],
# (keys, index tensor) pair that expands the de-duplicated rows back to the
# full batch of 2; [0, 0] duplicates the single row of the stride-1 keys.
inverse_indices=(['user_id', 'item_id', 'id_3', 'id_4', 'id_5', 'raw_1', 'raw_4', 'combo_1', 'lookup_2', 'lookup_3',
'lookup_4', 'match_2', 'match_3', 'match_4', 'click_50_seq__item_id', 'click_50_seq__id_3',
'click_50_seq__raw_1'],
torch.tensor([[0, 0], [0, 1],[0, 1], [0, 1], [0, 1], [0, 1],[0, 0], [0, 1], [0, 1], [0, 1],
[0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]])
)
)
# Table specs: (name, num_embeddings, embedding_dim, feature_names).
# All other EmbeddingBagConfig kwargs are identical across tables, so build
# the configs from a compact spec table instead of 13 near-identical literals.
_table_specs = [
    ("user_id_emb", 1000000, 16, ["user_id"]),
    ("item_id_emb", 10000, 16, ["item_id"]),
    ("id_3_emb", 5, 8, ["id_3"]),
    ("id_4_emb", 100, 16, ["id_4", "id_5"]),
    ("raw_1_emb", 5, 16, ["raw_1"]),
    ("raw_4_emb", 5, 16, ["raw_4"]),
    ("combo_1_emb", 1000000, 16, ["combo_1"]),
    ("lookup_2_emb", 10000, 8, ["lookup_2"]),
    ("lookup_3_emb", 1000, 8, ["lookup_3"]),
    ("lookup_4_emb", 5, 16, ["lookup_4"]),
    ("match_2_emb", 100000, 16, ["match_2"]),
    ("match_3_emb", 10000, 8, ["match_3"]),
    ("match_4_emb", 5, 16, ["match_4"]),
]
eb_configs2 = [
    EmbeddingBagConfig(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        name=name,
        feature_names=feature_names,
        weight_init_max=None,
        weight_init_min=None,
        pruning_indices_remapping=None,
        need_pos=False,
    )
    for name, num_embeddings, embedding_dim, feature_names in _table_specs
]
# Build the (non-quantized) collection and print a baseline forward pass.
ebc = EmbeddingBagCollection(eb_configs2)
baseline_out = ebc(kt2)
print(baseline_out)
from torchrec.inference.modules import quantize_embeddings
import torch
import torch.nn as nn
class EmbeddingGroupImpl(nn.Module):
    """Thin wrapper module that forwards a sparse feature batch to an
    embedding collection.

    Args:
        ebc: a callable ``nn.Module`` (here an ``EmbeddingBagCollection``)
            that maps the sparse input to embedding output.
    """

    def __init__(self, ebc):
        super().__init__()
        self.ebc = ebc

    def forward(self, sparse_feature):
        # BUG FIX: the original body called self.ebc(...) without `return`,
        # so the wrapper always returned None and quantize_embeddings output
        # could never be observed through it.
        return self.ebc(sparse_feature)
a = EmbeddingGroupImpl(ebc=ebc)
# IDIOM FIX: invoke the module via __call__ (a(kt2)) rather than a.forward(kt2);
# calling .forward() directly bypasses nn.Module hooks and pre/post processing.
a(kt2)
# Quantize the wrapped embedding tables in place to int8 for inference.
quant_model = quantize_embeddings(a, dtype=torch.qint8, inplace=True)
print(quant_model(kt2))
Traceback (most recent call last):
File "/larec/tzrec/tests/test_per2.py", line 89, in <module>
print(quant_model(kt2))
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/larec/tzrec/tests/test_per2.py", line 83, in forward
self.ebc(sparse_feature)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torchrec/quant/embedding_modules.py", line 487, in forward
else emb_op.forward(
File "/opt/conda/lib/python3.10/site-packages/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py", line 764, in forward
torch.ops.fbgemm.bounds_check_indices(
File "/opt/conda/lib/python3.10/site-packages/torch/_ops.py", line 758, in __call__
return self._op(*args, **(kwargs or {}))
RuntimeError: offsets size 27 is not equal to B (1) * T (14) + 1