Source code for revelionn.main_module

import torch
from sklearn.metrics import accuracy_score, roc_auc_score
from torch import nn, optim

from .early_stopping import EarlyStopping


class MainModelProcessing:
    """
    Class for training, evaluation and processing the main network model.

    Attributes
    ----------
    device : torch.device
        Tensor processing device.
    main_net : torch.nn.Module
        The model of the main neural network.
    classes : dict or None
        Names of neural network output classes. ``None`` until a checkpoint
        is loaded with ``load_model``.

    Methods
    -------
    load_model(path_to_model)
        Loads the weights of the neural network model from a file.
    train_model(train_loader, valid_loader, patience, epochs, filename,
                class_label, module_name, main_net_class,
                transformation_name, img_size, num_channels)
        Training and validation of the main neural network.
    evaluate_model(test_loader)
        Evaluation of the model on the test set.
    get_main_net()
        Returns the main neural network.
    get_class_labels()
        Returns names of neural network output classes.
    get_device()
        Returns the current tensor processing device.
    """

    def __init__(self, main_net, device):
        """
        Sets all the necessary attributes for the MainModelProcessing object.

        Parameters
        ----------
        main_net : torch.nn.Module
            The model of the main neural network.
        device : torch.device
            Tensor processing device.
        """
        self.device = device
        self.main_net = main_net.to(self.device)
        self.classes = None

    def get_main_net(self):
        """
        Returns the main neural network.

        Returns
        -------
        main_net : torch.nn.Module
            The main neural network.
        """
        return self.main_net

    def get_class_labels(self):
        """
        Returns names of neural network output classes.

        Returns
        -------
        classes : dict or None
            Names of neural network output classes, or ``None`` if no
            checkpoint has been loaded yet.
        """
        return self.classes

    def get_device(self):
        """
        Returns the current tensor processing device.

        Returns
        -------
        device : torch.device
            Tensor processing device.
        """
        return self.device

    def load_model(self, path_to_model):
        """
        Loads the weights of the neural network model from a file.

        The file is expected to contain a dictionary with the keys
        ``'model_state_dict'`` and ``'classes'``.

        Parameters
        ----------
        path_to_model : str
            The path to the file containing weights.

        Returns
        -------
        None
        """
        # NOTE: torch.load unpickles the file content; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(path_to_model, map_location=self.device)
        self.main_net.load_state_dict(checkpoint['model_state_dict'])
        self.classes = checkpoint['classes']

    @staticmethod
    def _mean_or_nan(total, count):
        """Returns ``total / count``, or NaN when there is nothing to average."""
        return total / count if count else float('nan')

    def _run_epoch(self, loader, criterion, optimizer=None):
        """
        Runs one pass over ``loader`` and averages the per-batch metrics.

        Parameters
        ----------
        loader : iterable
            Yields ``(images, labels)`` batches.
        criterion : callable
            Loss function applied to ``(network output, float labels)``.
        optimizer : torch.optim.Optimizer, optional
            When given, the network is put in training mode and updated on
            every batch. When omitted, the network is put in evaluation mode
            and gradients are disabled.

        Returns
        -------
        mean_loss : float
            Loss averaged over all batches (NaN if there were no batches).
        mean_acc : float
            Accuracy averaged over all batches (NaN if there were no batches).
        mean_auc : float
            ROC AUC averaged over the batches for which it could be computed
            (NaN if there were none).
        """
        training = optimizer is not None
        self.main_net.train(training)

        total_loss = 0.0
        total_acc = 0.0
        total_auc = 0.0
        num_batches = 0
        num_auc_batches = 0

        with torch.set_grad_enabled(training):
            for images, labels in loader:
                images = images.to(self.device)
                labels = labels.to(self.device).float()

                if training:
                    optimizer.zero_grad()

                # The output is compared against 0.5 and fed to BCELoss, so it
                # is treated as a probability rather than a raw logit.
                probs = self.main_net(images)
                loss = criterion(probs, labels)

                if training:
                    loss.backward()
                    optimizer.step()

                num_batches += 1
                total_loss += loss.item()

                predictions = (probs > 0.5).long()
                total_acc += accuracy_score(labels.cpu(), predictions.cpu())

                try:
                    auc = roc_auc_score(labels.cpu(),
                                        probs.cpu().detach().numpy())
                except ValueError:
                    # AUC could not be computed for this batch; leave it out
                    # of the AUC average instead of failing the whole epoch.
                    continue
                total_auc += auc
                num_auc_batches += 1

        return (self._mean_or_nan(total_loss, num_batches),
                self._mean_or_nan(total_acc, num_batches),
                self._mean_or_nan(total_auc, num_auc_batches))

    def train_model(self, train_loader, valid_loader, patience, epochs,
                    filename, class_label, module_name, main_net_class,
                    transformation_name, img_size, num_channels):
        """
        Training and validation of the main neural network.

        After every epoch the metrics are printed and appended to
        ``{filename}.txt``. Whenever the early-stopping helper reports an
        improvement of the validation loss, the model is saved to
        ``{filename}.rvl``.

        Parameters
        ----------
        train_loader : torch.utils.data.DataLoader
            Training data loader yielding ``(images, labels)`` batches.
        valid_loader : torch.utils.data.DataLoader
            Validation data loader yielding ``(images, labels)`` batches.
        patience : int
            How many epochs to wait after last time validation loss improved.
        epochs : int
            The number of training epochs of the main neural network.
        filename : str
            The name of the file in which the parameters of the trained model
            will be saved.
        class_label : str
            The name of the label of the class used for training.
        module_name : str
            The name of the file containing the main network class.
        main_net_class : str
            Name of the main network class.
        transformation_name : str
            Name of the variable storing transformations.
        img_size : int
            The size of the image side.
        num_channels : int
            The number of image channels.

        Returns
        -------
        None
        """
        self.main_net.to(self.device)

        criterion = nn.BCELoss()
        optimizer = optim.Adam(self.main_net.parameters(), lr=0.001)
        early_stopping = EarlyStopping(patience=patience, verbose=True)
        log_path = f'{filename}.txt'

        for e in range(epochs):
            train_loss, train_acc, train_auc = self._run_epoch(
                train_loader, criterion, optimizer)
            valid_loss, valid_acc, valid_auc = self._run_epoch(
                valid_loader, criterion)

            result = f"Epoch {e + 1}/{epochs}.. " \
                     f"Train loss: {train_loss:.3f}.. " \
                     f"Valid loss: {valid_loss:.3f}.. " \
                     f"Train acc: {train_acc:.3f}.. " \
                     f"Valid acc: {valid_acc:.3f}.. " \
                     f"Train AUC: {train_auc:.3f}.. " \
                     f"Valid AUC: {valid_auc:.3f}.. "
            print(result)

            with open(log_path, "a") as file:
                file.write(result + '\n')

            valid_loss_decrease = early_stopping(valid_loss)

            if early_stopping.early_stop:
                print("Early stopping")
                break
            elif valid_loss_decrease is not None:
                with open(log_path, "a") as file:
                    file.write(valid_loss_decrease + '\n')

                # Checkpoint the weights together with the metadata stored
                # alongside them in the .rvl file.
                classes = {1: class_label, 0: f'Not{class_label}'}
                torch.save({'classes': classes,
                            'model_state_dict': self.main_net.state_dict(),
                            'main_net_module_name': module_name,
                            'main_net_class': main_net_class,
                            'transformation_name': transformation_name,
                            'img_size': img_size,
                            'num_channels': num_channels
                            }, f'{filename}.rvl')

    def evaluate_model(self, test_loader):
        """
        Evaluation of the model on the test set.

        Parameters
        ----------
        test_loader : torch.utils.data.DataLoader
            Test data loader. Combines a dataset and a sampler, and provides
            an iterable over the given dataset.

        Returns
        -------
        test_loss : float
            Test loss (NaN if the loader yields no batches).
        test_acc : float
            Accuracy on the test set (NaN if the loader yields no batches).
        test_auc : float
            ROC AUC on the test set (NaN if it could not be computed for any
            batch).
        """
        criterion = nn.BCELoss()
        test_loss, test_acc, test_auc = self._run_epoch(test_loader, criterion)

        print(f"Test loss: {test_loss:.4f}.. "
              f"Test acc: {test_acc:.4f}.. "
              f"Test AUC: {test_auc:.4f}.. ")

        return test_loss, test_acc, test_auc