- class TrainPipeline(config, item_vocab, model, optimizer)
Bases: object
- memory_sampling(memory: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
- optimize_model(memory: torch.Tensor)
- run()
- tmp_Q_eps_greedy(state, actions)
- train_gnn_model(model, optimizer, subgraph, positive_edges, negative_edges)
- train_kge_model(kge_model, train_pars, info, train_triples, valid_triples, max_steps=10)
Training pipeline for the model.
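The method name `tmp_Q_eps_greedy(state, actions)` suggests epsilon-greedy action selection over Q-values, although the reference above does not document its behavior. The following is a minimal, framework-free sketch of that general technique, not the pipeline's actual implementation. It assumes the Q-values for the candidate actions have already been computed (for example, by the pipeline's model) and are passed in as a plain sequence of floats. The function name `eps_greedy_select` and its `epsilon` and `rng` parameters are hypothetical.

```python
import random
from typing import Optional, Sequence


def eps_greedy_select(
    q_values: Sequence[float],
    epsilon: float,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical helper: return the index of the chosen action.

    With probability `epsilon` a uniformly random action is chosen (explore);
    otherwise the action with the highest Q-value is chosen (exploit).
    """
    if len(q_values) == 0:
        raise ValueError("q_values must contain at least one action")
    if not 0.0 <= epsilon <= 1.0:
        raise ValueError("epsilon must be in [0, 1]")
    if rng is None:
        rng = random.Random()
    if rng.random() < epsilon:
        # Explore: pick any candidate action uniformly at random.
        return rng.randrange(len(q_values))
    # Exploit: pick the action with the largest estimated Q-value.
    # max() returns the first maximal index, so ties resolve to the lowest index.
    return max(range(len(q_values)), key=lambda i: q_values[i])


if __name__ == "__main__":
    q = [0.1, 0.7, 0.3]
    # epsilon=0.0 never explores, so the greedy action (index 1) is returned.
    print(eps_greedy_select(q, epsilon=0.0))
```

If the Q-values come from a 1-D `torch.Tensor`, calling `.tolist()` on it first yields a list this sketch accepts. In a typical setup, `epsilon` is decayed over training so that early steps explore more and later steps mostly exploit.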