ec6adbea0cc543aa91ceda56475cc905

Self-supervised denoising using ensembles

Authors: Eric Roberts and Petrus Zwart

E-mail: PHZwart@lbl.gov, EJRoberts@lbl.gov

This notebook highlights some basic functionality with the dlsia package.

Here we will demonstrate how to use the randomized networks to perform self-supervised denoising of images corrupted by Gaussian noise. The denoising approach is based on training a neural network to minimize a Total Variation target. We will use an ensemble of networks to do so, will cluster these networks on the basis of their performance, and average the results. Because the total number of parameters in a network is typically lower than the number of pixels, the data-to-parameter ratio can be more favourable. By using an ensemble approach, we can also produce error estimates.

Imports and utility functions

[1]:
import sys
import os
import numpy as np
import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary

from dlsia.core import helpers
from dlsia.core.train_scripts import train_regression
from dlsia.core.networks import smsnet
from dlsia.viz_tools import plots

from torch.utils.data import DataLoader, TensorDataset
import einops
import matplotlib.pyplot as plt
import umap

from sklearn.cluster import KMeans
from scipy.signal import medfilt2d

Utility functions

A number of utility functions for this notebook are provided here.

[2]:
class random_atom_structure(object):
    """
    Build a random 2D structure made of atoms (point coordinates).

    Candidate positions are drawn uniformly inside the grid, away from the
    border. The first atom is always accepted; every later candidate is
    accepted only when its nearest already-placed atom is neither too close
    nor too far away.

    Note:
        ``min_distance`` and ``max_distance`` are compared against the
        *squared* Euclidean distance to the nearest atom, so both values
        have to be supplied in squared grid units.

    Args:
        n_atoms (int):        The number of atoms we try to place.

        min_distance (float): Lower bound (exclusive) on the squared
                              distance to the nearest neighbour.

        max_distance (float): Upper bound (exclusive) on the squared
                              distance to the nearest neighbour.

        gridXY (int):         The size of the grid on which we return.

        border (int):         The border-bumper: we don't want atoms
                              on the border.

        max_tries (int):      How many tries we are willing to take
                              per atom.

    Return:
        An object that builds coordinates that fulfill the criteria
        stipulated above
    """

    def __init__(self,
                 n_atoms: int,
                 min_distance: float,
                 max_distance: float,
                 gridXY: int = 128,
                 border: int = 20,
                 max_tries: int = 100):
        """
        Store the construction parameters; nothing is drawn yet.

        Args:
            n_atoms (int):        The number of atoms we try to place.

            min_distance (float): Lower bound (exclusive) on the squared
                                  distance to the nearest neighbour.

            max_distance (float): Upper bound (exclusive) on the squared
                                  distance to the nearest neighbour.

            gridXY (int):         The size of the grid on which we return.

            border (int):         The border-bumper: we don't want atoms
                                  on the border.

            max_tries (int):      How many tries we are willing to take
                                  per atom.
        """
        self.n_atoms = n_atoms
        self.min_distance = min_distance
        self.max_distance = max_distance
        self.gridXY = gridXY
        self.border = border
        self.max_tries = max_tries

    def random_xy(self):
        """
        Draw one random position inside the grid, away from the border.

        Return:
            An np.array of shape (1, 2) holding [x, y].
        """
        this_x = np.random.uniform(self.border, self.gridXY-self.border)
        this_y = np.random.uniform(self.border, self.gridXY-self.border)
        return np.array([[this_x, this_y]])

    def draw(self):
        """
        Construct a random structure with the parameters given above.

        An atom that cannot be placed within the allowed number of tries is
        skipped, so fewer than ``n_atoms`` coordinates may be returned; a
        warning is issued when that happens.

        Args:
            None
        Output:
            An np.array of shape (n_placed, 2), shifted so that its mean
            sits at the centre of the grid. Shape (0, 2) when no atoms are
            requested.
        """
        # Nothing to place: return an empty coordinate list instead of
        # failing on the mean of a missing array.
        if self.n_atoms <= 0:
            return np.empty((0, 2))

        # A candidate is always drawn at least once, and up to
        # max_tries + 1 times per atom.
        attempts_per_atom = max(1, int(self.max_tries) + 1)

        coordinates = None
        for _ in range(self.n_atoms):
            for _ in range(attempts_per_atom):
                xy = self.random_xy()
                if coordinates is None:
                    # the very first atom has no neighbours to check
                    coordinates = xy
                    break
                delta = coordinates - xy
                # squared distance to the nearest already-placed atom
                nearest_sq = np.min(np.sum(delta*delta, axis=1))
                if self.min_distance < nearest_sq < self.max_distance:
                    coordinates = np.concatenate([coordinates, xy])
                    break

        n_placed = coordinates.shape[0]
        if n_placed < self.n_atoms:
            import warnings
            warnings.warn(
                "random_atom_structure.draw placed only %i of %i atoms; "
                "increase max_tries or relax the distance limits"
                % (n_placed, self.n_atoms)
            )

        # recentre the structure on the middle of the grid
        mxy = np.mean(coordinates, axis=0)
        offset = np.array([self.gridXY, self.gridXY])/2.0
        return coordinates-mxy+offset


def build_density(coordinates: np.array,
                  sigma: float,
                  gridXY: int,
                  cut: float):
    """
    Render a set of atom coordinates as a binary image.

    Every atom contributes an isotropic Gaussian; the summed map is then
    thresholded so that the result only contains 0.0 and 1.0.

    Args:
        coordinates (np.array( (n_atoms, 2) )): (x, y) position per atom;
            x runs along the columns, y along the rows of the image

        sigma (float): the width of the gaussian

        gridXY (int): the grid size

        cut (float): threshold on the summed density; pixels strictly
            above it become 1.0, all others 0.0

    Return:
        A 2D array of shape (gridXY, gridXY) with the desired image

    """
    # gridXY sample positions from 0 to gridXY (both included), used for
    # both image axes
    axis = np.linspace(0, gridXY, gridXY)
    column_pos = axis[np.newaxis, :]
    row_pos = axis[:, np.newaxis]
    two_sigma_sq = 2.0 * sigma**2.0

    density = np.zeros((gridXY, gridXY))
    for atom_x, atom_y in coordinates:
        dist_sq = (column_pos-atom_x)**2 + (row_pos-atom_y)**2
        density += np.exp(-dist_sq / two_sigma_sq)

    return np.where(density > cut, 1.0, 0.0)

Create data

[3]:
# --- parameters of the synthetic structure ---------------------------------
n_atoms = 80
# NOTE: random_atom_structure.draw compares these two limits against the
# *squared* distance to the nearest atom (np.sum(dd*dd, axis=1)), so they are
# expressed in squared grid units.
min_distance = 1500
max_distance = 5700
gridXY = 512
border = 60
sigma = 12.5
# standard deviation of the additive Gaussian noise
noise_level = 1.50

RAS_obj = random_atom_structure(n_atoms=n_atoms,
                                min_distance=min_distance,
                                max_distance=max_distance,
                                gridXY=gridXY,
                                border=border)

xy = RAS_obj.draw()

# noise-free binary image of the structure
this_map1 = build_density(coordinates=xy,
                          sigma=sigma,
                          gridXY=gridXY,
                          cut=0.75)

noise_map1 = np.random.normal(0, noise_level, this_map1.shape)

# The second image uses exactly the same coordinates and settings, so the
# noise-free map is identical; copy it instead of rebuilding it. Only the
# noise realisation differs between the two images.
this_map2 = this_map1.copy()

noise_map2 = np.random.normal(0, noise_level, this_map2.shape)


View data

[4]:
# Show the noise-free image, the noisy image and, as a classical baseline,
# an 11x11 median-filtered version of the noisy image - one figure each.
for panel_title, panel_image in (
        ("Noise Free Data", this_map1),
        ("Noisy Data", this_map1+noise_map1),
        ("Median filtered data", medfilt2d(this_map1+noise_map1, 11)),
):
    plt.figure(figsize=(5, 5))
    plt.title(panel_title)
    plt.imshow(panel_image)
    plt.colorbar(shrink=0.75)
    plt.show()

../_images/tutorialLinks_denoising_selfSupervised_8_0.png
../_images/tutorialLinks_denoising_selfSupervised_8_1.png
../_images/tutorialLinks_denoising_selfSupervised_8_2.png

Create networks and train

Total variational loss

We define our own custom loss function for our networks to minimize. This loss function supplants the familiar mean square error and cross entropy losses commonly used for regression and segmentation/classification.

[5]:
class tv_target(nn.Module):
    """
    Total-variation style loss for self-supervised denoising.

    The loss is the sum of
      * ``weight`` times ``smooth_target`` applied to horizontally
        neighbouring pixels of the prediction,
      * ``weight`` times ``smooth_target`` applied to vertically
        neighbouring pixels of the prediction,
      * ``noise_target`` applied to the network input and the prediction.

    Args:
        smooth_target: callable loss comparing two tensors; used on
                       shifted copies of the prediction.
        noise_target:  callable loss comparing the input with the
                       prediction.
        weight (float): multiplier of both smoothness terms.
    """

    def __init__(self, smooth_target, noise_target, weight=1.0):
        super().__init__()
        self.smooth_target = smooth_target
        self.noise_target = noise_target
        self.weight = weight

    def forward(self, inp, x):
        """
        Evaluate the loss.

        Args:
            inp: the network input.
            x:   the network prediction; indexed with four indices, the
                 last two being treated as the image axes.

        Return:
            The combined loss as returned by the wrapped loss callables.
        """
        # differences between neighbours along the last axis
        smooth_last = self.smooth_target(x[:, :, :, 1:], x[:, :, :, 0:-1])
        # differences between neighbours along the second-to-last axis
        smooth_prev = self.smooth_target(x[:, :, 1:, :], x[:, :, 0:-1, :])
        # keeps the prediction close to the (noisy) input
        fidelity = self.noise_target(inp, x)
        return self.weight*smooth_last + self.weight*smooth_prev + fidelity

[6]:
# Turn both noisy images into tensors and prepend two singleton axes, giving
# each the shape (1, 1, Y, X).
data1 = torch.Tensor(this_map1+noise_map1)[None, None]
data2 = torch.Tensor(this_map2+noise_map2)[None, None]
# Only the first image is used for training; data2 is kept for evaluation.
dataset = TensorDataset(data1)
dataloader = DataLoader(dataset)

Create networks

Hyperparameters governing the degree of randomness in network structures are defined first.

[7]:
# Settings shared by every randomly generated network; they are handed to
# smsnet.random_SMS_network when the ensemble is built.
in_channels, out_channels = 1, 1
num_layers = 10
# forwarded as 'LL_alpha' / 'LL_gamma' in layer_probabilities
alpha = 0.50
gamma = 0
# forwarded as 'LL_max_degree' / 'LL_min_degree' in layer_probabilities
max_k = min_k = 5
hidden_out_channels = [5]
dilation_choices = list(range(1, 6))
[8]:
nets = []
N_networks = 25

# Plain string settings, identical for every network.
dilation_mode = "Edges"
network_type = "Regression_Sigmoid"
network_mode = "Full"

for ii in range(N_networks):

    # The two dicts are rebuilt for each network so that every constructor
    # call receives its own fresh objects.
    layer_probabilities = dict(LL_alpha=alpha,
                               LL_gamma=gamma,
                               LL_max_degree=max_k,
                               LL_min_degree=min_k,
                               IL=0.1,
                               LO=0.1,
                               IO=False)
    sizing_settings = dict(stride_base=2,
                           min_power=0,
                           max_power=0)

    netSMS = smsnet.random_SMS_network(
        in_channels=in_channels,
        out_channels=out_channels,
        layers=num_layers,
        dilation_choices=dilation_choices,
        hidden_out_channels=hidden_out_channels,
        layer_probabilities=layer_probabilities,
        sizing_settings=sizing_settings,
        dilation_mode=dilation_mode,
        network_type=network_type,
        network_mode=network_mode,
    )

    # count only the parameters that take part in the optimisation
    pytorch_total_params = sum(
        p.numel() for p in netSMS.parameters() if p.requires_grad
    )
    print("Total number of refineable parameters: ", pytorch_total_params)
    nets.append(netSMS)
Total number of refineable parameters:  13140
Total number of refineable parameters:  13180
Total number of refineable parameters:  9945
Total number of refineable parameters:  16755
Total number of refineable parameters:  11345
Total number of refineable parameters:  18795
Total number of refineable parameters:  11505
Total number of refineable parameters:  17165
Total number of refineable parameters:  13955
Total number of refineable parameters:  12225
Total number of refineable parameters:  14340
Total number of refineable parameters:  15415
Total number of refineable parameters:  12805
Total number of refineable parameters:  10725
Total number of refineable parameters:  12080
Total number of refineable parameters:  13505
Total number of refineable parameters:  10860
Total number of refineable parameters:  11117
Total number of refineable parameters:  13210
Total number of refineable parameters:  9565
Total number of refineable parameters:  9835
Total number of refineable parameters:  14325
Total number of refineable parameters:  11600
Total number of refineable parameters:  10910
Total number of refineable parameters:  18070

Training loop

We define a custom training loop first.

[9]:
def simple_train_loop(net, dataloader, optim, crit, epochs, device):
    """
    Train ``net`` for a fixed number of epochs with a self-supervised loss.

    The loss is evaluated between the network input and the network output,
    so no separate target is needed. The loss of every batch is printed
    during the last epoch only.

    Args:
        net:        callable model; ``net(inp)`` yields the prediction.
        dataloader: iterable of batches; element 0 of a batch is the input.
        optim:      optimizer exposing ``zero_grad()`` and ``step()``
                    (inside this function the name hides the ``optim``
                    module imported at the top of the notebook).
        crit:       callable ``crit(inp, prediction)`` returning a scalar
                    loss tensor.
        epochs (int): number of passes over ``dataloader``.
        device:     device the inputs are moved to before the forward pass.

    Return:
        float: the batch losses summed over the last epoch, or 0.0 when
        no epoch was run.
    """
    # defined up front so that epochs == 0 returns 0.0 instead of raising
    value = 0.0
    for ii in range(epochs):
        # restart the accumulator: only the last epoch's total is returned
        value = 0.0
        is_last_epoch = (ii+1) % epochs == 0
        for batch in dataloader:
            inp = batch[0].to(device)

            x = net(inp)

            loss = crit(inp, x)
            loss_value = loss.item()
            value += loss_value
            if is_last_epoch:
                print("Epoch", ii, "Loss", loss_value)

            optim.zero_grad()
            loss.backward()
            optim.step()
    return value
[10]:
# Settings that are the same for every network are created once.
epochs = 50
weight = 2.5
LEARNING_RATE = 1e-2
device = helpers.get_device()

# Combined loss: L1 on neighbouring-pixel differences (smoothness) plus MSE
# between input and prediction. tv_target only stores these callables and
# the weight, so one instance can serve all networks.
criterion_noise = nn.MSELoss()
criterion_tv = nn.L1Loss()
combo_crit = tv_target(criterion_tv, criterion_noise, weight)

obtained_target_values = []
for net in nets:
    # Move the model to the target device *before* building its optimizer,
    # so the optimizer is constructed on parameters in their final location.
    net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    res = simple_train_loop(net,
                            dataloader,
                            optimizer,
                            combo_crit,
                            epochs,
                            device)
    # summed loss of the last epoch, one entry per network
    obtained_target_values.append(res)
Epoch 49 Loss 2.2903685569763184
Epoch 49 Loss 2.293128252029419
Epoch 49 Loss 2.2999167442321777
Epoch 49 Loss 2.3059940338134766
Epoch 49 Loss 2.2927701473236084
Epoch 49 Loss 2.2938392162323
Epoch 49 Loss 2.2998194694519043
Epoch 49 Loss 2.2990071773529053
Epoch 49 Loss 2.291795492172241
Epoch 49 Loss 2.2963905334472656
Epoch 49 Loss 2.2961220741271973
Epoch 49 Loss 2.294356346130371
Epoch 49 Loss 2.2918319702148438
Epoch 49 Loss 2.2986533641815186
Epoch 49 Loss 2.2950170040130615
Epoch 49 Loss 2.3031492233276367
Epoch 49 Loss 2.293179512023926
Epoch 49 Loss 2.301177740097046
Epoch 49 Loss 2.296251058578491
Epoch 49 Loss 2.314724922180176
Epoch 49 Loss 2.2965447902679443
Epoch 49 Loss 2.29355788230896
Epoch 49 Loss 2.311861276626587
Epoch 49 Loss 2.293119430541992
Epoch 49 Loss 2.2987704277038574

Reconstruction and network evaluation

[11]:
# Run the training image (data1) through every trained network on the CPU,
# without gradient tracking, and keep each prediction as a 2D array.
recons = []
with torch.no_grad():
    for net in nets:
        dn = net.cpu()(data1)
        # drop the two leading singleton axes before converting to numpy
        recons.append(dn[0, 0].numpy())
# the list of N images becomes a single array with axes (N, Y, X)
recons = einops.rearrange(recons, "N Y X -> N Y X")
[12]:
# Pairwise Pearson correlation between the flattened reconstructions of all
# networks; cc_mat[ii, jj] compares network ii with network jj.
cc_mat = np.zeros((N_networks,N_networks))
for ii in range(N_networks):
    for jj in range(N_networks):
        # np.corrcoef returns a 2x2 matrix; the off-diagonal entry is the
        # correlation between the two inputs
        this_cc = np.corrcoef( recons[ii].flatten(), recons[jj].flatten())
        cc_mat[ii,jj]=this_cc[0,1]

plt.imshow(cc_mat, vmin=0.85,vmax=1.0)
plt.title("Correlation between filtered images")
plt.colorbar()
plt.show()

# Custom metric handed to UMAP below. The "points" it receives hold network
# indices (see X further down), which are used to look up the precomputed
# correlation; the distance is 1 - correlation.
def cc_metric(a,b):
    results = []
    for aa,bb in zip(a,b):
        tmp = (1.0 - cc_mat[int(aa),int(bb)])
        results.append(tmp)
    results = np.array(results)
    return results

# One row per network, containing only that network's index.
X = np.arange(N_networks).astype(int).reshape(-1,1)
mapper = umap.UMAP(min_dist=0, n_neighbors=2, metric=cc_metric)
U = mapper.fit_transform(X)
# scatter plot of the first two embedding coordinates
plt.plot(U[:,0], U[:,1], '.')
plt.show()
../_images/tutorialLinks_denoising_selfSupervised_21_0.png
/home/ejroberts/anaconda3/envs/dlsia/lib/python3.9/site-packages/umap/umap_.py:1772: UserWarning: custom distance metric does not return gradient; inverse_transform will be unavailable. To enable using inverse_transform method, define a distance function that returns a tuple of (distance [float], gradient [np.array])
  warn(
../_images/tutorialLinks_denoising_selfSupervised_21_2.png
[13]:
# Group the networks by K-means on their UMAP embedding U.
n_cluster = 1
# n_init is given explicitly (10 is the value named as the current default
# in scikit-learn's FutureWarning) so results do not change when the default
# switches to 'auto', and the warning is not raised.
cobj = KMeans(n_clusters=n_cluster, n_init=10)
cluster = cobj.fit_predict(U)

means = []
stds = []
cluster_size = []
cluster_name = np.arange(n_cluster).astype(int)

for ii in range(n_cluster):
    # boolean mask of the networks assigned to cluster ii
    sel = cluster == ii
    plt.plot(U[sel, 0], U[sel, 1], '.')
    # per-pixel mean and spread over the reconstructions of this cluster
    m = np.mean(recons[sel, ...], axis=0)
    s = np.std(recons[sel, ...], axis=0)
    means.append(m)
    stds.append(s)
    cluster_size.append(np.sum(sel))
plt.legend(cluster_name)
plt.show()


# one figure per cluster, showing the mean reconstruction
for m, s, ss, cn in zip(means, stds, cluster_size, cluster_name):
    plt.figure(figsize=(5, 5))
    plt.imshow(m)
    plt.title("Cluster %i - %i members"%(cn, ss))
    plt.colorbar(shrink=0.75)
    plt.show()
/home/ejroberts/anaconda3/envs/dlsia/lib/python3.9/site-packages/sklearn/cluster/_kmeans.py:870: FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning
  warnings.warn(
../_images/tutorialLinks_denoising_selfSupervised_22_1.png
../_images/tutorialLinks_denoising_selfSupervised_22_2.png
[14]:
# indices of the networks that were assigned to cluster 0
(sel,) = np.where(cluster == 0)
print(sel)
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24]
[15]:
# keep only the networks whose index was selected above
these_nets = [nets[ii] for ii in sel]
[16]:
# Denoise the second noisy image (data2), which was not used for training,
# with the networks of the selected cluster.
new_recons = []
with torch.no_grad():
    # these_nets holds the networks picked from cluster 0; iterating over it
    # (rather than over all of nets) makes the cluster selection take effect
    for net in these_nets:
        dn = net.cpu()(data2)
        # rescale every prediction so that its maximum equals 1
        dn = dn / torch.max(dn)
        new_recons.append(dn.numpy()[0, 0])
# the list of N images becomes a single array with axes (N, Y, X)
new_recons = einops.rearrange(new_recons, "N Y X -> N Y X")
# per-pixel ensemble mean and spread (the latter serves as error estimate)
new_mean = np.mean(new_recons, axis=0)
new_std = np.std(new_recons, axis=0)
[17]:
# Ensemble mean, ensemble standard deviation and the noisy input they were
# computed from - one figure each.
for panel_image, panel_title in (
        (new_mean, "Denoised - Mean"),
        (new_std, "Denoised - STD"),
        (this_map2+noise_map2, "Input Noisy Data"),
):
    plt.figure(figsize=(5, 5))
    plt.imshow(panel_image)
    plt.title(panel_title)
    plt.colorbar(shrink=0.75)
    plt.show()
../_images/tutorialLinks_denoising_selfSupervised_26_0.png
../_images/tutorialLinks_denoising_selfSupervised_26_1.png
../_images/tutorialLinks_denoising_selfSupervised_26_2.png
[ ]: