
SimulationParameters

svetlanna.SimulationParameters

SimulationParameters(axes: Mapping[str, Tensor | float])
SimulationParameters(
    *,
    x: Tensor | float,
    y: Tensor | float,
    wavelength: Tensor | float,
    **additional_axes: Tensor | float
)
SimulationParameters(
    axes: Mapping[str, Tensor | float] | None = None,
    /,
    **kwaxes: Tensor | float,
)

Bases: Module

Simulation parameters. Manages coordinate systems and physical parameters for optical simulations. Required axes: x, y, wavelength. Additional axes can be added.

Inherits from nn.Module so that axes are registered as buffers and follow the parent Element across devices (for example, when .to() or .cuda() is called on the model).

Note

Axes are registered as non-persistent buffers (persistent=False). This means they are not included in state_dict() and will not be saved during checkpointing. The simulation grid must be provided when constructing the model; it does not need to be restored from a checkpoint.

Examples:

Let's define a simulation grid 1 mm wide and 1 mm high with 512 points along each axis (Nx = Ny = 512) and a wavelength of 632.8 nm:

import svetlanna as sv
from svetlanna.units import ureg
import torch

sim_params = sv.SimulationParameters(
    x=torch.linspace(-0.5, 0.5, 512) * ureg.mm,
    y=torch.linspace(-0.5, 0.5, 512) * ureg.mm,
    wavelength=632.8 * ureg.nm,
)
For polychromatic simulations, pass wavelength as a tensor:
sim_params = sv.SimulationParameters(
    x=torch.linspace(-0.5, 0.5, 512) * ureg.mm,
    y=torch.linspace(-0.5, 0.5, 512) * ureg.mm,
    wavelength=torch.linspace(600, 800, 10) * ureg.nm,
)

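The order of axes matters! It defines the order of dimensions in wavefront tensors: non-scalar axes appear in the reverse of the order in which they are passed, so the first axis becomes the last (innermost) dimension. In the first case above, all optical elements will expect wavefront tensors of shape (..., Ny, Nx), since the scalar wavelength adds no dimension; in the second case, the expected shape is (..., Nwavelength, Ny, Nx). Here ... stands for any number of leading dimensions (e.g., a batch dimension).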

If you change the order:

sim_params = sv.SimulationParameters(
    x=torch.linspace(-0.5, 0.5, 512) * ureg.mm,
    wavelength=torch.linspace(600, 800, 10) * ureg.nm,
    y=torch.linspace(-0.5, 0.5, 512) * ureg.mm,
)
the expected order of axes is ('y', 'wavelength', 'x'), so all optical elements will expect wavefront tensors with shape (..., Ny, Nwavelength, Nx).

You can add custom axes as needed:

sim_params = sv.SimulationParameters(
    t=torch.linspace(0, 1, 5) * ureg.s,  # time axis
    x=torch.linspace(-0.5, 0.5, 512) * ureg.mm,
    wavelength=632.8 * ureg.nm,
    y=torch.linspace(-0.5, 0.5, 512) * ureg.mm,
)
In this case, the expected order of axes is ('y', 'x', 't'): wavelength is a scalar and contributes no dimension, so all optical elements will expect wavefront tensors of shape (..., Ny, Nx, Nt).
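The examples above are consistent with a simple rule: non-scalar axes become the trailing wavefront dimensions in the reverse of the order they are passed, and scalar axes are dropped. The pure-Python sketch below encodes that rule in a hypothetical helper, `expected_wavefront_dims`, which is not part of svetlanna; the rule is inferred from the four examples rather than taken from the library's source, and length-1 sequences are treated as scalars following the `axis_names` description.

```python
from typing import Mapping, Sequence, Union

# Hypothetical helper, not part of svetlanna: encodes the ordering rule
# inferred from the documented examples.
AxisValue = Union[float, Sequence[float]]


def expected_wavefront_dims(axes: Mapping[str, AxisValue]) -> tuple[str, ...]:
    """Return the trailing wavefront dimensions, outermost first."""
    # Scalars (plain numbers or length-1 sequences) contribute no dimension.
    non_scalar = [
        name
        for name, value in axes.items()
        if not isinstance(value, (int, float)) and len(value) > 1
    ]
    # The first axis passed becomes the last (innermost) dimension.
    return tuple(reversed(non_scalar))


xs = [-0.5 + i / 511 for i in range(512)]     # 512 x (or y) samples
wl = [600 + 200 * i / 9 for i in range(10)]   # 10 wavelengths
t = [i / 4 for i in range(5)]                 # 5 time samples

print(expected_wavefront_dims({"x": xs, "y": xs, "wavelength": 632.8}))
# ('y', 'x')
print(expected_wavefront_dims({"x": xs, "y": xs, "wavelength": wl}))
# ('wavelength', 'y', 'x')
print(expected_wavefront_dims({"x": xs, "wavelength": wl, "y": xs}))
# ('y', 'wavelength', 'x')
print(expected_wavefront_dims({"t": t, "x": xs, "wavelength": 632.8, "y": xs}))
# ('y', 'x', 't')
```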

axis_names property

axis_names: tuple[str, ...]

Get names of non-scalar axes (those with length > 1).

device property

device: device

Get the device where all axes are stored.

from_ranges classmethod

from_ranges(
    *,
    x_range: tuple[float, float],
    x_points: int,
    y_range: tuple[float, float],
    y_points: int,
    wavelength: Tensor | float,
    **additional_axes: Tensor | float
) -> Self

Create SimulationParameters from coordinate ranges.

Parameters:

  • x_range (tuple[float, float]) –

    (min, max) range for x-axis. Use ureg for units.

  • x_points (int) –

    Number of points along x-axis.

  • y_range (tuple[float, float]) –

    (min, max) range for y-axis. Use ureg for units.

  • y_points (int) –

    Number of points along y-axis.

  • wavelength (Tensor | float) –

    Optical wavelength. Use ureg for units.

  • **additional_axes (Tensor | float, default: {} ) –

    Additional axes.

Examples:

>>> from svetlanna import SimulationParameters
>>> from svetlanna.units import ureg
>>> params = SimulationParameters.from_ranges(
...     x_range=(-1*ureg.mm, 1*ureg.mm), x_points=256,
...     y_range=(-1*ureg.mm, 1*ureg.mm), y_points=256,
...     wavelength=632.8*ureg.nm
... )
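This page does not say how from_ranges places points inside each range. Assuming it behaves like torch.linspace (both endpoints included, which is an assumption), the grid step is (max - min) / (points - 1). The hypothetical grid_points helper below computes the x-coordinates of the example above, in millimetres:

```python
# Sketch assuming from_ranges spaces points like torch.linspace (both range
# endpoints included); grid_points is a hypothetical helper, not svetlanna API.
def grid_points(lo: float, hi: float, n: int) -> list[float]:
    """Return n evenly spaced coordinates from lo to hi, endpoints included."""
    if n == 1:
        return [lo]
    step = (hi - lo) / (n - 1)
    return [lo + i * step for i in range(n)]


# Values in millimetres, mirroring x_range=(-1 mm, 1 mm), x_points=256.
xs = grid_points(-1.0, 1.0, 256)
dx_mm = xs[1] - xs[0]  # grid step: 2 mm / 255

print(len(xs))                 # 256
print(round(dx_mm * 1000, 2))  # 7.84 (micrometres)
```

If the assumption holds, the from_ranges call above describes the same grid as passing torch.linspace(-1, 1, 256) * ureg.mm for x and y to the constructor.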

from_dict classmethod

from_dict(axes_dict: Mapping[str, Tensor | float]) -> Self

Create SimulationParameters from a dictionary.

Parameters:

  • axes_dict (Mapping[str, Tensor | float]) –

    Dictionary with axis names as keys and tensor/scalar values.

clone

clone() -> 'SimulationParameters'

Create a deep copy of the SimulationParameters instance.

Returns:

  • SimulationParameters

    Deep copy of this instance.

equal

equal(value: SimulationParameters) -> bool

Check equality with another SimulationParameters instance. Tensor axes are compared with torch.equal (see the PyTorch documentation for details). Comparing instances on different devices raises RuntimeError because torch.equal requires both tensors to be on the same device.

Parameters:

  • value (SimulationParameters) –

    Instance to compare with.

Returns:

  • bool

    True if all axes are equal, False otherwise.

meshgrid

meshgrid(x_axis: str, y_axis: str) -> tuple[Tensor, Tensor]

Create a coordinate meshgrid from two axes.

Parameters:

  • x_axis (str) –

    Name of the axis for x-coordinates (typically 'x').

  • y_axis (str) –

    Name of the axis for y-coordinates (typically 'y').

Returns:

  • tuple[Tensor, Tensor]

    2D coordinate grids with 'xy' indexing convention.

Examples:

import svetlanna as sv
import torch

sim_params = sv.SimulationParameters(
    x=torch.linspace(-0.5, 0.5, 10),
    y=torch.linspace(-0.5, 0.5, 12),
    wavelength=1,
)

X, Y = sim_params.meshgrid("x", "y")
print(X.shape)  # torch.Size([12, 10])
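The 'xy' convention puts y along rows and x along columns, so each grid has shape (Ny, Nx), matching the (..., Ny, Nx) wavefront layout and the torch.Size([12, 10]) printed above. A minimal pure-Python sketch of the convention (the helper name meshgrid_xy is hypothetical):

```python
# Pure-Python illustration of the 'xy' indexing convention, the same one
# torch.meshgrid uses with indexing="xy". meshgrid_xy is a hypothetical name.
def meshgrid_xy(
    x: list[float], y: list[float]
) -> tuple[list[list[float]], list[list[float]]]:
    # One row per y value and one column per x value: shape (len(y), len(x)).
    X = [list(x) for _ in y]
    Y = [[yv] * len(x) for yv in y]
    return X, Y


x = [0.0, 1.0, 2.0]   # Nx = 3
y = [10.0, 20.0]      # Ny = 2
X, Y = meshgrid_xy(x, y)

print(len(X), len(X[0]))  # 2 3, i.e. shape (Ny, Nx)
print(X[1][2], Y[1][2])   # 2.0 20.0, since X[i][j] == x[j] and Y[i][j] == y[i]
```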

axis_sizes

axis_sizes(axs: tuple[str, ...] | None = None) -> Size

Get the sizes of the specified axes in the given order (cached for performance).

Parameters:

  • axs (tuple[str, ...] | None, default: None ) –

    Tuple of axis names in the desired order.

Returns:

  • Size

    Size object with lengths of specified axes.

Examples:

import svetlanna as sv
import torch

sim_params = sv.SimulationParameters(
    x=torch.linspace(-0.5, 0.5, 10),
    y=torch.linspace(-0.5, 0.5, 12),
    wavelength=1,
)

print(sim_params.axis_sizes(('y', 'x')))  # torch.Size([12, 10])

index

index(name: str) -> int

Get the negative index of an axis's dimension in wavefront tensors.

Parameters:

  • name (str) –

    Name of the axis.

Returns:

  • int

    Negative index for use in tensor operations.

Raises:

  • AxisNotFound

    If the axis doesn't exist or is scalar.
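A negative index counts from the end of the shape, so it stays valid however many leading (e.g., batch) dimensions a wavefront has. The sketch below illustrates this with a hypothetical helper, negative_index, assuming the index is taken over the non-scalar axes in wavefront order; it is not svetlanna's implementation, and it raises KeyError where the real method raises AxisNotFound.

```python
# Hypothetical helper, not svetlanna's code: negative index of a named axis,
# given the trailing wavefront dimensions in tensor order (outermost first).
def negative_index(dims: tuple[str, ...], name: str) -> int:
    if name not in dims:
        # svetlanna raises AxisNotFound here; KeyError stands in for it.
        raise KeyError(name)
    return dims.index(name) - len(dims)


dims = ("wavelength", "y", "x")              # wavefront shape (..., Nwl, Ny, Nx)
print(negative_index(dims, "x"))             # -1
print(negative_index(dims, "wavelength"))    # -3

# The same index works whatever the leading batch dimensions are:
shape_no_batch = (5, 2, 3)
shape_batched = (8, 4, 5, 2, 3)
i = negative_index(dims, "y")
print(shape_no_batch[i], shape_batched[i])   # 2 2
```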

cast

cast(
    tensor: Tensor, *axes: str, shape_check: bool = True
) -> Tensor

Cast tensor to match simulation parameters axes for broadcasting.

Reshapes tensor so it can be broadcast with wavefront tensors. Scalar axes are skipped (they don't affect tensor shape).

Parameters:

  • tensor (Tensor) –

    Input tensor whose trailing dimensions correspond to axes.

  • *axes (str, default: () ) –

    Axis names corresponding to the tensor's trailing dimensions.

Returns:

  • Tensor

    Tensor reshaped for broadcasting with wavefront.

Examples:

import svetlanna as sv
import torch

sim_params = sv.SimulationParameters(
    x=torch.linspace(-0.5, 0.5, 3),
    y=torch.linspace(-0.5, 0.5, 2),
    wavelength=torch.linspace(1, 2, 5),
)
# axes: (wavelength, y, x)
print(sim_params.axis_sizes(("wavelength", "y", "x")))  # torch.Size([5, 2, 3])

a = torch.rand(2, 3)  # y, x
a = sim_params.cast(a, "y", "x")
print(a.shape)  # torch.Size([1, 2, 3])
# a is now ready to broadcast with tensor of shape (5, 2, 3)
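Casting matters most when a tensor does not cover the innermost axes. Under PyTorch's broadcasting rules shapes are aligned from the right, so a per-wavelength tensor of shape (5,) cannot be combined with a (5, 2, 3) wavefront until it is reshaped to (5, 1, 1). The pure-Python sketch below computes such target shapes with hypothetical helpers; it assumes cast inserts a size-1 dimension for every non-scalar axis missing from the tensor, which agrees with the example above but is not confirmed for other cases (for instance, axes given out of order).

```python
# Hypothetical helpers, not svetlanna's implementation. cast_shape assumes the
# tensor has exactly one dimension per named axis, listed in wavefront order,
# and that every other non-scalar axis becomes a size-1 dimension.
def cast_shape(
    tensor_shape: tuple[int, ...],
    tensor_axes: tuple[str, ...],
    wavefront_dims: tuple[str, ...],
) -> tuple[int, ...]:
    if len(tensor_shape) != len(tensor_axes):
        raise ValueError("expected one axis name per tensor dimension")
    # list.index raises ValueError for an unknown (or scalar) axis.
    positions = [wavefront_dims.index(a) for a in tensor_axes]
    if positions != sorted(positions):
        raise ValueError("tensor axes must follow the wavefront order")
    size_of = dict(zip(tensor_axes, tensor_shape))
    return tuple(size_of.get(name, 1) for name in wavefront_dims)


def broadcast_shapes(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[int, ...]:
    # NumPy/PyTorch rule: align shapes on the right; sizes must match or be 1.
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + a
    b = (1,) * (n - len(b)) + b
    out = []
    for p, q in zip(a, b):
        if p != q and p != 1 and q != 1:
            raise ValueError(f"cannot broadcast {p} with {q}")
        out.append(q if p == 1 else p)
    return tuple(out)


dims = ("wavelength", "y", "x")                   # wavefront shape (..., 5, 2, 3)
print(cast_shape((2, 3), ("y", "x"), dims))       # (1, 2, 3), as in the example
per_wl = cast_shape((5,), ("wavelength",), dims)  # one value per wavelength
print(per_wl)                                     # (5, 1, 1)
print(broadcast_shapes(per_wl, (5, 2, 3)))        # (5, 2, 3)
# Without casting, (5,) is aligned with the x axis (size 3), so
# broadcast_shapes((5,), (5, 2, 3)) raises ValueError.
```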