class Extrapolate(dataset: List[Graph], model: torch.nn.Module)

Bases: object

Extrapolation utility for both node classification and graph classification tasks.

To build the extrapolation splits:

extrapolation = Extrapolate(dataset=dataset, model=model)
train_dataset, test_dataset, val_dataset = extrapolation(train_indices, val_indices, init_edges, remove_init_edges, white_list, score_func)
Parameters:
  • dataset – (List[Graph]): A dataset of Graph objects.

  • model – (torch.nn.Module): The model to explain.