- class BaseSampler(data: Graph, device: torch.device, loss_info: Dict)
Bases: ABC
Base class for sampling positive and negative edges for an unsupervised loss function.
- Parameters:
data – (Graph): Input dataset
device – (device): Either ‘cuda’ or ‘cpu’
loss_info – (dict): Parameters of the unsupervised loss function
- abstract sample(batch: torch_geometric.typing.Tensor) → torch_geometric.typing.Tensor
Sample edges. Must be implemented by subclasses.
- Parameters:
batch – (Batch): Nodes for which positive edges are sampled
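A minimal sketch of the subclassing contract, assuming only the constructor arguments and the abstract `sample` method listed above. To stay dependency-free it uses hypothetical stand-ins: a plain list of `(src, dst)` edges instead of `Graph`, a device string instead of `torch.device`, and a list of node ids instead of a `Batch`. `PositiveOnlySampler` is a hypothetical subclass for illustration, not part of the library.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


class BaseSampler(ABC):
    """Simplified stand-in for BaseSampler: stores its arguments, defers sample()."""

    def __init__(self, data: List[Edge], device: str, loss_info: Dict):
        self.data = data            # stand-in for Graph: a list of (src, dst) edges
        self.device = device        # 'cuda' or 'cpu'
        self.loss_info = loss_info  # parameters of the unsupervised loss

    @abstractmethod
    def sample(self, batch: List[int]):
        """Sample edges; every concrete subclass must override this."""


class PositiveOnlySampler(BaseSampler):
    """Hypothetical subclass: returns existing edges whose source is in the batch."""

    def sample(self, batch: List[int]) -> List[Edge]:
        nodes = set(batch)
        return [(u, v) for (u, v) in self.data if u in nodes]


edges = [(0, 1), (1, 2), (2, 0), (3, 1)]
sampler = PositiveOnlySampler(edges, device="cpu", loss_info={})
print(sampler.sample([0, 3]))  # [(0, 1), (3, 1)]

try:
    BaseSampler(edges, "cpu", {})  # sample() is abstract, so this fails
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)
```

Because `sample` is decorated with `abstractmethod`, Python refuses to instantiate the base class directly, which is what "Must be implemented" enforces.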
- class BaseSamplerWithNegative(data: Graph, device: torch.device, loss_info: Dict)
Bases: BaseSampler
Sampler that draws negative edges in addition to positive ones
- Parameters:
data – (Graph): Input graph data
device – (device): Either ‘cuda’ or ‘cpu’
loss_info – (dict): Parameters of the unsupervised loss function
- sample(batch: torch_geometric.data.Batch) → Tuple[torch_geometric.typing.Tensor, torch_geometric.typing.Tensor]
Sample positive and negative edges for the nodes in the batch
- Parameters:
batch – (Batch): Nodes from which positive and negative edges are sampled
- Returns:
(Tensor, Tensor): Positive and negative edge samples
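The reference above does not say how negative edges are chosen. A common strategy, assumed here and not confirmed as this library's method, is to keep each existing edge `(u, v)` with `u` in the batch as a positive sample and to pair `u` with a uniformly drawn non-neighbour `w` as a negative sample. The hypothetical function `sample_pos_neg` returns both as 2 x N index lists, mirroring the shape of an `edge_index` tensor, but uses plain Python lists instead of `Tensor` to stay dependency-free.

```python
import random
from typing import List, Tuple

EdgeIndex = List[List[int]]  # [[sources...], [targets...]], like a 2 x N edge_index


def sample_pos_neg(edges: List[Tuple[int, int]], num_nodes: int, batch: List[int],
                   seed: int = 0) -> Tuple[EdgeIndex, EdgeIndex]:
    """Hypothetical helper: (positive, negative) edges for batch nodes.

    Assumption: negatives are made by replacing the target of each positive
    edge with a uniformly random node that is not already a neighbour.
    """
    rng = random.Random(seed)  # seeded so results are reproducible
    edge_set = set(edges)
    nodes = set(batch)
    pos_src, pos_dst, neg_src, neg_dst = [], [], [], []
    for u, v in edges:
        if u not in nodes:
            continue
        # Candidate negatives: any other node that u is not already linked to.
        candidates = [w for w in range(num_nodes)
                      if w != u and (u, w) not in edge_set]
        if not candidates:
            continue  # u is linked to every node; no negative exists, skip the pair
        pos_src.append(u)
        pos_dst.append(v)
        neg_src.append(u)
        neg_dst.append(rng.choice(candidates))
    return [pos_src, pos_dst], [neg_src, neg_dst]


edges = [(0, 1), (1, 2), (2, 0), (3, 1)]
pos, neg = sample_pos_neg(edges, num_nodes=4, batch=[0, 3])
print(pos)  # [[0, 3], [1, 1]]
print(neg)  # random non-edges with the same sources as pos
```

Pairing every positive with exactly one negative keeps the two outputs the same length, which suits loss functions that compare positive and negative scores element-wise; the ratio could instead be read from `loss_info`, but that is not documented here.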