Source code for

# @Time   : 2020/9/3
# @Author : Yupeng Hou
# @Email  :

# @Time   : 2020/10/16, 2020/9/15, 2020/10/25, 2022/7/10
# @Author : Yupeng Hou, Xingyu Pan, Yushuo Chen, Lanling Xu
# @Email  :,,,


import os
from collections import Counter

import numpy as np
import pandas as pd
import torch
from scipy.sparse import coo_matrix

from import Dataset
from recbole.utils import FeatureSource, FeatureType, set_color
from recbole.utils.url import decide_download, download_url, extract_zip

[docs]class KnowledgeBasedDataset(Dataset): """:class:`KnowledgeBasedDataset` is based on :class:``, and load ``.kg`` and ``.link`` additionally. Entities are remapped together with ``item_id`` specially. All entities are remapped into three consecutive ID sections. - virtual entities that only exist in interaction data. - entities that exist both in interaction data and kg triplets. - entities only exist in kg triplets. It also provides several interfaces to transfer ``.kg`` features into coo sparse matrix, csr sparse matrix, :class:`DGL.Graph` or :class:`PyG.Data`. Attributes: head_entity_field (str): The same as ``config['HEAD_ENTITY_ID_FIELD']``. tail_entity_field (str): The same as ``config['TAIL_ENTITY_ID_FIELD']``. relation_field (str): The same as ``config['RELATION_ID_FIELD']``. entity_field (str): The same as ``config['ENTITY_ID_FIELD']``. kg_feat (pandas.DataFrame): Internal data structure stores the kg triplets. It's loaded from file ``.kg``. item2entity (dict): Dict maps ``item_id`` to ``entity``, which is loaded from file ``.link``. entity2item (dict): Dict maps ``entity`` to ``item_id``, which is loaded from file ``.link``. Note: :attr:`entity_field` doesn't exist exactly. It's only a symbol, representing entity features. ``[UI-Relation]`` is a special relation token. """ def __init__(self, config): super().__init__(config) def _get_field_from_config(self): super()._get_field_from_config() self.head_entity_field = self.config["HEAD_ENTITY_ID_FIELD"] self.tail_entity_field = self.config["TAIL_ENTITY_ID_FIELD"] self.relation_field = self.config["RELATION_ID_FIELD"] self.entity_field = self.config["ENTITY_ID_FIELD"] self.kg_reverse_r = self.config["kg_reverse_r"] self.entity_kg_num_interval = self.config["entity_kg_num_interval"] self.relation_kg_num_interval = self.config["relation_kg_num_interval"] self._check_field( "head_entity_field", "tail_entity_field", "relation_field", "entity_field" ) self.set_field_property( self.entity_field, FeatureType.TOKEN, FeatureSource.KG, 1 ) self.logger.debug( set_color("relation_field", "blue") + f": {self.relation_field}" ) self.logger.debug(set_color("entity_field", "blue") + f": {self.entity_field}") def _data_filtering(self): super()._data_filtering() self._filter_kg_by_triple_num() self._filter_link() def _filter_kg_by_triple_num(self): """Filter by number of triples. The interval of the number of triples can be set, and only entities/relations whose number of triples is in the specified interval can be retained. See :doc:`../user_guide/data/data_args` for detail arg setting. Note: Lower bound of the interval is also called k-core filtering, which means this method will filter loops until all the entities and relations has at least k triples. """ entity_kg_num_interval = self._parse_intervals_str( self.config["entity_kg_num_interval"] ) relation_kg_num_interval = self._parse_intervals_str( self.config["relation_kg_num_interval"] ) if entity_kg_num_interval is None and relation_kg_num_interval is None: return entity_kg_num = Counter() if entity_kg_num_interval: head_entity_kg_num = Counter(self.kg_feat[self.head_entity_field].values) tail_entity_kg_num = Counter(self.kg_feat[self.tail_entity_field].values) entity_kg_num = head_entity_kg_num + tail_entity_kg_num relation_kg_num = ( Counter(self.kg_feat[self.relation_field].values) if relation_kg_num_interval else Counter() ) while True: ban_head_entities = self._get_illegal_ids_by_inter_num( field=self.head_entity_field, feat=None, inter_num=entity_kg_num, inter_interval=entity_kg_num_interval, ) ban_tail_entities = self._get_illegal_ids_by_inter_num( field=self.tail_entity_field, feat=None, inter_num=entity_kg_num, inter_interval=entity_kg_num_interval, ) ban_entities = ban_head_entities | ban_tail_entities ban_relations = self._get_illegal_ids_by_inter_num( field=self.relation_field, feat=None, inter_num=relation_kg_num, inter_interval=relation_kg_num_interval, ) if len(ban_entities) == 0 and len(ban_relations) == 0: break dropped_kg = pd.Series(False, index=self.kg_feat.index) head_entity_kg = self.kg_feat[self.head_entity_field] tail_entity_kg = self.kg_feat[self.tail_entity_field] relation_kg = self.kg_feat[self.relation_field] dropped_kg |= head_entity_kg.isin(ban_entities) dropped_kg |= tail_entity_kg.isin(ban_entities) dropped_kg |= relation_kg.isin(ban_relations) entity_kg_num -= Counter(head_entity_kg[dropped_kg].values) entity_kg_num -= Counter(tail_entity_kg[dropped_kg].values) relation_kg_num -= Counter(relation_kg[dropped_kg].values) dropped_index = self.kg_feat.index[dropped_kg] self.logger.debug(f"[{len(dropped_index)}] dropped triples.") self.kg_feat.drop(dropped_index, inplace=True) def _filter_link(self): """Filter rows of :attr:`item2entity` and :attr:`entity2item`, whose ``entity_id`` doesn't occur in kg triplets and ``item_id`` doesn't occur in interaction records. """ item_tokens = self._get_rec_item_token() ent_tokens = self._get_entity_token() illegal_item = set() illegal_ent = set() for item in self.item2entity: ent = self.item2entity[item] if item not in item_tokens or ent not in ent_tokens: illegal_item.add(item) illegal_ent.add(ent) for item in illegal_item: del self.item2entity[item] for ent in illegal_ent: del self.entity2item[ent] remained_inter = pd.Series(True, index=self.inter_feat.index) remained_inter &= self.inter_feat[self.iid_field].isin(self.item2entity.keys()) self.inter_feat.drop(self.inter_feat.index[~remained_inter], inplace=True) def _download(self): super()._download() url = self._get_download_url("kg_url", allow_none=True) if url is None: return"Prepare to download linked knowledge graph from [{url}].") if decide_download(url): # No need to create dir, as `super()._download()` has created one. path = download_url(url, self.dataset_path) extract_zip(path, self.dataset_path) os.unlink(path) f"\nLinked KG for [{self.dataset_name}] requires additional conversion " f"to atomic files (.kg and .link).\n" f"Please refer to " f"for detailed instructions.\n" f"You can run RecBole after the conversion, see you soon." ) exit(0) else:"Stop download.") exit(-1) def _load_data(self, token, dataset_path): super()._load_data(token, dataset_path) self.kg_feat = self._load_kg(self.dataset_name, self.dataset_path) self.item2entity, self.entity2item = self._load_link( self.dataset_name, self.dataset_path ) def __str__(self): info = [ super().__str__(), f"The number of entities: {self.entity_num}", f"The number of relations: {self.relation_num}", f"The number of triples: {len(self.kg_feat)}", f"The number of items that have been linked to KG: {len(self.item2entity)}", ] # yapf: disable return "\n".join(info) def _build_feat_name_list(self): feat_name_list = super()._build_feat_name_list() if self.kg_feat is not None: feat_name_list.append("kg_feat") return feat_name_list def _load_kg(self, token, dataset_path): self.logger.debug(set_color(f"Loading kg from [{dataset_path}].", "green")) kg_path = os.path.join(dataset_path, f"{token}.kg") if not os.path.isfile(kg_path): raise ValueError(f"[{token}.kg] not found in [{dataset_path}].") df = self._load_feat(kg_path, FeatureSource.KG) self._check_kg(df) return df def _check_kg(self, kg): kg_warn_message = "kg data requires field [{}]" assert self.head_entity_field in kg, kg_warn_message.format( self.head_entity_field ) assert self.tail_entity_field in kg, kg_warn_message.format( self.tail_entity_field ) assert self.relation_field in kg, kg_warn_message.format(self.relation_field) def _load_link(self, token, dataset_path): self.logger.debug(set_color(f"Loading link from [{dataset_path}].", "green")) link_path = os.path.join(dataset_path, f"{token}.link") if not os.path.isfile(link_path): raise ValueError(f"[{token}.link] not found in [{dataset_path}].") df = self._load_feat(link_path, "link") self._check_link(df) item2entity, entity2item = {}, {} for item_id, entity_id in zip( df[self.iid_field].values, df[self.entity_field].values ): item2entity[item_id] = entity_id entity2item[entity_id] = item_id return item2entity, entity2item def _check_link(self, link): link_warn_message = "link data requires field [{}]" assert self.entity_field in link, link_warn_message.format(self.entity_field) assert self.iid_field in link, link_warn_message.format(self.iid_field) def _init_alias(self): """Add :attr:`alias_of_entity_id`, :attr:`alias_of_relation_id` and update :attr:`_rest_fields`.""" self._set_alias("entity_id", [self.head_entity_field, self.tail_entity_field]) self._set_alias("relation_id", [self.relation_field]) super()._init_alias() self._rest_fields = np.setdiff1d( self._rest_fields, [self.entity_field], assume_unique=True ) def _get_rec_item_token(self): """Get set of entity tokens from fields in ``rec`` level.""" remap_list = self._get_remap_list(self.alias["item_id"]) tokens, _ = self._concat_remaped_tokens(remap_list) return set(tokens) def _get_entity_token(self): """Get set of entity tokens from fields in ``ent`` level.""" remap_list = self._get_remap_list(self.alias["entity_id"]) tokens, _ = self._concat_remaped_tokens(remap_list) return set(tokens) def _reset_ent_remapID(self, field, idmap, id2token, token2id): self.field2id_token[field] = id2token self.field2token_id[field] = token2id for feat in self.field2feats(field): ftype = self.field2type[field] if ftype == FeatureType.TOKEN: old_idx = feat[field].values else: old_idx = feat[field].agg(np.concatenate) new_idx = idmap[old_idx] if ftype == FeatureType.TOKEN: feat[field] = new_idx else: split_point = np.cumsum(feat[field].agg(len))[:-1] feat[field] = np.split(new_idx, split_point) def _merge_item_and_entity(self): """Merge item-id and entity-id into the same id-space.""" item_token = self.field2id_token[self.iid_field] entity_token = self.field2id_token[self.head_entity_field] item_num = len(item_token) link_num = len(self.item2entity) entity_num = len(entity_token) # reset item id item_priority = np.array([token in self.item2entity for token in item_token]) item_order = np.argsort(item_priority, kind="stable") item_id_map = np.zeros_like(item_order) item_id_map[item_order] = np.arange(item_num) new_item_id2token = item_token[item_order] new_item_token2id = {t: i for i, t in enumerate(new_item_id2token)} for field in self.alias["item_id"]: self._reset_ent_remapID( field, item_id_map, new_item_id2token, new_item_token2id ) # reset entity id entity_priority = np.array( [ token != "[PAD]" and token not in self.entity2item for token in entity_token ] ) entity_order = np.argsort(entity_priority, kind="stable") entity_id_map = np.zeros_like(entity_order) for i in entity_order[1 : link_num + 1]: entity_id_map[i] = new_item_token2id[self.entity2item[entity_token[i]]] entity_id_map[entity_order[link_num + 1 :]] = np.arange( item_num, item_num + entity_num - link_num - 1 ) new_entity_id2token = np.concatenate( [new_item_id2token, entity_token[entity_order[link_num + 1 :]]] ) for i in range(item_num - link_num, item_num): new_entity_id2token[i] = self.item2entity[new_entity_id2token[i]] new_entity_token2id = {t: i for i, t in enumerate(new_entity_id2token)} for field in self.alias["entity_id"]: self._reset_ent_remapID( field, entity_id_map, new_entity_id2token, new_entity_token2id ) self.field2id_token[self.entity_field] = new_entity_id2token self.field2token_id[self.entity_field] = new_entity_token2id def _add_auxiliary_relation(self): """Add auxiliary relations in ``self.relation_field``.""" if self.kg_reverse_r: # '0' is used for padding, so the number needs to be reduced by one original_rel_num = len(self.field2id_token[self.relation_field]) - 1 original_hids = self.kg_feat[self.head_entity_field] original_tids = self.kg_feat[self.tail_entity_field] original_rels = self.kg_feat[self.relation_field] # Internal id gap of a relation and its reverse edge is original relation num reverse_rels = original_rels + original_rel_num # Add mapping for internal and external ID of relations for i in range(1, original_rel_num + 1): original_token = self.field2id_token[self.relation_field][i] reverse_token = original_token + "_r" self.field2token_id[self.relation_field][reverse_token] = ( i + original_rel_num ) self.field2id_token[self.relation_field] = np.append( self.field2id_token[self.relation_field], reverse_token ) # Update knowledge graph triples with reverse relations reverse_kg_data = { self.head_entity_field: original_tids, self.relation_field: reverse_rels, self.head_entity_field: original_hids, } reverse_kg_feat = pd.DataFrame(reverse_kg_data) self.kg_feat = pd.concat([self.kg_feat, reverse_kg_feat]) # Add UI-relation pairs in the relation field kg_rel_num = len(self.field2id_token[self.relation_field]) self.field2token_id[self.relation_field]["[UI-Relation]"] = kg_rel_num self.field2id_token[self.relation_field] = np.append( self.field2id_token[self.relation_field], "[UI-Relation]" ) def _remap_ID_all(self): super()._remap_ID_all() self._merge_item_and_entity() self._add_auxiliary_relation() @property def relation_num(self): """Get the number of different tokens of ``self.relation_field``. Returns: int: Number of different tokens of ``self.relation_field``. """ return self.num(self.relation_field) @property def entity_num(self): """Get the number of different tokens of entities, including virtual entities. Returns: int: Number of different tokens of entities, including virtual entities. """ return self.num(self.entity_field) @property def head_entities(self): """ Returns: numpy.ndarray: List of head entities of kg triplets. """ return self.kg_feat[self.head_entity_field].numpy() @property def tail_entities(self): """ Returns: numpy.ndarray: List of tail entities of kg triplets. """ return self.kg_feat[self.tail_entity_field].numpy() @property def relations(self): """ Returns: numpy.ndarray: List of relations of kg triplets. """ return self.kg_feat[self.relation_field].numpy() @property def entities(self): """ Returns: numpy.ndarray: List of entity id, including virtual entities. """ return np.arange(self.entity_num)
[docs] def kg_graph(self, form="coo", value_field=None): """Get graph or sparse matrix that describe relations between entities. For an edge of <src, tgt>, ``graph[src, tgt] = 1`` if ``value_field`` is ``None``, else ``graph[src, tgt] = self.kg_feat[value_field][src, tgt]``. Currently, we support graph in `DGL`_ and `PyG`_, and two type of sparse matrices, ``coo`` and ``csr``. Args: form (str, optional): Format of sparse matrix, or library of graph data structure. Defaults to ``coo``. value_field (str, optional): edge attributes of graph, or data of sparse matrix, Defaults to ``None``. Returns: Graph / Sparse matrix of kg triplets. .. _DGL: .. _PyG: """ args = [ self.kg_feat, self.head_entity_field, self.tail_entity_field, form, value_field, ] if form in ["coo", "csr"]: return self._create_sparse_matrix(*args) elif form in ["dgl", "pyg"]: return self._create_graph(*args) else: raise NotImplementedError("kg graph format [{}] has not been implemented.")
def _create_ckg_sparse_matrix(self, form="coo", show_relation=False): user_num = self.user_num hids = self.head_entities + user_num tids = self.tail_entities + user_num uids = self.inter_feat[self.uid_field].numpy() iids = self.inter_feat[self.iid_field].numpy() + user_num ui_rel_num = len(uids) ui_rel_id = self.relation_num - 1 assert self.field2id_token[self.relation_field][ui_rel_id] == "[UI-Relation]" src = np.concatenate([uids, iids, hids]) tgt = np.concatenate([iids, uids, tids]) if not show_relation: data = np.ones(len(src)) else: kg_rel = self.kg_feat[self.relation_field].numpy() ui_rel = np.full(2 * ui_rel_num, ui_rel_id, dtype=kg_rel.dtype) data = np.concatenate([ui_rel, kg_rel]) node_num = self.entity_num + self.user_num mat = coo_matrix((data, (src, tgt)), shape=(node_num, node_num)) if form == "coo": return mat elif form == "csr": return mat.tocsr() else: raise NotImplementedError( f"Sparse matrix format [{form}] has not been implemented." ) def _create_ckg_graph(self, form="dgl", show_relation=False): user_num = self.user_num kg_tensor = self.kg_feat inter_tensor = self.inter_feat head_entity = kg_tensor[self.head_entity_field] + user_num tail_entity = kg_tensor[self.tail_entity_field] + user_num user = inter_tensor[self.uid_field] item = inter_tensor[self.iid_field] + user_num src =[user, item, head_entity]) tgt =[item, user, tail_entity]) if show_relation: ui_rel_num = user.shape[0] ui_rel_id = self.relation_num - 1 assert ( self.field2id_token[self.relation_field][ui_rel_id] == "[UI-Relation]" ) kg_rel = kg_tensor[self.relation_field] ui_rel = torch.full((2 * ui_rel_num,), ui_rel_id, dtype=kg_rel.dtype) edge =[ui_rel, kg_rel]) if form == "dgl": import dgl graph = dgl.graph((src, tgt)) if show_relation: graph.edata[self.relation_field] = edge return graph elif form == "pyg": from import Data edge_attr = edge if show_relation else None graph = Data(edge_index=torch.stack([src, tgt]), edge_attr=edge_attr) return graph else: raise NotImplementedError( f"Graph format [{form}] has not been implemented." )
[docs] def ckg_graph(self, form="coo", value_field=None): """Get graph or sparse matrix that describe relations of CKG, which combines interactions and kg triplets into the same graph. Item ids and entity ids are added by ``user_num`` temporally. For an edge of <src, tgt>, ``graph[src, tgt] = 1`` if ``value_field`` is ``None``, else ``graph[src, tgt] = self.kg_feat[self.relation_field][src, tgt]`` or ``graph[src, tgt] = [UI-Relation]``. Currently, we support graph in `DGL`_ and `PyG`_, and two type of sparse matrices, ``coo`` and ``csr``. Args: form (str, optional): Format of sparse matrix, or library of graph data structure. Defaults to ``coo``. value_field (str, optional): ``self.relation_field`` or ``None``, Defaults to ``None``. Returns: Graph / Sparse matrix of kg triplets. .. _DGL: .. _PyG: """ if value_field is not None and value_field != self.relation_field: raise ValueError( f"Value_field [{value_field}] can only be [{self.relation_field}] in ckg_graph." ) show_relation = value_field is not None if form in ["coo", "csr"]: return self._create_ckg_sparse_matrix(form, show_relation) elif form in ["dgl", "pyg"]: return self._create_ckg_graph(form, show_relation) else: raise NotImplementedError("ckg graph format [{}] has not been implemented.")