Convolutional Diffractive Network

This notebook is based on the article “Optical Diffractive Convolutional Neural Networks Implemented in an All-Optical Way” [1].

… combining the 4f system as an optical convolutional layer with diffractive networks
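In an ideal 4f system, the field at the Fourier plane of the first lens is the Fourier transform of the input, a mask multiplies that spectrum, and the second lens transforms back, so the whole system applies a convolution to the input field. A minimal numerical sketch of this idea (for intuition only; the notebook below builds the actual system from svetlanna elements):

import torch

def ideal_4f(field: torch.Tensor, fourier_mask: torch.Tensor) -> torch.Tensor:
    # Fourier plane of the first lens (centered spectrum)
    spectrum = torch.fft.fftshift(torch.fft.fft2(field))
    # the mask in the Fourier plane multiplies the spectrum ...
    filtered = spectrum * fourier_mask
    # ... which is a convolution in the image plane (a real 4f system also
    # inverts the image; this sketch ignores that)
    return torch.fft.ifft2(torch.fft.ifftshift(filtered))

field = torch.randn(200, 200, dtype=torch.complex64)  # toy input field
phase_mask = torch.exp(1j * 2 * torch.pi * torch.rand(200, 200))  # phase-only mask
print(ideal_4f(field, phase_mask).shape)  # torch.Size([200, 200])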

Imports

[1]:
import os
import sys
import random
[2]:
import time
import json
[3]:
# import warnings
# warnings.simplefilter("always")  # always show warnings!
[4]:
import numpy as np
[5]:
from collections import Counter
[6]:
import torch
from torch.utils.data import Dataset
[7]:
from torch import nn
[8]:
from torch.nn import functional
[9]:
import torchvision
import torchvision.transforms as transforms
[10]:
from torchvision.transforms import InterpolationMode
[11]:
# our library
from svetlanna import SimulationParameters
from svetlanna.parameters import ConstrainedParameter
[12]:
# our library
from svetlanna import Wavefront
from svetlanna import elements
from svetlanna.detector import Detector, DetectorProcessorClf
[13]:
from svetlanna.transforms import ToWavefront
[14]:
# dataset
from src.wf_datasets import DatasetOfWavefronts
[15]:
from tqdm import tqdm
[16]:
from datetime import datetime
[17]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1 import make_axes_locatable

plt.style.use('dark_background')
%matplotlib inline
# %config InlineBackend.figure_format = 'retina'
[ ]:

[18]:
today_date = datetime.today().strftime('%d-%m-%Y_%H-%M')  # date for a results folder name
today_date
[18]:
'11-04-2025_00-38'
[19]:
# Define all necessary variables for that notebook
VARIABLES = {
    # FILEPATHS
    'data_path': './data',  # folder which will be created (if not exists) to load/store Weizmann dataset
    'results_path': f'models/convolutional/conv_exp_{today_date}',  # filepath to save results!

    # GENERAL SETTINGS - SECTION 1 of the notebook
    'wavelength': 750 * 1e-6,  # working wavelength, in [m]
    'neuron_size': 0.5 * 750 * 1e-6,  # size of a pixel for DiffractiveLayers, in [m]
    'mesh_size': (200, 200),  # full size of a layer = numerical mesh
        # Comment: value from the article [1] - (200, 200)

    'use_apertures': False,  # if we need to add apertures before each Diffractive layer
        # Comment: value from the article [1] - unknown
    'aperture_size': (64, 64),  # size of each aperture = a detector square for classes zones
        # Comment: value from the article [1] - unknown

    # DATASET OF SUBSEQUENCES SETTINGS - SECTION 2 of the notebook
    'resize': (28, 28),  # size to resize pictures to, before zero padding up to the mesh size
        # Comment: value from the article: the input image of 28 * 28 pixels size
    'modulation': 'amp',  # modulation type to make a wavefront from each picture mask (see 2.3.2.)
        # Comment: can be equal to `phase`, `amp` or `both`

    # NETWORK - SECTION 3 of the notebook
    'max_phase': 2 * torch.pi,  # maximal possible phase for each DiffractiveLayer
    'free_space_method': 'AS',  # propagation method
        # Comment: can be 'AS' or 'fresnel'
    'distance': 3 * 1e-2,  # distance between diffractive layers

    # 4F-SYSTEM
    'focal_length': 3 * 1e-2,  # in [m]
    'lens_radius': torch.inf,
        # Comment: if a lens radius is equal to torch.inf - analytical lens!
    'learnable_kernel': False,

    # DIFFRACTIVE LAYERS
    'use_slm': False,  # use SLM (if True) or DiffractiveLayers (if False)
    'num_layers': 5,
    'init_phases': torch.pi,
    # value or a list of initial constant phases for DiffractiveLayers OR SLM
    # SLM settings - if 'use_slm' == True
        # Comment: the size of each SLM equals the mesh size from SimulationParameters!
    'slm_shapes': [(200, 200), (200, 200), (100, 100), (100, 100), (100, 100)],
        # list of size 'num_layers'
    'slm_levels': 256,
        # value OR a list (len = 'num_layers') of numbers of levels for each SLM
    'slm_step_funcs': 'linear',  # value OR a list of step function names
        # Comment: available step function names - 'linear'

    # NETWORK LEARNING - SECTION 4 of the notebook
    'calculate_accuracies': True,  # always True when CrossEntropyLoss is used! (see 3.4.)
        # Comment: MSE loss used!
    'DEVICE': 'cpu',  # if `cuda` - we will check if it is available (see first cells of Sec. 4)
    'train_batch_size': 64,  # batch sizes for training (see 4.1.1.)
    'val_batch_size': 64,
        # Comment: value from the article [1] - 64  # for both train and test?
    'adam_lr': 0.01,  # learning rate for Adam optimizer (see 4.1.2.)
        # Comment: value from the article [1] - 0.01
    'number_of_epochs': 20,  # number of epochs to train
        # Comment: value from the article [1] - 100-300 ?!
}
[20]:
# functions for SLM step (look documentation of SLM)
SLM_STEPS = {
    'linear': lambda x: x,
}
[21]:
RESULTS_FOLDER = VARIABLES['results_path']

# create a directory to store results
if not os.path.exists(RESULTS_FOLDER):
    os.makedirs(RESULTS_FOLDER)
[22]:
RESULTS_FOLDER
[22]:
'models/convolutional/conv_exp_11-04-2025_00-38'
[23]:
# save experiment conditions (VARIABLES dictionary)
with open(f'{RESULTS_FOLDER}/conditions.json', 'w', encoding='utf8') as json_file:
    json.dump(VARIABLES, json_file, ensure_ascii=True)
[ ]:

1. Simulation parameters

Since [1] provides no details, we take the physical parameters from another article [2], which we used in some previous notebooks.

[24]:
working_wavelength = VARIABLES['wavelength']  # [m] - like for MNIST

c_const = 299_792_458  # [m / s]
working_frequency = c_const / working_wavelength # [Hz]
[25]:
print(f'lambda    = {working_wavelength * 1e3:.3f} mm')
print(f'frequency = {working_frequency / 1e12:.3f} THz')
lambda    = 0.750 mm
frequency = 0.400 THz
[26]:
# neuron size (square)
neuron_size = VARIABLES['neuron_size']  # [m] - like for MNIST
print(f'neuron size = {neuron_size * 1e3:.3f} mm')
neuron size = 0.375 mm
[27]:
APERTURES = VARIABLES['use_apertures']  # add apertures BEFORE each diffractive layer or not
[28]:
LAYER_SIZE = VARIABLES['mesh_size']  # mesh size
DETECTOR_SIZE = VARIABLES['aperture_size']
[29]:
# number of neurons in simulation
x_layer_nodes = LAYER_SIZE[1]
y_layer_nodes = LAYER_SIZE[0]
# Comment: Same size as proposed!

print(f'Layer size (in neurons): {x_layer_nodes} x {y_layer_nodes} = {x_layer_nodes * y_layer_nodes}')
Layer size (in neurons): 200 x 200 = 40000
[30]:
# physical size of each layer
x_layer_size_m = x_layer_nodes * neuron_size  # [m]
y_layer_size_m = y_layer_nodes * neuron_size

print(f'Layer size (in mm): {x_layer_size_m * 1e3 :.3f} x {y_layer_size_m * 1e3 :.3f}')
Layer size (in mm): 75.000 x 75.000
[31]:
X_LAYER_SIZE_M = x_layer_size_m
Y_LAYER_SIZE_M = y_layer_size_m
[ ]:

[32]:
# simulation parameters for the rest of the notebook

SIM_PARAMS = SimulationParameters(
    axes={
        'W': torch.linspace(-x_layer_size_m / 2, x_layer_size_m / 2, x_layer_nodes),
        'H': torch.linspace(-y_layer_size_m / 2, y_layer_size_m / 2, y_layer_nodes),
        'wavelength': working_wavelength,  # only one wavelength!
    }
)
[ ]:

2. Dataset preparation

2.1. MNIST Dataset

[33]:
# initialize a directory for a dataset
MNIST_DATA_FOLDER = VARIABLES['data_path']  # folder to store data

NUM_CLASSES = 10

2.1.1. Load Train and Test datasets of images

[34]:
# TRAIN (images)
mnist_train_ds = torchvision.datasets.MNIST(
    root=MNIST_DATA_FOLDER,
    train=True,  # for train dataset
    download=True,
)
[35]:
# TEST (images)
mnist_test_ds = torchvision.datasets.MNIST(
    root=MNIST_DATA_FOLDER,
    train=False,  # for test dataset
    download=True,
)
[36]:
print(f'Train data: {len(mnist_train_ds)}')
print(f'Test data : {len(mnist_test_ds)}')
Train data: 60000
Test data : 10000

2.1.2. Create Train and Test datasets of wavefronts

the input image of \(28 \times 28\) pixels size was expanded to \(200 \times 200\) with zero padding

[37]:
# select modulation type
MODULATION_TYPE = VARIABLES['modulation']  # using ONLY amplitude to encode each picture in a Wavefront!
RESIZE_SHAPE = VARIABLES['resize']  # size to resize pictures to, before zero padding up to the mesh size
[38]:
resize_y = RESIZE_SHAPE[0]
resize_x = RESIZE_SHAPE[1]  # shape for transforms.Resize

# paddings along OY
pad_top = int((y_layer_nodes - resize_y) / 2)
pad_bottom = y_layer_nodes - pad_top - resize_y
# paddings along OX
pad_left = int((x_layer_nodes - resize_x) / 2)
pad_right = x_layer_nodes - pad_left - resize_x  # params for transforms.Pad
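A quick sanity check of the padding arithmetic (an illustrative assertion, not part of the original run): for 28 → 200 both paddings are 86, and padding plus image restores the full mesh size.

assert pad_top + resize_y + pad_bottom == y_layer_nodes    # 86 + 28 + 86 == 200
assert pad_left + resize_x + pad_right == x_layer_nodes    # 86 + 28 + 86 == 200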
[39]:
# compose all transforms!
image_transform_for_ds = transforms.Compose(
  [
      transforms.ToTensor(),
      transforms.Resize(
          size=(resize_y, resize_x),
          interpolation=InterpolationMode.NEAREST,
      ),
      transforms.Pad(
          padding=(
              pad_left,  # left padding
              pad_top,  # top padding
              pad_right,  # right padding
              pad_bottom  # bottom padding
          ),
          fill=0,
      ),  # padding to match sizes!
      ToWavefront(modulation_type=MODULATION_TYPE)  # <- selected modulation type here!!!
  ]
)
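For intuition, with MODULATION_TYPE = 'amp' each normalized pixel value becomes the field amplitude with zero phase, while 'phase' would set the phase at unit amplitude. A simplified sketch of the idea (the actual transform is svetlanna.transforms.ToWavefront, whose exact conventions may differ):

import torch

def picture_to_field(image: torch.Tensor, modulation: str = 'amp') -> torch.Tensor:
    # image: real tensor with values in [0, 1]
    if modulation == 'amp':    # pixel value -> amplitude, phase = 0
        return image.to(torch.complex64)
    if modulation == 'phase':  # pixel value -> phase, amplitude = 1
        return torch.exp(1j * 2 * torch.pi * image)
    raise ValueError(f'unknown modulation: {modulation}')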
[40]:
import src.detector_segmentation as detector_segmentation
[41]:
number_of_classes = NUM_CLASSES
[42]:
detector_segment_size = 6.4 * working_wavelength
[43]:
# size of each segment in neurons
x_segment_nodes = int(detector_segment_size / neuron_size)
y_segment_nodes = int(detector_segment_size / neuron_size)
# each segment of size = (y_segment_nodes, x_segment_nodes)
[44]:
y_boundary_nodes = y_segment_nodes * 12
x_boundary_nodes = x_segment_nodes * 12
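Checking the numbers (illustrative only): with λ = 0.75 mm and a 0.5λ neuron, each segment is int(6.4λ / 0.5λ) = int(12.8) = 12 neurons, so the 12-segment detector region spans 144 × 144 neurons inside the 200 × 200 mesh.

assert x_segment_nodes == y_segment_nodes == 12     # int(6.4 / 0.5) = int(12.8)
assert x_boundary_nodes == y_boundary_nodes == 144  # 12 segments of 12 neurons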
[45]:
DETECTOR_MASK = detector_segmentation.squares_mnist(
    y_boundary_nodes, x_boundary_nodes,  # size of a detector or an aperture (in the middle of detector)
    SIM_PARAMS
)

To visualize detector zones (for further use)

[46]:
ZONES_HIGHLIGHT_COLOR = 'r'
ZONES_LW = 0.5
selected_detector_mask = DETECTOR_MASK.clone().detach()
[47]:
def get_zones_patches(detector_mask):
    """
    Returns a list of patches to draw zones in final visualisation
    """
    zones_patches = []

    delta = 1  # 0.5

    for ind_class in range(number_of_classes):
        idx_y, idx_x = (detector_mask == ind_class).nonzero(as_tuple=True)

        zone_rect = patches.Rectangle(
            (idx_x[0] - delta, idx_y[0] - delta),
            idx_x[-1] - idx_x[0] + 2 * delta, idx_y[-1] - idx_y[0] + 2 * delta,
            linewidth=ZONES_LW,
            edgecolor=ZONES_HIGHLIGHT_COLOR,
            facecolor='none'
        )

        zones_patches.append(zone_rect)

    return zones_patches

Visualize mask

[48]:
fig, ax0 = plt.subplots(1, 1, figsize=(3, 3))

ax0.set_title(f'Detector segments')
ax0.imshow(selected_detector_mask, cmap='grey')

for zone in get_zones_patches(selected_detector_mask):
    # add zone's patches to the axis
    # zone_copy = copy(zone)
    ax0.add_patch(zone)

plt.show()
../../_images/examples_notebook_diff_networks_examples_09_convolutional_4f_62_0.png
[49]:
# TRAIN dataset of WAVEFRONTS
mnist_wf_train_ds = DatasetOfWavefronts(
    init_ds=mnist_train_ds,  # dataset of images
    transformations=image_transform_for_ds,  # image transformation
    sim_params=SIM_PARAMS,  # simulation parameters
    target='detector',
    detector_mask=DETECTOR_MASK
)
[50]:
# TEST dataset of WAVEFRONTS
mnist_wf_test_ds = DatasetOfWavefronts(
    init_ds=mnist_test_ds,  # dataset of images
    transformations=image_transform_for_ds,  # image transformation
    sim_params=SIM_PARAMS,  # simulation parameters
    target='detector',
    detector_mask=DETECTOR_MASK
)
[51]:
print(f'Train data: {len(mnist_train_ds)}')
print(f'Test data : {len(mnist_test_ds)}')
Train data: 60000
Test data : 10000
[52]:
# plot several EXAMPLES from TRAIN dataset
n_examples = 4  # number of examples to plot
# choosing indices of images (from train) to plot
random.seed(78)
train_examples_ids = random.sample(range(len(mnist_train_ds)), n_examples)

all_examples_wavefronts = []

n_lines = 3
fig, axs = plt.subplots(n_lines, n_examples, figsize=(n_examples * 3, n_lines * 3.2))
for ind_ex, ind_train in enumerate(train_examples_ids):
    image, label = mnist_train_ds[ind_train]

    axs[0][ind_ex].set_title(f'id={ind_train} [{label}]')
    axs[0][ind_ex].imshow(image, cmap='gray')

    wavefront, target_image = mnist_wf_train_ds[ind_train]
    assert isinstance(wavefront, Wavefront)

    all_examples_wavefronts.append(wavefront)

    axs[1][ind_ex].set_title(f'$|WF|^2$')
    # here we can plot intensity for a wavefront
    axs[1][ind_ex].imshow(
        wavefront.intensity, cmap='gray',
        vmin=0, vmax=1
    )

    axs[2][ind_ex].set_title(f'Target image')
    axs[2][ind_ex].imshow(
        target_image, cmap='gray',
        vmin=0, vmax=1
    )

    for zone in get_zones_patches(selected_detector_mask):
        # add zone's patches to the axis
        # zone_copy = copy(zone)
        axs[2][ind_ex].add_patch(zone)

plt.show()
../../_images/examples_notebook_diff_networks_examples_09_convolutional_4f_68_0.png
[ ]:

3. Diffractive Network with Convolutional Layer

[53]:
FS_METHOD = VARIABLES['free_space_method']
FS_DISTANCE = VARIABLES['distance']  # [m] - distance between diffractive layers

MAX_PHASE = VARIABLES['max_phase']

3.1. Optical Convolutional Layer

See Figure 2 in [1]!

[54]:
FOCAL_LENGTH = VARIABLES['focal_length']
LENS_R = VARIABLES['lens_radius']

LEARN_CONV = VARIABLES['learnable_kernel']

3.1.1. Function to return 4f system

[55]:
def get_free_space(
    freespace_sim_params,
    freespace_distance,  # in [m]!
    freespace_method='AS',
):
    """
    Returns a FreeSpace layer with a fixed (non-learnable) propagation distance.
    """
    return elements.FreeSpace(
        simulation_parameters=freespace_sim_params,
        distance=freespace_distance,  # distance is not learnable!
        method=freespace_method
    )
[56]:
def get_4f_convolutional_layer(
    sim_params,
    focal_length,  # in [m]
    lens_radius,  # in [m]
    convolutional_mask,  # mask from 0 to max_phase
    learnable_mask=False,  # if a convolutional mask is learnable
    max_phase=2 * torch.pi,
    freespace_method='AS',
):
    """
    Returns a list of elements for a 4f system with a Diffractive layer in a Fourier plane.
    """
    if learnable_mask:
        diff_layer = elements.DiffractiveLayer(
            simulation_parameters=sim_params,
            mask=ConstrainedParameter(
                convolutional_mask,
                min_value=0,
                max_value=max_phase
            ),  # HERE WE ARE USING CONSTRAINED PARAMETER!
        )
    else:
        diff_layer = elements.DiffractiveLayer(
            simulation_parameters=sim_params,
            mask=convolutional_mask,  # mask is not changing during the training!
        )

    return [
        get_free_space(
            sim_params, focal_length, freespace_method
        ),
        elements.ThinLens(sim_params, focal_length, lens_radius),
        get_free_space(
            sim_params, focal_length, freespace_method
        ),
        diff_layer,  # DiffractiveLayer in a Fourier plane!
        get_free_space(
            sim_params, focal_length, freespace_method
        ),
        elements.ThinLens(sim_params, focal_length, lens_radius),
        get_free_space(
            sim_params, focal_length, freespace_method
        ),
    ]

3.1.2. Mask for a DiffractiveLayer placed in a Fourier plane

Citations from [1]

In the convolution layer, \(16\) convolution kernels were discretized into a \(4 \times 4\) array and tiled into a \(200 \times 200\) size planar space, as shown in Figure 6.

[57]:
import math  # for the kernel orientation angles below

from kernels.kernels import *  # local helpers: gaussian_kernel, gabor_kernel, etc.
[58]:
# Generate All 16 Kernels

def generate_16_kernels(size=9):
    k3 = predefined_3x3_kernels()

    kernels = [
        gaussian_kernel(size, sigma=1.0),
        gaussian_kernel(size, sigma=2.0),
        laplacian_of_gaussian(size, sigma=1.0),
        gabor_kernel(size, theta=0),
        gabor_kernel(size, theta=math.pi / 4),
        gabor_kernel(size, theta=math.pi / 2),
        gabor_kernel(size, theta=3 * math.pi / 4),
        upscale_kernel(k3['sobel_x'], size),
        upscale_kernel(k3['sobel_y'], size),
        upscale_kernel(k3['prewitt_x'], size),
        upscale_kernel(k3['prewitt_y'], size),
        upscale_kernel(k3['emboss'], size),
        identity_kernel(size),
        center_surround_edge_kernel(size),
        gabor_kernel(size, theta=math.pi / 8),
        gabor_kernel(size, theta=5 * math.pi / 8)
    ]

    return torch.stack(kernels)  # shape: (16, 9, 9)
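The kernel helpers above come from a local module; as an assumption about their conventions, here is a plausible sketch of one of them, a normalized 2D Gaussian (the real gaussian_kernel in kernels/kernels.py may differ):

import torch

def gaussian_kernel_sketch(size: int, sigma: float = 1.0) -> torch.Tensor:
    # centered 1D coordinates, e.g. -4 ... 4 for size=9
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-coords**2 / (2 * sigma**2))
    kernel = torch.outer(g, g)  # separable 2D Gaussian
    return kernel / kernel.sum()  # normalize to unit sum

print(gaussian_kernel_sketch(9).shape)  # torch.Size([9, 9])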
[59]:
# TODO: Randomly mix kernels before arrangement!
def embed_equally_spaced_kernels(canvas_size=200, grid_size=4, kernel_size=9):

    spacing = (canvas_size - (grid_size * kernel_size)) // (grid_size + 1)  # == 32
    kernels = generate_16_kernels()  # (16, 9, 9)
    canvas = torch.zeros((canvas_size, canvas_size))

    idx = 0
    for i in range(grid_size):
        for j in range(grid_size):
            y = spacing + i * (kernel_size + spacing)
            x = spacing + j * (kernel_size + spacing)
            canvas[y:y + kernel_size, x:x + kernel_size] = kernels[idx]
            idx += 1

    return canvas
[60]:
KERNELS_MASK = embed_equally_spaced_kernels() * MAX_PHASE
[61]:
fig, ax0 = plt.subplots(1, 1, figsize=(4, 4))

ax0.set_title(f'Convolutional DiffractiveLayer')
ax0.imshow(KERNELS_MASK, cmap='grey', vmin=0, vmax=MAX_PHASE)

plt.show()
../../_images/examples_notebook_diff_networks_examples_09_convolutional_4f_84_0.png

3.1.3. Convolution Layer (4f system)

[62]:
CONV_LAYER = get_4f_convolutional_layer(
    SIM_PARAMS,
    FOCAL_LENGTH,  # in [m]
    LENS_R,  # in [m]
    KERNELS_MASK,  # mask from 0 to max_phase
    learnable_mask=LEARN_CONV,  # if a convolutional mask is learnable
    max_phase=MAX_PHASE,
    freespace_method=FS_METHOD,
)
[ ]:

3.2. Optical Network after a Convolutional Layer

[63]:
USE_SLM = VARIABLES['use_slm']
NUM_LAYERS = VARIABLES['num_layers']  # number of diffractive layers
[64]:
if isinstance(VARIABLES['init_phases'], list):
    INIT_PHASES = VARIABLES['init_phases']
else:
    INIT_PHASES = [VARIABLES['init_phases'] for _ in range(NUM_LAYERS)]


assert len(INIT_PHASES) == NUM_LAYERS
[65]:
if USE_SLM:
    SLM_VARIABLES = {}

    for key in ['slm_shapes', 'slm_levels', 'slm_step_funcs']:
        if key != 'slm_step_funcs':

            if isinstance(VARIABLES[key], list):
                SLM_VARIABLES[key] = VARIABLES[key]
            else:  # all SLM's have the same parameter
                SLM_VARIABLES[key] = [VARIABLES[key] for _ in range(NUM_LAYERS)]

        else:  # for step functions!

            if isinstance(VARIABLES[key], list):
                SLM_VARIABLES[key] = [SLM_STEPS[name] for name in VARIABLES[key]]
            else:  # all SLM's have the same parameter
                SLM_VARIABLES[key] = [SLM_STEPS[VARIABLES[key]] for _ in range(NUM_LAYERS)]

        assert len(SLM_VARIABLES[key]) == NUM_LAYERS

3.2.1. Functions to get new elements

[66]:
# functions that return single elements for further architecture
def get_const_phase_layer(
    sim_params: SimulationParameters,
    value, max_phase=2 * torch.pi
):
    """
    Returns DiffractiveLayer with a constant phase mask.
    """
    x_nodes, y_nodes = sim_params.axes_size(axs=('W', 'H'))

    const_mask = torch.ones(size=(y_nodes, x_nodes)) * value

    return elements.DiffractiveLayer(
        simulation_parameters=sim_params,
        mask=ConstrainedParameter(
            const_mask,
            min_value=0,
            max_value=max_phase
        ),  # HERE WE ARE USING CONSTRAINED PARAMETER!
    )  # ATTENTION TO DOCUMENTATION!


# CHANGE ACCORDING TO THE DOCUMENTATION OF SLM!
def get_const_slm_layer(
    sim_params: SimulationParameters,
    mask_shape,
    phase_value,
    num_levels,
    step_func,
    height_m=Y_LAYER_SIZE_M,
    width_m=X_LAYER_SIZE_M,
    max_phase=2 * torch.pi
):
    """
    Returns SpatialLightModulator with a constant phase mask.
    """
    y_nodes, x_nodes = mask_shape
    const_mask = torch.ones(size=(y_nodes, x_nodes)) * phase_value

    return elements.SpatialLightModulator(
        simulation_parameters=sim_params,
        mask=ConstrainedParameter(
            const_mask,
            min_value=0,
            max_value=max_phase
        ),  # HERE WE ARE USING CONSTRAINED PARAMETER!
        height=height_m,
        width=width_m,
        # location=(0., 0.),  # by default
        number_of_levels=num_levels,
        step_function=step_func,
        # mode='nearest',  # by default it is 'nearest'
    )  # ATTENTION TO DOCUMENTATION!

3.2.2. Elements list

Function to get a list of elements to reproduce an architecture:

[67]:
def get_elements_list(
    num_layers,
    simulation_parameters,
    freespace_distance,
    freespace_method,
    apertures=False,
    aperture_size=(100, 100),
):
    """
    Composes a list of elements for an encoder (with no Detector).
    ...

    Parameters
    ----------
    num_layers : int
        Number of layers in the system.
    simulation_parameters : SimulationParameters()
        Simulation parameters for the task.
    freespace_distance : float
        A distance between phase layers in [m].
    freespace_method : str
        Propagation method for free spaces in a setup.

    apertures : bool
        If True, then a square aperture is added before each DiffractiveLayer (and the detector).
        Comment: these are strict square apertures!
    aperture_size : tuple
        A size of square apertures.

    Returns
    -------
    elements_list : list(Element)
        List of Elements for the network (phase layers, free spaces, detector).
    """
    elements_list = []  # list of elements

    use_slm = USE_SLM
    init_phases = INIT_PHASES
    if use_slm:
        slm_masks_shape = SLM_VARIABLES['slm_shapes']
        slm_levels = SLM_VARIABLES['slm_levels']
        slm_funcs = SLM_VARIABLES['slm_step_funcs']

    if apertures:  # equal masks for all apertures (select a part in the middle)
        aperture_mask = torch.ones(size=aperture_size)

        y_nodes, x_nodes = simulation_parameters.axes_size(axs=('H', 'W'))
        y_mask, x_mask = aperture_mask.size()
        pad_top = int((y_nodes - y_mask) / 2)
        pad_bottom = y_nodes - pad_top - y_mask
        pad_left = int((x_nodes - x_mask) / 2)
        pad_right = x_nodes - pad_left - x_mask  # params for transforms.Pad

        # padding transform to match aperture size with simulation parameters
        aperture_mask = functional.pad(
            input=aperture_mask,
            pad=(pad_left, pad_right, pad_top, pad_bottom),
            mode='constant',
            value=0
        )

    # no FreeSpace is needed before the first DiffractiveLayer:
    # the 4f system already ends with a FreeSpace

    # compose the architecture
    for ind_layer in range(num_layers):

        # add a strict square Aperture
        if apertures:
            elements_list.append(
                elements.Aperture(
                    simulation_parameters=simulation_parameters,
                    mask=aperture_mask
                )
            )

        # ------------------------------------------------------------PHASE LAYER
        if use_slm:  # add a phase layer (SLM or DiffractiveLayer)
            # add SLM (learnable phase mask)
            elements_list.append(
                get_const_slm_layer(
                    simulation_parameters,
                    mask_shape=slm_masks_shape[ind_layer],
                    phase_value=init_phases[ind_layer],
                    num_levels=slm_levels[ind_layer],
                    step_func=slm_funcs[ind_layer],
                    max_phase=MAX_PHASE
                )
            )
        else:
            # add DiffractiveLayer (learnable phase mask)
            elements_list.append(
                get_const_phase_layer(
                    simulation_parameters,
                    value=init_phases[ind_layer],
                    max_phase=MAX_PHASE
                )
            )
        # -----------------------------------------------------------------------

        # add FreeSpace
        elements_list.append(
            get_free_space(
                simulation_parameters,  # simulation parameters for the notebook
                freespace_distance,  # in [m]
                freespace_method=freespace_method,
            )
        )

    # ---------------------------------------------------------------------------
    # add Detector in the end of the system!
    elements_list.append(
        Detector(
            simulation_parameters=simulation_parameters,
            func='intensity'  # detector that returns intensity
        )
    )

    return elements_list
[ ]:

3.3. Model with a Convolutional Layer (4f System)

3.3.1. Model Class

[68]:
class ConvolutionalSystem(nn.Module):
    """
    A simple convolutional network with a 4f system as an optical convolutional layer.
    """

    def __init__(
        self,
        sim_params: SimulationParameters,
        conv_layer_list: list,
        num_layers: int,
        fs_distance: float,
        fs_method: str = 'AS',
        device: str | torch.device = torch.get_default_device(),
    ):
        """
        Parameters
        ----------
        sim_params : SimulationParameters
            Simulation parameters for the task.
        conv_layer_list : list
            List of Elements for a Convolutional layer (4f system).
        num_layers : int
            Number of DiffractiveLayer's after 4f-system.
        fs_distance : float
            A distance between phase layers in [m].
        fs_method : str
            Propagation method for free spaces in a setup.
        """
        super().__init__()

        self.sim_params = sim_params

        self.h, self.w = self.sim_params.axes_size(
            axs=('H', 'W')
        )  # height and width for a wavefronts

        self.__device = device
        self.fs_method = fs_method

        # CONVOLUTIONAL LAYER
        self.conv_layer_list = conv_layer_list
        self.conv_layer = nn.Sequential(*conv_layer_list).to(self.__device)

        # NETWORK
        elements_list = get_elements_list(
            num_layers,
            self.sim_params,
            fs_distance,
            fs_method,
            apertures=VARIABLES['use_apertures'],
            aperture_size=VARIABLES['aperture_size'],
        )  # Detector is here!

        # self.encoder_elements = encoder_elements_list
        self.net = nn.Sequential(*elements_list).to(self.__device)

    def stepwise_propagation(self, input_wavefront: Wavefront, mode: str = 'convolution'):
        """
        Function that consistently applies forward method of each element of
        Convolution Layer or the Network after Convolution
        to an `input_wavefront`.

        Parameters
        ----------
        input_wavefront : torch.Tensor
            A wavefront that enters the optical network.
        mode : str
            Specify a mode 'convolution' or 'after convolution'.

        Returns
        -------
        str
            A string that represents a scheme of a propagation through a setup.
        list(torch.Tensor)
            A list of an input wavefront evolution during a propagation through a setup.
        """
        this_wavefront = input_wavefront
        # list of wavefronts while propagation of an initial wavefront through the system
        steps_wavefront = [this_wavefront]  # input wavefront is a zeroth step

        optical_scheme = ''  # string that represents a linear optical setup (schematic)

        if mode == 'convolution':
            net = self.conv_layer
        elif mode == 'after convolution':
            net = self.net
        else:
            raise ValueError(f'unknown mode: {mode}')

        net.eval()
        for ind_element, element in enumerate(net):
            # for visualization in a console
            element_name = type(element).__name__
            optical_scheme += f'-({ind_element})-> [{ind_element + 1}. {element_name}] '

            if ind_element == len(net) - 1:
                optical_scheme += f'-({ind_element + 1})->'

            # element forward
            with torch.no_grad():
                this_wavefront = element.forward(this_wavefront)

            steps_wavefront.append(this_wavefront)  # add a wavefront to list of steps

        return optical_scheme, steps_wavefront

    def forward(self, wavefront_in):
        """
        Parameters
        ----------
        wavefront_in: Wavefront('bs', 'H', 'W')

        Returns
        -------
        detector_image : torch.Tensor
            Image on a Detector.
        """
        # works both for a single wavefront ('H', 'W') and for a batch ('bs', 'H', 'W')

        # convolutional layer
        wavefront_after_convolution = self.conv_layer(wavefront_in)
        # other layers
        detector_image = self.net(wavefront_after_convolution)

        return detector_image

3.3.2. Empty model

[69]:
def get_net():
    return ConvolutionalSystem(
        SIM_PARAMS,
        CONV_LAYER,
        NUM_LAYERS,
        FS_DISTANCE,
        FS_METHOD,
    )
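For inspection, stepwise_propagation can trace a single wavefront through the system; a hedged usage sketch (the sample index and printing are illustrative, not part of the recorded run):

# Illustrative: trace one training wavefront through the 4f convolutional layer
example_wf, _ = mnist_wf_train_ds[0]
demo_net = get_net()
scheme, steps = demo_net.stepwise_propagation(example_wf, mode='convolution')
print(scheme)      # e.g. '-(0)-> [1. FreeSpace] -(1)-> [2. ThinLens] ...'
print(len(steps))  # input wavefront + one wavefront per element of the 4f system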
[ ]:

3.4. Detector processor (to calculate accuracies only)

Comment: DetectorProcessor in our library is used to process the information on the detector. For example, for the current task the DetectorProcessor must return only 10 values (one value per class).
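Conceptually, the processor reduces the detector image to one number per class zone (e.g., the integrated intensity inside each zone); a minimal sketch of that reduction, assuming the mask stores a class index per pixel (the real logic lives in svetlanna.detector.DetectorProcessorClf):

import torch

def zone_sums(detector_image: torch.Tensor, mask: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    # integrated intensity inside each class zone -> tensor of 10 "class scores"
    return torch.stack([detector_image[mask == c].sum() for c in range(num_classes)])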

[70]:
CALCULATE_ACCURACIES = VARIABLES['calculate_accuracies']
# if False, accuracies will not be calculated!
[71]:
# create a DetectorProcessorClf object
if CALCULATE_ACCURACIES:
    detector_processor = DetectorProcessorClf(
        simulation_parameters=SIM_PARAMS,
        num_classes=NUM_CLASSES,
        segmented_detector=DETECTOR_MASK,
    )
else:
    detector_processor = None
[ ]:

4. Training of the network

[74]:
DEVICE = VARIABLES['DEVICE']  # note: 'mps' does not support CrossEntropyLoss
[75]:
if DEVICE == 'cuda':
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

DEVICE
[75]:
'cpu'

4.1. Prepare some stuff for training

4.1.1. DataLoader’s

Citations from methods of [1]:

The batch size set for the training process was \(64\).

[76]:
train_bs = VARIABLES['train_batch_size']  # a batch size for training set
val_bs = VARIABLES['val_batch_size']

For this task, phase-only transmission masks were designed by training a five-layer \(D^2NN\) with \(55000\) images (\(5000\) validation images) from the MNIST (Modified National Institute of Standards and Technology) handwritten digit database.

[77]:
# mnist_wf_train_ds
train_wf_ds, val_wf_ds = torch.utils.data.random_split(
    dataset=mnist_wf_train_ds,
    lengths=[55000, 5000],  # sizes from the article
    generator=torch.Generator().manual_seed(178)  # for reproducibility
)
[78]:
train_wf_loader = torch.utils.data.DataLoader(
    train_wf_ds,
    batch_size=train_bs,
    shuffle=True,
    # num_workers=2,
    drop_last=False,
)

val_wf_loader = torch.utils.data.DataLoader(
    val_wf_ds,
    batch_size=val_bs,
    shuffle=False,
    # num_workers=2,
    drop_last=False,
)
[79]:
test_wf_loader = torch.utils.data.DataLoader(
    mnist_wf_test_ds,
    batch_size=val_bs,
    shuffle=True,
    # num_workers=2,
    drop_last=False,
)
[ ]:

4.1.2. Optimizer and loss function

[80]:
LR = VARIABLES['adam_lr']
[81]:
def get_adam_optimizer(net):
    return torch.optim.Adam(
        params=net.parameters(),  # NETWORK PARAMETERS!
        lr=LR
    )
[82]:
LOSS = 'MSE'
[83]:
if LOSS == 'MSE':
    loss_func_clf = nn.MSELoss()  # by default: reduction='mean'
    loss_func_name = 'MSE'
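With target='detector' (Section 2), each sample's target is an image that is bright only inside the correct class zone (the 'Target image' row plotted above), so the MSE loss pushes the detector intensity into that zone. Schematically, assuming that convention:

import torch

def make_target(detector_mask: torch.Tensor, label: int) -> torch.Tensor:
    # ones inside the true class zone, zeros elsewhere (illustrative convention;
    # the real targets are produced by DatasetOfWavefronts with target='detector')
    return (detector_mask == label).to(torch.float32)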
[ ]:

4.1.3. Training and evaluation loops

[84]:
def onn_train_mse(
    optical_net, wavefronts_dataloader,
    detector_processor_clf,  # DETECTOR PROCESSOR needed for accuracies only!
    loss_func, optimizer,
    device='cpu', show_process=False
):
    """
    Function to train `optical_net` (classification task)
    ...

    Parameters
    ----------
        optical_net : torch.nn.Module
            Neural Network composed of Elements.
        wavefronts_dataloader : torch.utils.data.DataLoader
            A loader (by batches) for the train dataset of wavefronts.
        detector_processor_clf : DetectorProcessorClf
            A processor of a detector image for a classification task, that returns `probabilities` of classes.
        loss_func :
            Loss function for a multi-class classification task.
        optimizer: torch.optim
            Optimizer...
        device : str
            Device to compute on.
        show_process : bool
            Flag to show (or not) a progress bar.

    Returns
    -------
        batches_losses : list[float]
            Losses for each batch in an epoch.
        batches_accuracies : list[float]
            Accuracies for each batch in an epoch.
        epoch_accuracy : float
            Accuracy for an epoch.
    """
    optical_net.train()  # activate 'train' mode of a model
    batches_losses = []  # to store loss for each batch
    batches_accuracies = []  # to store accuracy for each batch

    correct_preds = 0
    size = 0

    for batch_wavefronts, batch_targets in tqdm(
        wavefronts_dataloader,
        total=len(wavefronts_dataloader),
        desc='train', position=0,
        leave=True, disable=not show_process
    ):  # go by batches
        # batch_wavefronts - input wavefronts, batch_targets - target detector images
        batch_size = batch_wavefronts.size()[0]

        batch_wavefronts = batch_wavefronts.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()

        # forward of an optical network
        detector_output = optical_net(batch_wavefronts)

        # calculate loss for a batch
        loss = loss_func(detector_output, batch_targets)

        loss.backward()
        optimizer.step()

        # ACCURACY
        if CALCULATE_ACCURACIES:
            # process a detector image
            batch_labels = detector_processor_clf.batch_forward(batch_targets).argmax(1)
            batch_probas = detector_processor_clf.batch_forward(detector_output)

            batch_correct_preds = (
                batch_probas.argmax(1) == batch_labels
            ).type(torch.float).sum().item()

            correct_preds += batch_correct_preds
            size += batch_size

        # accumulate losses and accuracies for batches
        batches_losses.append(loss.item())
        if CALCULATE_ACCURACIES:
            batches_accuracies.append(batch_correct_preds / batch_size)
        else:
            batches_accuracies.append(0.)

    if CALCULATE_ACCURACIES:
        epoch_accuracy = correct_preds / size
    else:
        epoch_accuracy = 0.

    return batches_losses, batches_accuracies, epoch_accuracy
[85]:
def onn_validate_mse(
    optical_net, wavefronts_dataloader,
    detector_processor_clf,  # DETECTOR PROCESSOR NEEDED!
    loss_func,
    device='cpu', show_process=False
    ):
    """
    Function to validate `optical_net` (classification task)
    ...

    Parameters
    ----------
        optical_net : torch.nn.Module
            Neural Network composed of Elements.
        wavefronts_dataloader : torch.utils.data.DataLoader
            A loader (by batches) for a dataset of wavefronts to validate on.
        detector_processor_clf : DetectorProcessorClf
            A processor of a detector image for a classification task, that returns `probabilities` of classes.
        loss_func :
            Loss function for a multi-class classification task.
        device : str
            Device to compute on.
        show_process : bool
            Flag to show (or not) a progress bar.

    Returns
    -------
        batches_losses : list[float]
            Losses for each batch in an epoch.
        batches_accuracies : list[float]
            Accuracies for each batch in an epoch.
        epoch_accuracy : float
            Accuracy for an epoch.
    """
    optical_net.eval()  # activate 'eval' mode of a model
    batches_losses = []  # to store loss for each batch
    batches_accuracies = []  # to store accuracy for each batch

    correct_preds = 0
    size = 0

    for batch_wavefronts, batch_targets in tqdm(
        wavefronts_dataloader,
        total=len(wavefronts_dataloader),
        desc='validation', position=0,
        leave=True, disable=not show_process
    ):  # go by batches
        # batch_wavefronts - input wavefronts, batch_targets - target detector images
        batch_size = batch_wavefronts.size()[0]

        batch_wavefronts = batch_wavefronts.to(device)
        batch_targets = batch_targets.to(device)

        with torch.no_grad():
            detector_outputs = optical_net(batch_wavefronts)
            # calculate loss for a batch
            loss = loss_func(detector_outputs, batch_targets)

        # ACCURACY
        if CALCULATE_ACCURACIES:
            # process a detector image
            batch_labels = detector_processor_clf.batch_forward(batch_targets).argmax(1)
            batch_probas = detector_processor_clf.batch_forward(detector_outputs)

            batch_correct_preds = (
                batch_probas.argmax(1) == batch_labels
            ).type(torch.float).sum().item()

            correct_preds += batch_correct_preds
            size += batch_size

        # accumulate losses and accuracies for batches
        batches_losses.append(loss.item())
        if CALCULATE_ACCURACIES:
            batches_accuracies.append(batch_correct_preds / batch_size)
        else:
            batches_accuracies.append(0.)

    if CALCULATE_ACCURACIES:
        epoch_accuracy = correct_preds / size
    else:
        epoch_accuracy = 0.

    return batches_losses, batches_accuracies, epoch_accuracy
[ ]:

4.2. Training of the optical network

4.2.1. Before training

a diffractive layer … neurons … were initialized with \(\pi\) for phase values and \(1\) for amplitude values …

[86]:
SIM_PARAMS = SIM_PARAMS.to(DEVICE)

untrained_net = get_net().to(DEVICE)
if detector_processor:
    detector_processor = detector_processor.to(DEVICE)
[87]:
test_losses_0, _, test_accuracy_0 = onn_validate_mse(
    untrained_net,  # optical network composed in 3.
    test_wf_loader,  # dataloader of the test set
    detector_processor,  # detector processor
    loss_func_clf,
    device=DEVICE,
    show_process=True,
)  # evaluate the model

print(
    'Results before training on TEST set:\n' +
    f'\t{loss_func_name} : {np.mean(test_losses_0):.6f}\n' +
    f'\tAccuracy : {(test_accuracy_0*100):>0.1f} %'
)
validation: 100%|██████████████████████████████████████████| 157/157 [00:38<00:00,  4.10it/s]
Results before training on TEST set:
        MSE : 0.006371
        Accuracy : 9.9 %

[ ]:

4.2.2. Training

[88]:
n_epochs = VARIABLES['number_of_epochs']
print_each = 2  # print each n'th epoch info
[89]:
scheduler = None  # scheduler for lr tuning during training
[ ]:

[91]:
# Recreate a system to restart training!
SIM_PARAMS = SIM_PARAMS.to(DEVICE)
net_to_train = get_net().to(DEVICE)

# Link optimizer to the recreated net!
optimizer_clf = get_adam_optimizer(net_to_train)
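The loop below calls scheduler.step(mean_val_loss) when a scheduler is set, which matches a plateau-based scheduler; a hedged sketch of enabling one (not used in the recorded run):

# Optional LR scheduling: reduce the learning rate when the validation loss plateaus
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#     optimizer_clf, mode='min', factor=0.5, patience=3
# )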
[ ]:

[92]:
train_epochs_losses = []
val_epochs_losses = []  # to store losses of each epoch

train_epochs_acc = []
val_epochs_acc = []  # to store accuracies

torch.manual_seed(98)  # for reproducibility

for epoch in range(n_epochs):
    if (epoch == 0) or ((epoch + 1) % print_each == 0) or (epoch == n_epochs - 1):
        print(f'Epoch #{epoch + 1}: ', end='')
        show_progress = True
    else:
        show_progress = False

    # TRAIN
    start_train_time = time.time()  # start time of the epoch (train)
    train_losses, _, train_accuracy = onn_train_mse(
        net_to_train,  # optical network composed in 3.
        train_wf_loader,  # dataloader of training set
        detector_processor,  # detector processor
        loss_func_clf,
        optimizer_clf,
        device=DEVICE,
        show_process=show_progress,
    )  # train the model
    mean_train_loss = np.mean(train_losses)

    if (epoch == 0) or ((epoch + 1) % print_each == 0) or (epoch == n_epochs - 1):  # train info
        print('Training results')
        print(f'\t{loss_func_name} : {mean_train_loss:.6f}')
        if CALCULATE_ACCURACIES:
            print(f'\tAccuracy : {(train_accuracy*100):>0.1f} %')
        print(f'\t------------   {time.time() - start_train_time:.2f} s')

    # VALIDATION
    start_val_time = time.time()  # start time of the epoch (validation)
    val_losses, _, val_accuracy = onn_validate_mse(
        net_to_train,  # optical network composed in 3.
        val_wf_loader,  # dataloader of validation set
        detector_processor,  # detector processor
        loss_func_clf,
        device=DEVICE,
        show_process=show_progress,
    )  # evaluate the model
    mean_val_loss = np.mean(val_losses)

    if (epoch == 0) or ((epoch + 1) % print_each == 0) or (epoch == n_epochs - 1):  # validation info
        print('Validation results')
        print(f'\t{loss_func_name} : {mean_val_loss:.6f}')
        if CALCULATE_ACCURACIES:
            print(f'\tAccuracy : {(val_accuracy*100):>0.1f} %')
        print(f'\t------------   {time.time() - start_val_time:.2f} s')

    if scheduler:
        scheduler.step(mean_val_loss)

    # save losses
    train_epochs_losses.append(mean_train_loss)
    val_epochs_losses.append(mean_val_loss)
    # save accuracies
    train_epochs_acc.append(train_accuracy)
    val_epochs_acc.append(val_accuracy)
Epoch #1:
train: 100%|███████████████████████████████████████████████| 860/860 [08:08<00:00,  1.76it/s]
Training results
        MSE : 0.005861
        Accuracy : 51.9 %
        ------------   488.81 s
validation: 100%|████████████████████████████████████████████| 79/79 [00:28<00:00,  2.75it/s]
Validation results
        MSE : 0.005754
        Accuracy : 64.0 %
        ------------   28.74 s
Epoch #2:
train: 100%|███████████████████████████████████████████████| 860/860 [08:58<00:00,  1.60it/s]
Training results
        MSE : 0.005718
        Accuracy : 64.7 %
        ------------   538.35 s
validation: 100%|████████████████████████████████████████████| 79/79 [00:30<00:00,  2.63it/s]
Validation results
        MSE : 0.005709
        Accuracy : 63.8 %
        ------------   30.09 s
Epoch #4:
train: 100%|███████████████████████████████████████████████| 860/860 [10:12<00:00,  1.41it/s]
Training results
        MSE : 0.005675
        Accuracy : 67.2 %
        ------------   612.07 s
validation: 100%|████████████████████████████████████████████| 79/79 [00:30<00:00,  2.61it/s]
Validation results
        MSE : 0.005678
        Accuracy : 67.5 %
        ------------   30.33 s
Epoch #6:
train: 100%|███████████████████████████████████████████████| 860/860 [09:11<00:00,  1.56it/s]
Training results
        MSE : 0.005661
        Accuracy : 67.9 %
        ------------   551.55 s
validation: 100%|████████████████████████████████████████████| 79/79 [00:30<00:00,  2.57it/s]
Validation results
        MSE : 0.005667
        Accuracy : 68.6 %
        ------------   30.70 s
Epoch #8:
train: 100%|███████████████████████████████████████████████| 860/860 [10:45<00:00,  1.33it/s]
Training results
        MSE : 0.005653
        Accuracy : 68.3 %
        ------------   645.20 s
validation: 100%|████████████████████████████████████████████| 79/79 [00:32<00:00,  2.41it/s]
Validation results
        MSE : 0.005660
        Accuracy : 67.1 %
        ------------   32.82 s
Epoch #10:
train: 100%|███████████████████████████████████████████████| 860/860 [11:06<00:00,  1.29it/s]
Training results
        MSE : 0.005648
        Accuracy : 68.6 %
        ------------   666.04 s
validation: 100%|████████████████████████████████████████████| 79/79 [00:31<00:00,  2.53it/s]
Validation results
        MSE : 0.005655
        Accuracy : 68.3 %
        ------------   31.26 s
Epoch #12:
train: 100%|███████████████████████████████████████████████| 860/860 [09:17<00:00,  1.54it/s]
Training results
        MSE : 0.005645
        Accuracy : 68.8 %
        ------------   557.04 s
validation: 100%|████████████████████████████████████████████| 79/79 [00:30<00:00,  2.61it/s]
Validation results
        MSE : 0.005652
        Accuracy : 68.8 %
        ------------   30.32 s
Epoch #14:
train: 100%|███████████████████████████████████████████████| 860/860 [09:12<00:00,  1.56it/s]
Training results
        MSE : 0.005643
        Accuracy : 68.9 %
        ------------   552.11 s
validation: 100%|████████████████████████████████████████████| 79/79 [00:30<00:00,  2.61it/s]
Validation results
        MSE : 0.005651
        Accuracy : 68.0 %
        ------------   30.30 s
Epoch #16:
train: 100%|███████████████████████████████████████████████| 860/860 [09:34<00:00,  1.50it/s]
Training results
        MSE : 0.005641
        Accuracy : 68.9 %
        ------------   574.58 s
validation: 100%|████████████████████████████████████████████| 79/79 [00:31<00:00,  2.53it/s]
Validation results
        MSE : 0.005650
        Accuracy : 68.5 %
        ------------   31.20 s
Epoch #18:
train: 100%|███████████████████████████████████████████████| 860/860 [09:37<00:00,  1.49it/s]
Training results
        MSE : 0.005640
        Accuracy : 68.9 %
        ------------   577.47 s
validation: 100%|████████████████████████████████████████████| 79/79 [00:31<00:00,  2.50it/s]
Validation results
        MSE : 0.005649
        Accuracy : 68.9 %
        ------------   31.63 s
Epoch #20:
train: 100%|███████████████████████████████████████████████| 860/860 [10:10<00:00,  1.41it/s]
Training results
        MSE : 0.005639
        Accuracy : 69.1 %
        ------------   610.99 s
validation: 100%|████████████████████████████████████████████| 79/79 [00:31<00:00,  2.54it/s]
Validation results
        MSE : 0.005648
        Accuracy : 68.6 %
        ------------   31.16 s

[ ]:

[93]:
# learning curve
fig, axs = plt.subplots(1, 2, figsize=(10, 3))

axs[0].plot(range(1, n_epochs + 1), np.array(train_epochs_losses) * 1e3, label='train')
axs[0].plot(range(1, n_epochs + 1), np.array(val_epochs_losses) * 1e3, linestyle='dashed', label='validation')

axs[1].plot(range(1, n_epochs + 1), train_epochs_acc, label='train')
axs[1].plot(range(1, n_epochs + 1), val_epochs_acc, linestyle='dashed', label='validation')

axs[0].set_ylabel(loss_func_name + r' $\times 10^3$')
axs[0].set_xlabel('Epoch')
axs[0].legend()

axs[1].set_ylabel('Accuracy')
axs[1].set_xlabel('Epoch')
axs[1].legend()

plt.show()
../../_images/examples_notebook_diff_networks_examples_09_convolutional_4f_148_0.png
[ ]:

[95]:
n_cols = NUM_LAYERS  # number of columns for DiffractiveLayer's masks visualization
n_rows = 1

# plot wavefronts phase
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3 + 2, n_rows * 3.2))
ind_diff_layer = 0

cmap = 'gist_stern'  # 'gist_stern' 'rainbow'

net_to_plot = net_to_train.net

for ind_layer, layer in enumerate(net_to_plot):
    if isinstance(layer, elements.DiffractiveLayer) or isinstance(layer, elements.SpatialLightModulator):
        # plot masks for Diffractive layers
        if n_rows > 1:
            ax_this = axs[ind_diff_layer // n_cols][ind_diff_layer % n_cols]
        else:
            ax_this = axs[ind_diff_layer % n_cols]

        # ax_this.set_title(titles[ind_module])

        trained_mask = layer.mask.detach()

        phase_mask_this = ax_this.imshow(
            trained_mask, cmap=cmap,
            # vmin=0, vmax=MAX_PHASE
        )
        ind_diff_layer += 1

        if APERTURES:  # select only a part within apertures!
            x_frame = (x_layer_nodes - DETECTOR_SIZE[1]) / 2
            y_frame = (y_layer_nodes - DETECTOR_SIZE[0]) / 2
            ax_this.axis([x_frame, x_layer_nodes - x_frame, y_layer_nodes - y_frame, y_frame])

fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.87, 0.15, 0.01, 0.7])
plt.colorbar(phase_mask_this, cax=cbar_ax)

plt.show()
../../_images/examples_notebook_diff_networks_examples_09_convolutional_4f_154_0.png
[ ]:

[78]:
# array with all losses
all_losses_header = ','.join([
    f'{loss_func_name.split()[0]}_train', f'{loss_func_name.split()[0]}_val',
    'accuracy_train', 'accuracy_val'
])
all_losses_array = np.array(
    [train_epochs_losses, val_epochs_losses, train_epochs_acc, val_epochs_acc]
).T
[79]:
# filepath to save the model
model_filepath = f'{RESULTS_FOLDER}/conv_net.pth'
# filepath to save losses
losses_filepath = f'{RESULTS_FOLDER}/training_curves.csv'
[80]:
# saving model
torch.save(net_to_train.state_dict(), model_filepath)
[81]:
# saving losses
np.savetxt(
    losses_filepath, all_losses_array,
    delimiter=',', header=all_losses_header, comments=""
)
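To restore the trained network later, the standard PyTorch pattern applies (illustrative; paths follow the cells above):

# Reload the saved weights into a freshly built network
restored_net = get_net()
restored_net.load_state_dict(torch.load(model_filepath, map_location='cpu'))
restored_net.eval()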
[ ]:

4.4. Examples of encoding/decoding

4.4.1. Select a sample to encode/decode

[108]:
n_test_examples = 5

random.seed(78)
test_examples_ids = random.sample(range(len(mnist_wf_test_ds)), n_test_examples)
[ ]:

[109]:
fig, axs = plt.subplots(1, n_test_examples, figsize=(n_test_examples * 3, 3))
cmap = 'grey'

for ax, ind_test in zip(axs, test_examples_ids):
    test_wavefront, test_target = mnist_wf_test_ds[ind_test]

    ax.set_title(f'{ind_test}')
    ax.imshow(test_wavefront.intensity, cmap=cmap)

    # ax.set_xticks([])
    # ax.set_yticks([])

plt.show()
../../_images/examples_notebook_diff_networks_examples_09_convolutional_4f_169_0.png
[ ]:

4.4.2. Encode/decode an example

[93]:
def encode_and_decode(autoencoder, input_wf, use_encoder_aperture=False):
    # if use_encoder_aperture is True - apply a strict square aperture to the encoded image
    # aperture is defined as REGION_MASK, that was used in loss!

    with torch.no_grad():
        # ENCODE
        encoded_image = autoencoder.encode(input_wf)
        if not PRESERVE_PHASE:
            encoded_image = encoded_image.abs() + 0j  # reset phases before decoding!
        if use_encoder_aperture:
            encoded_image = encoded_image * REGION_MASK  # apply aperture for encoded image

        # DECODE
        decoded_image = autoencoder.decode(encoded_image)
        if not PRESERVE_PHASE:
            decoded_image = decoded_image.abs() + 0j  # reset phases of the decoded image!

    return encoded_image, decoded_image
[ ]:

[110]:
n_lines = 3  # image / encoded / decoded

fig, axs = plt.subplots(n_lines, n_test_examples, figsize=(n_test_examples * 3, n_lines * 3.2))

to_plot = 'amp'  # <--- choose what to plot
cmap = 'grey'  # choose colormaps
use_encoder_aperture = True

max_amp = 1  # upper limits for colorplots
max_phase = 2 * torch.pi

for ind_ex, ind_test in enumerate(test_examples_ids):
    ax_image, ax_encoded, ax_decoded = axs[0][ind_ex], axs[1][ind_ex], axs[2][ind_ex]

    test_wavefront, test_target = mnist_wf_test_ds[ind_test]

    ax_image.set_title(f'id={ind_test} [{test_target}]')
    if to_plot == 'amp':
        ax_image.imshow(
            test_wavefront.intensity, cmap=cmap,
            # vmin=0, vmax=max_amp
        )
    if to_plot == 'phase':
        ax_image.imshow(
            test_wavefront.phase, cmap=cmap,
            # vmin=0, vmax=max_phase
        )

    encoded_this, decoded_this = encode_and_decode(
        autoencoder_to_train, test_wavefront, use_encoder_aperture
    )

    ax_encoded.set_title('encoded')
    if to_plot == 'amp':
        ax_encoded.imshow(
            encoded_this.intensity, cmap=cmap,
            # vmin=0, vmax=max_amp
        )
    if to_plot == 'phase':
        ax_encoded.imshow(
            encoded_this.phase, cmap=cmap,
            # vmin=0, vmax=max_phase
        )
    if use_encoder_aperture:  # select only a part within apertures!
        x_frame = (x_layer_nodes - REGION_MASK_SIZE[1]) / 2
        y_frame = (y_layer_nodes - REGION_MASK_SIZE[0]) / 2
        ax_encoded.set_xlim([x_frame, x_layer_nodes - x_frame])
        ax_encoded.set_ylim([y_layer_nodes - y_frame, y_frame])

    ax_decoded.set_title('decoded')
    if to_plot == 'amp':
        ax_decoded.imshow(
            decoded_this.intensity, cmap=cmap,
            # vmin=0, vmax=max_amp
        )
    if to_plot == 'phase':
        ax_decoded.imshow(
            decoded_this.phase, cmap=cmap,
            # vmin=0, vmax=max_phase
        )

plt.show()
../../_images/examples_notebook_diff_networks_examples_09_convolutional_4f_176_0.png