
Supervised Image Denoising with MSDNets and SMSNet ensembles

Authors: Eric Roberts and Petrus Zwart

E-mail: PHZwart@lbl.gov, EJRoberts@lbl.gov

This notebook highlights some basic functionality of the dlsia package.

In this notebook we set up a Mixed-Scale Dense Network (MSDNet) and train it to denoise images corrupted by Gaussian noise. Subsequently, we train a number of randomized sparse mixed-scale networks (SMSNets, also referred to as RMSNets below) on the same task and show how to obtain error estimates via ensemble methods.

[1]:
import sys
import os
import numpy as np
import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset

import h5py
from torchsummary import summary
[2]:
from dlsia.core import helpers
from dlsia.core.train_scripts import train_regression
from dlsia.core.networks import msdnet, smsnet, baggins
from dlsia.test_data.two_d import build_test_data, torch_hdf5_loader
from dlsia.viz_tools import plots, draw_sparse_network


Create Data

We produce noisy, single-class data consisting of peaks in Gaussian noise. The data are generated and stored in HDF5 files.

Parameters to toggle:

  • batch_size : number of images passed through GPU at a single time

  • n_imgs : number of images in training set

  • n_peaks : number of circular peaks in each image

  • n_xy : size of images

  • snr : signal-to-noise ratio; more noise for lower number
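
To make the role of snr concrete, the sketch below adds Gaussian noise to a one-dimensional peak profile in plain Python. It assumes that snr is the ratio of the peak amplitude to the noise standard deviation; build_test_data may use a different convention, and the helper add_gaussian_noise is hypothetical rather than part of dlsia.

```python
import random
import statistics


def add_gaussian_noise(signal, snr, rng):
    """Add zero-mean Gaussian noise with sigma = max(signal) / snr.

    Assumption: snr = peak amplitude / noise standard deviation.
    Illustrative helper only; not part of dlsia.
    """
    sigma = max(signal) / snr
    noisy = [value + rng.gauss(0.0, sigma) for value in signal]
    return noisy, sigma


rng = random.Random(0)

# A long flat background with a single peak of height 1.0
clean = [0.0] * 5000
clean[2500] = 1.0

noisy, sigma = add_gaussian_noise(clean, snr=3, rng=rng)
residual = [n - c for n, c in zip(noisy, clean)]

print(f"expected noise sigma: {sigma:.3f}")
print(f"measured noise sigma: {statistics.pstdev(residual):.3f}")
```

Under this assumption, halving snr doubles the noise standard deviation, which is why lower values give noisier images.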

[3]:
makeData = True
batch_size = 100
num_workers = 0
showNoisyData = True
use_scaled_data = True

Generate hdf5 peak data

[4]:
if makeData == True:
    n_imgs = 200
    n_peaks = 8
    n_xy = 32
    snr=3
    mask_radius = 1.0

    build_test_data.build_data_standard_sets_2d(n_imgs=n_imgs,
                                                n_peaks=n_peaks,
                                                n_xy=n_xy,
                                                snr=snr,
                                                mask_radius=mask_radius)
100%|████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 269.52it/s]
100%|████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 465.98it/s]
100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 60.96it/s]

Load hdf5 peak data

The data generator above can produce the following datasets:

  • trax_GT : ground truth

  • trax_obs : obstructed, noisy images

  • trax_obs_norm : noisy images linearly scaled to the interval [0,1]

  • trax_mask : binary masked images indicating peak (1) or background (0)
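
The linear scaling behind trax_obs_norm can be sketched as a min-max normalization. This is a plain-Python illustration with a hypothetical helper, scale_to_unit_interval; whether dlsia scales each image separately or the whole dataset at once is not stated here, so the per-image choice below is an assumption.

```python
def scale_to_unit_interval(image):
    """Linearly map pixel values so the minimum becomes 0 and the maximum becomes 1."""
    flat = [value for row in image for value in row]
    lo, hi = min(flat), max(flat)
    if hi == lo:
        # A constant image has no range to stretch; map it to all zeros.
        return [[0.0 for _ in row] for row in image]
    return [[(value - lo) / (hi - lo) for value in row] for row in image]


# A tiny 2 x 3 "image" with values outside [0, 1]
noisy = [[-0.5, 0.0, 0.5],
         [1.0, 1.5, 0.25]]

scaled = scale_to_unit_interval(noisy)
print(scaled)
```

Because the map is linear, relative differences between pixels are preserved; only the offset and the overall scale change.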

[5]:
if use_scaled_data == True:
    x_label = "trax_obs_norm"
else:
    x_label = "trax_obs"

f_train = "train_data_2d.hdf5"
f_test  = "test_data_2d.hdf5"
f_validation = "validate_data_2d.hdf5"

MyData_train = torch_hdf5_loader.Hdf5Dataset2D(filename=f_train,
                                               x_label=x_label,
                                               y_label="trax_GT")
MyData_validation = torch_hdf5_loader.Hdf5Dataset2D(filename=f_validation,
                                                    x_label=x_label,
                                                    y_label="trax_GT")
MyData_test = torch_hdf5_loader.Hdf5Dataset2D(filename=f_test,
                                              x_label=x_label,
                                              y_label="trax_GT")

loader_params = {'batch_size': batch_size, 'shuffle': True, 'num_workers': num_workers}
train_loader = DataLoader(MyData_train, **loader_params)
loader_params = {'batch_size': batch_size, 'shuffle': False, 'num_workers': num_workers}
validation_loader = DataLoader(MyData_validation, **loader_params)
test_loader = DataLoader(MyData_test, **loader_params)

View peak data

[6]:
if showNoisyData == True:
    for batch in train_loader:
        # y_label is "trax_GT", so the second tensor holds the ground truth
        noisy, target = batch

        plt.figure(figsize=(12,10))
        plt.subplot(321)
        plt.imshow(noisy[0,0,:,:]); plt.colorbar(shrink=0.8); plt.title('Noisy');
        plt.subplot(322);
        plt.imshow(target[0,0,:,:]); plt.colorbar(shrink=0.8); plt.title('Ground Truth');
        plt.subplot(323)
        plt.imshow(noisy[1,0,:,:]); plt.colorbar(shrink=0.8)
        plt.subplot(324);
        plt.imshow(target[1,0,:,:]); plt.colorbar(shrink=0.8)
        plt.subplot(325)
        plt.imshow(noisy[2,0,:,:]); plt.colorbar(shrink=0.8)
        plt.subplot(326);
        plt.imshow(target[2,0,:,:]); plt.colorbar(shrink=0.8)

        break

    plt.rcParams.update({'font.size': 18})
    plt.tight_layout()
    plt.show()
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_10_0.png

Create and train MSDNet

Many options can be customized; see dlsia/core/networks/msdnet.py for details.
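
The max_dilation setting controls the range of scales the network mixes. A minimal sketch, assuming the cyclic assignment commonly described for MSDNets, in which layer i uses dilation (i mod max_dilation) + 1; dlsia's actual default may differ, and both helper functions below are hypothetical.

```python
def cyclic_dilations(num_layers, max_dilation):
    """Dilation per layer, cycling 1, 2, ..., max_dilation (assumed scheme)."""
    return [(layer % max_dilation) + 1 for layer in range(num_layers)]


def footprint(dilation, kernel_size=3):
    """Side length in pixels spanned by a single dilated square kernel."""
    return dilation * (kernel_size - 1) + 1


dilations = cyclic_dilations(num_layers=40, max_dilation=8)
footprints = [footprint(d) for d in range(1, 9)]

print(dilations[:10])   # dilations of the first ten layers
print(footprints)       # pixel span of a 3x3 kernel for dilations 1..8
```

With 32 x 32 images, a 3x3 kernel at dilation 8 already spans 17 pixels, so a single layer can relate features more than half an image apart.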

[7]:
in_channels = 1
out_channels = 1
num_layers = 40
max_dilation = 8
activation = nn.ReLU()
normalization = nn.BatchNorm2d

msdnet_model = msdnet.MixedScaleDenseNetwork(in_channels = in_channels,
                                             out_channels = out_channels,
                                             num_layers=num_layers,
                                             max_dilation = max_dilation,
                                             activation=activation,
                                             normalization=normalization,
                                            )

pytorch_total_params = sum(p.numel() for p in msdnet_model.parameters() if p.requires_grad)
print("Total number of refineable parameters: ", pytorch_total_params)
Total number of refineable parameters:  9182

Train the model

We perform the following:

  • specify the number of epochs,

  • select the mean squared error (MSE, L2) loss as our scoring criterion,

  • select a learning rate of 1/100,

  • choose the popular Adam optimizer for traversing the loss landscape.
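
To show what these choices do, here is a dependency-free sketch of the MSE loss and a single-parameter Adam update, using the standard Adam equations with the same learning rate of 1e-2. The toy problem (fitting a scalar gain) and the helper names are hypothetical; in the notebook itself, nn.MSELoss and optim.Adam do this work.

```python
import math


def mse(predictions, targets):
    """Mean squared error over all elements."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def adam_step(param, grad, state, lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter."""
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2
    m_hat = state["m"] / (1 - beta1 ** state["t"])   # bias-corrected first moment
    v_hat = state["v"] / (1 - beta2 ** state["t"])   # bias-corrected second moment
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


# Toy problem: learn a scalar gain w so that w * x matches the targets.
inputs = [0.0, 1.0, 2.0, 3.0]
targets = [0.0, 0.5, 1.0, 1.5]      # generated with the true gain 0.5
w = 0.0
state = {"t": 0, "m": 0.0, "v": 0.0}

initial_loss = mse([w * x for x in inputs], targets)
for epoch in range(500):
    # Analytic gradient: d(MSE)/dw = mean(2 * (w*x - y) * x)
    grad = sum(2 * (w * x - y) * x for x, y in zip(inputs, targets)) / len(inputs)
    w = adam_step(w, grad, state)
final_loss = mse([w * x for x in inputs], targets)

print(f"w = {w:.3f}, loss {initial_loss:.4f} -> {final_loss:.2e}")
```

Early Adam steps have a size close to the learning rate regardless of the gradient magnitude, which is one reason a fixed rate such as 1e-2 works across networks of different sizes.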

[8]:
epochs = 100   # set number of epochs
criterion = nn.MSELoss()
learning_rate = 1e-2

optimizer = optim.Adam(msdnet_model.parameters(), lr=learning_rate)
[9]:
device = helpers.get_device()
msdnet_model.to(device)
msdnet_model, results = train_regression(msdnet_model,
                                         train_loader,
                                         validation_loader,
                                         epochs,
                                         criterion,
                                         optimizer,
                                         device=device,
                                         show=10
                                        )
torch.cuda.empty_cache()
Epoch 10 of 100 | Learning rate 1.000e-02
Training Loss: 2.0480e-02 | Validation Loss: 1.9214e-02
Training CC: 0.6652   Validation CC  : 0.6940
Epoch 20 of 100 | Learning rate 1.000e-02
Training Loss: 9.7999e-03 | Validation Loss: 9.5666e-03
Training CC: 0.8558   Validation CC  : 0.8613
Epoch 30 of 100 | Learning rate 1.000e-02
Training Loss: 7.6222e-03 | Validation Loss: 7.7206e-03
Training CC: 0.8899   Validation CC  : 0.8896
Epoch 40 of 100 | Learning rate 1.000e-02
Training Loss: 6.4154e-03 | Validation Loss: 6.6536e-03
Training CC: 0.9082   Validation CC  : 0.9057
Epoch 50 of 100 | Learning rate 1.000e-02
Training Loss: 5.4005e-03 | Validation Loss: 5.6188e-03
Training CC: 0.9232   Validation CC  : 0.9210
Epoch 60 of 100 | Learning rate 1.000e-02
Training Loss: 4.6870e-03 | Validation Loss: 4.9463e-03
Training CC: 0.9338   Validation CC  : 0.9308
Epoch 70 of 100 | Learning rate 1.000e-02
Training Loss: 4.3836e-03 | Validation Loss: 4.5876e-03
Training CC: 0.9383   Validation CC  : 0.9360
Epoch 80 of 100 | Learning rate 1.000e-02
Training Loss: 4.0819e-03 | Validation Loss: 4.4152e-03
Training CC: 0.9426   Validation CC  : 0.9385
Epoch 90 of 100 | Learning rate 1.000e-02
Training Loss: 3.9863e-03 | Validation Loss: 4.3827e-03
Training CC: 0.9440   Validation CC  : 0.9390
Epoch 100 of 100 | Learning rate 1.000e-02
Training Loss: 3.9378e-03 | Validation Loss: 4.3817e-03
Training CC: 0.9447   Validation CC  : 0.9390
[10]:
plots.plot_training_results_regression(results)

Create and train SMSNets

Sparse versions of MSDNets with randomly connected layers and dilations are created here

Specify hyperparameters

Next, the hyperparameters below govern the random network connectivity. Choices include:

  • alpha : modifies distribution of consecutive connection length between network layers/nodes,

  • gamma : modifies distribution of layer/node degree,

  • IL : probability of connection between Input node and Layer node,

  • IO : probability of connection between Input node and Output node,

  • LO : probability of connection between Layer node and Output node,

  • dilation_choices : set of possible dilations along each individual node connection

The specific parameters and what they do are described in detail in the documentation. Please follow the brief comments below for a more cursory explanation.
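
As a rough illustration of how gamma shapes the node degrees, the sketch below samples degrees with probability proportional to degree^(-gamma). It assumes the degrees are restricted to the range from min_k to max_k; this is not dlsia's implementation, and the helper names are hypothetical.

```python
import random


def degree_weights(min_k, max_k, gamma):
    """Unnormalized weights proportional to degree**(-gamma) for min_k <= degree <= max_k."""
    degrees = list(range(min_k, max_k + 1))
    weights = [k ** (-gamma) for k in degrees]
    return degrees, weights


def sample_degrees(n_nodes, min_k, max_k, gamma, rng):
    """Draw one degree per node from the power-law weights."""
    degrees, weights = degree_weights(min_k, max_k, gamma)
    return rng.choices(degrees, weights=weights, k=n_nodes)


rng = random.Random(0)
uniform = sample_degrees(40, min_k=3, max_k=6, gamma=0.0, rng=rng)
skewed = sample_degrees(40, min_k=3, max_k=6, gamma=3.0, rng=rng)

print("gamma = 0.0 mean degree:", sum(uniform) / len(uniform))
print("gamma = 3.0 mean degree:", sum(skewed) / len(skewed))
```

With gamma = 0 every allowed degree is equally likely; larger values of gamma shift probability toward low-degree nodes and therefore sparser networks.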

[11]:
in_channels = 1 # single-channel (grayscale) input image
out_channels = 1 # single-channel denoised output
num_layers = 40

# When alpha > 0, short-range skip connections are favoured
alpha = 0.20

# When gamma is 0, the degree of each node is chosen uniformly between 0 and max_k
# specifically, P(degree) \propto degree^-gamma
gamma = 0.0

# we can limit the maximum and minimum degree of our graph
max_k = 6
min_k = 3

# feature channel possibilities per edge
hidden_out_channels = [1]

# possible dilation choices
dilation_choices = [1,2,3,4,5]

# Here are some parameters that define how networks are drawn at random
# the layer probabilities dictionary defines connections
layer_probabilities={'LL_alpha':alpha,
                     'LL_gamma': gamma,
                     'LL_max_degree':max_k,
                     'LL_min_degree':min_k,
                     'IL': 0.25,
                     'LO': 0.25,
                     'IO': False}

# if desired, one can introduce scale changes (down and upsample)
# a not-so-thorough look indicates that this isn't really super beneficial
# in the model systems we looked at
sizing_settings = {'stride_base':2, #better keep this at 2
                   'min_power': 0,
                   'max_power': 0}

# defines the type of network we want to build

network_type = "Regression"

Build networks and train

We specify the number of random networks to initialize and the number of epochs for which each is trained.

[12]:
nets = []
n_networks = 7
epochs = 50     # Set number of epochs
criterion = nn.MSELoss()   # For regression (denoising)
learning_rate = 1e-2

for ii in range(n_networks):
    torch.cuda.empty_cache()
    print("Network %i"%(ii+1))
    smsnet_model = smsnet.random_SMS_network(in_channels=in_channels,
                                             out_channels=out_channels,
                                             in_shape=(32,32),
                                             out_shape=(32,32),
                                             sizing_settings=sizing_settings,
                                             layers=num_layers,
                                             dilation_choices=dilation_choices,
                                             hidden_out_channels=hidden_out_channels,
                                             layer_probabilities=layer_probabilities,
                                             network_type=network_type
                                            )

    # let's plot the network
    net_plot,dil_plot,chan_plot = draw_sparse_network.draw_network(smsnet_model)
    plt.show()

    nets.append(smsnet_model)

    print("Start training")
    pytorch_total_params = sum(p.numel() for p in smsnet_model.parameters() if p.requires_grad)
    print("Total number of refineable parameters: ", pytorch_total_params)

    optimizer = optim.Adam(smsnet_model.parameters(), lr=learning_rate)  # Defined in loop, one per network
    device = helpers.get_device()
    smsnet_model = smsnet_model.to(device)
    tmp = train_regression(smsnet_model,
                           train_loader,
                           test_loader,  # used for validation here; pass validation_loader to keep the test set held out
                           epochs,
                           criterion,
                           optimizer,
                           device,
                           show=10)
    smsnet_model = smsnet_model.cpu()
    plots.plot_training_results_regression(tmp[1]).show()
Network 1
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_1.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_2.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_3.png
Start training
Total number of refineable parameters:  3939
Epoch 10 of 50 | Learning rate 1.000e-02
Training Loss: 9.2175e-03 | Validation Loss: 8.0032e-03
Training CC: 0.8686   Validation CC  : 0.8800
Epoch 20 of 50 | Learning rate 1.000e-02
Training Loss: 6.0803e-03 | Validation Loss: 6.5273e-03
Training CC: 0.9140   Validation CC  : 0.9042
Epoch 30 of 50 | Learning rate 1.000e-02
Training Loss: 5.1456e-03 | Validation Loss: 5.8442e-03
Training CC: 0.9271   Validation CC  : 0.9151
Epoch 40 of 50 | Learning rate 1.000e-02
Training Loss: 4.6548e-03 | Validation Loss: 5.4656e-03
Training CC: 0.9343   Validation CC  : 0.9213
Epoch 50 of 50 | Learning rate 1.000e-02
Training Loss: 4.3191e-03 | Validation Loss: 5.2219e-03
Training CC: 0.9392   Validation CC  : 0.9249
Network 2
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_7.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_8.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_9.png
Start training
Total number of refineable parameters:  3990
Epoch 10 of 50 | Learning rate 1.000e-02
Training Loss: 9.1804e-03 | Validation Loss: 8.0699e-03
Training CC: 0.8771   Validation CC  : 0.8871
Epoch 20 of 50 | Learning rate 1.000e-02
Training Loss: 6.4843e-03 | Validation Loss: 6.6771e-03
Training CC: 0.9072   Validation CC  : 0.9031
Epoch 30 of 50 | Learning rate 1.000e-02
Training Loss: 5.6728e-03 | Validation Loss: 6.0229e-03
Training CC: 0.9194   Validation CC  : 0.9134
Epoch 40 of 50 | Learning rate 1.000e-02
Training Loss: 5.1018e-03 | Validation Loss: 5.5597e-03
Training CC: 0.9277   Validation CC  : 0.9200
Epoch 50 of 50 | Learning rate 1.000e-02
Training Loss: 4.6591e-03 | Validation Loss: 5.3416e-03
Training CC: 0.9344   Validation CC  : 0.9233
Network 3
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_13.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_14.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_15.png
Start training
Total number of refineable parameters:  3669
Epoch 10 of 50 | Learning rate 1.000e-02
Training Loss: 1.2144e-02 | Validation Loss: 1.2543e-02
Training CC: 0.8506   Validation CC  : 0.8502
Epoch 20 of 50 | Learning rate 1.000e-02
Training Loss: 6.7236e-03 | Validation Loss: 7.7115e-03
Training CC: 0.9041   Validation CC  : 0.8868
Epoch 30 of 50 | Learning rate 1.000e-02
Training Loss: 5.6058e-03 | Validation Loss: 6.9776e-03
Training CC: 0.9203   Validation CC  : 0.8990
Epoch 40 of 50 | Learning rate 1.000e-02
Training Loss: 5.0782e-03 | Validation Loss: 6.8975e-03
Training CC: 0.9282   Validation CC  : 0.9011
Epoch 50 of 50 | Learning rate 1.000e-02
Training Loss: 4.6339e-03 | Validation Loss: 6.3910e-03
Training CC: 0.9346   Validation CC  : 0.9080
Network 4
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_19.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_20.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_21.png
Start training
Total number of refineable parameters:  3663
Epoch 10 of 50 | Learning rate 1.000e-02
Training Loss: 1.3449e-02 | Validation Loss: 1.3231e-02
Training CC: 0.7986   Validation CC  : 0.8016
Epoch 20 of 50 | Learning rate 1.000e-02
Training Loss: 8.0801e-03 | Validation Loss: 8.5300e-03
Training CC: 0.8832   Validation CC  : 0.8727
Epoch 30 of 50 | Learning rate 1.000e-02
Training Loss: 6.4468e-03 | Validation Loss: 7.3383e-03
Training CC: 0.9077   Validation CC  : 0.8927
Epoch 40 of 50 | Learning rate 1.000e-02
Training Loss: 5.5732e-03 | Validation Loss: 6.9593e-03
Training CC: 0.9207   Validation CC  : 0.8992
Epoch 50 of 50 | Learning rate 1.000e-02
Training Loss: 4.9964e-03 | Validation Loss: 6.5431e-03
Training CC: 0.9292   Validation CC  : 0.9049
Network 5
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_25.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_26.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_27.png
Start training
Total number of refineable parameters:  3462
Epoch 10 of 50 | Learning rate 1.000e-02
Training Loss: 1.1155e-02 | Validation Loss: 1.0104e-02
Training CC: 0.8359   Validation CC  : 0.8501
Epoch 20 of 50 | Learning rate 1.000e-02
Training Loss: 7.0713e-03 | Validation Loss: 7.0207e-03
Training CC: 0.8985   Validation CC  : 0.8984
Epoch 30 of 50 | Learning rate 1.000e-02
Training Loss: 5.5692e-03 | Validation Loss: 5.8999e-03
Training CC: 0.9208   Validation CC  : 0.9132
Epoch 40 of 50 | Learning rate 1.000e-02
Training Loss: 4.8418e-03 | Validation Loss: 5.3266e-03
Training CC: 0.9315   Validation CC  : 0.9219
Epoch 50 of 50 | Learning rate 1.000e-02
Training Loss: 4.3537e-03 | Validation Loss: 4.8718e-03
Training CC: 0.9387   Validation CC  : 0.9287
Network 6
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_31.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_32.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_33.png
Start training
Total number of refineable parameters:  3619
Epoch 10 of 50 | Learning rate 1.000e-02
Training Loss: 1.2206e-02 | Validation Loss: 1.0908e-02
Training CC: 0.8169   Validation CC  : 0.8338
Epoch 20 of 50 | Learning rate 1.000e-02
Training Loss: 7.2262e-03 | Validation Loss: 7.5402e-03
Training CC: 0.8960   Validation CC  : 0.8896
Epoch 30 of 50 | Learning rate 1.000e-02
Training Loss: 5.7538e-03 | Validation Loss: 6.8879e-03
Training CC: 0.9181   Validation CC  : 0.8989
Epoch 40 of 50 | Learning rate 1.000e-02
Training Loss: 4.9582e-03 | Validation Loss: 6.1591e-03
Training CC: 0.9298   Validation CC  : 0.9098
Epoch 50 of 50 | Learning rate 1.000e-02
Training Loss: 4.5296e-03 | Validation Loss: 5.7084e-03
Training CC: 0.9362   Validation CC  : 0.9171
Network 7
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_37.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_38.png
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_20_39.png
Start training
Total number of refineable parameters:  3326
Epoch 10 of 50 | Learning rate 1.000e-02
Training Loss: 1.3265e-02 | Validation Loss: 1.2832e-02
Training CC: 0.8193   Validation CC  : 0.8177
Epoch 20 of 50 | Learning rate 1.000e-02
Training Loss: 6.9772e-03 | Validation Loss: 7.6696e-03
Training CC: 0.9005   Validation CC  : 0.8864
Epoch 30 of 50 | Learning rate 1.000e-02
Training Loss: 5.6950e-03 | Validation Loss: 6.9365e-03
Training CC: 0.9193   Validation CC  : 0.8983
Epoch 40 of 50 | Learning rate 1.000e-02
Training Loss: 5.0621e-03 | Validation Loss: 6.5130e-03
Training CC: 0.9284   Validation CC  : 0.9049
Epoch 50 of 50 | Learning rate 1.000e-02
Training Loss: 4.6441e-03 | Validation Loss: 6.2002e-03
Training CC: 0.9344   Validation CC  : 0.9100

Testing our models

Finally, we load the testing data, pass it through the networks, and display the results.

Helper functions

[13]:
# corcoef is not imported in the cells above; it is assumed to live in dlsia.core
from dlsia.core import corcoef


def regression_metrics(preds, target):
    """
    Here, the Pearson correlation coefficient is calculated between the network
    predictions and the ground truth.
    """
    tmp = corcoef.cc(preds.cpu().flatten(), target.cpu().flatten())
    return tmp


def segment_imgs(testloader, net, num_display=10, plot=True, std=False):
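    """
    This function makes network predictions on testing data found in the 'testloader'
    PyTorch dataloader object.

    :param testloader: the PyTorch dataloader object used to retrieve testing data
    :param net: the trained deep network, or a bagged ensemble when std=True
    :param num_display: requested number of images to display; note that the
                        plotting loop below always shows the first 5 images per batch
    :param plot: if True, plot network results using matplotlib
    :param std: if True, the network also returns standard deviations (sigmas),
                which are plotted alongside the predictions

    :returns seg_imgs: the predicted images, concatenated into a single tensor
    :returns noisy_imgs: the input images, concatenated into a single tensor
    :returns target_imgs: the ground truth images, concatenated into a single tensor
    """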
    torch.cuda.empty_cache()

    seg_imgs = []
    noisy_imgs = []
    #target_imgs = []

    #running_CC_test_val = 0.0

    counter = 0
    with torch.no_grad():
        for batch in testloader:
            num_display = np.minimum(num_display, len(batch[0]))
            noisy, target = batch

            noisy = torch.FloatTensor(noisy)
            noisy = noisy.to(device)#.unsqueeze(1)

            sigmas = None
            if not std:
                output = net.to(device)(noisy)
            else:
                output, sigmas = net.to(device)(noisy, 'cpu', True)

            if counter == 0:
                seg_imgs = output.detach().cpu()
                noisy_imgs = noisy.detach().cpu()
                target_imgs = target.detach().cpu()
                if std:
                    sigmas = sigmas.detach().cpu()
            else:
                seg_imgs = torch.cat((seg_imgs, output.detach().cpu()), 0)
                noisy_imgs = torch.cat((noisy_imgs, noisy.detach().cpu()), 0)
                target_imgs = torch.cat((target_imgs, target.detach().cpu()), 0)
                if std:
                    sigmas = sigmas.detach().cpu()


            counter+=1

            if plot==True:
                for j in range(5):
                    if not std:
                        print(f'Images for batch # {counter}, number {j}')
                        plt.figure(figsize=(22,5))
                        plt.rcParams.update({'font.size': 20})

                        plt.subplot(131)
                        plt.imshow(noisy.cpu()[j,0,:,:].data); plt.colorbar(shrink=0.8); plt.title('Noisy');
                        plt.subplot(132)
                        plt.imshow(output[j,0,:,:].detach().cpu()); plt.colorbar(shrink=0.8); plt.title('Prediction');
                        plt.subplot(133)
                        plt.imshow(target.cpu()[j,0,:,:].data); plt.colorbar(shrink=0.8); plt.title('Ground Truth');

                        plt.tight_layout()

                        plt.show()
                    else:
                        print(f'Images for batch # {counter}, number {j}')
                        plt.figure(figsize=(22,5))
                        plt.rcParams.update({'font.size': 20})

                        plt.subplot(151)
                        plt.imshow(noisy.cpu()[j,0,:,:].data);
                        plt.colorbar(shrink=0.8); plt.title('Noisy');
                        plt.subplot(152)
                        plt.imshow(output[j,0,:,:].detach().cpu());
                        plt.colorbar(shrink=0.8); plt.title('Prediction');
                        plt.subplot(153)
                        plt.imshow(sigmas[j,0,:,:].detach().cpu());
                        plt.colorbar(shrink=0.8); plt.title('Sigmas');
                        plt.subplot(154)
                        plt.imshow(output[j,0,:,:].detach().cpu() / sigmas[j,0,:,:].detach().cpu(),vmax=30);
                        plt.colorbar(shrink=0.8); plt.title('Signal to Noise');

                        plt.subplot(155)
                        plt.imshow(target.cpu()[j,0,:,:].data); plt.colorbar(shrink=0.8); plt.title('Ground Truth');

                        plt.tight_layout()

                        plt.show()




    #CC = running_CC_test_val / len(testloader)
    torch.cuda.empty_cache()
    return seg_imgs, noisy_imgs, target_imgs
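
The Pearson correlation coefficient computed by regression_metrics, and reported as CC in the training logs, can be sketched in plain Python as follows. This is an illustration of the standard formula with a hypothetical helper, pearson_cc; it is not the code in dlsia's corcoef module.

```python
import math


def pearson_cc(x, y):
    """Pearson correlation coefficient between two equally long sequences."""
    n = len(x)
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


# Flattened ground truth and a slightly imperfect prediction
target = [0.0, 0.0, 1.0, 0.0, 0.5]
prediction = [0.1, -0.1, 0.9, 0.0, 0.6]

cc_perfect = pearson_cc(target, target)
cc_prediction = pearson_cc(target, prediction)
cc_inverted = pearson_cc(target, [-t for t in target])

print(round(cc_perfect, 6), round(cc_prediction, 4), round(cc_inverted, 6))
```

Because the coefficient is invariant to a positive linear rescaling of either input, a high CC indicates that the prediction has the right shape, while the MSE loss additionally penalizes errors in offset and scale.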

MSDNet predictions

[14]:
output, noisy, target  = segment_imgs(test_loader, msdnet_model, num_display=3)
Images for batch # 1, number 0
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_24_1.png
Images for batch # 1, number 1
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_24_3.png
Images for batch # 1, number 2
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_24_5.png
Images for batch # 1, number 3
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_24_7.png
Images for batch # 1, number 4
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_24_9.png

RMSNet predictions

The trained RMSNets are ‘ensembled’ together using bagging. Displayed here are:

  • Noisy/Prediction: original noisy images and denoised network ensemble predictions,

  • Sigmas: standard deviation of the denoising predictions across the ensemble,

  • Signal to Noise: Prediction / Sigmas.
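
The quantities in this list can be illustrated with a few lines of plain Python: the ensemble prediction is the per-pixel mean over the networks, and the error estimate is the per-pixel standard deviation. This sketch uses made-up numbers and a hypothetical helper, ensemble_statistics; it assumes a population standard deviation, which may not match what baggins computes internally.

```python
import statistics


def ensemble_statistics(member_predictions):
    """Per-pixel mean and standard deviation over ensemble members.

    member_predictions: list of flattened images, one per network.
    """
    pixels = list(zip(*member_predictions))          # regroup values by pixel
    mean = [statistics.fmean(values) for values in pixels]
    sigma = [statistics.pstdev(values) for values in pixels]
    return mean, sigma


# Three hypothetical networks predicting the same four pixels
member_predictions = [
    [0.0, 1.0, 0.5, 0.2],
    [0.0, 0.8, 0.5, 0.4],
    [0.0, 1.2, 0.5, 0.0],
]
mean, sigma = ensemble_statistics(member_predictions)

eps = 1e-12  # avoids division by zero where all members agree exactly
signal_to_noise = [m / (s + eps) for m, s in zip(mean, sigma)]

print("mean :", [round(m, 3) for m in mean])
print("sigma:", [round(s, 3) for s in sigma])
```

Pixels on which the networks disagree receive a large sigma and a low signal-to-noise value, which is how the ensemble flags less trustworthy regions of the denoised image.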

[15]:
bagged_model = baggins.model_baggin(nets,"regression", False)
[16]:
output, noisy, target  = segment_imgs(test_loader, bagged_model, std=True)
Images for batch # 1, number 0
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_27_1.png
Images for batch # 1, number 1
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_27_3.png
Images for batch # 1, number 2
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_27_5.png
Images for batch # 1, number 3
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_27_7.png
Images for batch # 1, number 4
../_images/tutorialLinks_denoising_MSDNet_SMSNetEnsemble_27_9.png