# -*- coding: utf-8 -*-
# @Time : 2020/9/16
# @Author : Shanlei Mu
# @Email : slmu@ruc.edu.cn
"""
recbole.model.init
########################
"""
import torch.nn as nn
from torch.nn.init import xavier_normal_, xavier_uniform_, constant_
def xavier_normal_initialization(module):
    r"""Initialize ``nn.Embedding`` / ``nn.Linear`` weights with `xavier_normal_`_.

    Designed to be passed to :meth:`torch.nn.Module.apply`, which invokes it on
    every submodule. For ``nn.Embedding`` and ``nn.Linear`` the weight tensor is
    filled in place using PyTorch's ``xavier_normal_``. For ``nn.Linear`` layers
    that have a bias, the bias is set to the constant 0. Modules of any other
    type are left unchanged.

    .. _`xavier_normal_`:
        https://pytorch.org/docs/stable/nn.init.html?highlight=xavier_normal_#torch.nn.init.xavier_normal_

    Args:
        module (nn.Module): the (sub)module to initialize, if its type is handled.

    Examples:
        >>> self.apply(xavier_normal_initialization)
    """
    if isinstance(module, nn.Embedding):
        xavier_normal_(module.weight.data)
    elif isinstance(module, nn.Linear):
        xavier_normal_(module.weight.data)
        # nn.Linear(..., bias=False) stores ``bias`` as None, so guard before filling.
        if module.bias is not None:
            constant_(module.bias.data, 0)