class ModelFactory

Bases: object

Factory responsible for flexible model creation based on user input.

build_model(conv: str, loss_function: Dict[Any, Any], device: torch.device, num_features: int, hidden_layer: int, out_layer: int, num_layers: int, dropout: float, heads: int = 0) BaseNet

Build a model from the given arguments.

Parameters:
  • device – (torch.device): Device the model runs on, either ‘cuda’ or ‘cpu’

  • hidden_layer – (int): The size of hidden layer (default:64)

  • out_layer – (int): The size of output layer (default:128)

  • dropout – (float): Dropout (default:0.0)

  • num_layers – (int): Number of layers in the model (default:2)

  • heads – (int): Number of heads in GAT conv (default:1)

  • conv – (str): Either ‘GCN’, ‘GAT’ or ‘SAGE’ convolution

Returns:

The constructed model, a BaseNet instance
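
Example:

The sketch below is not the ModelFactory implementation. It is a minimal, dependency-free illustration of how the arguments above could be turned into a stack of layers. The names LayerSpec and plan_layers are hypothetical. It also assumes that the first layer maps num_features to hidden_layer, that the last layer maps to out_layer, and that GAT heads do not change the layer width. None of this is confirmed by the documentation above.

```python
from dataclasses import dataclass
from typing import List

# Convolution names accepted by the `conv` parameter documented above.
SUPPORTED_CONVS = ("GCN", "GAT", "SAGE")


@dataclass
class LayerSpec:
    """Hypothetical description of one convolution layer (not part of the library)."""
    conv: str
    in_dim: int
    out_dim: int
    heads: int = 1


def plan_layers(conv: str, num_features: int, hidden_layer: int,
                out_layer: int, num_layers: int, heads: int = 1) -> List[LayerSpec]:
    """Return one LayerSpec per layer for a stack of `num_layers` convolutions."""
    if conv not in SUPPORTED_CONVS:
        raise ValueError(f"conv must be one of {SUPPORTED_CONVS}, got {conv!r}")
    if num_layers < 1:
        raise ValueError("num_layers must be at least 1")

    # Assumption: only GAT uses attention heads; other convolutions get heads=1.
    layer_heads = heads if conv == "GAT" else 1

    # Assumption: input width, then hidden widths, then the output width.
    dims = [num_features] + [hidden_layer] * (num_layers - 1) + [out_layer]
    return [
        LayerSpec(conv, dims[i], dims[i + 1], layer_heads)
        for i in range(num_layers)
    ]


if __name__ == "__main__":
    for spec in plan_layers("GAT", num_features=16, hidden_layer=64,
                            out_layer=128, num_layers=2, heads=4):
        print(spec)
```

Validating conv up front means an unsupported convolution name fails with a clear error before any layer is constructed.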