# -*- coding: utf-8 -*-
# @Time : 2020/7/17
# @Author : Shanlei Mu
# @Email : slmu@ruc.edu.cn
# UPDATE
# @Time : 2021/3/8
# @Author : Jiawei Guan
# @Email : guanjw@ruc.edu.cn
"""
recbole.utils.utils
################################
"""
import datetime
import importlib
import os
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from recbole.utils.enum_type import ModelType
def get_local_time():
    r"""Format the current local time as a compact timestamp.

    The value comes from ``datetime.datetime.now()`` and is rendered with the
    pattern ``'%b-%d-%Y_%H-%M-%S'`` (abbreviated month name, day, year, then
    hour, minute and second after an underscore).

    Returns:
        str: the formatted current time
    """
    return datetime.datetime.now().strftime('%b-%d-%Y_%H-%M-%S')
def ensure_dir(dir_path):
    r"""Make sure the directory exists; create it (and any parents) if needed.

    Args:
        dir_path (str): directory path

    Raises:
        FileExistsError: if ``dir_path`` already exists but is not a directory.
    """
    # `exist_ok=True` makes the call atomic with respect to other processes
    # creating the same directory; a separate "exists?" check followed by
    # `makedirs` can fail when two workers race on the same path.
    os.makedirs(dir_path, exist_ok=True)
def get_model(model_name):
    r"""Automatically select model class based on model name

    The lower-cased model name is looked up as a module inside each known
    ``recbole.model.*`` sub-package, in order; the first module found is
    imported and the attribute named ``model_name`` is returned from it.

    Args:
        model_name (str): model name

    Returns:
        Recommender: model class

    Raises:
        ValueError: if no sub-package contains a module for ``model_name``.
        AttributeError: if the module exists but has no attribute ``model_name``.
    """
    # `import importlib` alone does not guarantee that the `util` submodule
    # has been loaded, so import it explicitly before using `find_spec`.
    import importlib.util

    model_submodule = [
        'general_recommender', 'context_aware_recommender', 'sequential_recommender', 'knowledge_aware_recommender',
        'exlib_recommender'
    ]

    model_file_name = model_name.lower()
    model_module = None
    last_error = None
    for submodule in model_submodule:
        module_path = '.'.join(['recbole.model', submodule, model_file_name])
        try:
            spec = importlib.util.find_spec(module_path, __name__)
        except ModuleNotFoundError as err:
            # `find_spec` imports the parent packages and raises when one of
            # them is missing. Only that case means "not in this sub-package";
            # a missing unrelated dependency must still surface to the caller.
            missing = err.name
            if not missing or not module_path.startswith(missing + '.'):
                raise
            last_error = err
            continue
        if spec:
            model_module = importlib.import_module(module_path, __name__)
            break

    if model_module is None:
        raise ValueError(
            '`model_name` [{}] is not the name of an existing model.'.format(model_name)
        ) from last_error
    model_class = getattr(model_module, model_name)
    return model_class
def get_trainer(model_type, model_name):
    r"""Pick the trainer class that matches a model.

    A trainer named ``<model_name>Trainer`` in ``recbole.trainer`` takes
    precedence. If no such attribute exists, the choice falls back on the
    model type: ``KGTrainer`` for knowledge models, ``TraditionalTrainer``
    for traditional models, and the generic ``Trainer`` otherwise.

    Args:
        model_type (ModelType): model type
        model_name (str): model name

    Returns:
        Trainer: trainer class
    """
    trainer_package = 'recbole.trainer'
    try:
        return getattr(importlib.import_module(trainer_package), model_name + 'Trainer')
    except AttributeError:
        # No model-specific trainer: select the generic one for this model type.
        if model_type == ModelType.KNOWLEDGE:
            fallback_name = 'KGTrainer'
        elif model_type == ModelType.TRADITIONAL:
            fallback_name = 'TraditionalTrainer'
        else:
            fallback_name = 'Trainer'
        return getattr(importlib.import_module(trainer_package), fallback_name)
def early_stopping(value, best, cur_step, max_step, bigger=True):
    r"""Validation-based early stopping.

    A result counts as an improvement when it is at least as good as the
    best one so far (``>=`` when ``bigger`` is true, ``<=`` otherwise). An
    improvement resets the step counter; anything else increments it, and
    stopping is signalled once the counter exceeds ``max_step``.

    Args:
        value (float): current result
        best (float): best result
        cur_step (int): the number of consecutive steps that did not exceed the best result
        max_step (int): threshold steps for stopping
        bigger (bool, optional): whether the bigger the better

    Returns:
        tuple:
        - float,
          best result after this step
        - int,
          the number of consecutive steps that did not exceed the best result after this step
        - bool,
          whether to stop
        - bool,
          whether to update
    """
    improved = value >= best if bigger else value <= best
    if improved:
        # New best: reset the counter, never stop, ask the caller to update.
        return value, 0, False, True

    stalled_steps = cur_step + 1
    return best, stalled_steps, bool(stalled_steps > max_step), False
def calculate_valid_score(valid_result, valid_metric=None):
    r"""Extract the validation score from a validation result.

    Args:
        valid_result (dict): valid result
        valid_metric (str, optional): the selected metric in valid result for valid score;
            when it is empty or ``None``, ``'Recall@10'`` is used

    Returns:
        float: valid score
    """
    metric_key = valid_metric or 'Recall@10'
    return valid_result[metric_key]
def dict2str(result_dict):
    r"""Render a result dict as a single line of ``metric : value`` pairs.

    Keys and values are converted with ``str`` and the pairs are joined by a
    single space, in the dict's iteration order.

    Args:
        result_dict (dict): result dict

    Returns:
        str: result str
    """
    rendered_pairs = (
        ' : '.join((str(metric), str(value)))
        for metric, value in result_dict.items()
    )
    return ' '.join(rendered_pairs)
def init_seed(seed, reproducibility):
    r"""Seed the random generators of ``random``, numpy and torch (CPU and CUDA)
    and configure the cudnn flags.

    With ``reproducibility`` enabled, cudnn benchmarking is turned off and
    deterministic mode is turned on; otherwise the two flags are set the
    other way round.

    Args:
        seed (int): random seed
        reproducibility (bool): Whether to require reproducibility
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    deterministic = bool(reproducibility)
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic
def get_tensorboard(logger):
    r"""Create a Tensorboard ``SummaryWriter`` whose log directory is named after the logger.

    The first handler of ``logger`` that has a ``baseFilename`` attribute
    provides the name: the base name of that file, cut at its first ``'.'``.
    If no handler has one, the name is ``model-<current time>``. The writer
    is created under the ``log_tensorboard`` directory.

    Args:
        logger: its output filename is used to name the SummaryWriter's log_dir.
            If the filename is not available, we will name the log_dir according to the current time.

    Returns:
        SummaryWriter: it will write out events and summaries to the event file.
    """
    for handler in logger.handlers:
        if hasattr(handler, "baseFilename"):
            log_file_name = os.path.basename(handler.baseFilename)
            run_name = log_file_name.split('.')[0]
            break
    else:
        # No file-backed handler found: fall back to a time-stamped name.
        run_name = '{}-{}'.format('model', get_local_time())

    return SummaryWriter(os.path.join('log_tensorboard', run_name))
def get_gpu_usage(device=None):
    r"""Return the peak reserved memory and total memory of given device in a string.

    Both figures are divided by ``1024 ** 3`` and printed with two decimals.
    The reserved figure comes from ``torch.cuda.max_memory_reserved``, i.e. it
    is the peak value tracked by torch, not the instantaneous one.

    Args:
        device: cuda.device. It is the device that the model run on.

    Returns:
        str: it contains the info about reserved memory and total memory of given device,
        formatted as ``'<reserved> G/<total> G'``. When CUDA is not available
        the string ``'0.00 G/0.00 G'`` is returned.
    """
    usage_format = '{:.2f} G/{:.2f} G'
    if not torch.cuda.is_available():
        # Querying device properties without CUDA raises; report a zero
        # reading instead so this status helper is safe on CPU-only hosts.
        return usage_format.format(0.0, 0.0)

    bytes_per_unit = 1024 ** 3
    reserved = torch.cuda.max_memory_reserved(device) / bytes_per_unit
    total = torch.cuda.get_device_properties(device).total_memory / bytes_per_unit
    return usage_format.format(reserved, total)