import os
os.chdir('../..')
from torchvision.datasets import MNIST
import torch
import torch.nn as nn
import torchvision.transforms as TF
from eXNN.InnerNeuralTopology import api
# Load both MNIST splits, converting each image with TF.ToTensor().
# root is a relative path, so it resolves against the current working
# directory (changed by os.chdir('../..') at the top of the file).
# Only the train split passes download=True; the test split is opened with
# download=False and relies on its files already being under ./.cache
# (NOTE(review): presumably fetched by the train-split download — confirm).
_mnist_root = './.cache'
train_ds, test_ds = (
    MNIST(
        root=_mnist_root,
        train=split_is_train,
        download=split_is_train,
        transform=TF.ToTensor(),
    )
    for split_is_train in (True, False)
)

# Mini-batches of 36 images; only the training loader shuffles.
train_dl, test_dl = (
    torch.utils.data.DataLoader(split, batch_size=36, shuffle=reshuffle)
    for split, reshuffle in ((train_ds, True), (test_ds, False))
)
import torch.nn as nn
class SimpleNN(nn.Module):
    """Three-layer fully connected network.

    The input is flattened from dimension 1 onward. It then passes through two
    Linear + LeakyReLU blocks (``layer1``, ``layer2``) and a final Linear block
    (``layer3``). ``layer3`` has no output activation, so ``forward`` returns
    raw scores. The training code in this file feeds them to
    ``nn.CrossEntropyLoss``.

    Args:
        input_dim: number of features after flattening (28*28 for MNIST here).
        output_dim: size of the output layer.
        hidden_dim: width of both hidden layers.
        leaky_coef: negative slope passed to each ``nn.LeakyReLU``.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, leaky_coef=0.1):
        super().__init__()
        # The attribute names layer1/layer2/layer3 are referenced by name
        # elsewhere (e.g. layers=['layer2']), so keep them stable.
        self.layer1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(leaky_coef),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(leaky_coef),
        )
        self.layer3 = nn.Sequential(
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Flatten ``x`` from dim 1 onward and return the ``layer3`` output."""
        # Same result as nn.Flatten()(x) (start_dim=1, end_dim=-1), without
        # constructing a new module on every forward call.
        x = torch.flatten(x, start_dim=1)
        x = self.layer1(x)
        x = self.layer2(x)
        return self.layer3(x)
# Experiment: compute api.NetworkHomologies on the hidden block "layer2" of
# SimpleNN for an untrained model and for models trained with different
# Adam weight_decay values. Every run uses the same architecture, number of
# classes, epochs, loss and probe batch, so the runs differ only in
# weight_decay.
num_classes = 10  # MNIST has 10 digit classes (runs 2 and 3 previously used 20)
n_epochs = 20
loss_fn = nn.CrossEntropyLoss()

# Probe batch: the first 100 test images. It is built once and reused for every
# homology computation so that the results are computed on identical inputs.
data = torch.stack([test_ds[i][0] for i in range(100)])


def _layer2_homologies(net):
    """Return api.NetworkHomologies for ``net``'s ``layer2`` on the probe batch."""
    return api.NetworkHomologies(
        net, data, layers=['layer2'], hom_type="standard", coefs_type="2"
    )


def _train(net, opt, epochs):
    """Train ``net`` on ``train_dl`` with ``loss_fn`` and ``opt`` for ``epochs`` epochs.

    After each epoch, print the mean training loss over that epoch's batches.
    Printing only the last batch's loss is a noisy single-sample reading.
    """
    for epoch in range(epochs):
        running_loss, n_batches = 0.0, 0
        for imgs, lbls in train_dl:
            opt.zero_grad()
            loss = loss_fn(net(imgs), lbls)
            loss.backward()
            opt.step()
            running_loss += loss.item()
            n_batches += 1
        # max(..., 1) guards against an empty loader.
        print("Epoch {} mean loss: {}".format(epoch, running_loss / max(n_batches, 1)))


# Run 1: no weight decay; homologies before and after training.
model = SimpleNN(28*28, num_classes, 64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
res_unnorm_before = _layer2_homologies(model)
_train(model, optimizer, n_epochs)
res_unnorm_after = _layer2_homologies(model)

# Run 2: weight_decay=0.01.
model = SimpleNN(28*28, num_classes, 64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.01)
_train(model, optimizer, n_epochs)
res_norm_after = _layer2_homologies(model)

# Run 3: weight_decay=10 (the variable name marks this setting as "destructive").
model = SimpleNN(28*28, num_classes, 64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=10)
_train(model, optimizer, n_epochs)
res_norm_destructive = _layer2_homologies(model)

# Notebook-style cell outputs. These bare expressions display something only in
# an interactive (IPython/Jupyter) session; a plain `python` run ignores them.
res_unnorm_before["layer2"]
res_unnorm_after["layer2"]
res_norm_after["layer2"]
res_norm_destructive["layer2"]