Source code for chmncc.optimizers.sgd

"""Stochastic gradient descend scheduler"""
import torch.nn as nn
import torch
from chmncc.probabilistic_circuits.GatingFunction import DenseGatingFunction


def get_sgd_optimizer(
    net: nn.Module, lr: float, momentum: float = 0.9
) -> torch.optim.Optimizer:
    r"""
    SGD optimizer

    Args:
        net [nn.Module]: network architecture
        lr [float]: learning rate
        momentum [float]: momentum

    Returns:
        optimizer [torch.optim.Optimizer]
    """
    return torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)
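
A minimal usage sketch, not part of the original module: build the optimizer for a network and run a single update step. The toy classifier and dummy batch below are illustrative placeholders, not models from chmncc.

    import torch
    import torch.nn as nn

    from chmncc.optimizers.sgd import get_sgd_optimizer

    # Toy stand-in for a real network architecture
    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    optimizer = get_sgd_optimizer(net, lr=0.01, momentum=0.9)

    # One SGD step on a dummy batch
    inputs = torch.randn(8, 1, 28, 28)
    targets = torch.randint(0, 10, (8,))
    loss = nn.CrossEntropyLoss()(net(inputs), targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
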
def get_sdg_optimizer_with_gate(
    net: nn.Module, gate: DenseGatingFunction, lr: float, weight_decay: float
) -> torch.optim.Optimizer:
    r"""
    SGD optimizer over the parameters of both the network and the gating function

    Args:
        net [nn.Module]: network architecture
        gate [DenseGatingFunction]: gating function
        lr [float]: learning rate
        weight_decay [float]: weight decay

    Returns:
        optimizer [torch.optim.Optimizer]
    """
    return torch.optim.SGD(
        list(net.parameters()) + list(gate.parameters()),
        lr=lr,
        momentum=0.9,
        weight_decay=weight_decay,
    )
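
A hedged sketch of joint training with a gating function. Since the constructor of DenseGatingFunction is not shown on this page, the example substitutes a plain nn.Linear as a stand-in for the gate; any nn.Module exposing parameters() behaves identically as far as optimizer construction is concerned.

    import torch
    import torch.nn as nn

    from chmncc.optimizers.sgd import get_sdg_optimizer_with_gate

    # Placeholder modules; in practice `gate` is a DenseGatingFunction
    # built on top of the probabilistic circuit (constructor args omitted here).
    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    gate = nn.Linear(10, 10)

    optimizer = get_sdg_optimizer_with_gate(net, gate, lr=0.01, weight_decay=1e-4)

    # One joint update: both the network and the gate receive gradients.
    inputs = torch.randn(8, 1, 28, 28)
    targets = torch.randint(0, 10, (8,))
    loss = nn.CrossEntropyLoss()(gate(net(inputs)), targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
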