"""Gating function"""
import torch
from collections import OrderedDict, defaultdict
class DenseGatingFunction(torch.nn.Module):
    """Densely connected feed-forward gating network with several output heads.

    A shared trunk (``self.gate``) of ``Linear -> ReLU [-> Dropout]`` stages
    maps an input embedding to a hidden vector.  One linear head per
    parameter group then maps that hidden vector to a flat tensor, which
    ``forward`` reshapes to the group's shape.  Per the original author,
    these tensors represent the einet's parameters.

    Attributes set by the constructor:
        gate: ``torch.nn.Sequential`` trunk with sub-modules named
            ``linear{i}``, ``relu{i}`` and, when enabled, ``dropout{i}``.
        output_shapes: list of per-head shapes (without the batch dimension).
        outputs: list of the head modules, also registered on the module as
            ``output0``, ``output1``, ... so that their parameters are tracked.
    """

    def __init__(
        self,
        beta,
        gate_layers=(128, 256, 256),
        num_reps=1,
        gate_dropout=None,
        device="cpu",
    ):
        """Build the trunk and one output head per parameter group.

        Args:
            beta: Circuit object.  It must provide ``positive_iter()`` and
                ``elements``; the nodes it yields must provide
                ``is_decomposition()``, ``positive_elements`` and
                ``is_true()``.  The constructor writes ``beta.num_reps`` and
                ``beta.num_branches`` as a side effect.
            gate_layers: Sizes of the trunk layers; consecutive pairs define
                the ``Linear`` layers, so the first entry is the input size
                and the last entry is the hidden size fed to the heads.
            num_reps: Size of the trailing dimension of every head shape.
            gate_dropout: Dropout probability; a falsy value (``None`` or
                ``0``) disables dropout.
            device: Device on which the ``Linear`` layers are created.
        """
        super().__init__()

        # NOTE(review): these statements change process-wide torch state (RNG
        # seed and cuDNN flags) every time an instance is created.  They are
        # kept because existing callers may rely on the resulting
        # reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.manual_seed(1)

        beta.num_reps = num_reps

        layer_sizes = list(gate_layers)
        trunk = []
        for i, (n_in, n_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            trunk.append(
                (f"linear{i+1}", torch.nn.Linear(n_in, n_out, device=device))
            )
            trunk.append((f"relu{i+1}", torch.nn.ReLU()))
            # NOTE(review): dropout is added after every ReLU; the per-layer
            # name "dropout{i+1}" suggests this placement -- confirm.
            if gate_dropout:
                trunk.append((f"dropout{i+1}", torch.nn.Dropout(p=gate_dropout)))
        self.gate = torch.nn.Sequential(OrderedDict(trunk))

        # Group nodes by branching factor so that every group can be served
        # by a single head.  "True" nodes are filed under a branching factor
        # of 2.  Mixing nodes are not grouped (their handling was disabled in
        # the original code).
        num_branches = defaultdict(list)
        for node in beta.positive_iter():
            if node.is_decomposition():
                num_branches[len(node.positive_elements)].append(node)
            elif node.is_true():
                num_branches[2].append(node)
        beta.num_branches = num_branches

        # One (num_nodes, num_children, num_reps) shape per group, followed by
        # a (num_reps,) shape and a (len(beta.elements), num_reps) shape.
        self.output_shapes = [
            (len(nodes), width, num_reps) for width, nodes in num_branches.items()
        ]
        self.output_shapes.append((num_reps,))
        self.output_shapes.append((len(beta.elements), num_reps))

        hidden_size = layer_sizes[-1]
        self.outputs = []
        for i, shape in enumerate(self.output_shapes):
            head = torch.nn.Sequential(
                torch.nn.Linear(
                    hidden_size, torch.Size(shape).numel(), device=device
                ),
            )
            # setattr registers the head as a sub-module; the plain list only
            # preserves the order used by ``forward``.
            setattr(self, f"output{i}", head)
            self.outputs.append(head)

        self.initialize()

    def move_on_device(self, device: str):
        """Move the trunk and every output head to ``device``."""
        self.gate = self.gate.to(device)
        for head in self.outputs:
            # Module.to() moves parameters in place, no rebinding needed.
            head.to(device)

    def forward(self, x):
        """Return one tensor per head, reshaped to ``(-1, *shape)``.

        The list follows the order of ``self.output_shapes``.
        """
        hidden = self.gate(x)
        return [
            head(hidden).reshape(-1, *shape)
            for head, shape in zip(self.outputs, self.output_shapes)
        ]

    def get_output(self, x):
        """Return the trunk's hidden representation of ``x`` (no heads applied)."""
        return self.gate(x)

    def init_weights(self, m):
        """Initialise ``m`` if it is a ``Linear`` layer; ignore other modules.

        Weights get Xavier-uniform initialisation and the bias, when present,
        is filled with 0.01.
        """
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)

    def initialize(self):
        """Apply ``init_weights`` to every layer of the trunk and of the heads."""
        self.gate.apply(self.init_weights)
        for head in self.outputs:
            # Each head is a Sequential wrapper, so ``apply`` is required to
            # reach the Linear layer inside it; calling init_weights on the
            # wrapper directly would silently do nothing.
            head.apply(self.init_weights)