# Source code for recbole.data.dataset.customized_dataset

# @Time   : 2020/10/19
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

# UPDATE
# @Time   : 2021/7/9
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

"""
recbole.data.dataset.customized_dataset
#######################################

We only recommend building customized datasets by inheriting.

Customized datasets named ``[Model Name]Dataset`` can be automatically called.
"""

import numpy as np
import torch

from recbole.data.dataset import KGSeqDataset, SequentialDataset
from recbole.data.interaction import Interaction
from recbole.sampler import SeqSampler
from recbole.utils.enum_type import FeatureType


class GRU4RecKGDataset(KGSeqDataset):
    """Customized dataset named after the ``GRU4RecKG`` model.

    The class adds no behavior of its own: its constructor only forwards
    ``config`` to :class:`KGSeqDataset`.
    """

    def __init__(self, config):
        """Build the dataset by delegating to the parent constructor.

        Args:
            config: Configuration object handed unchanged to
                :class:`KGSeqDataset`.
        """
        super().__init__(config)
class KSRDataset(KGSeqDataset):
    """Customized dataset named after the ``KSR`` model.

    The class adds no behavior of its own: its constructor only forwards
    ``config`` to :class:`KGSeqDataset`.
    """

    def __init__(self, config):
        """Build the dataset by delegating to the parent constructor.

        Args:
            config: Configuration object handed unchanged to
                :class:`KGSeqDataset`.
        """
        super().__init__(config)
class DIENDataset(SequentialDataset):
    """:class:`DIENDataset` is based on :class:`~recbole.data.dataset.sequential_dataset.SequentialDataset`.
    It is different from :class:`SequentialDataset` in `data_augmentation`.
    It adds users' negative item list to interaction.

    The original version of sampling negative item list is implemented by Zhichao Feng (fzcbupt@gmail.com) in 2021/2/25,
    and he updated the codes in 2021/3/19. In 2021/7/9, Yupeng refactored SequentialDataset & SequentialDataLoader,
    then refactored DIENDataset as well.

    Attributes:
        augmentation (bool): Whether the interactions should be augmented in RecBole.
            (Not assigned in this class; presumably provided by the parent class.)
        seq_sampler (recbole.sampler.SeqSampler): A sampler used to sample negative item sequence.
        neg_item_list_field (str): Field name for negative item sequence.
        neg_item_list (torch.tensor): all users' negative item history sequence.
    """

    def __init__(self, config):
        """Initialize the parent dataset, then sample the negative item sequence.

        Args:
            config: Configuration object; ``config['LIST_SUFFIX']`` and
                ``config['NEG_PREFIX']`` are read to build the name of the
                negative item list field.
        """
        super().__init__(config)

        list_suffix = config['LIST_SUFFIX']
        neg_prefix = config['NEG_PREFIX']
        self.seq_sampler = SeqSampler(self)
        self.neg_item_list_field = neg_prefix + self.iid_field + list_suffix
        # One negative item is sampled per interaction row, aligned with ``inter_feat``.
        self.neg_item_list = self.seq_sampler.sample_neg_sequence(self.inter_feat[self.iid_field])

    def data_augmentation(self):
        """Augmentation processing for sequential dataset.

        E.g., ``u1`` has purchase sequence ``<i1, i2, i3, i4>``,
        then after augmentation, we will generate three cases.

        ``u1, <i1> | i2``

        (Which means given user_id ``u1`` and item_seq ``<i1>``,
        we need to predict the next item ``i2``.)

        The other cases are below:

        ``u1, <i1, i2> | i3``

        ``u1, <i1, i2, i3> | i4``

        In addition to the history sequence of every non-user field, the
        negative item sequence (``neg_item_list_field``) is filled from
        ``neg_item_list`` using the same history windows as the item field.
        The result replaces ``self.inter_feat``.
        """
        self.logger.debug('data_augmentation')

        self._aug_presets()

        self._check_field('uid_field', 'time_field')
        max_item_list_len = self.config['MAX_ITEM_LIST_LENGTH']
        self.sort(by=[self.uid_field, self.time_field], ascending=True)

        last_uid = None
        item_list_index, target_index, item_list_length = [], [], []
        seq_start = 0
        for i, uid in enumerate(self.inter_feat[self.uid_field].numpy()):
            if last_uid != uid:
                # First interaction of a new user: it opens a history window
                # but has no preceding items, so no sample is emitted for it.
                last_uid = uid
                seq_start = i
            else:
                # Slide the window start forward so the history never holds
                # more than ``max_item_list_len`` items.
                if i - seq_start > max_item_list_len:
                    seq_start += 1
                item_list_index.append(slice(seq_start, i))
                target_index.append(i)
                item_list_length.append(i - seq_start)

        item_list_index = np.array(item_list_index)
        # Explicit integer dtype: without it an empty list becomes a float64
        # array, which is not a valid index (same dtype as ``item_list_length``).
        target_index = np.array(target_index, dtype=np.int64)
        item_list_length = np.array(item_list_length, dtype=np.int64)

        new_length = len(item_list_index)
        new_data = self.inter_feat[target_index]
        new_dict = {
            self.item_list_length_field: torch.tensor(item_list_length),
        }

        for field in self.inter_feat:
            if field == self.uid_field:
                continue

            list_field = getattr(self, f'{field}_list_field')
            list_len = self.field2seqlen[list_field]
            shape = (new_length, list_len) if isinstance(list_len, int) else (new_length,) + list_len
            list_ftype = self.field2type[list_field]
            if list_ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ]:
                dtype = torch.int64
            else:
                dtype = torch.float64

            # Zero-padded history: row ``row`` holds the ``length`` values
            # that precede the target interaction.
            value = self.inter_feat[field]
            seq_tensor = torch.zeros(shape, dtype=dtype)
            for row, (index, length) in enumerate(zip(item_list_index, item_list_length)):
                seq_tensor[row][:length] = value[index]
            new_dict[list_field] = seq_tensor

            # DIEN: negative item sequence, same shape/dtype and windows as the item history.
            if field == self.iid_field:
                neg_tensor = torch.zeros(shape, dtype=dtype)
                for row, (index, length) in enumerate(zip(item_list_index, item_list_length)):
                    neg_tensor[row][:length] = self.neg_item_list[index]
                new_dict[self.neg_item_list_field] = neg_tensor

        new_data.update(Interaction(new_dict))
        self.inter_feat = new_data