from copy import deepcopy
import torch
from torch import nn
class MappingModule(nn.Module):
    """
    A module representing a common fully connected part of a simultaneous mapping network and blocks of concepts.

    The input first passes through the shared layers and then, independently, through each of the ``num_outs``
    concept blocks; the output of every block is passed through a sigmoid.

    Attributes
    ----------
    common_layers : nn.Sequential
        The shared layers. Empty (i.e. an identity mapping) when ``num_shared_neurons`` is empty.
    output_layers_list : nn.ModuleList
        A list of output layers (one per concept block), each of which maps the shared representation to an
        output tensor.
    sigmoid : nn.Sigmoid
        The sigmoid function used to transform the output tensor(s).

    Methods
    -------
    generate_layers(num_neurons)
        Generates a list of PyTorch layers based on the number of neurons in each layer.
    forward(x)
        Forward pass through the module.
    """

    def __init__(self, in_features, num_shared_neurons, num_output_neurons, num_outs=1):
        """
        Sets all the necessary attributes for the MappingModule object.

        Parameters
        ----------
        in_features : int
            The number of input features.
        num_shared_neurons : list[int]
            The number of neurons in consecutive fully connected layers of the common part of the network
            (internal representation of the simultaneous extraction network). May be empty, in which case the
            concept blocks are applied directly to the input.
        num_output_neurons : list[int]
            The number of neurons in consecutive fully connected layers of each of the concept blocks. When
            ``num_shared_neurons`` is non-empty, its first element must equal ``num_shared_neurons[-1]``; when
            ``num_shared_neurons`` is empty, ``in_features`` is prepended to it.
        num_outs : int
            The number of outputs of the simultaneous extraction network. It is determined by the number of extracted
            concepts.

        Raises
        ------
        ValueError
            If ``num_shared_neurons`` is non-empty and ``num_output_neurons`` is empty or does not start with
            ``num_shared_neurons[-1]``.
        """
        super().__init__()
        # Copy so the caller's sequences are never mutated (and tuples are accepted as well as lists).
        num_shared_neurons = list(num_shared_neurons)
        num_output_neurons = list(num_output_neurons)

        if num_shared_neurons:
            if not num_output_neurons or num_shared_neurons[-1] != num_output_neurons[0]:
                raise ValueError('The last element of num_shared_neurons list must have the same value as the first '
                                 'element of num_output_neurons list.')
            # in_features -> ... -> num_shared_neurons[-1], followed by a ReLU before the concept blocks.
            common_layers = self.generate_layers([in_features, *num_shared_neurons])
            common_layers.append(nn.ReLU())
            self.common_layers = nn.Sequential(*common_layers)
        else:
            # No shared part: an empty Sequential returns its input unchanged, so forward() works in this case too.
            self.common_layers = nn.Sequential()
            # The concept blocks then start directly from the input features.
            num_output_neurons = [in_features, *num_output_neurons]

        output_layers = self.generate_layers(num_output_neurons)
        # Every concept block is a deep copy of the same freshly built layers: the blocks start from identical
        # weights but have separate parameters.
        self.output_layers_list = nn.ModuleList(
            deepcopy(nn.Sequential(*output_layers)) for _ in range(num_outs)
        )
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def generate_layers(num_neurons):
        """
        Generates a list of PyTorch layers based on the number of neurons in each layer.

        Consecutive sizes are joined by ``nn.Linear`` layers, with an ``nn.ReLU`` between two linear layers (no
        activation after the last one). Fewer than two sizes yield an empty list.

        Parameters
        ----------
        num_neurons : list[int]
            The number of neurons in consecutive fully connected layers.

        Returns
        -------
        list[nn.Module]
            A list of PyTorch layers.
        """
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(num_neurons, num_neurons[1:])):
            if idx:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(n_in, n_out))
        return layers

    def forward(self, x):
        """
        Forward pass through the module.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.

        Returns
        -------
        tuple[torch.Tensor]
            The output tensor(s), one per concept block, each passed through a sigmoid.
        """
        shared = self.common_layers(x)
        return tuple(self.sigmoid(block(shared)) for block in self.output_layers_list)
class LayerDecoder(nn.Module):
    """
    Module consisting of a 1x1 convolution layer, followed by a ReLU activation function, a global average pooling
    layer, and a flattening layer.

    Parameters
    ----------
    in_channels : int
        The number of input channels to the 1x1 convolution layer.
    out_channels : int
        The number of output channels from the 1x1 convolution layer.
    padding : int, optional
        Zero-padding of the 1x1 convolution (default 0). A 1x1 kernel needs no padding: with ``padding=1`` every
        border position of the convolution output sees only zeros, equals ReLU(bias) and dilutes the global
        average. Pass ``padding=1`` only to reproduce the previous behaviour of this class.

    Attributes
    ----------
    layers : nn.Sequential
        A sequential container of the layers that make up this module.

    Methods
    -------
    forward(x)
        Forward pass through the module.
    """

    def __init__(self, in_channels, out_channels, padding=0):
        """
        Sets all the necessary attributes for the LayerDecoder object.

        Parameters
        ----------
        in_channels : int
            The number of input channels to the 1x1 convolution layer.
        out_channels : int
            The number of output channels from the 1x1 convolution layer.
        padding : int, optional
            Zero-padding of the 1x1 convolution. Defaults to 0 so the pooled average only covers real positions.
        """
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=padding),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )

    def forward(self, x):
        """
        Forward pass through the module.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor of shape (batch_size, in_channels, height, width).

        Returns
        -------
        torch.Tensor
            The output tensor of shape (batch_size, out_channels).
        """
        return self.layers(x)
class SimultaneousMappingNet(nn.Module):
    """
    Simultaneous Mapping Network for RevelioNN.

    Takes a tuple of activations of the selected layers of a convolutional network. Each activation tensor is
    processed by its own decoder block, the decoder outputs are concatenated along dimension 1 and passed to a
    ``MappingModule``: a common fully connected part followed by one block of fully connected layers per concept,
    each ending in a sigmoid. The width of each concept block's output is the last element of
    ``num_output_neurons`` (presumably 1, one probability per concept).

    Attributes
    ----------
    decoder_channels : int
        Output width of every decoder: the number of output channels of a convolutional decoder's 1x1 convolution,
        or the number of output neurons of a fully connected decoder.
    num_shared_neurons : list[int]
        Sizes of the consecutive fully connected layers of the common part of the network
        (internal representation of the simultaneous extraction network).
    num_output_neurons : list[int]
        Sizes of the consecutive fully connected layers of each concept block.
    num_outs : int
        Number of outputs of the simultaneous extraction network, i.e. the number of extracted concepts.
    decoders : torch.nn.ModuleList
        The decoder blocks, in the order of the researched layers they were built for.

    Methods
    -------
    forward(x)
        Forward pass through the network.
    get_decoder_channels()
        Returns the number of decoder channels.
    get_num_shared_neurons()
        Returns the number of neurons in consecutive fully connected layers of the common part of the network.
    get_num_output_neurons()
        Returns the number of neurons in consecutive fully connected layers of each of the concept blocks.
    get_num_outs()
        Returns the number of outputs of the simultaneous extraction network.
    """

    def __init__(self, activation_extractor, decoder_channels, num_shared_neurons, num_output_neurons, num_outs):
        """
        Build one decoder per supported researched layer and the mapping module on top of them.

        Parameters
        ----------
        activation_extractor : ActivationExtractor
            Identifies the layers of a convolutional neural network and extracts the activations they produce
            during inference. Its ``is_concatenate`` flag must be False.
        decoder_channels : int
            Output width of every decoder (1x1-convolution output channels, or fully connected output neurons).
        num_shared_neurons : list[int]
            Sizes of the consecutive fully connected layers of the common part of the network.
        num_output_neurons : list[int]
            Sizes of the consecutive fully connected layers of each concept block.
        num_outs : int
            Number of outputs of the network, i.e. the number of extracted concepts.

        Raises
        ------
        ValueError
            If ``activation_extractor.is_concatenate`` is true.
        """
        super().__init__()
        self.decoder_channels = decoder_channels
        self.num_shared_neurons = num_shared_neurons
        self.num_output_neurons = num_output_neurons
        self.num_outs = num_outs
        self.decoders = nn.ModuleList()

        if activation_extractor.is_concatenate:
            raise ValueError("ActivationExtractor.is_concatenate must be set to False for its use in "
                             "SimultaneousMappingNet.")

        layer_lookup = activation_extractor.get_layers_dict()
        researched_names = activation_extractor.get_layers_for_research()

        # Layer type -> decoder factory. Every matching entry appends a decoder (the checks are not exclusive).
        # Layers of any other type get no decoder at all.
        decoder_factories = (
            (torch.nn.BatchNorm2d, lambda layer: LayerDecoder(layer.num_features, decoder_channels)),
            (torch.nn.Conv2d, lambda layer: LayerDecoder(layer.out_channels, decoder_channels)),
            (torch.nn.Linear, lambda layer: nn.Linear(layer.out_features, decoder_channels)),
        )
        for name in researched_names:
            layer = layer_lookup[name]
            for layer_type, make_decoder in decoder_factories:
                if isinstance(layer, layer_type):
                    self.decoders.append(make_decoder(layer))

        # Every decoder emits decoder_channels features; their concatenation feeds the mapping module.
        concatenated_width = decoder_channels * len(self.decoders)
        self.mapping_module = MappingModule(concatenated_width, num_shared_neurons, num_output_neurons, num_outs)

    def forward(self, activations):
        """
        Forward pass through the network.

        NOTE(review): ``activations[i]`` is fed to ``self.decoders[i]``, which assumes the extractor yields one
        activation per decoder, in the same order — TODO confirm against ActivationExtractor, especially when
        researched layers of unsupported types were skipped.

        Parameters
        ----------
        activations : tuple[torch.Tensor]
            The input activations, one per decoder.

        Returns
        -------
        tuple[torch.Tensor]
            The output of the mapping module: one sigmoid-activated tensor per concept.
        """
        features = []
        if len(self.decoders) > 0:
            decoded = [decoder(activations[idx]) for idx, decoder in enumerate(self.decoders)]
            features = torch.cat(tuple(decoded), dim=1)
        return self.mapping_module(features)

    def get_decoder_channels(self):
        """
        Return the number of decoder channels.

        Returns
        -------
        int
            Output width of every decoder (1x1-convolution output channels, or fully connected output neurons).
        """
        return self.decoder_channels

    def get_num_shared_neurons(self):
        """
        Return the number of shared neurons.

        Returns
        -------
        list[int]
            Sizes of the consecutive fully connected layers of the common part of the network.
        """
        return self.num_shared_neurons

    def get_num_output_neurons(self):
        """
        Return the number of output neurons.

        Returns
        -------
        list[int]
            Sizes of the consecutive fully connected layers of each concept block.
        """
        return self.num_output_neurons

    def get_num_outs(self):
        """
        Return the number of outputs of the simultaneous extraction network.

        Returns
        -------
        int
            Number of outputs of the network, i.e. the number of extracted concepts.
        """
        return self.num_outs