SimulationParameters
svetlanna.SimulationParameters
SimulationParameters(
axes: Mapping[str, Tensor | float] | None = None,
/,
**kwaxes: Tensor | float,
)
Bases: Module
Simulation parameters.
Manages coordinate systems and physical parameters for optical simulations.
Required axes: x, y, wavelength.
Additional axes can be added.
Inherits from nn.Module so that axes are registered as buffers and
participate in automatic device management when used as submodules of
Elements.
Note
Axes are registered as non-persistent buffers (persistent=False).
This means they are not included in state_dict() and will not
be saved during checkpointing. The simulation grid must be provided
when constructing the model; it does not need to be restored from
a checkpoint.
Examples:
Let's define a simulation grid with a width and height of 1 mm, 512 points along each axis (Nx = Ny = 512), and a wavelength of 632.8 nm:
Alternatively, make the wavelength an array for polychromatic simulations:
The order of axes matters! It defines the order of dimensions in wavefront tensors.
In the first case above, all optical elements will expect wavefront tensors with shape (..., Ny, Nx),
while in the second case the expected shape will be (..., Nwavelength, Ny, Nx).
Here, ... stands for any number of leading dimensions (e.g., a batch dimension).
If you change the order so that the axis names become ('y', 'wavelength', 'x'), all optical elements will expect wavefront tensors with shape (..., Ny, Nwavelength, Nx).
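You can add custom axes as needed. For example, with an additional time axis t and a scalar wavelength, the axis names are ('y', 'x', 't'); the wavelength is omitted because it is scalar. All optical elements will then expect wavefront tensors with shape (..., Ny, Nx, Nt).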
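The shape rule described above can be sketched in plain Python. The helpers below (`axis_length`, `non_scalar_axis_names`, `expected_wavefront_shape`) are hypothetical and not part of svetlanna; they assume axes are given as an insertion-ordered dict mapping each name to a float (scalar) or a list of values, and they treat an axis as scalar when its length is at most 1, following the `axis_names` description. The dict orders are chosen to reproduce the shapes quoted above.

```python
from typing import Mapping, Sequence, Union

# Hypothetical stand-in for an axis value: a float (scalar) or a 1D list.
AxisValue = Union[float, Sequence[float]]


def axis_length(value: AxisValue) -> int:
    """Number of points on an axis; plain numbers count as length 1."""
    if isinstance(value, (int, float)):
        return 1
    return len(value)


def non_scalar_axis_names(axes: Mapping[str, AxisValue]) -> tuple[str, ...]:
    """Names of axes with length > 1, kept in definition order."""
    return tuple(name for name, value in axes.items() if axis_length(value) > 1)


def expected_wavefront_shape(
    axes: Mapping[str, AxisValue], leading: tuple[int, ...] = ()
) -> tuple[int, ...]:
    """Shape (*leading, N_axis1, N_axis2, ...) expected for wavefronts.

    `leading` plays the role of `...`, e.g. a batch dimension.
    """
    return tuple(leading) + tuple(
        axis_length(axes[name]) for name in non_scalar_axis_names(axes)
    )


grid = [i * 1e-3 / 511 - 0.5e-3 for i in range(512)]  # 512 points over 1 mm
wavelengths = [450e-9, 532e-9, 632.8e-9]
times = [i * 1e-12 for i in range(10)]

mono = {"y": grid, "x": grid, "wavelength": 632.8e-9}
poly = {"wavelength": wavelengths, "y": grid, "x": grid}
reordered = {"y": grid, "wavelength": wavelengths, "x": grid}
custom = {"y": grid, "x": grid, "wavelength": 632.8e-9, "t": times}

print(expected_wavefront_shape(mono))                # (512, 512)
print(expected_wavefront_shape(poly, leading=(8,)))  # (8, 3, 512, 512)
print(expected_wavefront_shape(reordered))           # (512, 3, 512)
print(non_scalar_axis_names(custom))                 # ('y', 'x', 't')
```

Note how the scalar wavelength disappears from the shape entirely, while reordering the dict moves the corresponding dimension.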
axis_names
property
Get names of non-scalar axes (those with length > 1).
from_ranges
classmethod
from_ranges(
*,
x_range: tuple[float, float],
x_points: int,
y_range: tuple[float, float],
y_points: int,
wavelength: Tensor | float,
**additional_axes: Tensor | float
) -> Self
Create SimulationParameters from coordinate ranges.
Parameters:
- x_range (tuple[float, float]) – (min, max) range for the x-axis. Use ureg for units.
- x_points (int) – Number of points along the x-axis.
- y_range (tuple[float, float]) – (min, max) range for the y-axis. Use ureg for units.
- y_points (int) – Number of points along the y-axis.
- wavelength (Tensor | float) – Optical wavelength. Use ureg for units.
- **additional_axes (Tensor | float, default: {}) – Additional axes.
Examples:
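A minimal sketch of how a (min, max) range and a point count could become coordinates, using the 1 mm, 512-point grid from the class examples. It assumes both endpoints are included (the `torch.linspace` convention), which the source does not confirm for `from_ranges`, and it uses plain metres instead of `ureg`. The `linspace` helper is hypothetical and written in pure Python so it runs without PyTorch.

```python
import math


def linspace(start: float, stop: float, num: int) -> list[float]:
    """`num` evenly spaced points from start to stop, endpoints included
    (the torch.linspace convention, assumed here for from_ranges)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


mm = 1e-3  # lengths in metres for this sketch
x_range, x_points = (-0.5 * mm, 0.5 * mm), 512
y_range, y_points = (-0.5 * mm, 0.5 * mm), 512

x = linspace(*x_range, x_points)
y = linspace(*y_range, y_points)
dx = x[1] - x[0]  # grid step: 1 mm over 511 intervals, about 1.957 micrometres

print(len(x), len(y))  # 512 512
```

Under the endpoint assumption, the spacing is width / (points - 1) rather than width / points, which matters when matching a sampling step to a physical pixel pitch.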
from_dict
classmethod
Create SimulationParameters from a dictionary.
Parameters:
- axes_dict (Mapping[str, Tensor | float]) – Dictionary with axis names as keys and tensor or scalar values.
clone
Create a deep copy of the SimulationParameters instance.
Returns:
- SimulationParameters – A new instance with cloned axes.
equal
equal(value: SimulationParameters) -> bool
Check equality with another SimulationParameters instance.
Tensor axes are compared with torch.equal; see its documentation for details.
Comparing instances on different devices raises a RuntimeError, because torch.equal requires both tensors to be on the same device.
Parameters:
- value (SimulationParameters) – SimulationParameters instance to compare with.
Returns:
- bool – True if all axes are equal, False otherwise.
meshgrid
Create a coordinate meshgrid from two axes.
Parameters:
- x_axis (str) – Name of the axis for x-coordinates (typically 'x').
- y_axis (str) – Name of the axis for y-coordinates (typically 'y').
Returns:
- tuple[Tensor, Tensor] – 2D coordinate grids with 'xy' indexing convention.
Examples:
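With indexing='xy', torch.meshgrid(x, y) returns grids whose first dimension follows y and whose second follows x, i.e. shape (Ny, Nx), which matches the (..., Ny, Nx) wavefront layout. The pure-Python sketch below reproduces that convention on small lists; `meshgrid_xy` is a hypothetical helper, and it assumes this method behaves like torch.meshgrid with 'xy' indexing when called with the 'x' and 'y' axes.

```python
def meshgrid_xy(
    x: list[float], y: list[float]
) -> tuple[list[list[float]], list[list[float]]]:
    """Nested-list analogue of torch.meshgrid(x, y, indexing="xy").

    Both grids have len(y) rows and len(x) columns, so that
    X[i][j] == x[j] and Y[i][j] == y[i] (rows follow y, columns follow x).
    """
    X = [list(x) for _ in y]
    Y = [[yi] * len(x) for yi in y]
    return X, Y


x = [0.0, 1.0, 2.0]
y = [10.0, 20.0]
X, Y = meshgrid_xy(x, y)
print(X)  # [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]
print(Y)  # [[10.0, 10.0, 10.0], [20.0, 20.0, 20.0]]

# Typical use: squared radius at every grid point, e.g. for a circular aperture.
R2 = [[X[i][j] ** 2 + Y[i][j] ** 2 for j in range(len(x))] for i in range(len(y))]
```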
axis_sizes
Get the sizes of the specified axes, in the given order (cached for performance).
Parameters:
- axs (tuple[str, ...] | None, default: None) – Tuple of axis names in the desired order.
Returns:
- Size – Size object with the lengths of the specified axes.
Examples:
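A sketch of the lookup this method performs, using a hypothetical dictionary of axis lengths. It assumes that `axs=None` means all non-scalar axes in definition order, which the source does not state. Because `axs` is a tuple, it is hashable and can serve as a key for a memoizing cache such as `functools.lru_cache`; this is one plausible way the caching could work, not a description of svetlanna's implementation. Since `torch.Size` is a subclass of `tuple`, a plain tuple stands in for it here.

```python
from functools import lru_cache
from typing import Optional

# Hypothetical axis lengths (a scalar axis would have length 1).
AXIS_LENGTHS = {"wavelength": 3, "y": 512, "x": 256}


@lru_cache(maxsize=None)
def axis_sizes(axs: Optional[tuple[str, ...]] = None) -> tuple[int, ...]:
    """Lengths of the requested axes, in the requested order.

    The cache stays valid only while AXIS_LENGTHS is not modified.
    """
    if axs is None:
        # Assumption: default to every non-scalar axis, in definition order.
        axs = tuple(name for name, n in AXIS_LENGTHS.items() if n > 1)
    return tuple(AXIS_LENGTHS[name] for name in axs)


print(axis_sizes())                 # (3, 512, 256)
print(axis_sizes(("x", "y")))       # (256, 512)
print(axis_sizes(("wavelength",)))  # (3,)
```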
index
Get the negative index of an axis in tensors.
Parameters:
- name (str) – Name of the axis.
Returns:
- int – Negative index for use in tensor operations.
Raises:
- AxisNotFound – If the axis doesn't exist or is scalar.
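Negative indices count from the end of a shape, so they stay correct however many leading (e.g. batch) dimensions a wavefront has. The sketch below computes such an index from an ordered tuple of non-scalar axis names; `negative_index` and the local `AxisNotFound` class are hypothetical stand-ins, not svetlanna's code.

```python
class AxisNotFound(KeyError):
    """Hypothetical stand-in for svetlanna's AxisNotFound exception."""


def negative_index(axis_names: tuple[str, ...], name: str) -> int:
    """Position of `name` counted from the end of a wavefront's shape."""
    if name not in axis_names:
        # Scalar axes are not listed among the axis names, so they land here too.
        raise AxisNotFound(name)
    return axis_names.index(name) - len(axis_names)


names = ("wavelength", "y", "x")  # wavefronts: (..., Nwavelength, Ny, Nx)
print(negative_index(names, "x"))           # -1
print(negative_index(names, "wavelength"))  # -3

shape = (8, 3, 512, 256)  # batch, wavelength, y, x
print(shape[negative_index(names, "y")])    # 512
```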
cast
Cast tensor to match simulation parameters axes for broadcasting.
Reshapes tensor so it can be broadcast with wavefront tensors. Scalar axes are skipped (they don't affect tensor shape).
Parameters:
- tensor (Tensor) – Input tensor whose trailing dimensions correspond to axes.
- *axes (str, default: ()) – Axis names corresponding to the tensor's trailing dimensions.
Returns:
- Tensor – Tensor reshaped for broadcasting with the wavefront.
Examples:
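A sketch of the shape computation behind this kind of cast: the tensor's trailing dimensions are mapped onto their axes, and every other wavefront axis gets a size-1 dimension so that broadcasting lines up. In PyTorch the result would be passed to `tensor.reshape(...)`. `cast_shape` is hypothetical, and it assumes the given axes already follow the order of the axis names so that no permutation is needed; whether svetlanna's `cast` also reorders dimensions is not stated here.

```python
def cast_shape(
    tensor_shape: tuple[int, ...],
    tensor_axes: tuple[str, ...],
    axis_names: tuple[str, ...],
) -> tuple[int, ...]:
    """Target shape that lets a tensor broadcast against wavefronts.

    The last len(tensor_axes) dims of tensor_shape belong to tensor_axes;
    any earlier dims are kept as leading dims. Wavefront axes the tensor
    does not have become size-1 dims.
    """
    split = len(tensor_shape) - len(tensor_axes)
    if split < 0:
        raise ValueError("more axis names than tensor dimensions")
    leading, trailing = tensor_shape[:split], tensor_shape[split:]
    sizes = dict(zip(tensor_axes, trailing))
    # Assumption: axes must be wavefront axes in the same relative order as
    # axis_names, so a plain reshape (without permute) is enough.
    if [a for a in axis_names if a in sizes] != list(tensor_axes):
        raise ValueError("tensor axes must follow the order of axis_names")
    return tuple(leading) + tuple(sizes.get(a, 1) for a in axis_names)


names = ("wavelength", "y", "x")  # wavefronts: (..., Nwavelength, Ny, Nx)

# A per-wavelength factor of shape (3,) becomes (3, 1, 1).
print(cast_shape((3,), ("wavelength",), names))      # (3, 1, 1)
# A batch of 8 masks of shape (Ny, Nx) becomes (8, 1, 512, 256).
print(cast_shape((8, 512, 256), ("y", "x"), names))  # (8, 1, 512, 256)
```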