# Source code for recbole.sampler.sampler

# -*- coding: utf-8 -*-
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn
# @File   : sampler.py

# UPDATE
# @Time   : 2021/7/23, 2020/8/31, 2020/10/6, 2020/9/18, 2021/3/19
# @Author : Xingyu Pan, Kaiyuan Li, Yupeng Hou, Yushuo Chen, Zhichao Feng
# @email  : xy_pan@foxmail.com, tsotfsk@outlook.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, fzcbupt@gmail.com

"""
recbole.sampler
########################
"""

import copy
from collections import Counter, deque

import numpy as np
import torch
from numpy.random import sample


class AbstractSampler(object):
    """:class:`AbstractSampler` is an abstract class; all samplers should inherit from it.

    This sampler supports returning a certain number of random value_ids according to the input key_id,
    and it also supports prohibiting certain key-value pairs by setting used_ids.

    Args:
        distribution (str): The string of distribution, which is used for subclass.

    Attributes:
        used_ids (numpy.ndarray): The result of :meth:`get_used_ids`.
    """

    def __init__(self, distribution):
        self.distribution = ''
        self.set_distribution(distribution)
        self.used_ids = self.get_used_ids()

    def set_distribution(self, distribution):
        """Set the distribution of sampler.

        Args:
            distribution (str): Distribution of the negative items. ``'popularity'`` additionally
                (re)builds the alias table used by :meth:`_pop_sampling`.
        """
        self.distribution = distribution
        if distribution == 'popularity':
            self._build_alias_table()

    def _uni_sampling(self, sample_num):
        """Sample [sample_num] items in the uniform distribution.

        Args:
            sample_num (int): the number of samples.

        Returns:
            sample_list (np.array): a list of samples.
        """
        raise NotImplementedError('Method [_uni_sampling] should be implemented')

    def _get_candidates_list(self):
        """Get sample candidates list for _pop_sampling()

        Returns:
            candidates_list (list): a list of candidates id; repeated ids raise their sampling weight.
        """
        raise NotImplementedError('Method [_get_candidates_list] should be implemented')

    def _build_alias_table(self):
        """Build the alias table (Walker/Vose alias method) for popularity-biased sampling.

        After this call, ``self.prob[k]`` is the acceptance threshold of bucket ``k``, and
        ``self.alias[k]`` is the id returned when the threshold is not met. ``-1`` is used
        only for buckets whose threshold is >= 1, so it is never returned.
        """
        candidates_list = self._get_candidates_list()
        self.prob = dict(Counter(candidates_list))
        self.alias = self.prob.copy()
        large_q = deque()
        small_q = deque()
        for i in self.prob:
            self.alias[i] = -1
            # Scale so the average bucket probability is exactly 1.
            self.prob[i] = self.prob[i] / len(candidates_list) * len(self.prob)
            if self.prob[i] > 1:
                large_q.append(i)
            elif self.prob[i] < 1:
                small_q.append(i)
        while large_q and small_q:
            large = large_q.popleft()
            small = small_q.popleft()
            # Fill the deficit of the small bucket with mass taken from the large one.
            self.alias[small] = large
            self.prob[large] = self.prob[large] - (1 - self.prob[small])
            if self.prob[large] < 1:
                small_q.append(large)
            elif self.prob[large] > 1:
                large_q.append(large)
        # Anything left over differs from 1 only by floating-point error. Pin it to 1 so a
        # bucket with alias -1 can never fall through to its (invalid) alias.
        for i in (*large_q, *small_q):
            self.prob[i] = 1.0

    def _pop_sampling(self, sample_num):
        """Sample [sample_num] items in the popularity-biased distribution.

        Args:
            sample_num (int): the number of samples.

        Returns:
            sample_list (np.array): a list of samples.
        """
        keys = list(self.prob)
        key_arr = np.array(keys)
        prob_arr = np.array([self.prob[k] for k in keys], dtype=float)
        alias_arr = np.array([self.alias[k] for k in keys])
        random_index_list = np.random.randint(0, len(keys), sample_num)
        random_prob_list = np.random.random(sample_num)
        # Keep the bucket's own id when the uniform draw is below its threshold, else its alias.
        accept = prob_arr[random_index_list] > random_prob_list
        return np.where(accept, key_arr[random_index_list], alias_arr[random_index_list])

    def sampling(self, sample_num):
        """Sampling [sample_num] item_ids.

        Args:
            sample_num (int): the number of samples.

        Returns:
            sample_list (np.array): a list of samples and the len is [sample_num].

        Raises:
            NotImplementedError: if :attr:`distribution` is neither ``'uniform'`` nor ``'popularity'``.
        """
        if self.distribution == 'uniform':
            return self._uni_sampling(sample_num)
        elif self.distribution == 'popularity':
            return self._pop_sampling(sample_num)
        else:
            raise NotImplementedError(f'The sampling distribution [{self.distribution}] is not implemented.')

    def get_used_ids(self):
        """
        Returns:
            numpy.ndarray: Used ids. Index is key_id, and element is a set of value_ids.
        """
        raise NotImplementedError('Method [get_used_ids] should be implemented')

    def sample_by_key_ids(self, key_ids, num):
        """Sampling by key_ids.

        Args:
            key_ids (numpy.ndarray or list): Input key_ids.
            num (int): Number of sampled value_ids for each key_id.

        Returns:
            torch.tensor: Sampled value_ids.
            value_ids[0], value_ids[len(key_ids)], value_ids[len(key_ids) * 2], ..., value_id[len(key_ids) * (num - 1)]
            is sampled for key_ids[0];
            value_ids[1], value_ids[len(key_ids) + 1], value_ids[len(key_ids) * 2 + 1], ...,
            value_id[len(key_ids) * (num - 1) + 1] is sampled for key_ids[1]; ...; and so on.
            An empty ``key_ids`` yields an empty int64 tensor.
        """
        key_ids = np.array(key_ids)
        key_num = len(key_ids)
        if key_num == 0:
            return torch.zeros(0, dtype=torch.int64)
        total_num = key_num * num
        if (key_ids == key_ids[0]).all():
            # Fast path: every key shares one exclusion set, so the check can be vectorized.
            key_id = key_ids[0]
            used = np.array(list(self.used_ids[key_id]))
            value_ids = self.sampling(total_num)
            check_list = np.arange(total_num)[np.isin(value_ids, used)]
            while len(check_list) > 0:
                value_ids[check_list] = value = self.sampling(len(check_list))
                mask = np.isin(value, used)
                check_list = check_list[mask]
        else:
            value_ids = np.zeros(total_num, dtype=np.int64)
            check_list = np.arange(total_num)
            # Repeat the keys so position j belongs to key_ids[j % key_num].
            key_ids = np.tile(key_ids, num)
            while len(check_list) > 0:
                value_ids[check_list] = self.sampling(len(check_list))
                check_list = np.array([
                    i for i, used, v in
                    zip(check_list, self.used_ids[key_ids[check_list]], value_ids[check_list])
                    if v in used
                ], dtype=np.int64)
        return torch.tensor(value_ids)
class Sampler(AbstractSampler):
    """:class:`Sampler` is used to sample negative items for each input user.

    In order to avoid positive items in train-phase to be sampled in valid-phase,
    and positive items in train-phase or valid-phase to be sampled in test-phase,
    we need to input the datasets of all phases for pre-processing. And, before using this sampler,
    it is needed to call :meth:`set_phase` to get the sampler of corresponding phase.

    Args:
        phases (str or list of str): All the phases of input.
        datasets (Dataset or list of Dataset): All the dataset for each phase.
        distribution (str, optional): Distribution of the negative items. Defaults to 'uniform'.

    Attributes:
        phase (str): the phase of sampler. It will not be set until :meth:`set_phase` is called.
    """

    def __init__(self, phases, datasets, distribution='uniform'):
        if not isinstance(phases, list):
            phases = [phases]
        if not isinstance(datasets, list):
            datasets = [datasets]
        if len(phases) != len(datasets):
            raise ValueError(f'Phases {phases} and datasets {datasets} should have the same length.')

        self.phases = phases
        self.datasets = datasets

        # Field names and id-space sizes are taken from the first dataset.
        self.uid_field = datasets[0].uid_field
        self.iid_field = datasets[0].iid_field

        self.user_num = datasets[0].user_num
        self.item_num = datasets[0].item_num

        super().__init__(distribution=distribution)

    def _get_candidates_list(self):
        # Item ids of every interaction across all phases; frequency drives popularity sampling.
        candidates_list = []
        for dataset in self.datasets:
            candidates_list.extend(dataset.inter_feat[self.iid_field].numpy())
        return candidates_list

    def _uni_sampling(self, sample_num):
        # Item id 0 is never drawn; ids come from [1, item_num).
        return np.random.randint(1, self.item_num, sample_num)

    def get_used_ids(self):
        """
        Returns:
            dict: Used item_ids is the same as positive item_ids.
            Key is phase, and value is a numpy.ndarray which index is user_id, and element is a set of item_ids.
            Each phase's sets also contain the items of all earlier phases.

        Raises:
            ValueError: if some user has interacted with every item, leaving no negatives.
        """
        used_item_id = dict()
        last = [set() for _ in range(self.user_num)]
        for phase, dataset in zip(self.phases, self.datasets):
            # Copy the previous phase's sets so later phases accumulate earlier positives.
            cur = np.array([set(s) for s in last])
            for uid, iid in zip(dataset.inter_feat[self.uid_field].numpy(), dataset.inter_feat[self.iid_field].numpy()):
                cur[uid].add(iid)
            last = used_item_id[phase] = cur

        for used_item_set in used_item_id[self.phases[-1]]:
            if len(used_item_set) + 1 == self.item_num:  # [pad] is a item.
                raise ValueError(
                    'Some users have interacted with all items, '
                    'which we can not sample negative items for them. '
                    'Please set `user_inter_num_interval` to filter those users.'
                )
        return used_item_id

    def set_phase(self, phase):
        """Get the sampler of corresponding phase.

        Args:
            phase (str): The phase of new sampler.

        Returns:
            Sampler: the copy of this sampler, :attr:`phase` is set the same as input phase, and :attr:`used_ids`
            is set to the value of corresponding phase.

        Raises:
            ValueError: if ``phase`` is not one of :attr:`phases`.
        """
        if phase not in self.phases:
            raise ValueError(f'Phase [{phase}] not exist.')
        new_sampler = copy.copy(self)
        new_sampler.phase = phase
        new_sampler.used_ids = new_sampler.used_ids[phase]
        return new_sampler

    def sample_by_user_ids(self, user_ids, item_ids, num):
        """Sampling by user_ids.

        Args:
            user_ids (numpy.ndarray or list): Input user_ids.
            item_ids (numpy.ndarray or list): Input item_ids (unused by this sampler).
            num (int): Number of sampled item_ids for each user_id.

        Returns:
            torch.tensor: Sampled item_ids.
            item_ids[0], item_ids[len(user_ids)], item_ids[len(user_ids) * 2], ..., item_id[len(user_ids) * (num - 1)]
            is sampled for user_ids[0];
            item_ids[1], item_ids[len(user_ids) + 1], item_ids[len(user_ids) * 2 + 1], ...,
            item_id[len(user_ids) * (num - 1) + 1] is sampled for user_ids[1]; ...; and so on.

        Raises:
            ValueError: if any user_id is outside ``[0, user_num)``.
        """
        # Validate up front: numpy silently wraps negative indices, so an IndexError-based
        # check misses them, and any unexplained IndexError used to be turned into ``None``.
        user_ids = np.asarray(user_ids)
        invalid = user_ids[(user_ids < 0) | (user_ids >= self.user_num)]
        if invalid.size > 0:
            raise ValueError(f'user_id [{invalid[0]}] not exist.')
        return self.sample_by_key_ids(user_ids, num)
class KGSampler(AbstractSampler):
    """:class:`KGSampler` is used to sample negative entities in a knowledge graph.

    Args:
        dataset (Dataset): The knowledge graph dataset, which contains triplets in a knowledge graph.
        distribution (str, optional): Distribution of the negative entities. Defaults to 'uniform'.
    """

    def __init__(self, dataset, distribution='uniform'):
        self.dataset = dataset

        self.hid_field = dataset.head_entity_field
        self.tid_field = dataset.tail_entity_field
        self.hid_list = dataset.head_entities
        self.tid_list = dataset.tail_entities

        # Set form of the head ids, used only to report invalid inputs.
        self.head_entities = set(dataset.head_entities)
        self.entity_num = dataset.entity_num

        super().__init__(distribution=distribution)

    def _uni_sampling(self, sample_num):
        # Entity id 0 is never drawn; ids come from [1, entity_num).
        return np.random.randint(1, self.entity_num, sample_num)

    def _get_candidates_list(self):
        # Heads and tails together; frequency drives popularity sampling.
        return list(self.hid_list) + list(self.tid_list)

    def get_used_ids(self):
        """
        Returns:
            numpy.ndarray: Used entity_ids is the same as tail_entity_ids in knowledge graph.
            Index is head_entity_id, and element is a set of tail_entity_ids.

        Raises:
            ValueError: if some head entity is linked to every entity, leaving no negatives.
        """
        used_tail_entity_id = np.array([set() for _ in range(self.entity_num)])
        for hid, tid in zip(self.hid_list, self.tid_list):
            used_tail_entity_id[hid].add(tid)

        for used_tail_set in used_tail_entity_id:
            if len(used_tail_set) + 1 == self.entity_num:  # [pad] is a entity.
                raise ValueError(
                    'Some head entities have relation with all entities, '
                    'which we can not sample negative entities for them.'
                )
        return used_tail_entity_id

    def sample_by_entity_ids(self, head_entity_ids, num=1):
        """Sampling by head_entity_ids.

        Args:
            head_entity_ids (numpy.ndarray or list): Input head_entity_ids.
            num (int, optional): Number of sampled entity_ids for each head_entity_id. Defaults to ``1``.

        Returns:
            torch.tensor: Sampled entity_ids.
            entity_ids[0], entity_ids[len(head_entity_ids)], entity_ids[len(head_entity_ids) * 2], ...,
            entity_id[len(head_entity_ids) * (num - 1)] is sampled for head_entity_ids[0];
            entity_ids[1], entity_ids[len(head_entity_ids) + 1], entity_ids[len(head_entity_ids) * 2 + 1], ...,
            entity_id[len(head_entity_ids) * (num - 1) + 1] is sampled for head_entity_ids[1]; ...; and so on.

        Raises:
            ValueError: if sampling fails with an IndexError and some input id is not a head entity.
            IndexError: if sampling fails and no offending head id can be identified.
        """
        try:
            return self.sample_by_key_ids(head_entity_ids, num)
        except IndexError as err:
            for head_entity_id in head_entity_ids:
                if head_entity_id not in self.head_entities:
                    raise ValueError(f'head_entity_id [{head_entity_id}] not exist.') from err
            # No culprit found: surface the original error instead of silently returning None.
            raise
class RepeatableSampler(AbstractSampler):
    """:class:`RepeatableSampler` is used to sample negative items for each input user. The difference from
    :class:`Sampler` is that, for each position, only the single paired input item is excluded; items are
    not excluded based on interaction history.

    Args:
        phases (str or list of str): All the phases of input.
        dataset (Dataset): The union of all datasets for each phase.
        distribution (str, optional): Distribution of the negative items. Defaults to 'uniform'.

    Attributes:
        phase (str): the phase of sampler. It will not be set until :meth:`set_phase` is called.
    """

    def __init__(self, phases, dataset, distribution='uniform'):
        if not isinstance(phases, list):
            phases = [phases]
        self.phases = phases
        self.dataset = dataset

        self.iid_field = dataset.iid_field
        self.user_num = dataset.user_num
        self.item_num = dataset.item_num

        super().__init__(distribution=distribution)

    def _uni_sampling(self, sample_num):
        # Item id 0 is never drawn; ids come from [1, item_num).
        return np.random.randint(1, self.item_num, sample_num)

    def _get_candidates_list(self):
        return list(self.dataset.inter_feat[self.iid_field].numpy())

    def get_used_ids(self):
        """
        Returns:
            numpy.ndarray: One empty set per user_id. The real exclusion sets are installed
            per call by :meth:`sample_by_user_ids`.
        """
        return np.array([set() for _ in range(self.user_num)])

    def sample_by_user_ids(self, user_ids, item_ids, num):
        """Sampling by user_ids.

        Args:
            user_ids (numpy.ndarray or list): Input user_ids.
            item_ids (numpy.ndarray or list): Input item_ids; ``item_ids[k]`` is excluded for position ``k``.
            num (int): Number of sampled item_ids for each user_id.

        Returns:
            torch.tensor: Sampled item_ids.
            item_ids[0], item_ids[len(user_ids)], item_ids[len(user_ids) * 2], ..., item_id[len(user_ids) * (num - 1)]
            is sampled for user_ids[0];
            item_ids[1], item_ids[len(user_ids) + 1], item_ids[len(user_ids) * 2 + 1], ...,
            item_id[len(user_ids) * (num - 1) + 1] is sampled for user_ids[1]; ...; and so on.

        Raises:
            ValueError: if any user_id is outside ``[0, user_num)``.
        """
        # Validate up front. The sampling call below indexes by position, not by user_id,
        # so an IndexError from it says nothing about user_ids; the old except-branch could
        # only end up returning None.
        user_ids = np.asarray(user_ids)
        invalid = user_ids[(user_ids < 0) | (user_ids >= self.user_num)]
        if invalid.size > 0:
            raise ValueError(f'user_id [{invalid[0]}] not exist.')
        # Note: overwrites self.used_ids for this call; key k is the k-th input position.
        self.used_ids = np.array([{i} for i in item_ids])
        return self.sample_by_key_ids(np.arange(len(user_ids)), num)

    def set_phase(self, phase):
        """Get the sampler of corresponding phase.

        Args:
            phase (str): The phase of new sampler.

        Returns:
            Sampler: the copy of this sampler, and :attr:`phase` is set the same as input phase.

        Raises:
            ValueError: if ``phase`` is not one of :attr:`phases`.
        """
        if phase not in self.phases:
            raise ValueError(f'Phase [{phase}] not exist.')
        new_sampler = copy.copy(self)
        new_sampler.phase = phase
        return new_sampler
class SeqSampler(AbstractSampler):
    """Sampler that builds a negative item sequence aligned with a positive one.

    Every position of the output holds a randomly drawn item that differs from the
    positive item at that same position.

    Args:
        dataset (Dataset): Dataset providing ``iid_field``, ``user_num`` and ``item_num``.
        distribution (str, optional): Distribution of the negative items. Defaults to 'uniform'.
    """

    def __init__(self, dataset, distribution='uniform'):
        self.dataset = dataset
        self.iid_field = dataset.iid_field
        self.user_num = dataset.user_num
        self.item_num = dataset.item_num
        super().__init__(distribution=distribution)

    def _uni_sampling(self, sample_num):
        # Draws from [1, item_num); id 0 is never produced.
        lower = 1
        return np.random.randint(lower, self.item_num, sample_num)

    def get_used_ids(self):
        """No per-key exclusion table; collisions are checked position by position instead."""
        return None

    def sample_neg_sequence(self, pos_sequence):
        """Draw one negative item per position of ``pos_sequence``.

        Args:
            pos_sequence (torch.Tensor): all users' item history sequence, with the shape of `(N, )`.

        Returns:
            torch.tensor: the negative sequence, with the same length as ``pos_sequence``,
            where no element equals the positive item at the same index.
        """
        length = len(pos_sequence)
        negatives = np.zeros(length, dtype=np.int64)
        pending = np.arange(length)
        # Redraw only the positions that still collide with their positive item.
        while len(pending) > 0:
            negatives[pending] = self.sampling(len(pending))
            (collided,) = np.where(negatives[pending] == pos_sequence[pending])
            pending = pending[collided]
        return torch.tensor(negatives)