Customize Samplers

In RecBole, the sampler module is designed to select negative items for training and evaluation.

Here we show how to develop a new sampler and apply it in RecBole. A new sampler is needed when the built-in strategies do not cover a more complex sampling method.

RecBole currently supports only two sampling strategies: random negative sampling (RNS) and popularity-biased negative sampling (PNS). RNS selects negative items from a uniform distribution, and PNS selects negative items from a popularity-biased distribution. For PNS, the distribution is based on each item's total number of interactions.
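The difference between the two strategies can be illustrated with plain NumPy. This is a minimal standalone sketch, not RecBole code: the interaction log, item_num and the convention that ID 0 is a padding item are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

item_num = 5  # hypothetical: IDs 0..4, with 0 assumed to be padding
interacted_items = np.array([1, 1, 1, 1, 2, 2, 3, 4])  # toy interaction log

# RNS: every real item (1..item_num - 1) is equally likely.
rns_samples = rng.integers(1, item_num, size=10)

# PNS: the probability of an item is proportional to its interaction count.
popularity = np.bincount(interacted_items, minlength=item_num).astype(float)
pns_probs = popularity / popularity.sum()
pns_samples = rng.choice(item_num, size=10, p=pns_probs)
```

In this toy log, item 1 accounts for half of the interactions, so PNS draws it with probability 0.5, while RNS draws each of the four items with probability 0.25.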

To create a new sampler in our framework, inherit from AbstractSampler, implement __init__(), override three functions (_uni_sampling(), _get_candidates_list() and get_used_ids()), and create a new sampling function.

Here, we take the KGSampler as an example.

Create a New Sampler Class

To begin with, we create a new sampler based on AbstractSampler:

from recbole.sampler import AbstractSampler
class KGSampler(AbstractSampler):
    pass

Implement __init__()

Then we implement __init__(). In this method, we can flexibly define and initialize the parameters; the only requirement is to invoke super().__init__(distribution).

def __init__(self, dataset, distribution='uniform'):
    self.dataset = dataset

    self.hid_field = dataset.head_entity_field
    self.tid_field = dataset.tail_entity_field
    self.hid_list = dataset.head_entities
    self.tid_list = dataset.tail_entities

    self.head_entities = set(dataset.head_entities)
    self.entity_num = dataset.entity_num

    super().__init__(distribution=distribution)

Implement _uni_sampling()

To implement RNS for KGSampler, we need to override _uni_sampling(). Here we use numpy.random.randint() to randomly select entity IDs. This function returns the IDs of the selected samples (here, entity IDs).

Example code:

def _uni_sampling(self, sample_num):
    # Draw sample_num entity IDs uniformly from [1, entity_num); requires `import numpy as np`.
    return np.random.randint(1, self.entity_num, sample_num)
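The bounds passed to numpy.random.randint() matter: the lower bound is inclusive and the upper bound is exclusive. The standalone sketch below checks this on a toy entity_num (a made-up value); the reading that ID 0 is skipped because it is the [pad] entity is an assumption based on the comment in get_used_ids().

```python
import numpy as np

entity_num = 6  # hypothetical: IDs 0..5, with 0 assumed to be the [pad] entity

np.random.seed(0)
samples = np.random.randint(1, entity_num, 1000)

# low=1 is inclusive and high=entity_num is exclusive,
# so every sample lies in 1..entity_num - 1 and ID 0 is never drawn.
lowest, highest = samples.min(), samples.max()
```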

Implement _get_candidates_list()

To implement PNS for KGSampler, we need to override _get_candidates_list(). This function builds the candidate list for PNS; the sampling distribution is then set based on Counter(candidate_list). The function returns a list of candidate IDs.

Example code:

def _get_candidates_list(self):
    return list(self.hid_list) + list(self.tid_list)
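To see why the head and tail lists are simply concatenated, it helps to count a toy candidate list. The sketch below is standalone and only illustrates how Counter(candidate_list) could be turned into probabilities; the triples are made up, and how RecBole consumes the counter internally is not shown here.

```python
from collections import Counter

import numpy as np

# Toy knowledge-graph triples, given as parallel head/tail ID arrays.
hid_list = np.array([1, 1, 2, 3])
tid_list = np.array([2, 3, 3, 4])

# Same construction as _get_candidates_list(): an entity appears once
# per triple it takes part in, either as head or as tail.
candidates = list(hid_list) + list(tid_list)
counts = Counter(int(e) for e in candidates)

# Normalizing the counts gives a popularity-biased distribution.
total = sum(counts.values())
probs = {entity: count / total for entity, count in counts.items()}
```

Entity 3 occurs in three of the eight candidate slots, so it ends up with probability 0.375, while entity 4 occurs once and gets 0.125.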

Implement get_used_ids()

In negative sampling, we do not want to sample positive instances, so this function records the positive samples. It returns a numpy.ndarray whose index is the ID; here, each element is the set of tail entities already linked to that head entity. The returned value is saved in self.used_ids.

Example code:

def get_used_ids(self):
    used_tail_entity_id = np.array([set() for _ in range(self.entity_num)])
    for hid, tid in zip(self.hid_list, self.tid_list):
        used_tail_entity_id[hid].add(tid)

    for used_tail_set in used_tail_entity_id:
        if len(used_tail_set) + 1 == self.entity_num:  # [pad] is an entity.
            raise ValueError(
                'Some head entities have relation with all entities, '
                'which we can not sample negative entities for them.'
            )
    return used_tail_entity_id
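The returned structure is a one-dimensional object array of Python sets, indexed by head entity ID. The following standalone sketch builds it for a few made-up triples, so the values are illustrative only.

```python
import numpy as np

entity_num = 5          # hypothetical: IDs 0..4
hid_list = [1, 1, 2]    # toy head entity IDs
tid_list = [2, 3, 4]    # toy tail entity IDs

# One empty set per entity; NumPy stores them in an object array.
used = np.array([set() for _ in range(entity_num)])
for hid, tid in zip(hid_list, tid_list):
    used[hid].add(tid)

# Indexing by head ID gives the tails that must not be sampled as negatives.
positives_of_1 = used[1]

# Fancy indexing with several head IDs returns an array of sets.
batch = used[[1, 2]]
```

A head entity whose set has entity_num - 1 elements would have no negative left to sample, which is the case the ValueError above guards against.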

Implement the sampling function

AbstractSampler already implements the sample_by_key_ids() function, which takes three parameters: key_ids, num and used_ids. key_ids is the list of target IDs to sample for, num is the number of samples per ID, and used_ids holds the positive samples.

The function samples num instances for each element of key_ids and returns a numpy.ndarray. The entries at indices 0, len(key_ids), len(key_ids) * 2, …, len(key_ids) * (num - 1) are the results for key_ids[0]. The entries at indices 1, len(key_ids) + 1, len(key_ids) * 2 + 1, …, len(key_ids) * (num - 1) + 1 are the results for key_ids[1], and so on.
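The index layout is easier to follow on a concrete array. The sketch below uses a hand-written stand-in for the flat result (the values are made up so that each one shows which key and which draw it belongs to) and recovers the samples of each key with a reshape.

```python
import numpy as np

key_ids = np.array([10, 20, 30])
num = 2

# Stand-in for the flat output: flat[i] belongs to key_ids[i % len(key_ids)].
flat = np.array([101, 201, 301, 102, 202, 302])

# The key that each position of the flat array corresponds to.
owner = np.tile(key_ids, num)

# Row j of per_key holds the num samples drawn for key_ids[j].
per_key = flat.reshape(num, len(key_ids)).T

# Equivalent strided view for key_ids[0]: indices 0, len(key_ids), ...
first_key_samples = flat[0::len(key_ids)]
```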

If the process above is not appropriate, you can also design your own sampler.

Example code:

def sample_by_entity_ids(self, head_entity_ids, num=1):
    try:
        return self.sample_by_key_ids(head_entity_ids, num, self.used_ids[head_entity_ids])
    except IndexError:
        for head_entity_id in head_entity_ids:
            if head_entity_id not in self.head_entities:
                raise ValueError(f'head_entity_id [{head_entity_id}] not exist.')
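To make the role of used_ids concrete, one plausible way to avoid positives is rejection sampling: draw candidates, then redraw only the positions that hit a known positive. The sketch below is a standalone illustration and not RecBole's implementation; sample_negatives and all of its parameters are hypothetical names.

```python
import numpy as np


def sample_negatives(key_ids, num, used_ids, entity_num, rng):
    """Hypothetical helper: draw `num` negatives per key by rejection sampling."""
    # Repeat the keys num times so that position i belongs to key_ids[i % len(key_ids)].
    keys = np.tile(np.asarray(key_ids), num)
    samples = rng.integers(1, entity_num, size=len(keys))
    # Positions whose draw collides with a known positive must be redrawn.
    pending = [i for i in range(len(keys)) if samples[i] in used_ids[keys[i]]]
    while pending:
        samples[pending] = rng.integers(1, entity_num, size=len(pending))
        pending = [i for i in pending if samples[i] in used_ids[keys[i]]]
    return samples


entity_num = 6  # toy value: IDs 0..5, with 0 assumed to be padding
used_ids = np.array([set() for _ in range(entity_num)])
used_ids[1].update({2, 3})
used_ids[2].update({4})

rng = np.random.default_rng(0)
key_ids = [1, 2]
negatives = sample_negatives(key_ids, num=3, used_ids=used_ids,
                             entity_num=entity_num, rng=rng)
```

The loop only terminates if every key has at least one entity left to sample, which is why a check like the one in get_used_ids() is useful before sampling starts.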

Complete Code

import numpy as np

from recbole.sampler import AbstractSampler


class KGSampler(AbstractSampler):
    def __init__(self, dataset, distribution='uniform'):
        self.dataset = dataset

        self.hid_field = dataset.head_entity_field
        self.tid_field = dataset.tail_entity_field
        self.hid_list = dataset.head_entities
        self.tid_list = dataset.tail_entities

        self.head_entities = set(dataset.head_entities)
        self.entity_num = dataset.entity_num

        super().__init__(distribution=distribution)

    def _uni_sampling(self, sample_num):
        return np.random.randint(1, self.entity_num, sample_num)

    def _get_candidates_list(self):
        return list(self.hid_list) + list(self.tid_list)

    def get_used_ids(self):
        used_tail_entity_id = np.array([set() for _ in range(self.entity_num)])
        for hid, tid in zip(self.hid_list, self.tid_list):
            used_tail_entity_id[hid].add(tid)

        for used_tail_set in used_tail_entity_id:
            if len(used_tail_set) + 1 == self.entity_num:  # [pad] is an entity.
                raise ValueError(
                    'Some head entities have relation with all entities, '
                    'which we can not sample negative entities for them.'
                )
        return used_tail_entity_id

    def sample_by_entity_ids(self, head_entity_ids, num=1):
        try:
            return self.sample_by_key_ids(head_entity_ids, num)
        except IndexError:
            for head_entity_id in head_entity_ids:
                if head_entity_id not in self.head_entities:
                    raise ValueError(f'head_entity_id [{head_entity_id}] not exist.')