# -*- coding: utf-8 -*-
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn
# @File : sampler.py
# UPDATE
# @Time : 2020/8/17, 2020/8/31, 2020/10/6, 2020/9/18
# @Author : Xingyu Pan, Kaiyuan Li, Yupeng Hou, Yushuo Chen
# @email : panxy@ruc.edu.cn, tsotfsk@outlook.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn
"""
recbole.sampler
########################
"""
import random
import copy
import numpy as np
class AbstractSampler(object):
    """:class:`AbstractSampler` is an abstract class that all samplers should inherit from. It supports returning
    a certain number of random value_ids for each input key_id, and it supports prohibiting certain key-value
    pairs by setting used_ids. For efficiency, values are not drawn from an RNG on every call: the list returned
    by :meth:`get_random_list` is shuffled once, and :attr:`random_pr` then walks over it cyclically. Subclasses
    must therefore implement :meth:`get_random_list` and :meth:`get_used_ids`.

    Args:
        distribution (str): The string of distribution, which is used by the subclass.

    Attributes:
        random_list (list or numpy.ndarray): The shuffled result of :meth:`get_random_list`.
        random_pr (int): Number of values drawn so far; the next value is
            ``random_list[random_pr % random_list_length]``.
        random_list_length (int): Length of :attr:`random_list`.
        used_ids (numpy.ndarray): The result of :meth:`get_used_ids`.
    """

    def __init__(self, distribution):
        self.distribution = distribution
        self.random_list = self.get_random_list()
        # Shuffled in place, so get_random_list() should return a fresh mutable sequence
        # rather than a view of data owned by another object.
        random.shuffle(self.random_list)
        self.random_pr = 0
        self.random_list_length = len(self.random_list)
        self.used_ids = self.get_used_ids()

    def get_random_list(self):
        """Build the pool of candidate value_ids.

        A value that appears k times in the pool is drawn k times per pass of the cyclic walk,
        so duplicates act as sampling weights.

        Returns:
            np.ndarray or list: Random list of value_id.
        """
        raise NotImplementedError('method [get_random_list] should be implemented')

    def get_used_ids(self):
        """
        Returns:
            np.ndarray: Used ids. Index is key_id, and element is a set of value_ids.
        """
        raise NotImplementedError('method [get_used_ids] should be implemented')

    def random(self):
        """Return the next value_id of the cyclic walk over :attr:`random_list`.

        Returns:
            value_id (int): Random value_id. Generated by :attr:`random_list`.
        """
        value_id = self.random_list[self.random_pr % self.random_list_length]
        self.random_pr += 1
        return value_id

    def sample_by_key_ids(self, key_ids, num, used_ids):
        """Sampling by key_ids.

        Args:
            key_ids (np.ndarray or list): Input key_ids.
            num (int): Number of sampled value_ids for each key_id.
            used_ids (np.ndarray): Used ids aligned with ``key_ids``: ``used_ids[i]`` is the set of value_ids
                that must not be sampled for ``key_ids[i]``.

        Returns:
            np.ndarray: Sampled value_ids.
            value_ids[0], value_ids[len(key_ids)], value_ids[len(key_ids) * 2], ..., value_id[len(key_ids) * (num - 1)]
            is sampled for key_ids[0];
            value_ids[1], value_ids[len(key_ids) + 1], value_ids[len(key_ids) * 2 + 1], ...,
            value_id[len(key_ids) * (num - 1) + 1] is sampled for key_ids[1]; ...; and so on.

        Raises:
            ValueError: If every entry of :attr:`random_list` is in the used set of some key, so that no
                valid value_id can ever be sampled for it.
        """
        key_num = len(key_ids)
        value_ids = np.zeros(key_num * num, dtype=np.int64)
        # Repeating the aligned per-key sets `num` times yields the "key i at positions i, i + key_num, ..." layout.
        used_id_list = np.tile(used_ids, num)
        for i, key_used_ids in enumerate(used_id_list):
            cur = self.random()
            tries = 1
            while cur in key_used_ids:
                # random() walks random_list cyclically, so random_list_length consecutive rejections mean every
                # candidate has been checked and none is allowed: stop instead of spinning forever.
                if tries >= self.random_list_length:
                    raise ValueError('no unused value_id can be sampled for key_ids[{}]'.format(i % len(used_ids)))
                cur = self.random()
                tries += 1
            value_ids[i] = cur
        return value_ids


class Sampler(AbstractSampler):
    """:class:`Sampler` is used to sample negative items for each input user. In order to avoid positive items
    in train-phase being sampled in valid-phase, and positive items in train-phase or valid-phase being sampled
    in test-phase, we need to input the datasets of all phases for pre-processing. Before using this sampler,
    call :meth:`set_phase` to get the sampler of the corresponding phase.

    Args:
        phases (str or list of str): All the phases of input.
        datasets (Dataset or list of Dataset): All the dataset for each phase.
        distribution (str, optional): Distribution of the negative items. Defaults to 'uniform'.

    Attributes:
        phase (str): the phase of sampler. It will not be set until :meth:`set_phase` is called.
    """

    def __init__(self, phases, datasets, distribution='uniform'):
        if not isinstance(phases, list):
            phases = [phases]
        if not isinstance(datasets, list):
            datasets = [datasets]
        if len(phases) != len(datasets):
            raise ValueError('phases {} and datasets {} should have the same length'.format(phases, datasets))
        self.phases = phases
        self.datasets = datasets
        # Field names and id-space sizes are read from the first dataset only.
        self.uid_field = datasets[0].uid_field
        self.iid_field = datasets[0].iid_field
        self.n_users = datasets[0].user_num
        self.n_items = datasets[0].item_num
        # Must come last: the base constructor calls get_random_list/get_used_ids, which read the fields above.
        super().__init__(distribution=distribution)

    def get_random_list(self):
        """
        Returns:
            np.ndarray or list: Random list of item_id. For 'uniform' it holds every item id from 1 to
            ``n_items - 1`` once; for 'popularity' it holds the item column of every phase's interactions.

        Raises:
            NotImplementedError: If :attr:`distribution` is neither 'uniform' nor 'popularity'.
        """
        if self.distribution == 'uniform':
            # id 0 is excluded -- presumably a padding id; confirm against the Dataset conventions.
            return list(range(1, self.n_items))
        elif self.distribution == 'popularity':
            random_item_list = []
            for dataset in self.datasets:
                random_item_list.extend(dataset.inter_feat[self.iid_field].values)
            return random_item_list
        else:
            raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))

    def get_used_ids(self):
        """
        Returns:
            dict: Used item_ids is the same as positive item_ids.
            Key is phase, and value is a np.ndarray which index is user_id, and element is a set of item_ids.
            The sets are cumulative: each phase also contains the positives of all earlier phases.
        """
        used_item_id = dict()
        last = [set() for _ in range(self.n_users)]
        for phase, dataset in zip(self.phases, self.datasets):
            # Copy the previous phase's sets so later phases accumulate without mutating earlier ones.
            cur = np.array([set(s) for s in last])
            for uid, iid in dataset.inter_feat[[self.uid_field, self.iid_field]].values:
                cur[uid].add(iid)
            last = used_item_id[phase] = cur
        return used_item_id

    def set_phase(self, phase):
        """Get the sampler of corresponding phase.

        Args:
            phase (str): The phase of new sampler.

        Returns:
            Sampler: the shallow copy of this sampler, :attr:`phase` is set the same as input phase, and
            :attr:`used_ids` is set to the value of corresponding phase.

        Raises:
            ValueError: If ``phase`` is not one of :attr:`phases`.
        """
        if phase not in self.phases:
            raise ValueError('phase [{}] not exist'.format(phase))
        new_sampler = copy.copy(self)
        new_sampler.phase = phase
        new_sampler.used_ids = new_sampler.used_ids[phase]
        return new_sampler

    def sample_by_user_ids(self, user_ids, num):
        """Sampling by user_ids. Must be called on a sampler returned by :meth:`set_phase`.

        Args:
            user_ids (np.ndarray or list): Input user_ids.
            num (int): Number of sampled item_ids for each user_id.

        Returns:
            np.ndarray: Sampled item_ids.
            item_ids[0], item_ids[len(user_ids)], item_ids[len(user_ids) * 2], ..., item_id[len(user_ids) * (num - 1)]
            is sampled for user_ids[0];
            item_ids[1], item_ids[len(user_ids) + 1], item_ids[len(user_ids) * 2 + 1], ...,
            item_id[len(user_ids) * (num - 1) + 1] is sampled for user_ids[1]; ...; and so on.

        Raises:
            ValueError: If any user_id is outside ``[0, n_users)``.
        """
        # Validate explicitly: numpy silently wraps negative indices instead of raising IndexError,
        # which would sample against another user's positive items.
        ids = np.asarray(user_ids)
        invalid = (ids < 0) | (ids >= self.n_users)
        if invalid.any():
            raise ValueError('user_id [{}] not exist'.format(ids[invalid][0]))
        return self.sample_by_key_ids(user_ids, num, self.used_ids[user_ids])


class KGSampler(AbstractSampler):
    """:class:`KGSampler` is used to sample negative entities in a knowledge graph. For each head entity, the
    tail entities it is already linked to (see :meth:`get_used_ids`) are never sampled.

    Args:
        dataset (Dataset): The knowledge graph dataset, which contains triplets in a knowledge graph.
        distribution (str, optional): Distribution of the negative entities. Defaults to 'uniform'.
    """

    def __init__(self, dataset, distribution='uniform'):
        self.dataset = dataset
        self.hid_field = dataset.head_entity_field
        self.tid_field = dataset.tail_entity_field
        self.hid_list = dataset.head_entities
        self.tid_list = dataset.tail_entities
        self.head_entities = set(dataset.head_entities)
        self.entity_num = dataset.entity_num
        # Must come last: the base constructor calls get_random_list/get_used_ids, which read the fields above.
        super().__init__(distribution=distribution)

    def get_random_list(self):
        """
        Returns:
            np.ndarray or list: Random list of entity_id. For 'uniform' it holds every entity id from 1 to
            ``entity_num - 1`` once; for 'popularity' it holds all head ids followed by all tail ids.

        Raises:
            NotImplementedError: If :attr:`distribution` is neither 'uniform' nor 'popularity'.
        """
        if self.distribution == 'uniform':
            return list(range(1, self.entity_num))
        elif self.distribution == 'popularity':
            # A fresh list, so the in-place shuffle in the base constructor leaves the dataset untouched.
            return list(self.hid_list) + list(self.tid_list)
        else:
            raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))

    def get_used_ids(self):
        """
        Returns:
            np.ndarray: Used entity_ids is the same as tail_entity_ids in knowledge graph.
            Index is head_entity_id, and element is a set of tail_entity_ids.
        """
        used_tail_entity_id = np.array([set() for _ in range(self.entity_num)])
        for hid, tid in zip(self.hid_list, self.tid_list):
            used_tail_entity_id[hid].add(tid)
        return used_tail_entity_id

    def sample_by_entity_ids(self, head_entity_ids, num=1):
        """Sampling by head_entity_ids.

        Args:
            head_entity_ids (np.ndarray or list): Input head_entity_ids.
            num (int, optional): Number of sampled entity_ids for each head_entity_id. Defaults to ``1``.

        Returns:
            np.ndarray: Sampled entity_ids.
            entity_ids[0], entity_ids[len(head_entity_ids)], entity_ids[len(head_entity_ids) * 2], ...,
            entity_id[len(head_entity_ids) * (num - 1)] is sampled for head_entity_ids[0];
            entity_ids[1], entity_ids[len(head_entity_ids) + 1], entity_ids[len(head_entity_ids) * 2 + 1], ...,
            entity_id[len(head_entity_ids) * (num - 1) + 1] is sampled for head_entity_ids[1]; ...; and so on.

        Raises:
            ValueError: If any head_entity_id is outside ``[0, entity_num)``.
        """
        # Validate explicitly: numpy silently wraps negative indices instead of raising IndexError,
        # which would sample against another entity's tails.
        ids = np.asarray(head_entity_ids)
        invalid = (ids < 0) | (ids >= self.entity_num)
        if invalid.any():
            raise ValueError('head_entity_id [{}] not exist'.format(ids[invalid][0]))
        return self.sample_by_key_ids(head_entity_ids, num, self.used_ids[head_entity_ids])


class RepeatableSampler(AbstractSampler):
    """:class:`RepeatableSampler` is used to sample negative items for each input user. Unlike
    :class:`Sampler`, it excludes nothing: :meth:`get_used_ids` returns an empty set for every user, so any
    item in :attr:`random_list` may be sampled for any user.

    Args:
        phases (str or list of str): All the phases of input.
        dataset (Dataset): The union of all datasets for each phase.
        distribution (str, optional): Distribution of the negative items. Defaults to 'uniform'.

    Attributes:
        phase (str): the phase of sampler. It will not be set until :meth:`set_phase` is called.
    """

    def __init__(self, phases, dataset, distribution='uniform'):
        if not isinstance(phases, list):
            phases = [phases]
        self.phases = phases
        self.dataset = dataset
        self.iid_field = dataset.iid_field
        self.user_num = dataset.user_num
        self.item_num = dataset.item_num
        # Must come last: the base constructor calls get_random_list/get_used_ids, which read the fields above.
        super().__init__(distribution=distribution)

    def get_random_list(self):
        """
        Returns:
            np.ndarray or list: Random list of item_id. For 'uniform' it holds every item id from 1 to
            ``item_num - 1`` once; for 'popularity' it is a copy of the dataset's item column.

        Raises:
            NotImplementedError: If :attr:`distribution` is neither 'uniform' nor 'popularity'.
        """
        if self.distribution == 'uniform':
            return list(range(1, self.item_num))
        elif self.distribution == 'popularity':
            # Return a copy: the base constructor shuffles this in place, and `.values` may share memory
            # with the dataset's DataFrame, which would scramble its item column.
            return list(self.dataset.inter_feat[self.iid_field].values)
        else:
            raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))

    def get_used_ids(self):
        """
        Returns:
            np.ndarray: Index is user_id, and element is an empty set of item_ids, i.e. no item is
            excluded for any user.
        """
        return np.array([set() for _ in range(self.user_num)])

    def sample_by_user_ids(self, user_ids, num):
        """Sampling by user_ids.

        Args:
            user_ids (np.ndarray or list): Input user_ids.
            num (int): Number of sampled item_ids for each user_id.

        Returns:
            np.ndarray: Sampled item_ids.
            item_ids[0], item_ids[len(user_ids)], item_ids[len(user_ids) * 2], ..., item_id[len(user_ids) * (num - 1)]
            is sampled for user_ids[0];
            item_ids[1], item_ids[len(user_ids) + 1], item_ids[len(user_ids) * 2 + 1], ...,
            item_id[len(user_ids) * (num - 1) + 1] is sampled for user_ids[1]; ...; and so on.

        Raises:
            ValueError: If any user_id is outside ``[0, user_num)``.
        """
        # Validate explicitly against user_num (this class has no n_users attribute); numpy would
        # otherwise silently wrap negative indices.
        ids = np.asarray(user_ids)
        invalid = (ids < 0) | (ids >= self.user_num)
        if invalid.any():
            raise ValueError('user_id [{}] not exist'.format(ids[invalid][0]))
        return self.sample_by_key_ids(user_ids, num, self.used_ids[user_ids])

    def set_phase(self, phase):
        """Get the sampler of corresponding phase.

        Args:
            phase (str): The phase of new sampler.

        Returns:
            Sampler: the shallow copy of this sampler, and :attr:`phase` is set the same as input phase.

        Raises:
            ValueError: If ``phase`` is not one of :attr:`phases`.
        """
        if phase not in self.phases:
            raise ValueError('phase [{}] not exist'.format(phase))
        new_sampler = copy.copy(self)
        new_sampler.phase = phase
        return new_sampler