# Source code for recbole.sampler.sampler

# -*- coding: utf-8 -*-
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn
# @File   : sampler.py

# UPDATE
# @Time   : 2020/8/17, 2020/8/31, 2020/10/6, 2020/9/18
# @Author : Xingyu Pan, Kaiyuan Li, Yupeng Hou, Yushuo Chen
# @email  : panxy@ruc.edu.cn, tsotfsk@outlook.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn

"""
recbole.sampler
########################
"""

import random
import copy
import numpy as np


class AbstractSampler(object):
    """:class:`AbstractSampler` is an abstract class; all samplers should inherit from it.

    This sampler returns a certain number of random value_ids for each input key_id, and it can
    prohibit certain key-value pairs through ``used_ids``.

    For efficiency, :attr:`random_list` is shuffled once at construction time and values are then
    drawn by walking it cyclically with the pointer :attr:`random_pr`. Subclasses must therefore
    implement :meth:`get_random_list` and :meth:`get_used_ids`.

    Args:
        distribution (str): The name of the distribution; it is interpreted by subclasses.

    Attributes:
        random_list (list or numpy.ndarray): The shuffled result of :meth:`get_random_list`.
        used_ids (numpy.ndarray): The result of :meth:`get_used_ids`.
    """

    def __init__(self, distribution):
        self.distribution = distribution

        self.random_list = self.get_random_list()
        # Shuffle once, in place; afterwards values are read cyclically by :meth:`random`.
        random.shuffle(self.random_list)
        self.random_pr = 0
        self.random_list_length = len(self.random_list)

        self.used_ids = self.get_used_ids()

    def get_random_list(self):
        """Build the candidate list that random value_ids are drawn from.

        Returns:
            np.ndarray or list: Random list of value_id.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError('method [get_random_list] should be implemented')

    def get_used_ids(self):
        """Build the prohibited key-value pairs.

        Returns:
            np.ndarray: Used ids. Index is key_id, and element is a set of value_ids.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError('method [get_used_ids] should be implemented')

    def random(self):
        """Draw the next value_id from :attr:`random_list`.

        The list is traversed cyclically, so consecutive calls within one pass never repeat a
        position; the sequence restarts from the beginning after ``random_list_length`` calls.

        Returns:
            value_id (int): Random value_id. Generated by :attr:`random_list`.
        """
        value_id = self.random_list[self.random_pr % self.random_list_length]
        self.random_pr += 1
        return value_id

    def sample_by_key_ids(self, key_ids, num, used_ids):
        """Sampling by key_ids.

        Args:
            key_ids (np.ndarray or list): Input key_ids. Only ``len(key_ids)`` is used here.
            num (int): Number of sampled value_ids for each key_id.
            used_ids (np.ndarray): Used ids. Index is the position in ``key_ids``, and element is
                the set of value_ids that must not be sampled for that key.

        Returns:
            np.ndarray: Sampled value_ids (dtype ``int64``, length ``len(key_ids) * num``).
            value_ids[0], value_ids[len(key_ids)], value_ids[len(key_ids) * 2], ...,
            value_ids[len(key_ids) * (num - 1)] are sampled for key_ids[0];
            value_ids[1], value_ids[len(key_ids) + 1], value_ids[len(key_ids) * 2 + 1], ...,
            value_ids[len(key_ids) * (num - 1) + 1] are sampled for key_ids[1]; and so on.

        Note:
            Rejection sampling is used: if a key's used set contains every value in
            :attr:`random_list`, the loop for that key never terminates.
        """
        key_num = len(key_ids)
        total_num = key_num * num
        value_ids = np.zeros(total_num, dtype=np.int64)
        # Repeat the per-key used sets ``num`` times so position i belongs to key i % key_num,
        # which yields the interleaved layout documented above.
        used_id_list = np.tile(used_ids, num)
        for i, forbidden in enumerate(used_id_list):
            cur = self.random()
            while cur in forbidden:
                cur = self.random()
            value_ids[i] = cur
        return value_ids
class Sampler(AbstractSampler):
    """:class:`Sampler` is used to sample negative items for each input user.

    To avoid sampling train-phase positive items in the valid phase, and train- or valid-phase
    positive items in the test phase, the datasets of all phases are passed in for pre-processing.
    Before using this sampler, call :meth:`set_phase` to obtain the sampler of the corresponding
    phase.

    Args:
        phases (str or list of str): All the phases of input.
        datasets (Dataset or list of Dataset): All the datasets, one for each phase.
        distribution (str, optional): Distribution of the negative items. Defaults to 'uniform'.

    Attributes:
        phase (str): The phase of the sampler. It is not set until :meth:`set_phase` is called.
    """

    def __init__(self, phases, datasets, distribution='uniform'):
        if not isinstance(phases, list):
            phases = [phases]
        if not isinstance(datasets, list):
            datasets = [datasets]
        if len(phases) != len(datasets):
            raise ValueError('phases {} and datasets {} should have the same length'.format(phases, datasets))

        self.phases = phases
        self.datasets = datasets

        # Field names and sizes are taken from the first dataset only.
        self.uid_field = datasets[0].uid_field
        self.iid_field = datasets[0].iid_field

        self.n_users = datasets[0].user_num
        self.n_items = datasets[0].item_num

        super().__init__(distribution=distribution)

    def get_random_list(self):
        """Build the candidate item list.

        ``'uniform'`` yields item ids ``1 .. n_items - 1`` once each (id 0 is excluded, presumably
        a padding id — confirm). ``'popularity'`` concatenates the item column of every phase's
        interactions, so items appear as often as they were interacted with.

        Returns:
            np.ndarray or list: Random list of item_id.

        Raises:
            NotImplementedError: If :attr:`distribution` is not supported.
        """
        if self.distribution == 'uniform':
            return list(range(1, self.n_items))
        elif self.distribution == 'popularity':
            random_item_list = []
            for dataset in self.datasets:
                random_item_list.extend(dataset.inter_feat[self.iid_field].values)
            return random_item_list
        else:
            raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))

    def get_used_ids(self):
        """Collect, for every phase, the positive items seen so far for each user.

        Phases are processed in the given order, and each phase's sets start as copies of the
        previous phase's sets, so the used items of a phase include those of all earlier phases.

        Returns:
            dict: Used item_ids are the same as positive item_ids. Key is phase, and value is a
            np.ndarray whose index is user_id and whose element is a set of item_ids.
        """
        used_item_id = dict()
        last = [set() for _ in range(self.n_users)]
        for phase, dataset in zip(self.phases, self.datasets):
            # Copy each set so later phases do not mutate earlier phases' results.
            cur = np.array([set(s) for s in last])
            for uid, iid in dataset.inter_feat[[self.uid_field, self.iid_field]].values:
                cur[uid].add(iid)
            last = used_item_id[phase] = cur
        return used_item_id

    def set_phase(self, phase):
        """Get the sampler of the corresponding phase.

        Args:
            phase (str): The phase of the new sampler.

        Returns:
            Sampler: A shallow copy of this sampler whose :attr:`phase` is set to ``phase`` and
            whose :attr:`used_ids` is the array for that phase.

        Raises:
            ValueError: If ``phase`` is not one of :attr:`phases`.
        """
        if phase not in self.phases:
            raise ValueError('phase [{}] not exist'.format(phase))
        new_sampler = copy.copy(self)
        new_sampler.phase = phase
        new_sampler.used_ids = new_sampler.used_ids[phase]
        return new_sampler

    def sample_by_user_ids(self, user_ids, num):
        """Sampling by user_ids.

        Args:
            user_ids (np.ndarray or list): Input user_ids.
            num (int): Number of sampled item_ids for each user_id.

        Returns:
            np.ndarray: Sampled item_ids.
            item_ids[0], item_ids[len(user_ids)], item_ids[len(user_ids) * 2], ...,
            item_ids[len(user_ids) * (num - 1)] are sampled for user_ids[0];
            item_ids[1], item_ids[len(user_ids) + 1], item_ids[len(user_ids) * 2 + 1], ...,
            item_ids[len(user_ids) * (num - 1) + 1] are sampled for user_ids[1]; and so on.

        Raises:
            ValueError: If some user_id is outside ``[0, n_users)``.
            IndexError: Re-raised if indexing failed for any other reason.
        """
        try:
            return self.sample_by_key_ids(user_ids, num, self.used_ids[user_ids])
        except IndexError as err:
            for user_id in user_ids:
                if user_id < 0 or user_id >= self.n_users:
                    raise ValueError('user_id [{}] not exist'.format(user_id)) from err
            # No out-of-range id explains the failure: propagate instead of returning None.
            raise
class KGSampler(AbstractSampler):
    """:class:`KGSampler` is used to sample negative entities in a knowledge graph.

    Args:
        dataset (Dataset): The knowledge graph dataset, which contains triplets in a knowledge graph.
        distribution (str, optional): Distribution of the negative entities. Defaults to 'uniform'.
    """

    def __init__(self, dataset, distribution='uniform'):
        self.dataset = dataset

        self.hid_field = dataset.head_entity_field
        self.tid_field = dataset.tail_entity_field
        self.hid_list = dataset.head_entities
        self.tid_list = dataset.tail_entities

        # Set of head entity ids, used only to validate inputs in sample_by_entity_ids.
        self.head_entities = set(dataset.head_entities)
        self.entity_num = dataset.entity_num

        super().__init__(distribution=distribution)

    def get_random_list(self):
        """Build the candidate entity list.

        ``'uniform'`` yields entity ids ``1 .. entity_num - 1`` once each. ``'popularity'``
        concatenates the head and tail lists, so entities appear as often as they occur in
        triplets.

        Returns:
            np.ndarray or list: Random list of entity_id.

        Raises:
            NotImplementedError: If :attr:`distribution` is not supported.
        """
        if self.distribution == 'uniform':
            return list(range(1, self.entity_num))
        elif self.distribution == 'popularity':
            return list(self.hid_list) + list(self.tid_list)
        else:
            raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))

    def get_used_ids(self):
        """Collect the known tail entities of every head entity.

        Returns:
            np.ndarray: Used entity_ids are the same as tail_entity_ids in the knowledge graph.
            Index is head_entity_id (``0 .. entity_num - 1``), and element is a set of
            tail_entity_ids.
        """
        used_tail_entity_id = np.array([set() for _ in range(self.entity_num)])
        for hid, tid in zip(self.hid_list, self.tid_list):
            used_tail_entity_id[hid].add(tid)
        return used_tail_entity_id

    def sample_by_entity_ids(self, head_entity_ids, num=1):
        """Sampling by head_entity_ids.

        Args:
            head_entity_ids (np.ndarray or list): Input head_entity_ids.
            num (int, optional): Number of sampled entity_ids for each head_entity_id.
                Defaults to ``1``.

        Returns:
            np.ndarray: Sampled entity_ids.
            entity_ids[0], entity_ids[len(head_entity_ids)], entity_ids[len(head_entity_ids) * 2], ...,
            entity_ids[len(head_entity_ids) * (num - 1)] are sampled for head_entity_ids[0];
            entity_ids[1], entity_ids[len(head_entity_ids) + 1], entity_ids[len(head_entity_ids) * 2 + 1], ...,
            entity_ids[len(head_entity_ids) * (num - 1) + 1] are sampled for head_entity_ids[1];
            and so on.

        Raises:
            ValueError: If indexing failed and some id is not a known head entity.
            IndexError: Re-raised if indexing failed for any other reason.
        """
        try:
            return self.sample_by_key_ids(head_entity_ids, num, self.used_ids[head_entity_ids])
        except IndexError as err:
            for head_entity_id in head_entity_ids:
                if head_entity_id not in self.head_entities:
                    raise ValueError('head_entity_id [{}] not exist'.format(head_entity_id)) from err
            # No unknown id explains the failure: propagate instead of returning None.
            raise
class RepeatableSampler(AbstractSampler):
    """:class:`RepeatableSampler` is used to sample negative items for each input user.

    Unlike :class:`Sampler`, it does not exclude any item: :meth:`get_used_ids` returns an empty
    set for every user, so any item in :attr:`random_list` may be sampled in any phase.
    NOTE(review): an earlier description said it only samples items that have not appeared in
    any phase; the code here does not do that — confirm the intended semantics.

    Args:
        phases (str or list of str): All the phases of input.
        dataset (Dataset): The union of all datasets for each phase.
        distribution (str, optional): Distribution of the negative items. Defaults to 'uniform'.

    Attributes:
        phase (str): The phase of the sampler. It is not set until :meth:`set_phase` is called.
    """

    def __init__(self, phases, dataset, distribution='uniform'):
        if not isinstance(phases, list):
            phases = [phases]
        self.phases = phases
        self.dataset = dataset

        self.iid_field = dataset.iid_field
        self.user_num = dataset.user_num
        self.item_num = dataset.item_num

        super().__init__(distribution=distribution)

    def get_random_list(self):
        """Build the candidate item list.

        The result is shuffled in place by ``AbstractSampler.__init__``. For ``'popularity'`` a
        copy of the item column is returned, so that the shuffle cannot reorder (or fail on) the
        dataset's own interaction data.

        Returns:
            np.ndarray or list: Random list of item_id.

        Raises:
            NotImplementedError: If :attr:`distribution` is not supported.
        """
        if self.distribution == 'uniform':
            return list(range(1, self.item_num))
        elif self.distribution == 'popularity':
            return list(self.dataset.inter_feat[self.iid_field].values)
        else:
            raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))

    def get_used_ids(self):
        """Build the per-user prohibited items, which are always empty here.

        Returns:
            np.ndarray: Index is user_id, and element is an empty set of item_ids, so no item
            is excluded for any user.
        """
        return np.array([set() for _ in range(self.user_num)])

    def sample_by_user_ids(self, user_ids, num):
        """Sampling by user_ids.

        Args:
            user_ids (np.ndarray or list): Input user_ids.
            num (int): Number of sampled item_ids for each user_id.

        Returns:
            np.ndarray: Sampled item_ids.
            item_ids[0], item_ids[len(user_ids)], item_ids[len(user_ids) * 2], ...,
            item_ids[len(user_ids) * (num - 1)] are sampled for user_ids[0];
            item_ids[1], item_ids[len(user_ids) + 1], item_ids[len(user_ids) * 2 + 1], ...,
            item_ids[len(user_ids) * (num - 1) + 1] are sampled for user_ids[1]; and so on.

        Raises:
            ValueError: If some user_id is outside ``[0, user_num)``.
            IndexError: Re-raised if indexing failed for any other reason.
        """
        try:
            return self.sample_by_key_ids(user_ids, num, self.used_ids[user_ids])
        except IndexError as err:
            for user_id in user_ids:
                # This class stores the user count as ``user_num`` (there is no ``n_users``).
                if user_id < 0 or user_id >= self.user_num:
                    raise ValueError('user_id [{}] not exist'.format(user_id)) from err
            # No out-of-range id explains the failure: propagate instead of returning None.
            raise

    def set_phase(self, phase):
        """Get the sampler of the corresponding phase.

        Args:
            phase (str): The phase of the new sampler.

        Returns:
            Sampler: A shallow copy of this sampler whose :attr:`phase` is set to ``phase``.

        Raises:
            ValueError: If ``phase`` is not one of :attr:`phases`.
        """
        if phase not in self.phases:
            raise ValueError('phase [{}] not exist'.format(phase))
        new_sampler = copy.copy(self)
        new_sampler.phase = phase
        return new_sampler