# Source code for recbole.data.dataset.customized_dataset

# @Time   : 2020/10/19
# @Author : Yupeng Hou
# @Email  :

# @Time   : 2021/7/9
# @Author : Yupeng Hou
# @Email  :


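"""Customized datasets for specific models.

We only recommend building customized datasets by inheriting from the existing
dataset classes. A customized dataset named ``[Model Name]Dataset`` can be
selected automatically (by name) for the corresponding model.
"""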

import numpy as np
import torch

from import KGSeqDataset, SequentialDataset
from import Interaction
from recbole.sampler import SeqSampler
from recbole.utils.enum_type import FeatureType

class GRU4RecKGDataset(KGSeqDataset):
    """Dataset for GRU4RecKG.

    Behaves exactly like :class:`KGSeqDataset`; no model-specific processing
    is added here.
    """

    def __init__(self, config):
        # Nothing extra to initialise: hand the config straight to the parent.
        super(GRU4RecKGDataset, self).__init__(config)
class KSRDataset(KGSeqDataset):
    """Dataset for KSR.

    Behaves exactly like :class:`KGSeqDataset`; no model-specific processing
    is added here.
    """

    def __init__(self, config):
        # Nothing extra to initialise: hand the config straight to the parent.
        super(KSRDataset, self).__init__(config)
class DIENDataset(SequentialDataset):
    """:class:`DIENDataset` is based on :class:`SequentialDataset`.

    It differs from :class:`SequentialDataset` only in :meth:`data_augmentation`:
    it additionally adds each user's negative item list to the interactions.

    The original version of negative item list sampling was implemented by
    Zhichao Feng on 2021/2/25, and he updated the code on 2021/3/19. On
    2021/7/9, Yupeng refactored SequentialDataset & SequentialDataLoader, and
    then refactored DIENDataset as well.

    Attributes:
        augmentation (bool): Whether the interactions should be augmented in
            RecBole (not set in this class; presumably provided by the parent).
        seq_sampler (recbole.sampler.SeqSampler): A sampler used to sample the
            negative item sequence.
        neg_item_list_field (str): Field name for the negative item sequence.
        neg_item_list (torch.tensor): Negative item ids, one per row of
            ``inter_feat``; re-drawn in :meth:`data_augmentation` so that it
            follows the sorted row order used there.
    """

    def __init__(self, config):
        super().__init__(config)

        list_suffix = config["LIST_SUFFIX"]
        neg_prefix = config["NEG_PREFIX"]
        self.seq_sampler = SeqSampler(self)
        self.neg_item_list_field = neg_prefix + self.iid_field + list_suffix
        # Initial draw (kept so the attribute exists right after construction).
        # It follows the current, possibly unsorted, row order of ``inter_feat``.
        self.neg_item_list = self.seq_sampler.sample_neg_sequence(
            self.inter_feat[self.iid_field]
        )

    @staticmethod
    def _build_padded_list(value, shape, dtype, item_list_index, item_list_length):
        """Return a zero tensor of ``shape``/``dtype`` whose row ``i`` starts with
        ``value[item_list_index[i]]`` (``item_list_length[i]`` entries)."""
        padded = torch.zeros(shape, dtype=dtype)
        for i, (index, length) in enumerate(zip(item_list_index, item_list_length)):
            padded[i][:length] = value[index]
        return padded

    def data_augmentation(self):
        """Augmentation processing for sequential dataset.

        E.g., ``u1`` has purchase sequence ``<i1, i2, i3, i4>``,
        then after augmentation, we will generate three cases.

        ``u1, <i1> | i2``

        (Which means given user_id ``u1`` and item_seq ``<i1>``,
        we need to predict the next item ``i2``.)

        The other cases are below:

        ``u1, <i1, i2> | i3``

        ``u1, <i1, i2, i3> | i4``

        In addition to the regular ``*_list`` fields, a negative item sequence
        (``neg_item_list_field``) is built with the same history windows as the
        item id sequence.
        """
        self.logger.debug("data_augmentation")

        self._aug_presets()

        self._check_field("uid_field", "time_field")
        max_item_list_len = self.config["MAX_ITEM_LIST_LENGTH"]
        self.sort(by=[self.uid_field, self.time_field], ascending=True)

        # The sort above reorders ``inter_feat``, while the negatives drawn in
        # ``__init__`` follow the pre-sort row order. Re-draw them here so that
        # row k of ``neg_item_list`` is paired with row k of the sorted data,
        # matching the ``value[index]`` slicing used below.
        self.neg_item_list = self.seq_sampler.sample_neg_sequence(
            self.inter_feat[self.iid_field].numpy()
        )

        item_list_index, target_index, item_list_length = [], [], []
        last_uid = None
        seq_start = 0
        for i, uid in enumerate(self.inter_feat[self.uid_field].numpy()):
            if last_uid != uid:
                # First interaction of a new user: it has no history, so no sample.
                last_uid = uid
                seq_start = i
            else:
                # Slide the window so the history never exceeds the max length.
                if i - seq_start > max_item_list_len:
                    seq_start += 1
                item_list_index.append(slice(seq_start, i))
                target_index.append(i)
                item_list_length.append(i - seq_start)

        item_list_index = np.array(item_list_index)
        target_index = np.array(target_index)
        item_list_length = np.array(item_list_length, dtype=np.int64)

        new_length = len(item_list_index)
        new_data = self.inter_feat[target_index]
        new_dict = {
            self.item_list_length_field: torch.tensor(item_list_length),
        }

        for field in self.inter_feat:
            if field == self.uid_field:
                continue
            list_field = getattr(self, f"{field}_list_field")
            list_len = self.field2seqlen[list_field]
            shape = (
                (new_length, list_len)
                if isinstance(list_len, int)
                else (new_length,) + list_len
            )
            if (
                self.field2type[field] in [FeatureType.FLOAT, FeatureType.FLOAT_SEQ]
                and field in self.config["numerical_features"]
            ):
                # Numerical float features get an extra trailing dimension of size 2.
                shape += (2,)
            list_ftype = self.field2type[list_field]
            dtype = (
                torch.int64
                if list_ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ]
                else torch.float64
            )
            new_dict[list_field] = self._build_padded_list(
                self.inter_feat[field], shape, dtype, item_list_index, item_list_length
            )

            # DIEN: build the negative item sequence alongside the item id
            # sequence, reusing its shape, dtype and per-row history windows.
            if field == self.iid_field:
                new_dict[self.neg_item_list_field] = self._build_padded_list(
                    self.neg_item_list,
                    shape,
                    dtype,
                    item_list_index,
                    item_list_length,
                )

        new_data.update(Interaction(new_dict))
        self.inter_feat = new_data