- class ModelGraphClassification(*args: Any, **kwargs: Any)
Bases: Module
Model for the graph classification task.
- Parameters:
dataset – ([Graph]): List of input graphs
device – (device): Device to run the model on, ‘cuda’ or ‘cpu’
conv – (str): Name of the convolution layer used in the neural network
hidden_layer – (int): The size of hidden layer (default: 64)
dropout – (float): Dropout probability (default: 0)
num_layers – (int): Number of layers in the model (default: 2)
ssl_flag – (bool): If True, the self-supervised loss is also optimized during training, in addition to the semi-supervised loss
heads – (int): Number of attention heads in the GAT layer
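The `ssl_flag` parameter decides whether a second loss term enters training. The sketch below is framework-free and only illustrates the idea: it assumes the supervised and self-supervised losses are summed with equal weight. The helper `total_loss` is hypothetical, and the real model may weight or schedule the two terms differently.

```python
def total_loss(sup_loss: float, ssl_loss: float, ssl_flag: bool) -> float:
    """Combine the supervised and self-supervised loss values.

    Assumption: an unweighted sum when ssl_flag is True; this is not
    taken from the ModelGraphClassification source.
    """
    if ssl_flag:
        # Self-supervised term is optimized in addition to the supervised one.
        return sup_loss + ssl_loss
    # Without SSL, only the supervised loss drives training.
    return sup_loss
```

For example, `total_loss(0.5, 0.25, True)` returns `0.75`, while the same call with `ssl_flag=False` returns `0.5`.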
- static convert_dataset(data: List[Graph], train_indices: List[int], val_indices: List[int]) → Tuple[List[Graph], List[Graph], List[Graph], int]
Split the input dataset into train, test, and validation sets according to the provided indices
- Parameters:
data – ([Graph]): List of graphs as input dataset
train_indices – ([int]): List of indices for train dataset
val_indices – ([int]): List of indices for validation dataset
- Returns:
([Graph], [Graph], [Graph], int): Lists of train, test, and validation graphs, and the minimum size among all graphs
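A minimal sketch of the splitting that `convert_dataset` describes, under several assumptions not confirmed by the source: graphs whose indices appear in neither `train_indices` nor `val_indices` form the test split; the tuple is returned in the order train, test, validation (following the description above); and a graph's "size" is its number of nodes. The `Graph` dataclass and `convert_dataset_sketch` are hypothetical stand-ins, not the library's types.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Graph:
    """Hypothetical stand-in for the library's Graph type; only num_nodes is used."""
    num_nodes: int


def convert_dataset_sketch(
    data: List[Graph], train_indices: List[int], val_indices: List[int]
) -> Tuple[List[Graph], List[Graph], List[Graph], int]:
    train_set, val_set = set(train_indices), set(val_indices)
    train = [data[i] for i in train_indices]
    val = [data[i] for i in val_indices]
    # Assumption: every graph not listed for train or validation goes to test.
    test = [g for i, g in enumerate(data) if i not in train_set and i not in val_set]
    # Assumption: "size" means the number of nodes in a graph.
    min_size = min(g.num_nodes for g in data)
    return train, test, val, min_size
```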
- forward(x: torch_geometric.typing.Tensor, edge_index: torch_geometric.typing.Adj, batch: torch_geometric.typing.Tensor) → Tuple[torch_geometric.typing.Tensor, torch_geometric.typing.Tensor]
Compute the forward pass: propagate node representations through the layers of the model and produce graph-level outputs
- Parameters:
x – (Tensor): Input node features
edge_index – (Adj): Edge index of a batch
batch – (Tensor): Batch vector assigning each node to the graph it belongs to
- Returns:
(Tensor, Tensor): Predicted label probabilities and the predicted average degree of each graph
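The `batch` argument follows the PyTorch Geometric convention: `batch[i]` is the index of the graph that node `i` belongs to. A graph-level model typically uses it to pool node representations into one vector per graph. The sketch below shows mean pooling in plain Python; whether `ModelGraphClassification` uses mean pooling is an assumption (PyTorch Geometric offers `torch_geometric.nn.global_mean_pool(x, batch)` for this), and `global_mean_pool_sketch` is a hypothetical helper.

```python
from typing import List


def global_mean_pool_sketch(x: List[List[float]], batch: List[int]) -> List[List[float]]:
    """Average node feature rows per graph using a PyG-style batch vector.

    Assumes graph indices in batch are contiguous from 0, so every graph has
    at least one node.
    """
    num_graphs = max(batch) + 1
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for row, g in zip(x, batch):
        counts[g] += 1
        for j, v in enumerate(row):
            sums[g][j] += v
    # One averaged feature vector per graph.
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]
```

For instance, nodes with features `[1, 2]` and `[3, 4]` in graph 0 and `[10, 0]` in graph 1 pool to `[[2.0, 3.0], [10.0, 0.0]]`.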
- static loss_sup(pred: torch_geometric.typing.Tensor, label: torch_geometric.typing.Tensor) → torch_geometric.typing.Tensor
Negative log-likelihood loss
- Parameters:
pred – (Tensor): Predicted labels
label – (Tensor): Ground-truth labels
- Returns:
(Tensor): Loss
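Negative log-likelihood loss averages, over the samples, the negated log-probability assigned to each sample's true class. The plain-Python sketch below assumes `pred` holds per-class log-probabilities (as produced by a log-softmax) and that the loss is averaged over the batch, which matches the default `reduction='mean'` of `torch.nn.functional.nll_loss`; that `loss_sup` works exactly this way is an assumption, and `nll_loss_sketch` is a hypothetical helper.

```python
import math
from typing import List


def nll_loss_sketch(log_probs: List[List[float]], labels: List[int]) -> float:
    """Mean negative log-likelihood.

    log_probs[i][c] is the log-probability of class c for sample i (assumption).
    """
    # Pick the log-probability of each sample's true class, negate, and average.
    return -sum(row[y] for row, y in zip(log_probs, labels)) / len(labels)
```

A perfectly confident, correct prediction has log-probability `0` for the true class and therefore contributes no loss.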
- static self_supervised_loss(deg_pred: torch_geometric.typing.Tensor, batch: torch_geometric.data.Batch) → torch_geometric.typing.Tensor
Self-supervised loss for the graph classification task: the MSE between the predicted average degree of each graph and the true one
- Parameters:
deg_pred – (Tensor): Predicted average degree of each graph in the dataset
batch – (Batch): Batch of training data
- Returns:
(Tensor): Loss
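To make the self-supervised target concrete, the sketch below computes each graph's true average degree from a PyG-style `edge_index` and batch vector, then takes the mean squared error against the predictions. It assumes undirected edges are stored in both directions (the usual PyTorch Geometric convention), so a graph's average degree equals its number of `edge_index` columns divided by its number of nodes. The helpers `average_degrees` and `self_supervised_loss_sketch` are hypothetical; the real method reads these fields from a `torch_geometric.data.Batch` object.

```python
from typing import List, Tuple


def average_degrees(edge_index: Tuple[List[int], List[int]], batch: List[int]) -> List[float]:
    """Average node degree of every graph in the batch.

    Assumes each undirected edge appears twice in edge_index (u->v and v->u).
    """
    num_graphs = max(batch) + 1
    nodes = [0] * num_graphs
    edges = [0] * num_graphs
    for g in batch:
        nodes[g] += 1
    # Attribute each directed edge to the graph of its source node.
    for src in edge_index[0]:
        edges[batch[src]] += 1
    return [e / n for e, n in zip(edges, nodes)]


def self_supervised_loss_sketch(
    deg_pred: List[float], edge_index: Tuple[List[int], List[int]], batch: List[int]
) -> float:
    """MSE between predicted and true average degrees, one value per graph."""
    true_deg = average_degrees(edge_index, batch)
    return sum((p - t) ** 2 for p, t in zip(deg_pred, true_deg)) / len(true_deg)
```

For a batch holding a 3-node path (degrees 1, 2, 1) and a single 2-node edge, the true average degrees are 4/3 and 1.0.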