- class ModelFactory
Bases:
object
Factory responsible for flexible model creation based on user input.
- build_model(conv: str, loss_function: Dict[Any, Any], device: torch.device, num_features: int, hidden_layer: int, out_layer: int, num_layers: int, dropout: float, heads: int = 0) → BaseNet
Build a model from the given configuration.
- Parameters:
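conv – (str): The graph convolution type, either ‘GCN’, ‘GAT’ or ‘SAGE’
loss_function – (Dict[Any, Any]): The loss function configuration
device – (torch.device): The device to build the model on, either ‘cuda’ or ‘cpu’
num_features – (int): The number of input features
hidden_layer – (int): The size of the hidden layer (default: 64)
out_layer – (int): The size of the output layer (default: 128)
num_layers – (int): The number of layers in the model (default: 2)
dropout – (float): The dropout probability (default: 0.0)
heads – (int): The number of attention heads in the GAT convolution (default: 1)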
- Returns:
The constructed model (a BaseNet instance)
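The following is a minimal, dependency-free sketch of how a factory like this might validate its arguments and dispatch on conv. It is not the library's implementation: `SketchNet` and `build_model_sketch` are hypothetical stand-ins (a real BaseNet would be a torch module), loss_function and device are omitted, and the keyword defaults simply mirror the defaults stated in the parameter list above. Zeroing heads for non-GAT convolutions is an assumption made for illustration.

```python
from dataclasses import dataclass

# Hypothetical stand-in for BaseNet: it only records the configuration.
@dataclass
class SketchNet:
    conv: str
    num_features: int
    hidden_layer: int
    out_layer: int
    num_layers: int
    dropout: float
    heads: int

# Convolution names listed in the parameter documentation.
SUPPORTED_CONVS = ("GCN", "GAT", "SAGE")


def build_model_sketch(conv: str, num_features: int, hidden_layer: int = 64,
                       out_layer: int = 128, num_layers: int = 2,
                       dropout: float = 0.0, heads: int = 1) -> SketchNet:
    """Validate the arguments and return a config object for the requested conv."""
    if conv not in SUPPORTED_CONVS:
        raise ValueError(f"conv must be one of {SUPPORTED_CONVS}, got {conv!r}")
    if not 0.0 <= dropout <= 1.0:
        raise ValueError("dropout must be in [0, 1]")
    if num_layers < 1:
        raise ValueError("num_layers must be >= 1")
    if conv == "GAT" and heads < 1:
        raise ValueError("GAT requires heads >= 1")
    # Assumption: only GAT uses attention heads, so other convs record 0.
    effective_heads = heads if conv == "GAT" else 0
    return SketchNet(conv, num_features, hidden_layer, out_layer,
                     num_layers, dropout, effective_heads)


gat = build_model_sketch("GAT", num_features=16, heads=4)
print(gat.conv, gat.heads, gat.hidden_layer)  # GAT 4 64
```

Keeping validation in the factory means an unsupported conv string fails early with a clear error, before any model weights are allocated on the target device.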