import os
os.chdir('../..')
from pathlib import Path
import torch
from torchvision.models import resnet18
from pathlib import Path
from torchvision.datasets import CIFAR10
import torchvision.transforms as TF
from examples.CIFAR10.models import *
from eXNN.InnerNeuralViz import VisualizeNetSpace
# Prepare data: a random sample of a few batches from the CIFAR-10 test split.
import itertools

# NOTE(review): these are the ImageNet channel statistics, not CIFAR-10's own;
# kept unchanged on the assumption that the pretrained model was trained with
# the same normalization — confirm against the training code.
_normalize = TF.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
tfm = TF.Compose([TF.ToTensor(), _normalize])
# download=False: the dataset must already be present under ./.cache.
test_ds = CIFAR10(root='./.cache', train=False, download=False, transform=tfm)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=128, shuffle=True)

# Number of loader batches to collect (up to n_batches * 128 images).
n_batches = 10
data, labels = [], []
# islice stops cleanly if the loader has fewer than n_batches batches,
# instead of raising StopIteration the way a bare next() loop would.
for images, targets in itertools.islice(test_dl, n_batches):
    data.append(images)
    labels.append(targets)
# Concatenate the per-batch tensors along the batch dimension.
data = torch.cat(data, dim=0)
labels = torch.cat(labels, dim=0)
# Download the repository https://github.com/Med-AI-Lab/eXNN-task-CIFAR10
# and change model_repo to the root of the downloaded repository.
model_repo = Path('../eXNN-task-CIFAR10')

# Load the pretrained model.
# Fall back to CPU when CUDA is unavailable: a hard-coded 'cuda:0'
# map_location makes torch.load fail on CPU-only machines.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
simple_model = resnet18(num_classes=10)
# NOTE: load_state_dict copies the loaded tensors into the model's existing
# parameters, which resnet18() created on CPU. The model is never moved to
# `device`, so it stays on CPU together with the `data` tensors.
simple_model.load_state_dict(torch.load(model_repo / "ResNet18.sd.pt", map_location=device))
# Switch to eval mode so BatchNorm uses its running statistics.
simple_model = simple_model.eval()

# Layers whose activations are visualized with the 'umap' method.
# NOTE(review): `res` is indexed below by layer name and by 'input';
# presumably VisualizeNetSpace returns a mapping of that shape — confirm.
layers = ['layer1', 'layer2', 'layer3', 'layer4', 'avgpool', 'fc']
res = VisualizeNetSpace(simple_model, 'umap', data, layers, labels=labels, chunk_size=128)
# Let's look at how well the trained model separates the classes.
# In the raw input space, the classes are not separated at all.
res["input"]  # visualization of the raw input data (shown as a notebook cell output)
# Halfway through the network, some class structure emerges.
res["layer2"]  # visualization of the 'layer2' activations
# Finally, after the last layer, the classes are well separated.
res["fc"]  # visualization of the final 'fc' layer output