Customize Samplers

This section shows how to develop a new sampler and apply it in RecBole. A custom sampler is needed when a more complex sampling method is required.

Here, we take the KGSampler as an example.

Create a New Sampler Class

To begin with, we create a new sampler based on AbstractSampler:

from recbole.sampler import AbstractSampler
class KGSampler(AbstractSampler):
    pass

Implement __init__()

Next, we implement __init__(). In this method we can freely define and initialize whatever attributes the sampler needs; the only requirement is to invoke super().__init__(distribution=distribution).

def __init__(self, dataset, distribution='uniform'):
    self.dataset = dataset

    self.hid_field = dataset.head_entity_field
    self.tid_field = dataset.tail_entity_field
    self.hid_list = dataset.head_entities
    self.tid_list = dataset.tail_entities

    self.head_entities = set(dataset.head_entities)
    self.entity_num = dataset.entity_num

    super().__init__(distribution=distribution)
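
Note that the attributes above are assigned before super().__init__() is called. The following minimal sketch illustrates why that order matters. It uses a hypothetical stand-in base class, ToyAbstractSampler, which is not RecBole's AbstractSampler; it only assumes that the base constructor calls get_random_list() and get_used_ids(), in line with the behavior this guide describes. GoodSampler and BadSampler are likewise illustrative names.

```python
class ToyAbstractSampler:
    """Hypothetical stand-in for AbstractSampler, for illustration only."""

    def __init__(self, distribution):
        self.distribution = distribution
        # Assumption: the base constructor calls the subclass hooks right
        # away, so everything they read must already exist on self.
        self.random_list = self.get_random_list()
        self.used_ids = self.get_used_ids()

    def get_random_list(self):
        raise NotImplementedError

    def get_used_ids(self):
        raise NotImplementedError


class GoodSampler(ToyAbstractSampler):
    def __init__(self, entity_num, distribution='uniform'):
        self.entity_num = entity_num  # set attributes first
        super().__init__(distribution=distribution)

    def get_random_list(self):
        return list(range(1, self.entity_num))

    def get_used_ids(self):
        return [set() for _ in range(self.entity_num)]


class BadSampler(GoodSampler):
    def __init__(self, entity_num, distribution='uniform'):
        # Calling the base constructor first fails: get_random_list()
        # runs before self.entity_num has been assigned.
        ToyAbstractSampler.__init__(self, distribution=distribution)
        self.entity_num = entity_num


good = GoodSampler(entity_num=5)
try:
    BadSampler(entity_num=5)
    bad_failed = False
except AttributeError:
    bad_failed = True
```

In this sketch, GoodSampler constructs normally, while BadSampler raises AttributeError inside the base constructor.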

Implement get_random_list()

We do not use the random functions in Python or NumPy because of their lower efficiency. Instead, we implement our own random() function, whose key idea is to combine a random list with a pointer. The pointer refers to one element of the random list. Each call to self.random() returns that element and moves the pointer forward by one element. When the pointer moves past the last element, it wraps around to the head of the list.

In AbstractSampler, __init__() calls get_random_list() and shuffles the result, so we only need to return a list containing all the candidate elements.
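
The random-list-plus-pointer idea can be sketched as follows. This is a minimal illustration, not RecBole's implementation: the class name PointerRandom, its attributes, and the seed parameter are all hypothetical.

```python
import random


class PointerRandom:
    """Hypothetical helper: draw values by walking a pre-shuffled list."""

    def __init__(self, candidates, seed=0):
        self.random_list = list(candidates)
        # Shuffle once up front; later draws need no random-number call.
        random.Random(seed).shuffle(self.random_list)
        self.pointer = 0

    def random(self):
        value = self.random_list[self.pointer]
        # Advance the pointer, wrapping to the head after the last element.
        self.pointer = (self.pointer + 1) % len(self.random_list)
        return value


sampler = PointerRandom(range(1, 6))
first_pass = [sampler.random() for _ in range(5)]
second_pass = [sampler.random() for _ in range(5)]
```

In this sketch, one full pass visits every candidate exactly once, and the next pass repeats the same order. Each draw costs only a list lookup and an increment.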

Note that 0 may be the token used for padding, so this value should be reserved and kept out of the returned list.

Example code:

def get_random_list(self):
    if self.distribution == 'uniform':
        return list(range(1, self.entity_num))
    elif self.distribution == 'popularity':
        return list(self.hid_list) + list(self.tid_list)
    else:
        raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))
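
To see what the two branches produce, the sketch below builds both lists from a few toy triplets. The data and variable names are hypothetical. Because an entity appears in the 'popularity' list once per occurrence in the triplets, drawing uniformly from that list selects entities in proportion to how often they occur.

```python
from collections import Counter

# Toy triplets (hypothetical data): head and tail entity IDs, one per triplet.
hid_list = [1, 1, 1, 2]
tid_list = [3, 3, 4, 3]

entity_num = 5  # IDs 0..4, with 0 reserved for padding

# 'uniform': every non-padding entity appears exactly once.
uniform_list = list(range(1, entity_num))
# 'popularity': every entity appears once per occurrence in the triplets.
popularity_list = list(hid_list) + list(tid_list)

uniform_counts = Counter(uniform_list)
popularity_counts = Counter(popularity_list)
```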

Implement get_used_ids()

Negative sampling must not return positive instances, so this function computes the positive samples to exclude. It returns a numpy.ndarray whose index is the ID and whose element holds the positive samples for that ID. The return value is saved in self.used_ids.

Example code:

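import numpy as np

def get_used_ids(self):
    # One set per entity ID; the array index is the head entity ID.
    used_tail_entity_id = np.array([set() for i in range(self.entity_num)])
    # Record every tail entity that appears with each head entity.
    for hid, tid in zip(self.hid_list, self.tid_list):
        used_tail_entity_id[hid].add(tid)
    return used_tail_entity_id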
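
The same construction can be tried outside the class on a few toy triplets. This is a sketch with hypothetical data, showing the shape of the result and how an array of IDs selects the matching sets.

```python
import numpy as np

# Toy triplets (hypothetical data).
hid_list = [1, 1, 2]
tid_list = [3, 4, 3]
entity_num = 5

# Object array: one (initially empty) set per entity ID.
used_ids = np.array([set() for _ in range(entity_num)])
for hid, tid in zip(hid_list, tid_list):
    used_ids[hid].add(tid)

# Indexing with an array of IDs returns the matching sets in one step.
head_entity_ids = np.array([2, 1])
selected = used_ids[head_entity_ids]
```

Indexing with an ID outside range(entity_num) raises IndexError at this step.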

Implement the Sampling Function

AbstractSampler already implements the sample_by_key_ids() function, which takes three parameters: key_ids, num and used_ids. key_ids is the list of IDs to sample for, num is the number of samples drawn for each ID, and used_ids holds the positive samples of each ID.

The function samples num instances for each element of key_ids and returns a numpy.ndarray. The entries at indices 0, len(key_ids), len(key_ids) * 2, …, len(key_ids) * (num - 1) are the samples for key_ids[0]; the entries at indices 1, len(key_ids) + 1, len(key_ids) * 2 + 1, …, len(key_ids) * (num - 1) + 1 are the samples for key_ids[1]; and so on.
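
The output layout can be illustrated with a toy function. toy_sample_by_key_ids is a hypothetical rejection sampler written for this example; it mimics only the index layout described above and is not how RecBole implements sample_by_key_ids(). The candidates and seed parameters are also assumptions.

```python
import numpy as np


def toy_sample_by_key_ids(key_ids, num, used_ids, candidates, seed=0):
    """Hypothetical rejection sampler that mimics the output layout only."""
    rng = np.random.default_rng(seed)
    n = len(key_ids)
    result = np.empty(n * num, dtype=np.int64)
    for j in range(n * num):
        # Position j belongs to key_ids[j % n].
        used = used_ids[j % n]
        value = int(rng.choice(candidates))
        while value in used:  # reject positives and redraw
            value = int(rng.choice(candidates))
        result[j] = value
    return result


key_ids = [1, 2]
used_ids = [{3, 4}, {3}]  # positives for key 1 and key 2
candidates = [1, 2, 3, 4]
num = 3

samples = toy_sample_by_key_ids(key_ids, num, used_ids, candidates)
# One row per sampling round, one column per key.
per_key = samples.reshape(num, len(key_ids))
samples_for_key_0 = per_key[:, 0]  # indices 0, 2, 4 of samples
samples_for_key_1 = per_key[:, 1]  # indices 1, 3, 5 of samples
```

Reshaping to (num, len(key_ids)) is a convenient way to recover the samples of each key: column i holds every sample drawn for key_ids[i].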

If this procedure does not fit your needs, you can also design your own sampling function.

Example code:

def sample_by_entity_ids(self, head_entity_ids, num=1):
    try:
        return self.sample_by_key_ids(head_entity_ids, num, self.used_ids[head_entity_ids])
    except IndexError:
        for head_entity_id in head_entity_ids:
            if head_entity_id not in self.head_entities:
                raise ValueError('head_entity_id [{}] not exist'.format(head_entity_id))

Complete Code

import numpy as np

from recbole.sampler import AbstractSampler


class KGSampler(AbstractSampler):
    """:class:`KGSampler` is used to sample negative entities in a knowledge graph.

    Args:
        dataset (Dataset): The knowledge graph dataset, which contains triplets in a knowledge graph.
        distribution (str, optional): Distribution of the negative entities. Defaults to 'uniform'.
    """
    def __init__(self, dataset, distribution='uniform'):
        self.dataset = dataset

        self.hid_field = dataset.head_entity_field
        self.tid_field = dataset.tail_entity_field
        self.hid_list = dataset.head_entities
        self.tid_list = dataset.tail_entities

        self.head_entities = set(dataset.head_entities)
        self.entity_num = dataset.entity_num

        super().__init__(distribution=distribution)

    def get_random_list(self):
        """
        Returns:
            np.ndarray or list: Random list of entity_id.
        """
        if self.distribution == 'uniform':
            return list(range(1, self.entity_num))
        elif self.distribution == 'popularity':
            return list(self.hid_list) + list(self.tid_list)
        else:
            raise NotImplementedError('Distribution [{}] has not been implemented'.format(self.distribution))

    def get_used_ids(self):
        """
        Returns:
            np.ndarray: Used entity_ids is the same as tail_entity_ids in knowledge graph.
            Index is head_entity_id, and element is a set of tail_entity_ids.
        """
        used_tail_entity_id = np.array([set() for i in range(self.entity_num)])
        for hid, tid in zip(self.hid_list, self.tid_list):
            used_tail_entity_id[hid].add(tid)
        return used_tail_entity_id

    def sample_by_entity_ids(self, head_entity_ids, num=1):
        """Sampling by head_entity_ids.

        Args:
            head_entity_ids (np.ndarray or list): Input head_entity_ids.
            num (int, optional): Number of sampled entity_ids for each head_entity_id. Defaults to ``1``.

        Returns:
            np.ndarray: Sampled entity_ids.
            entity_ids[0], entity_ids[len(head_entity_ids)], entity_ids[len(head_entity_ids) * 2], ...,
            entity_ids[len(head_entity_ids) * (num - 1)] are sampled for head_entity_ids[0];
            entity_ids[1], entity_ids[len(head_entity_ids) + 1], entity_ids[len(head_entity_ids) * 2 + 1], ...,
            entity_ids[len(head_entity_ids) * (num - 1) + 1] are sampled for head_entity_ids[1]; ...; and so on.
        """
        try:
            return self.sample_by_key_ids(head_entity_ids, num, self.used_ids[head_entity_ids])
        except IndexError:
            for head_entity_id in head_entity_ids:
                if head_entity_id not in self.head_entities:
                    raise ValueError('head_entity_id [{}] not exist'.format(head_entity_id))