# Customize Samplers¶

Here we present how to develop a new sampler, and apply it into RecBole. The new sampler is used when we need complex sampling method.

Here, we take the KGSampler as an example.

## Create a New Sampler Class¶

To begin with, we create a new sampler based on AbstractSampler:

from recbole.sampler import AbstractSampler
class KGSampler(AbstractSampler):
pass


## Implement __init__()¶

Then, we implement __init__(), in this method, we can flexibly define and initialize the parameters, where we only need to invoke super.__init__(distribution).

def __init__(self, dataset, distribution='uniform'):
self.dataset = dataset

self.tid_field = dataset.tail_entity_field
self.tid_list = dataset.tail_entities

self.entity_num = dataset.entity_num

super().__init__(distribution=distribution)


## Implement get_random_list()¶

We do not use the random function in python or numpy due to their lower efficiency. Instead, we realize our own random() function, where the key method is to combine the random list with the pointer. The pointer point to some element in the random list. When one calls self.random(), the element is returned, and moves the pointer backward by one element. If the pointer point to the last element, then it will return to the head of the element.

In AbstractSampler, the __init__() will call get_random_list(), and shuffle the results. We only need to return a list including all the elements.

It should be noted 0 can be the token used for padding, thus one should remain this value.

Example code:

def get_random_list(self):
if self.distribution == 'uniform':
return list(range(1, self.entity_num))
elif self.distribution == 'popularity':
return list(self.hid_list) + list(self.tid_list)
else:
raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))


## Implement get_used_ids()¶

For negative sampling, we do not want to sample positive instance, this function is used to compute the positive sample. The function will return numpy, and the index is the ID. The return value will be saved in self.used_ids.

Example code:

def get_used_ids(self):
used_tail_entity_id = np.array([set() for i in range(self.entity_num)])
for hid, tid in zip(self.hid_list, self.tid_list):
return used_tail_entity_id


## Implementing the sampling function¶

In AbstractSampler, we have implemented sample_by_key_ids() function, where we have three parameters: key_ids, num and used_ids. Key_ids is the candidate objective ID list, num is the number of samples, used_ids are the positive sample list.

In the function, we sample num instances for each element in key_ids. The function finally return numpy.ndarray, the index of 0, len(key_ids), len(key_ids) * 2, …, len(key_ids) * (num - 1) is the result of key_ids[0]. The index of 1, len(key_ids) + 1, len(key_ids) * 2 + 1, …, len(key_ids) * (num - 1) + 1 is the result of key_ids[1].

One can also design her own sampler, if the above process is not appropriate.

Example code:

def sample_by_entity_ids(self, head_entity_ids, num=1):
try:
except IndexError:


## Complete Code¶

class KGSampler(AbstractSampler):
""":class:KGSampler is used to sample negative entities in a knowledge graph.

Args:
dataset (Dataset): The knowledge graph dataset, which contains triplets in a knowledge graph.
distribution (str, optional): Distribution of the negative entities. Defaults to 'uniform'.
"""
def __init__(self, dataset, distribution='uniform'):
self.dataset = dataset

self.tid_field = dataset.tail_entity_field
self.tid_list = dataset.tail_entities

self.entity_num = dataset.entity_num

super().__init__(distribution=distribution)

def get_random_list(self):
"""
Returns:
np.ndarray or list: Random list of entity_id.
"""
if self.distribution == 'uniform':
return list(range(1, self.entity_num))
elif self.distribution == 'popularity':
return list(self.hid_list) + list(self.tid_list)
else:
raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))

def get_used_ids(self):
"""
Returns:
np.ndarray: Used entity_ids is the same as tail_entity_ids in knowledge graph.
Index is head_entity_id, and element is a set of tail_entity_ids.
"""
used_tail_entity_id = np.array([set() for i in range(self.entity_num)])
for hid, tid in zip(self.hid_list, self.tid_list):
return used_tail_entity_id

Args:
num (int, optional): Number of sampled entity_ids for each head_entity_id. Defaults to 1.

Returns:
np.ndarray: Sampled entity_ids.