RUCAIBox / RecBole

A unified, comprehensive and efficient recommendation library

Home Page: https://recbole.io/

Error when running SRGNN

cyxg7 opened this issue

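Dear maintainers, hello! I get an error when running SRGNN. My guess is that the problem comes from the main function using the trainer and interaction from the recbole framework rather than from recbole_gnn, but I don't know how to modify or extend it. I'd really appreciate your help in clearing this up, and I look forward to your reply. Many thanks! (I asked this in the GNN module a couple of days ago, but perhaps everyone was too busy and nobody replied, so I'm asking again here. Thank you for your help!)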
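One way to test this guess is to print which package the trainer's class is actually defined in. The helper below is a minimal sketch that relies only on Python's standard `__module__` attribute; the name `defining_package` is hypothetical and not part of RecBole or RecBole-GNN.

```python
def defining_package(obj: object) -> str:
    """Return the top-level package that defines the class of ``obj``.

    Hypothetical helper for illustration; not part of RecBole or RecBole-GNN.
    """
    # e.g. 'recbole.trainer.trainer' -> 'recbole'
    module_path = type(obj).__module__
    return module_path.split('.')[0]


if __name__ == '__main__':
    import json

    # json.JSONDecoder is defined in the json.decoder submodule
    print(defining_package(json.JSONDecoder()))  # prints "json"
```

In main.py one could call `logger.info(defining_package(trainer))` right after the trainer is created. Note that the traceback only shows `fit` and `_train_epoch` running from `site-packages\recbole\trainer\trainer.py`; an inherited method shows up there even when the trainer's own class is defined in another package, so checking the class directly is more informative than reading the file paths alone.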

Main function:
from recbole_gnn.config import Config
from recbole_gnn.utils import create_dataset, data_preparation
from recbole.utils import init_logger, init_seed
from recbole_gnn.utils import set_color, get_trainer
from logging import getLogger

from test import SRGNN

if __name__ == '__main__':
    # configurations initialization
    config = Config(
        model=SRGNN,
        dataset='diginetica',
        config_file_list=['config.yaml', 'config_model.yaml'],
    )
    init_seed(config['seed'], config['reproducibility'])

    # logger initialization
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    # dataset filtering
    dataset = create_dataset(config)
    logger.info(dataset)

    # dataset splitting
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # model loading and initialization
    model = SRGNN(config, train_data.dataset).to(config['device'])
    logger.info(model)

    # trainer loading and initialization
    trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)

    # model training
    best_valid_score, best_valid_result = trainer.fit(
        train_data, valid_data, saved=True, show_progress=config['show_progress']
    )

    # model evaluation
    test_result = trainer.evaluate(test_data, load_best_model=True, show_progress=config['show_progress'])

    logger.info(set_color('best valid result:', 'yellow') + f': {best_valid_result}')
    logger.info(set_color('test result:', 'yellow') + f': {test_result}')
Both config.yaml and config_model.yaml use the parameters provided with the framework.

Output:
General Hyper Parameters:
gpu_id = 0
use_gpu = True
seed = 2020
state = INFO
reproducibility = True
data_path = dataset/diginetica
checkpoint_dir = saved
show_progress = True
save_dataset = False
dataset_save_path = None
save_dataloaders = False
dataloaders_save_path = None
log_wandb = False

Training Hyper Parameters:
epochs = 500
train_batch_size = 4096
learner = adam
learning_rate = 0.001
neg_sampling = None
eval_step = 1
stopping_step = 10
clip_grad_norm = None
weight_decay = 0.0
loss_decimal_place = 4

Evaluation Hyper Parameters:
eval_args = {'split': {'LS': 'valid_and_test'}, 'mode': 'full', 'order': 'TO', 'group_by': 'user'}
repeatable = True
metrics = ['MRR', 'Precision']
topk = [10, 20]
valid_metric = MRR@10
valid_metric_bigger = True
eval_batch_size = 2000
metric_decimal_place = 5

Dataset Hyper Parameters:
field_separator =
seq_separator =
USER_ID_FIELD = session_id
ITEM_ID_FIELD = item_id
RATING_FIELD = rating
TIME_FIELD = timestamp
seq_len = None
LABEL_FIELD = label
threshold = None
NEG_PREFIX = neg_
load_col = {'inter': ['session_id', 'item_id', 'timestamp']}
unload_col = None
unused_col = None
additional_feat_suffix = None
rm_dup_inter = None
val_interval = None
filter_inter_by_user_or_item = True
user_inter_num_interval = [5,inf)
item_inter_num_interval = [5,inf)
alias_of_user_id = None
alias_of_item_id = None
alias_of_entity_id = None
alias_of_relation_id = None
preload_weight = None
normalize_field = None
normalize_all = None
ITEM_LIST_LENGTH_FIELD = item_length
LIST_SUFFIX = _list
MAX_ITEM_LIST_LENGTH = 20
POSITION_FIELD = position_id
HEAD_ENTITY_ID_FIELD = head_id
TAIL_ENTITY_ID_FIELD = tail_id
RELATION_ID_FIELD = relation_id
ENTITY_ID_FIELD = entity_id
benchmark_filename = None

Other Hyper Parameters:
wandb_project = recbole
require_pow = False
embedding_size = 64
step = 1
loss_type = CE
MODEL_TYPE = ModelType.SEQUENTIAL
gnn_transform = sess_graph
train_neg_sample_args = {'strategy': 'none'}
MODEL_INPUT_TYPE = InputType.POINTWISE
eval_type = EvaluatorType.RANKING
device = cpu
eval_neg_sample_args = {'strategy': 'full', 'distribution': 'uniform'}

06 Mar 13:17 INFO diginetica
The number of users: 72014
Average actions of users: 8.060905669809618
The number of items: 29454
Average actions of items: 19.70902794282416
The number of inters: 580490
The sparsity of the dataset: 99.97263260088765%
Remain Fields: ['session_id', 'item_id', 'timestamp']
06 Mar 13:17 INFO Constructing session graphs.
100%|██████████| 364451/364451 [00:33<00:00, 11034.37it/s]
06 Mar 13:18 INFO Constructing session graphs.
100%|██████████| 72013/72013 [00:07<00:00, 9464.61it/s]
06 Mar 13:18 INFO Constructing session graphs.
100%|██████████| 72013/72013 [00:07<00:00, 9047.17it/s]
06 Mar 13:18 INFO SessionGraph Transform in DataLoader.
06 Mar 13:18 INFO SessionGraph Transform in DataLoader.
06 Mar 13:18 INFO SessionGraph Transform in DataLoader.
06 Mar 13:18 INFO [Training]: train_batch_size = [4096] negative sampling: [{'strategy': 'none'}]
06 Mar 13:18 INFO [Evaluation]: eval_batch_size = [2000] eval_args: [{'split': {'LS': 'valid_and_test'}, 'mode': 'full', 'order': 'TO', 'group_by': 'user'}]
06 Mar 13:18 INFO SRGNN(
(item_embedding): Embedding(29454, 64, padding_idx=0)
(gnncell): SRGNNCell(
(incomming_conv): SRGNNConv()
(outcomming_conv): SRGNNConv()
(lin_ih): Linear(in_features=128, out_features=192, bias=True)
(lin_hh): Linear(in_features=64, out_features=192, bias=True)
)
(linear_one): Linear(in_features=64, out_features=64, bias=True)
(linear_two): Linear(in_features=64, out_features=64, bias=True)
(linear_three): Linear(in_features=64, out_features=1, bias=False)
(linear_transform): Linear(in_features=128, out_features=64, bias=True)
(loss_fct): CrossEntropyLoss()
)
Trainable parameters: 1947264
Train 0: 0%| | 0/89 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "E:/ADACONDA/envs/pytorch/pythonproject_test/Next Work/RecBole-GNN-main/main.py", line 41, in <module>
    best_valid_score, best_valid_result = trainer.fit(
  File "E:\ADACONDA\envs\pytorch\lib\site-packages\recbole\trainer\trainer.py", line 335, in fit
    train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
  File "E:\ADACONDA\envs\pytorch\lib\site-packages\recbole\trainer\trainer.py", line 181, in _train_epoch
    losses = loss_func(interaction)
  File "E:\ADACONDA\envs\pytorch\pythonproject_test\Next Work\RecBole-GNN-main\test.py", line 105, in calculate_loss
    x = interaction['x']
  File "E:\ADACONDA\envs\pytorch\lib\site-packages\recbole\data\interaction.py", line 131, in __getitem__
    return self.interaction[index]
KeyError: 'x'
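The last frame shows that `Interaction.__getitem__` is a plain dictionary lookup (`self.interaction[index]`), so the `KeyError` means the batch that reached `calculate_loss` has no `'x'` field at all, which suggests the session-graph fields were not attached to it. A quick check is to compare the batch's keys with the fields the model reads. The sketch below uses a hypothetical stand-in class, `StandInInteraction`, that mimics only the lookup visible in the traceback, plus a hypothetical helper `missing_fields`; neither is part of RecBole or RecBole-GNN. The field list is also an assumption: only `'x'` is confirmed by the traceback, while `'edge_index'` and `'alias_inputs'` are assumed names for the other session-graph inputs.

```python
from typing import Dict, Iterable, List


class StandInInteraction:
    """Hypothetical stand-in that mimics only the lookup seen in the traceback.

    Per the traceback, the real Interaction keeps its fields in a dict
    attribute named ``interaction`` and indexes into it in __getitem__.
    """

    def __init__(self, fields: Dict[str, object]) -> None:
        self.interaction = dict(fields)

    def __getitem__(self, index: str) -> object:
        # Same behaviour as the traceback: a missing key raises KeyError.
        return self.interaction[index]


def missing_fields(batch, required: Iterable[str]) -> List[str]:
    """Return the required field names absent from ``batch.interaction``.

    Hypothetical helper for illustration; not part of RecBole or RecBole-GNN.
    """
    available = set(batch.interaction.keys())
    return [name for name in required if name not in available]


if __name__ == '__main__':
    # Hypothetical batch that carries only plain sequential fields.
    batch = StandInInteraction({'session_id': [1], 'item_id_list': [[3, 5, 3]], 'item_length': [3]})
    # Assumed session-graph inputs; only 'x' is confirmed by the traceback.
    print(missing_fields(batch, ['x', 'edge_index', 'alias_inputs']))
    # prints ['x', 'edge_index', 'alias_inputs']
```

With the real objects, and assuming each batch is an `Interaction` exposing the `interaction` attribute shown in the traceback, one could run `print(missing_fields(next(iter(train_data)), ['x']))` just before `trainer.fit(...)`. If `'x'` is already missing there, the problem lies in how the batches are produced rather than in the model code in test.py.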