# Source code for recbole.utils.utils

# -*- coding: utf-8 -*-
# @Time   : 2020/7/17
# @Author : Shanlei Mu
# @Email  : slmu@ruc.edu.cn

# UPDATE
# @Time   : 2021/3/8
# @Author : Jiawei Guan
# @Email  : guanjw@ruc.edu.cn

"""
recbole.utils.utils
################################
"""

import datetime
import importlib
import os
import random

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from recbole.utils.enum_type import ModelType


def get_local_time():
    r"""Return the current local time as a filesystem-friendly string.

    The format is ``'%b-%d-%Y_%H-%M-%S'`` (e.g. ``Mar-08-2021_14-05-09``),
    which contains no spaces or colons and is therefore safe to embed in
    file and directory names.

    Returns:
        str: current time
    """
    return datetime.datetime.now().strftime('%b-%d-%Y_%H-%M-%S')
def ensure_dir(dir_path):
    r"""Make sure the directory exists, creating it (and any parents) if needed.

    Uses ``exist_ok=True`` instead of a separate existence check, so two
    processes creating the same directory concurrently cannot make this call
    fail with ``FileExistsError``.

    Args:
        dir_path (str): directory path

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)
def get_model(model_name):
    r"""Automatically select model class based on model name

    The module ``model_name.lower()`` is searched for in each
    ``recbole.model.<submodule>`` package listed below, in order. The class
    named ``model_name`` is taken from the first module that exists.

    Args:
        model_name (str): model name

    Returns:
        Recommender: model class

    Raises:
        ValueError: if no listed sub-package contains a matching module.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util

    model_submodule = [
        'general_recommender', 'context_aware_recommender', 'sequential_recommender',
        'knowledge_aware_recommender', 'exlib_recommender'
    ]

    model_file_name = model_name.lower()
    model_module = None
    for submodule in model_submodule:
        module_path = '.'.join(['recbole.model', submodule, model_file_name])
        try:
            spec = importlib.util.find_spec(module_path, __name__)
        except ModuleNotFoundError as exc:
            # find_spec imports the parent packages of a dotted path; if one of
            # *those* is absent (e.g. an optional sub-package), just skip it.
            # Any other missing module is a real dependency error: re-raise.
            missing = exc.name or ''
            if not module_path.startswith(missing + '.'):
                raise
            spec = None
        if spec is not None:
            model_module = importlib.import_module(module_path, __name__)
            break

    if model_module is None:
        raise ValueError('`model_name` [{}] is not the name of an existing model.'.format(model_name))
    model_class = getattr(model_module, model_name)
    return model_class
def get_trainer(model_type, model_name):
    r"""Automatically select trainer class based on model type and model name

    A model-specific ``<model_name>Trainer`` in ``recbole.trainer`` takes
    precedence. Otherwise a generic trainer is chosen from ``model_type``:
    ``KGTrainer`` for knowledge-aware models, ``TraditionalTrainer`` for
    traditional models, and ``Trainer`` for everything else.

    Args:
        model_type (ModelType): model type
        model_name (str): model name

    Returns:
        Trainer: trainer class
    """
    trainer_module = importlib.import_module('recbole.trainer')
    try:
        return getattr(trainer_module, model_name + 'Trainer')
    except AttributeError:
        pass  # no dedicated trainer: fall back on the model type below

    if model_type == ModelType.KNOWLEDGE:
        fallback_name = 'KGTrainer'
    elif model_type == ModelType.TRADITIONAL:
        fallback_name = 'TraditionalTrainer'
    else:
        fallback_name = 'Trainer'
    return getattr(trainer_module, fallback_name)
def early_stopping(value, best, cur_step, max_step, bigger=True):
    r"""Validation-based early stopping.

    A result that ties the best counts as an improvement. Training is
    signalled to stop once more than ``max_step`` consecutive steps have
    failed to improve.

    Args:
        value (float): current result
        best (float): best result
        cur_step (int): the number of consecutive steps that did not exceed the best result
        max_step (int): threshold steps for stopping
        bigger (bool, optional): whether the bigger the better

    Returns:
        tuple:
        - float, best result after this step
        - int, the number of consecutive steps that did not exceed the best result after this step
        - bool, whether to stop
        - bool, whether to update
    """
    improved = value >= best if bigger else value <= best
    if improved:
        # new best: reset the patience counter, never stop, request an update
        return value, 0, False, True

    stale_steps = cur_step + 1
    return best, stale_steps, bool(stale_steps > max_step), False
def calculate_valid_score(valid_result, valid_metric=None):
    r"""Return the validation score from a validation result.

    Args:
        valid_result (dict): valid result
        valid_metric (str, optional): the selected metric in valid result for valid score.
            When falsy, ``'Recall@10'`` is used.

    Returns:
        float: valid score
    """
    metric_key = valid_metric if valid_metric else 'Recall@10'
    return valid_result[metric_key]
def dict2str(result_dict):
    r"""Render a result dict as a single line.

    Each entry becomes ``"<metric> : <value>"`` (both sides via ``str``), and
    entries are separated by one space, in the dict's iteration order.

    Args:
        result_dict (dict): result dict

    Returns:
        str: result str
    """
    pieces = (f'{metric!s} : {value!s}' for metric, value in result_dict.items())
    return ' '.join(pieces)
def init_seed(seed, reproducibility):
    r"""Seed Python, NumPy and PyTorch RNGs and configure cuDNN.

    When ``reproducibility`` is truthy, cuDNN autotuning is disabled and
    deterministic algorithms are requested. Otherwise autotuning is enabled
    and determinism is not enforced.

    Args:
        seed (int): random seed
        reproducibility (bool): Whether to require reproducibility
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    torch.backends.cudnn.benchmark = not reproducibility
    torch.backends.cudnn.deterministic = bool(reproducibility)
def get_tensorboard(logger):
    r"""Creates a SummaryWriter of Tensorboard that can log PyTorch models and metrics into a directory for
    visualization within the TensorBoard UI.

    For the convenience of the user, the naming rule of the SummaryWriter's
    log_dir is the same as the logger: it is the base name of the first
    handler that has a ``baseFilename``, with only its final extension
    removed. Everything lives under ``log_tensorboard/``.

    Args:
        logger: its output filename is used to name the SummaryWriter's log_dir.
                If the filename is not available (or yields an empty name), we will
                name the log_dir according to the current time.

    Returns:
        SummaryWriter: it will write out events and summaries to the event file.
    """
    base_path = 'log_tensorboard'

    dir_name = None
    for handler in logger.handlers:
        if hasattr(handler, 'baseFilename'):
            # splitext strips only the last extension, so names such as
            # "run-lr0.001.log" keep their embedded dots.
            dir_name = os.path.splitext(os.path.basename(handler.baseFilename))[0]
            break
    if not dir_name:
        # no file handler, or a name that reduced to '' -> time-based name
        dir_name = '{}-{}'.format('model', get_local_time())

    dir_path = os.path.join(base_path, dir_name)
    writer = SummaryWriter(dir_path)
    return writer
def get_gpu_usage(device=None):
    r"""Return the peak reserved memory and total memory of a CUDA device as a string.

    The "reserved" figure comes from ``torch.cuda.max_memory_reserved``, i.e.
    the peak amount reserved by the caching allocator. Both figures are
    reported in GiB (1024 ** 3 bytes) with two decimals.

    Args:
        device: cuda.device. It is the device that the model run on.

    Returns:
        str: ``'<reserved> G/<total> G'`` for the given device.
    """
    bytes_per_gib = 1024 ** 3
    reserved_gib = torch.cuda.max_memory_reserved(device) / bytes_per_gib
    total_gib = torch.cuda.get_device_properties(device).total_memory / bytes_per_gib
    return '{:.2f} G/{:.2f} G'.format(reserved_gib, total_gib)