In [3]:
import os
os.chdir('../..')
In [4]:
from torchvision.datasets import MNIST
import torch
import torch.nn as nn
import torchvision.transforms as TF
from tqdm.auto import tqdm
from eXNN.NetBayesianization import wrap, api
In [5]:
train_ds = MNIST(root='./.cache', train=True, download=True, 
                 transform=TF.ToTensor()) 
test_ds = MNIST(root='./.cache', train=False, download=False, 
                transform=TF.ToTensor())
In [6]:
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=36, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=36, shuffle=False)
In [7]:
num_classes = 10
In [8]:
model = nn.Sequential(nn.Flatten() ,nn.Linear(28*28, 128), 
                       nn.ReLU(), nn.Linear(128, 64), 
                       nn.ReLU(), nn.Linear(64, num_classes), nn.Softmax(dim=1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
criterion = nn.CrossEntropyLoss()
images, labels = next(iter(train_dl))
images = images.view(images.shape[0], -1)
logps = model(images)
loss = criterion(logps, labels)
In [9]:
# train
n_epochs = 20
for e in range(n_epochs):
    running_loss = 0
    for images, labels in train_dl:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    else:
        print("Epoch {} - Training loss: {}".format(e, running_loss/len(train_dl)))
Epoch 0 - Training loss: 2.261323278294971
Epoch 1 - Training loss: 1.805009106139473
Epoch 2 - Training loss: 1.6562042623919215
Epoch 3 - Training loss: 1.6340304391428462
Epoch 4 - Training loss: 1.6253673680613838
Epoch 5 - Training loss: 1.6201491944433761
Epoch 6 - Training loss: 1.6161638898054282
Epoch 7 - Training loss: 1.6131943888341014
Epoch 8 - Training loss: 1.610495502461054
Epoch 9 - Training loss: 1.6082047563723338
Epoch 10 - Training loss: 1.6061707546987
Epoch 11 - Training loss: 1.6044750517545951
Epoch 12 - Training loss: 1.6026566071501733
Epoch 13 - Training loss: 1.6009767129168084
Epoch 14 - Training loss: 1.5995301769104415
Epoch 15 - Training loss: 1.598026476223882
Epoch 16 - Training loss: 1.596490390156298
Epoch 17 - Training loss: 1.5951921355602765
Epoch 18 - Training loss: 1.5939387459917989
Epoch 19 - Training loss: 1.592780917435974
In [10]:
correct_count, all_count = 0, 0
for images,labels in test_dl:
  for i in range(len(labels)):
    img = images[i].view(1, 784)
    with torch.no_grad():
        logps = model(img)

    ps = torch.exp(logps)
    probab = list(ps.numpy()[0])
    pred_label = probab.index(max(probab))
    true_label = labels.numpy()[i]
    if(true_label == pred_label):
      correct_count += 1
    all_count += 1

print("Number Of Images Tested =", all_count)
print("\nModel Accuracy =", (correct_count/all_count))
Number Of Images Tested = 10000

Model Accuracy = 0.8676
In [11]:
# build classical bayesian model
bayes_model = api.BasicBayesianWrapper(model, 'basic', 0.1, None, None)
In [12]:
# predict
n_iter = 5
bayes_model.predict(images, n_iter)
Out[12]:
{'mean': tensor([[2.9532e-11, 5.2978e-16, 8.1444e-11, 1.7176e-14, 9.9914e-01, 1.0547e-10,
          4.2251e-07, 3.6712e-07, 1.2748e-06, 8.5830e-04],
         [1.1691e-05, 4.5427e-19, 1.2401e-11, 2.8679e-11, 6.7911e-02, 8.0002e-10,
          8.9902e-11, 6.6425e-03, 2.2742e-02, 9.0269e-01],
         [3.4553e-11, 2.5543e-14, 1.0935e-09, 5.7184e-15, 9.9760e-01, 4.0230e-10,
          1.0476e-06, 2.0018e-06, 4.5529e-07, 2.3946e-03],
         [1.1240e-07, 1.3264e-07, 4.9493e-02, 9.4469e-01, 2.1763e-11, 7.1732e-09,
          5.9953e-12, 1.9604e-07, 5.8205e-03, 6.3144e-09],
         [6.4917e-07, 1.7414e-11, 6.5679e-06, 3.0981e-14, 3.3005e-05, 3.6413e-10,
          9.9993e-01, 4.1751e-09, 2.9818e-05, 8.0287e-07],
         [6.5410e-11, 6.6342e-12, 1.3535e-09, 2.4292e-14, 9.9905e-01, 1.4828e-09,
          8.2082e-06, 1.3887e-04, 6.4032e-06, 7.9413e-04],
         [3.4536e-11, 9.9948e-01, 4.6569e-05, 2.0753e-05, 1.9267e-09, 3.0979e-08,
          1.0521e-10, 1.5474e-04, 2.9737e-04, 8.0818e-08],
         [7.6249e-08, 2.1923e-21, 1.2633e-12, 5.8596e-11, 3.3252e-11, 1.5483e-11,
          1.8945e-14, 9.9996e-01, 1.1970e-06, 3.6505e-05],
         [2.4746e-08, 2.5182e-20, 3.7466e-01, 6.2241e-01, 1.0300e-22, 7.1300e-15,
          7.8195e-20, 2.9997e-06, 2.9281e-03, 4.5721e-16],
         [4.4059e-06, 3.3438e-10, 2.1919e-03, 1.0622e-08, 2.4677e-07, 2.4794e-09,
          9.9780e-01, 1.8539e-09, 9.9301e-08, 1.0529e-11],
         [3.8775e-07, 4.5346e-09, 3.2732e-01, 1.9245e-07, 1.0176e-10, 6.6516e-10,
          6.7267e-01, 1.9395e-06, 4.1945e-06, 9.6492e-08],
         [9.9698e-01, 2.1078e-25, 4.9260e-09, 8.6904e-08, 2.2775e-20, 2.0775e-14,
          6.3757e-16, 3.5064e-10, 3.0176e-03, 4.6168e-13],
         [7.1520e-11, 9.9198e-01, 6.9043e-06, 4.0900e-03, 2.6829e-09, 4.6651e-08,
          6.5929e-08, 1.4372e-03, 2.2930e-03, 1.9325e-04],
         [5.5255e-12, 5.9469e-11, 9.9917e-01, 2.7469e-04, 1.4915e-16, 4.7527e-13,
          8.6415e-14, 1.2295e-05, 5.4089e-04, 4.8728e-12],
         [7.5882e-12, 3.1069e-07, 1.0082e-06, 6.7307e-01, 5.6121e-14, 2.4525e-12,
          1.3250e-14, 3.4160e-10, 3.2693e-01, 4.9761e-10],
         [1.5300e-12, 7.8408e-19, 8.0296e-12, 2.2149e-12, 9.9957e-01, 2.8388e-11,
          2.0583e-09, 2.8162e-09, 1.0111e-05, 4.2022e-04],
         [1.0636e-04, 5.3497e-11, 4.1289e-08, 6.9251e-01, 8.5493e-09, 8.0031e-09,
          5.8158e-13, 1.5342e-09, 3.0738e-01, 1.4467e-08],
         [5.7821e-13, 1.5704e-21, 2.8036e-08, 1.0877e-19, 1.5467e-13, 7.6538e-19,
          1.0000e+00, 4.6054e-23, 1.4071e-10, 6.7137e-19],
         [1.7867e-14, 1.1333e-17, 6.0350e-08, 3.9653e-07, 2.8885e-20, 2.2173e-15,
          1.6792e-23, 1.0000e+00, 1.2793e-08, 4.4273e-09],
         [2.2531e-07, 2.0562e-16, 3.1036e-06, 1.2432e-01, 1.7421e-16, 8.4244e-13,
          7.1240e-18, 1.4506e-11, 8.7567e-01, 1.7567e-08],
         [1.7934e-06, 5.6103e-23, 6.5966e-12, 1.1016e-13, 1.3281e-01, 3.6958e-11,
          4.3662e-10, 4.8125e-04, 1.1116e-03, 8.6560e-01],
         [8.0582e-01, 1.0909e-19, 2.0707e-10, 6.6656e-08, 1.9221e-15, 2.1462e-14,
          9.4625e-08, 5.3875e-17, 1.9418e-01, 1.4121e-14],
         [1.2741e-12, 9.9996e-01, 5.2313e-06, 1.2778e-06, 2.7131e-12, 1.7222e-10,
          1.1467e-06, 3.5611e-09, 3.3198e-05, 1.0357e-09],
         [1.1162e-16, 1.0103e-17, 1.0000e+00, 1.8743e-06, 1.3710e-27, 1.8964e-20,
          3.2808e-18, 1.2958e-15, 4.9383e-09, 4.9819e-25],
         [6.3087e-13, 2.3703e-16, 7.0743e-08, 1.0000e+00, 4.8865e-25, 4.2783e-17,
          1.5532e-20, 1.4318e-11, 1.0737e-09, 3.8811e-18],
         [4.2325e-13, 6.2211e-16, 2.2513e-12, 6.0268e-12, 9.9894e-01, 1.1815e-10,
          2.1535e-09, 1.9690e-07, 2.1638e-04, 8.4165e-04],
         [8.0499e-03, 3.5356e-04, 1.4036e-04, 2.7409e-05, 2.0006e-02, 7.3161e-06,
          5.4193e-02, 4.8783e-06, 9.1722e-01, 8.5895e-07],
         [3.9207e-11, 1.5274e-26, 3.8151e-08, 1.0181e-21, 3.5139e-11, 3.2051e-19,
          1.0000e+00, 6.7910e-22, 8.7406e-12, 7.1344e-18]],
        grad_fn=<SelectBackward0>),
 'std': tensor([[4.7925e-11, 1.1722e-15, 1.4181e-10, 3.4199e-14, 1.7284e-03, 1.3500e-10,
          8.8716e-07, 3.9786e-07, 2.4465e-06, 1.7287e-03],
         [1.7169e-05, 9.9804e-19, 1.1622e-11, 5.7100e-11, 1.4875e-01, 1.4220e-09,
          1.9612e-10, 1.4485e-02, 5.0496e-02, 1.4791e-01],
         [4.3818e-11, 5.6730e-14, 1.3513e-09, 1.2557e-14, 5.0983e-03, 7.4148e-10,
          2.1789e-06, 3.1245e-06, 6.4612e-07, 5.0951e-03],
         [2.0008e-07, 2.9414e-07, 1.0632e-01, 1.0954e-01, 4.8663e-11, 1.1523e-08,
          8.3925e-12, 3.5558e-07, 8.3777e-03, 1.3800e-08],
         [1.0861e-06, 3.1384e-11, 1.4095e-05, 3.5292e-14, 3.3999e-05, 4.9820e-10,
          9.6448e-05, 7.0105e-09, 6.5681e-05, 1.1212e-06],
         [6.5112e-11, 1.3023e-11, 1.6513e-09, 5.0884e-14, 1.7283e-03, 1.7609e-09,
          1.7632e-05, 1.5617e-04, 1.2652e-05, 1.6612e-03],
         [4.2326e-11, 6.2184e-04, 9.0331e-05, 2.4987e-05, 3.1510e-09, 3.9686e-08,
          1.1450e-10, 1.5214e-04, 5.2036e-04, 1.7650e-07],
         [1.2575e-07, 3.6428e-21, 2.6019e-12, 1.2171e-10, 4.4922e-11, 3.4149e-11,
          4.1989e-14, 7.3088e-05, 2.6642e-06, 7.0291e-05],
         [5.2352e-08, 5.6310e-20, 4.6282e-01, 4.6573e-01, 2.3033e-22, 1.5927e-14,
          1.7477e-19, 6.7035e-06, 6.5472e-03, 1.0150e-15],
         [5.8167e-06, 3.5647e-10, 3.6039e-03, 1.0500e-08, 4.8324e-07, 3.8403e-09,
          3.6089e-03, 2.5534e-09, 2.1755e-07, 9.6248e-12],
         [7.4842e-07, 6.9095e-09, 4.2577e-01, 3.1649e-07, 2.0835e-10, 9.7648e-10,
          4.2577e-01, 3.2136e-06, 9.1822e-06, 1.8669e-07],
         [6.7430e-03, 4.7132e-25, 1.0522e-08, 1.4448e-07, 5.0553e-20, 2.7837e-14,
          1.3654e-15, 7.3977e-10, 6.7431e-03, 8.3217e-13],
         [5.8630e-11, 8.1378e-03, 1.1172e-05, 7.2165e-03, 5.8489e-09, 4.8659e-08,
          1.1402e-07, 1.0817e-03, 2.8049e-03, 4.1093e-04],
         [1.1167e-11, 1.3298e-10, 1.0722e-03, 4.1221e-04, 3.3350e-16, 9.1999e-13,
          1.8367e-13, 2.4293e-05, 1.1090e-03, 8.8584e-12],
         [1.4543e-11, 6.9453e-07, 1.2641e-06, 4.6243e-01, 1.2549e-13, 4.2679e-12,
          2.6758e-14, 4.2605e-10, 4.6243e-01, 7.6616e-10],
         [1.5618e-12, 1.7208e-18, 1.3017e-11, 3.6381e-12, 9.1010e-04, 3.6254e-11,
          4.5470e-09, 4.2433e-09, 1.9078e-05, 9.1356e-04],
         [1.7740e-04, 1.1625e-10, 6.2768e-08, 4.2865e-01, 1.9070e-08, 1.1822e-08,
          1.0361e-12, 1.9665e-09, 4.2852e-01, 2.9499e-08],
         [8.7616e-13, 2.4444e-21, 5.9493e-08, 1.5868e-19, 2.0463e-13, 1.3917e-18,
          5.3312e-08, 8.5570e-23, 3.1455e-10, 1.0424e-18],
         [3.0328e-14, 2.5217e-17, 1.2540e-07, 6.9571e-07, 6.4578e-20, 4.0572e-15,
          2.3519e-23, 8.2763e-07, 2.8305e-08, 8.4444e-09],
         [4.5363e-07, 4.5544e-16, 6.9381e-06, 2.5616e-01, 3.1772e-16, 1.8771e-12,
          1.5927e-17, 3.2298e-11, 2.5617e-01, 3.9281e-08],
         [3.8750e-06, 1.2510e-22, 9.9205e-12, 1.4337e-13, 2.5096e-01, 8.0319e-11,
          9.5395e-10, 1.0463e-03, 2.4851e-03, 2.5061e-01],
         [4.2089e-01, 2.4393e-19, 4.2671e-10, 1.0233e-07, 4.2787e-15, 2.5957e-14,
          1.6387e-07, 1.0790e-16, 4.2089e-01, 3.1468e-14],
         [1.0217e-12, 3.9416e-05, 9.3934e-06, 2.2460e-06, 5.8968e-12, 1.5237e-10,
          1.6955e-06, 2.7872e-09, 3.8172e-05, 1.4270e-09],
         [1.9022e-16, 2.2591e-17, 2.4912e-06, 2.4697e-06, 3.0656e-27, 3.4979e-20,
          6.2412e-18, 2.6981e-15, 1.1018e-08, 1.0853e-24],
         [1.4065e-12, 4.8949e-16, 1.5814e-07, 1.5994e-07, 1.0927e-24, 7.5065e-17,
          2.5006e-20, 2.6360e-11, 1.6734e-09, 8.6304e-18],
         [8.3632e-13, 1.2440e-15, 3.8195e-12, 8.5341e-12, 1.1175e-03, 2.3071e-10,
          4.7824e-09, 1.9505e-07, 4.1092e-04, 1.1906e-03],
         [1.7891e-02, 4.7673e-04, 3.1139e-04, 5.1004e-05, 4.3868e-02, 1.5333e-05,
          4.7196e-02, 1.0831e-05, 9.5765e-02, 1.9041e-06],
         [7.5856e-11, 3.3296e-26, 6.4628e-08, 2.2363e-21, 7.6273e-11, 6.3039e-19,
          5.3312e-08, 9.9086e-22, 1.9528e-11, 1.4239e-17]],
        grad_fn=<SelectBackward0>)}
In [13]:
# build bayesian model with beta distibution
bayes_model = api.BasicBayesianWrapper(model, 'beta', None, 0.6, 0.3)
In [14]:
# predict
n_iter = 5
bayes_model.predict(images, n_iter)
Out[14]:
{'mean': tensor([[0.0441, 0.0519, 0.0385, 0.0234, 0.6511, 0.0391, 0.0440, 0.0502, 0.0187,
          0.0391],
         [0.0441, 0.0520, 0.0385, 0.0234, 0.1724, 0.0391, 0.0441, 0.0503, 0.0183,
          0.5178],
         [0.0436, 0.0514, 0.0381, 0.0233, 0.6506, 0.0387, 0.0436, 0.0497, 0.0225,
          0.0387],
         [0.0441, 0.0519, 0.0511, 0.6107, 0.0511, 0.0391, 0.0440, 0.0502, 0.0188,
          0.0391],
         [0.0437, 0.0514, 0.0381, 0.0233, 0.0584, 0.0387, 0.6343, 0.0497, 0.0225,
          0.0400],
         [0.0436, 0.0514, 0.0381, 0.0233, 0.6506, 0.0387, 0.0436, 0.0497, 0.0225,
          0.0387],
         [0.0441, 0.6481, 0.0417, 0.0237, 0.0511, 0.0391, 0.0440, 0.0503, 0.0189,
          0.0391],
         [0.0441, 0.0520, 0.0385, 0.0234, 0.0512, 0.0391, 0.0441, 0.4600, 0.0182,
          0.2293],
         [0.0441, 0.0520, 0.4107, 0.2512, 0.0512, 0.0391, 0.0441, 0.0503, 0.0182,
          0.0391],
         [0.0436, 0.0514, 0.0409, 0.0233, 0.0506, 0.0387, 0.6407, 0.0497, 0.0225,
          0.0387],
         [0.0445, 0.0524, 0.5514, 0.0274, 0.0516, 0.0395, 0.0487, 0.0507, 0.0152,
          0.1186],
         [0.6441, 0.0520, 0.0385, 0.0234, 0.0512, 0.0391, 0.0441, 0.0503, 0.0183,
          0.0391],
         [0.0441, 0.6105, 0.0387, 0.0575, 0.0511, 0.0391, 0.0440, 0.0503, 0.0191,
          0.0458],
         [0.0441, 0.0519, 0.6283, 0.0335, 0.0511, 0.0391, 0.0440, 0.0502, 0.0187,
          0.0391],
         [0.0441, 0.0519, 0.0385, 0.6210, 0.0511, 0.0391, 0.0440, 0.0502, 0.0211,
          0.0391],
         [0.0436, 0.0514, 0.0381, 0.0233, 0.6506, 0.0387, 0.0436, 0.0497, 0.0225,
          0.0387],
         [0.0437, 0.0514, 0.0381, 0.5495, 0.0506, 0.0387, 0.0436, 0.0497, 0.0962,
          0.0387],
         [0.0436, 0.0514, 0.0381, 0.0233, 0.0506, 0.0387, 0.6436, 0.0497, 0.0225,
          0.0387],
         [0.0441, 0.0520, 0.1508, 0.1100, 0.0512, 0.0391, 0.0441, 0.4504, 0.0182,
          0.0400],
         [0.0436, 0.0514, 0.0381, 0.2216, 0.0506, 0.0387, 0.0436, 0.0497, 0.4241,
          0.0387],
         [0.0441, 0.0520, 0.0385, 0.0234, 0.2171, 0.0391, 0.0441, 0.0503, 0.0182,
          0.4732],
         [0.6425, 0.0514, 0.0381, 0.0233, 0.0506, 0.0387, 0.0436, 0.0497, 0.0236,
          0.0387],
         [0.0441, 0.6499, 0.0404, 0.0235, 0.0511, 0.0391, 0.0440, 0.0502, 0.0188,
          0.0391],
         [0.0441, 0.0519, 0.6353, 0.0265, 0.0511, 0.0391, 0.0440, 0.0502, 0.0187,
          0.0391],
         [0.0441, 0.0520, 0.0385, 0.6234, 0.0512, 0.0391, 0.0441, 0.0503, 0.0182,
          0.0391],
         [0.0441, 0.0520, 0.0385, 0.0234, 0.6511, 0.0391, 0.0441, 0.0503, 0.0183,
          0.0391],
         [0.0440, 0.0514, 0.0381, 0.0233, 0.1766, 0.0387, 0.0490, 0.0497, 0.4906,
          0.0387],
         [0.0445, 0.0524, 0.0388, 0.0235, 0.0516, 0.0394, 0.6444, 0.0507, 0.0152,
          0.0394]], grad_fn=<SelectBackward0>),
 'std': tensor([[0.0606, 0.0714, 0.0529, 0.0387, 0.4778, 0.0537, 0.0605, 0.0690, 0.0329,
          0.0537],
         [0.0607, 0.0715, 0.0530, 0.0387, 0.2507, 0.0538, 0.0606, 0.0691, 0.0329,
          0.4563],
         [0.0599, 0.0706, 0.0523, 0.0387, 0.4785, 0.0531, 0.0598, 0.0682, 0.0338,
          0.0531],
         [0.0606, 0.0714, 0.0453, 0.5049, 0.0702, 0.0537, 0.0605, 0.0690, 0.0328,
          0.0537],
         [0.0598, 0.0706, 0.0523, 0.0387, 0.0643, 0.0531, 0.4800, 0.0682, 0.0338,
          0.0519],
         [0.0599, 0.0706, 0.0523, 0.0387, 0.4785, 0.0531, 0.0598, 0.0682, 0.0338,
          0.0531],
         [0.0606, 0.4732, 0.0504, 0.0384, 0.0702, 0.0537, 0.0605, 0.0689, 0.0327,
          0.0537],
         [0.0607, 0.0715, 0.0530, 0.0387, 0.0703, 0.0538, 0.0606, 0.4927, 0.0329,
          0.4049],
         [0.0607, 0.0715, 0.4774, 0.4206, 0.0703, 0.0538, 0.0606, 0.0691, 0.0329,
          0.0538],
         [0.0599, 0.0706, 0.0499, 0.0386, 0.0694, 0.0531, 0.4855, 0.0682, 0.0338,
          0.0531],
         [0.0612, 0.0721, 0.4479, 0.0366, 0.0710, 0.0543, 0.0575, 0.0697, 0.0339,
          0.1628],
         [0.4873, 0.0715, 0.0530, 0.0387, 0.0703, 0.0538, 0.0606, 0.0691, 0.0329,
          0.0538],
         [0.0606, 0.4466, 0.0527, 0.0723, 0.0702, 0.0537, 0.0605, 0.0689, 0.0326,
          0.0495],
         [0.0606, 0.0714, 0.4862, 0.0376, 0.0702, 0.0537, 0.0605, 0.0690, 0.0329,
          0.0537],
         [0.0606, 0.0714, 0.0529, 0.5141, 0.0702, 0.0537, 0.0605, 0.0690, 0.0314,
          0.0537],
         [0.0599, 0.0706, 0.0523, 0.0387, 0.4785, 0.0531, 0.0598, 0.0682, 0.0338,
          0.0531],
         [0.0599, 0.0706, 0.0523, 0.4568, 0.0694, 0.0531, 0.0598, 0.0682, 0.0923,
          0.0531],
         [0.0599, 0.0706, 0.0523, 0.0387, 0.0694, 0.0531, 0.4881, 0.0682, 0.0338,
          0.0531],
         [0.0607, 0.0715, 0.2346, 0.1843, 0.0703, 0.0538, 0.0606, 0.5043, 0.0329,
          0.0530],
         [0.0599, 0.0706, 0.0523, 0.4319, 0.0694, 0.0531, 0.0598, 0.0682, 0.5262,
          0.0531],
         [0.0607, 0.0715, 0.0530, 0.0387, 0.3448, 0.0538, 0.0606, 0.0691, 0.0329,
          0.4793],
         [0.4869, 0.0706, 0.0523, 0.0387, 0.0694, 0.0531, 0.0598, 0.0682, 0.0329,
          0.0531],
         [0.0606, 0.4748, 0.0513, 0.0386, 0.0702, 0.0537, 0.0605, 0.0690, 0.0328,
          0.0537],
         [0.0606, 0.0714, 0.4922, 0.0369, 0.0702, 0.0537, 0.0605, 0.0690, 0.0329,
          0.0537],
         [0.0607, 0.0715, 0.0530, 0.5162, 0.0703, 0.0538, 0.0606, 0.0691, 0.0329,
          0.0538],
         [0.0607, 0.0715, 0.0530, 0.0387, 0.4777, 0.0538, 0.0606, 0.0691, 0.0329,
          0.0538],
         [0.0596, 0.0706, 0.0523, 0.0387, 0.2602, 0.0531, 0.0559, 0.0682, 0.4779,
          0.0531],
         [0.0612, 0.0721, 0.0534, 0.0387, 0.0709, 0.0543, 0.4869, 0.0697, 0.0339,
          0.0543]], grad_fn=<SelectBackward0>)}
In [ ]: