Customize Samplers
In RecBole, the sampler module selects negative items for training and evaluation.
Here we show how to develop a new sampler and apply it in RecBole. A custom sampler is useful when you need a sampling method that is more complex than the built-in ones.
RecBole currently supports two sampling strategies: random negative sampling (RNS) and popularity-biased negative sampling (PNS). RNS selects negative items from a uniform distribution, while PNS selects them from a popularity-biased distribution, in which an item's popularity is its total number of interactions.
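To make the difference between the two strategies concrete, here is a toy sketch in plain NumPy. The interaction log and variable names are hypothetical, the sketch does not filter out positive items, and it is not RecBole's internal implementation; it only shows how the two distributions differ.

```python
from collections import Counter

import numpy as np

# Hypothetical toy interaction log: each entry is the item ID of one interaction.
interactions = [1, 1, 1, 2, 2, 3]
item_ids = sorted(set(interactions))  # [1, 2, 3]

# RNS: every item is equally likely.
uniform_p = np.full(len(item_ids), 1 / len(item_ids))

# PNS: an item's probability is proportional to its total number of interactions.
counts = Counter(interactions)
popularity_p = np.array([counts[i] for i in item_ids]) / len(interactions)

rng = np.random.default_rng(0)  # fixed seed for reproducibility
rns_samples = rng.choice(item_ids, size=10, p=uniform_p)
pns_samples = rng.choice(item_ids, size=10, p=popularity_p)
```

With this log, PNS draws item 1 half of the time, while RNS draws each item one third of the time.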
To create a new sampler, inherit from `AbstractSampler`, implement `__init__()`, override three methods (`_uni_sampling()`, `_get_candidates_list()` and `get_used_ids()`), and add a new sampling function.
Here, we take `KGSampler` as an example.
Create a New Sampler Class
To begin with, we create a new sampler class that inherits from `AbstractSampler`:
```python
from recbole.sampler import AbstractSampler


class KGSampler(AbstractSampler):
    pass
```
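Implement `__init__()`
Next, we implement `__init__()`. In this method we can freely define and initialize whatever attributes the sampler needs; the only requirement is to call `super().__init__(distribution)`. Call it at the end, after our own attributes are defined, because the base class relies on them during initialization (for example, the result of `get_used_ids()` is stored in `self.used_ids`).
```python
def __init__(self, dataset, distribution='uniform'):
    self.dataset = dataset
    # Field names and ID lists of the head and tail entities in the KG.
    self.hid_field = dataset.head_entity_field
    self.tid_field = dataset.tail_entity_field
    self.hid_list = dataset.head_entities
    self.tid_list = dataset.tail_entities
    # Set of head entity IDs, used to validate input in sample_by_entity_ids().
    self.head_entities = set(dataset.head_entities)
    self.entity_num = dataset.entity_num
    # Call the base constructor last, once the attributes above exist.
    super().__init__(distribution=distribution)
```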
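Why the call order matters can be seen with a simplified, hypothetical base class. `ToyAbstractSampler`, `GoodOrderSampler` and `BadOrderSampler` below are illustrative names, not RecBole classes; the sketch only assumes, as the `get_used_ids()` section describes, that the base constructor computes `self.used_ids` by calling the subclass's `get_used_ids()`.

```python
class ToyAbstractSampler:
    """Hypothetical, simplified stand-in for a sampler base class (not RecBole code)."""

    def __init__(self, distribution):
        self.distribution = distribution
        # The base constructor immediately asks the subclass for its positive IDs,
        # so every attribute that get_used_ids() reads must already exist.
        self.used_ids = self.get_used_ids()

    def get_used_ids(self):
        raise NotImplementedError


class GoodOrderSampler(ToyAbstractSampler):
    def __init__(self, hid_list, tid_list, entity_num, distribution='uniform'):
        # Define the subclass's own attributes first...
        self.hid_list = hid_list
        self.tid_list = tid_list
        self.entity_num = entity_num
        # ...then let the base class finish initialization.
        super().__init__(distribution)

    def get_used_ids(self):
        used = [set() for _ in range(self.entity_num)]
        for hid, tid in zip(self.hid_list, self.tid_list):
            used[hid].add(tid)
        return used


class BadOrderSampler(ToyAbstractSampler):
    # Reuse the same get_used_ids() logic as GoodOrderSampler.
    get_used_ids = GoodOrderSampler.get_used_ids

    def __init__(self, hid_list, tid_list, entity_num, distribution='uniform'):
        # Too early: get_used_ids() runs before entity_num and the lists are set.
        super().__init__(distribution)
        self.hid_list = hid_list
        self.tid_list = tid_list
        self.entity_num = entity_num


good = GoodOrderSampler([1, 1, 2], [2, 3, 3], entity_num=4)
print(good.used_ids)  # [set(), {2, 3}, {3}, set()]

bad_error = None
try:
    BadOrderSampler([1, 1, 2], [2, 3, 3], entity_num=4)
except AttributeError as err:
    bad_error = err  # raised because self.entity_num does not exist yet
```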
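Implement `_uni_sampling()`
To implement RNS for `KGSampler`, we override `_uni_sampling()`.
Here we use `numpy.random.randint()` to draw entity IDs uniformly at random. The function returns the IDs of the sampled entities (here, `entity_id`s). Because the upper bound of `randint()` is exclusive, the call below draws IDs from `1` to `entity_num - 1`, skipping ID `0`, which is the `[pad]` entity.
Example code:
```python
import numpy as np


def _uni_sampling(self, sample_num):
    # Draw sample_num entity IDs uniformly from [1, entity_num).
    return np.random.randint(1, self.entity_num, sample_num)
```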
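A quick check of the sampling range, assuming a hypothetical `entity_num` of 5 (entity IDs 0 to 4, with 0 as the `[pad]` entity):

```python
import numpy as np

entity_num = 5  # hypothetical: entity IDs 0..4, with 0 as the [pad] entity

np.random.seed(42)  # fixed seed so the draw is reproducible
samples = np.random.randint(1, entity_num, 1000)

# randint() excludes its upper bound, so every sample lies in [1, entity_num - 1]:
# the [pad] ID 0 is never drawn, while all real entities can be.
print(samples.min(), samples.max())
```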
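Implement `_get_candidates_list()`
To implement PNS for `KGSampler`, we override `_get_candidates_list()`.
This method returns a list of candidate IDs for PNS. The sampling distribution is built from `Counter(candidate_list)`, so an ID that occurs more often in the list is sampled with higher probability. For `KGSampler`, the list concatenates all head and tail entities, so an entity's weight is the number of times it appears in a triple as either head or tail.
Example code:
```python
def _get_candidates_list(self):
    # Every occurrence of an entity as a head or a tail adds to its popularity.
    return list(self.hid_list) + list(self.tid_list)
```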
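As a rough illustration of how the candidate list turns into sampling weights, the sketch below counts a toy set of triples with `Counter` and normalizes the counts into probabilities. The triples are made up, and dividing each count by the total is an assumption about how the distribution is formed; the data structure RecBole uses internally to draw from it may differ.

```python
from collections import Counter

# Hypothetical toy triples as parallel head/tail lists (entity 0 is [pad] and unused).
hid_list = [1, 1, 2, 1]
tid_list = [2, 3, 3, 4]

# Same concatenation as _get_candidates_list(): heads first, then tails.
candidate_list = list(hid_list) + list(tid_list)
counts = Counter(candidate_list)
total = sum(counts.values())

# Assumed normalization: weight = occurrences / total occurrences.
popularity = {entity: count / total for entity, count in sorted(counts.items())}
print(popularity)  # {1: 0.375, 2: 0.25, 3: 0.25, 4: 0.125}
```

Entity 1 appears three times (always as a head) and gets the largest weight, while entity 4 appears only once, as a tail.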
Implement `get_used_ids()`
Negative sampling must not return positive instances, so this method records the positive samples.
It returns a NumPy array indexed by ID: for `KGSampler`, entry `hid` is the set of tail entities linked to head entity `hid`. The returned value is saved in `self.used_ids`.
Example code:
```python
def get_used_ids(self):
    # One set of positive tail IDs per head entity ID (an object array of sets).
    used_tail_entity_id = np.array([set() for _ in range(self.entity_num)])
    for hid, tid in zip(self.hid_list, self.tid_list):
        used_tail_entity_id[hid].add(tid)
    for used_tail_set in used_tail_entity_id:
        # [pad] is counted in entity_num but never sampled, so a head linked to
        # every other entity has no negatives left.
        if len(used_tail_set) + 1 == self.entity_num:
            raise ValueError(
                'Some head entities have relations with all entities, '
                'so no negative entities can be sampled for them.'
            )
    return used_tail_entity_id
```
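To see how the recorded positives are used, here is one simple way to draw a negative tail for a given head: rejection sampling against the sets built by `get_used_ids()`. This is a hypothetical sketch on a toy graph (`sample_negative_tail` is an illustrative name, not a RecBole function), and it is not necessarily how `sample_by_key_ids()` works internally. It also shows why `get_used_ids()` raises an error for heads linked to every entity: for such a head, the loop below would never terminate.

```python
import numpy as np

# Hypothetical toy KG: entity 0 is [pad]; entities 1..4 are real.
entity_num = 5
hid_list = [1, 1, 2]
tid_list = [2, 3, 3]

# Same structure get_used_ids() builds: index = head ID, value = set of its tails.
used_ids = np.array([set() for _ in range(entity_num)])
for hid, tid in zip(hid_list, tid_list):
    used_ids[hid].add(tid)


def sample_negative_tail(head, rng):
    """Draw one tail that is not a positive for `head` (simple rejection loop)."""
    while True:
        candidate = int(rng.integers(1, entity_num))  # uniform over real entities
        if candidate not in used_ids[head]:
            return candidate


rng = np.random.default_rng(0)
negatives = [sample_negative_tail(1, rng) for _ in range(20)]
```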
Implement the sampling function
`AbstractSampler` already implements `sample_by_key_ids()`, which takes three parameters: `key_ids`, `num` and `used_ids`.
`key_ids` is the list of IDs to sample for, `num` is the number of samples drawn for each of them, and `used_ids` holds the positive IDs of each key, which must not be sampled.
The function samples `num` instances for each element of `key_ids` and returns a `numpy.ndarray` of length `len(key_ids) * num`.
Positions `0, len(key_ids), len(key_ids) * 2, …, len(key_ids) * (num - 1)` hold the samples for `key_ids[0]`.
Positions `1, len(key_ids) + 1, len(key_ids) * 2 + 1, …, len(key_ids) * (num - 1) + 1` hold the samples for `key_ids[1]`, and so on.
If this procedure does not suit your needs, you can also implement your own sampling logic.
Example code:
```python
def sample_by_entity_ids(self, head_entity_ids, num=1):
    try:
        # Exclude each head entity's positive tails from its samples.
        return self.sample_by_key_ids(head_entity_ids, num, self.used_ids[head_entity_ids])
    except IndexError:
        # Report which head entity is unknown; otherwise re-raise the original error.
        for head_entity_id in head_entity_ids:
            if head_entity_id not in self.head_entities:
                raise ValueError(f'head_entity_id [{head_entity_id}] does not exist.')
        raise
```
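The interleaved layout means that the `j`-th sample for `key_ids[i]` is stored at index `j * len(key_ids) + i`. Assuming the result is a 1-D NumPy array as described, reshaping it regroups the samples per key. The key IDs and sample values below are made up so that the layout is easy to read.

```python
import numpy as np

key_ids = np.array([10, 20, 30])  # hypothetical key IDs
num = 2                           # samples per key

# Hypothetical flat result in the layout described above:
# index j * len(key_ids) + i holds the j-th sample for key_ids[i].
flat = np.array([101, 201, 301, 102, 202, 302])

# reshape(num, len(key_ids)): row j = j-th sampling round, column i = key_ids[i].
# Transposing gives one row of samples per key.
per_key = flat.reshape(num, len(key_ids)).T
for key, samples in zip(key_ids, per_key):
    print(key, samples.tolist())
```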
Complete Code
```python
import numpy as np

from recbole.sampler import AbstractSampler


class KGSampler(AbstractSampler):
    def __init__(self, dataset, distribution='uniform'):
        self.dataset = dataset
        self.hid_field = dataset.head_entity_field
        self.tid_field = dataset.tail_entity_field
        self.hid_list = dataset.head_entities
        self.tid_list = dataset.tail_entities
        self.head_entities = set(dataset.head_entities)
        self.entity_num = dataset.entity_num
        # Called last: the base class uses the attributes set above.
        super().__init__(distribution=distribution)

    def _uni_sampling(self, sample_num):
        # RNS: uniform entity IDs in [1, entity_num); ID 0 is [pad].
        return np.random.randint(1, self.entity_num, sample_num)

    def _get_candidates_list(self):
        # PNS: popularity = number of occurrences as a head or a tail.
        return list(self.hid_list) + list(self.tid_list)

    def get_used_ids(self):
        used_tail_entity_id = np.array([set() for _ in range(self.entity_num)])
        for hid, tid in zip(self.hid_list, self.tid_list):
            used_tail_entity_id[hid].add(tid)
        for used_tail_set in used_tail_entity_id:
            if len(used_tail_set) + 1 == self.entity_num:  # [pad] is counted as an entity.
                raise ValueError(
                    'Some head entities have relations with all entities, '
                    'so no negative entities can be sampled for them.'
                )
        return used_tail_entity_id

    def sample_by_entity_ids(self, head_entity_ids, num=1):
        try:
            return self.sample_by_key_ids(head_entity_ids, num, self.used_ids[head_entity_ids])
        except IndexError:
            for head_entity_id in head_entity_ids:
                if head_entity_id not in self.head_entities:
                    raise ValueError(f'head_entity_id [{head_entity_id}] does not exist.')
            raise
```