# Source code for recbole.data.dataloader.knowledge_dataloader

# @Time   : 2020/7/7
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

# UPDATE
# @Time   : 2020/9/18, 2020/9/21, 2020/8/31
# @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li
# @email  : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com

"""
recbole.data.dataloader.knowledge_dataloader
################################################
"""

from recbole.data.dataloader import AbstractDataLoader, GeneralNegSampleDataLoader
from recbole.utils import InputType, KGDataLoaderState


class KGDataLoader(AbstractDataLoader):
    """:class:`KGDataLoader` is a dataloader which returns knowledge-graph triplets
    together with negative examples.

    Args:
        config (Config): The config of dataloader.
        dataset (Dataset): The dataset of dataloader.
        sampler (KGSampler): The knowledge graph sampler of dataloader.
        batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
        dl_format (InputType, optional): The input type of dataloader.
            Defaults to :obj:`~recbole.utils.InputType.PAIRWISE`.
        shuffle (bool, optional): Whether the dataloader will be shuffled after a round.
            Defaults to ``False``.

    Attributes:
        shuffle (bool): Whether the dataloader will be shuffled after a round.
            However, in :class:`KGDataLoader`, it's guaranteed to be ``True``.
    """

    def __init__(self, config, dataset, sampler, batch_size=1, dl_format=InputType.PAIRWISE, shuffle=False):
        self.sampler = sampler
        # Exactly one negative tail entity is drawn for every triplet.
        self.neg_sample_num = 1

        self.neg_prefix = config['NEG_PREFIX']
        self.hid_field = dataset.head_entity_field
        self.tid_field = dataset.tail_entity_field

        # kg negative cols: the negative-tail column is named NEG_PREFIX + tail field
        # and inherits the tail entity field's properties.
        self.neg_tid_field = self.neg_prefix + self.tid_field
        dataset.copy_field_property(self.neg_tid_field, self.tid_field)

        # NOTE(review): the attributes above are set before super().__init__,
        # presumably because the base constructor uses them (e.g. via
        # setup()/data_preprocess()) — confirm against AbstractDataLoader.
        super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)

    def setup(self):
        """Make sure that :attr:`shuffle` is True.

        If :attr:`shuffle` is False, it is changed to True and a warning is logged.
        """
        if self.shuffle is False:
            self.shuffle = True
            self.logger.warning('kg based dataloader must shuffle the data')

    @property
    def pr_end(self):
        """int: Number of rows (triplets) in ``self.dataset.kg_feat``."""
        return len(self.dataset.kg_feat)

    def _shuffle(self):
        """Randomly permute the rows of ``kg_feat`` and renumber its index from 0."""
        self.dataset.kg_feat = self.dataset.kg_feat.sample(frac=1).reset_index(drop=True)

    def _next_batch_data(self):
        """Return the next ``self.step`` triplets as an interaction and advance :attr:`pr`.

        In real-time mode the negative tails are sampled for this batch only.
        """
        cur_data = self.dataset.kg_feat[self.pr: self.pr + self.step]
        self.pr += self.step
        if self.real_time:
            cur_data = self._neg_sampling(cur_data)
        return self._dataframe_to_interaction(cur_data)

    def data_preprocess(self):
        """Do neg-sampling before training/evaluation.

        The sampled negative tails are stored as a column of ``dataset.kg_feat``.
        """
        self.dataset.kg_feat = self._neg_sampling(self.dataset.kg_feat)

    def _neg_sampling(self, kg_feat):
        """Attach one sampled negative tail entity to every triplet.

        Args:
            kg_feat (pandas.DataFrame): Triplets containing the head entity column.

        Returns:
            pandas.DataFrame: ``kg_feat`` with the negative-tail column as its last column.
            Any previously existing negative-tail column is replaced, so repeated
            preprocessing of the same frame does not fail.
        """
        hids = kg_feat[self.hid_field].to_list()
        neg_tids = self.sampler.sample_by_entity_ids(hids, self.neg_sample_num)
        if self.neg_tid_field in kg_feat.columns:
            # DataFrame.insert raises ValueError on an existing column; drop the
            # stale negatives (drop returns a new frame) and re-append fresh ones.
            kg_feat = kg_feat.drop(columns=self.neg_tid_field)
        kg_feat.insert(len(kg_feat.columns), self.neg_tid_field, neg_tids)
        return kg_feat
class KnowledgeBasedDataLoader(AbstractDataLoader):
    """:class:`KnowledgeBasedDataLoader` is used for knowledge based models.

    It has three states, which are saved in :attr:`state`. In different states,
    :meth:`~_next_batch_data` returns a different :class:`~recbole.data.interaction.Interaction`.
    For details, see :attr:`~state`.

    Args:
        config (Config): The config of dataloader.
        dataset (Dataset): The dataset of dataloader.
        sampler (Sampler): The sampler of dataloader.
        kg_sampler (KGSampler): The knowledge graph sampler of dataloader.
        neg_sample_args (dict): The neg_sample_args of dataloader.
        batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
        dl_format (InputType, optional): The input type of dataloader.
            Defaults to :obj:`~recbole.utils.enum_type.InputType.POINTWISE`.
        shuffle (bool, optional): Whether the dataloader will be shuffled after a round.
            Defaults to ``False``.

    Attributes:
        state (KGDataLoaderState):
            This dataloader has three states:

                - :obj:`~recbole.utils.enum_type.KGDataLoaderState.RS`
                - :obj:`~recbole.utils.enum_type.KGDataLoaderState.KG`
                - :obj:`~recbole.utils.enum_type.KGDataLoaderState.RSKG`

            In the first state (RS), this dataloader only returns the user-item interactions.
            In the second state (KG), it only returns the triplets with negative examples
            in a knowledge graph. In the last state (RSKG), it returns both knowledge graph
            information and user-item interaction information.
    """

    def __init__(self, config, dataset, sampler, kg_sampler, neg_sample_args,
                 batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):

        # using sampler
        self.general_dataloader = GeneralNegSampleDataLoader(config=config, dataset=dataset,
                                                             sampler=sampler, neg_sample_args=neg_sample_args,
                                                             batch_size=batch_size, dl_format=dl_format,
                                                             shuffle=shuffle)

        # using kg_sampler and dl_format is pairwise
        self.kg_dataloader = KGDataLoader(config, dataset, kg_sampler,
                                          batch_size=batch_size, dl_format=InputType.PAIRWISE, shuffle=shuffle)

        # The `pr` property delegates to main_dataloader, so it must exist before
        # super().__init__ (presumably the base constructor assigns self.pr — confirm).
        self.main_dataloader = self.general_dataloader

        super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)

    @property
    def pr(self):
        """Pointer of :class:`KnowledgeBasedDataLoader`.

        It is the pointer of :attr:`main_dataloader`, which depends on :attr:`state`.
        """
        return self.main_dataloader.pr

    @pr.setter
    def pr(self, value):
        self.main_dataloader.pr = value

    def __iter__(self):
        if not hasattr(self, 'state') or not hasattr(self, 'main_dataloader'):
            raise ValueError('The dataloader\'s state and main_dataloader must be set '
                             'when using the kg based dataloader')
        return super().__iter__()

    def _shuffle(self):
        # In RSKG both sub-loaders are consumed, so both are shuffled.
        if self.state == KGDataLoaderState.RSKG:
            self.general_dataloader._shuffle()
            self.kg_dataloader._shuffle()
        else:
            self.main_dataloader._shuffle()

    def __next__(self):
        if self.pr >= self.pr_end:
            # End of epoch: rewind every sub-loader that was consumed.
            if self.state == KGDataLoaderState.RSKG:
                self.general_dataloader.pr = 0
                self.kg_dataloader.pr = 0
            else:
                self.pr = 0
            raise StopIteration()
        return self._next_batch_data()

    def __len__(self):
        return len(self.main_dataloader)

    @property
    def pr_end(self):
        return self.main_dataloader.pr_end

    def _next_batch_data(self):
        if self.state == KGDataLoaderState.KG:
            return self.kg_dataloader._next_batch_data()
        elif self.state == KGDataLoaderState.RS:
            return self.general_dataloader._next_batch_data()
        elif self.state == KGDataLoaderState.RSKG:
            # Wrap the KG pointer around if the KG side has been exhausted.
            if self.kg_dataloader.pr >= self.kg_dataloader.pr_end:
                self.kg_dataloader.pr = 0
            kg_data = self.kg_dataloader._next_batch_data()
            rec_data = self.general_dataloader._next_batch_data()
            # Merge the KG fields into the recommendation batch.
            rec_data.update(kg_data)
            return rec_data

    def set_mode(self, state):
        """Set the mode of :class:`KnowledgeBasedDataLoader`. It can be set to three states:

            - KGDataLoaderState.RS
            - KGDataLoaderState.KG
            - KGDataLoaderState.RSKG

        The state of :class:`KnowledgeBasedDataLoader` affects the result of _next_batch_data().

        Args:
            state (KGDataLoaderState): the state of :class:`KnowledgeBasedDataLoader`.

        Raises:
            NotImplementedError: If ``state`` is not a member of ``KGDataLoaderState``.
        """
        if state not in set(KGDataLoaderState):
            # Report the rejected value; self.state may not exist yet or holds the old state.
            raise NotImplementedError('kg data loader has no state named [{}]'.format(state))
        self.state = state
        if self.state == KGDataLoaderState.RS:
            self.main_dataloader = self.general_dataloader
        elif self.state == KGDataLoaderState.KG:
            self.main_dataloader = self.kg_dataloader
        else:  # RSKG
            # The loader with fewer rows drives pr/pr_end, so an RSKG epoch ends
            # when the shorter side is exhausted (ties go to the KG loader).
            kgpr = self.kg_dataloader.pr_end
            rspr = self.general_dataloader.pr_end
            self.main_dataloader = self.general_dataloader if rspr < kgpr else self.kg_dataloader