tsadar.inverse.loss_function
Classes
LossFunction: a class responsible for managing the forward pass and loss computation for inverse Thomson scattering analysis.
- class tsadar.inverse.loss_function.LossFunction(cfg: Dict, scattering_angles, dummy_batch)
LossFunction is a class responsible for managing the forward pass and loss computation for inverse Thomson scattering analysis. This class encapsulates the logic for:
- Normalizing input and output data based on configuration.
- Computing theoretical spectra using a ThomsonScatteringDiagnostic instance.
- Calculating loss values and gradients for optimization, supporting various loss metrics (L1, L2, log-cosh, Poisson).
- Handling multiplexed analysis with EDF rotation if required.
- Applying additional penalties and moment regularization to the loss.
- Providing interfaces for loss, gradient, and Hessian computation compatible with optimization routines.
- cfg
Configuration dictionary constructed from user inputs.
- Type:
Dict
- ts_diag
Diagnostic object for theoretical spectrum calculation.
- multiplex_ang
Indicates if multiplexed analysis with EDF rotation is enabled.
- Type:
bool
- i_norm, e_norm
Normalization factors for output data.
- Type:
float
- i_input_norm, e_input_norm
Normalization factors for input data.
- Type:
float
- _loss_, _vg_func_, _h_func_
JIT-compiled loss, value-and-grad, and Hessian functions.
- Type:
callable
- array_loss
JIT-compiled postprocessing loss function.
- Type:
callable
- __init__(cfg, scattering_angles, dummy_batch)
Initializes the LossFunction with configuration, angles, and dummy data for normalization.
- vg_loss(diff_weights, static_weights, batch)
Computes the loss value and gradient with respect to weights for optimization.
- h_loss_wrt_params(weights, batch)
Computes the Hessian of the loss with respect to parameters.
- calc_ei_error(batch, ThryI, lamAxisI, ThryE, lamAxisE, uncert, reduce_func)
Calculates the error between experimental and theoretical spectra for IAW and EPW.
- calc_loss(ts_params, batch, denom, reduce_func)
Computes the total loss, including penalties and normalization, for a given parameter set and batch.
- __loss__(diff_weights, static_weights, batch)
Internal loss function wrapper for optimization routines.
- loss_functionals(d, t, uncert, method='l2')
Computes the element-wise loss between data and theory using the specified metric.
- penalties(weights)
Computes additional penalties (e.g., parameter bounds, moment regularization) to be added to the loss.
- _moment_loss_(params)
Computes regularization losses for the moments (density, temperature, momentum) of the distribution function.
- Usage:
Instantiate with configuration, scattering angles, and dummy data. Use vg_loss or loss for optimization routines.
- vg_loss(diff_weights, static_weights: Dict, batch: Dict)
Computes the value of the loss function and its gradient with respect to the weights. This is the main interface for evaluating the loss and its gradient, which are used to assess the goodness of fit and to update the model weights during optimization. It also handles the pre- and post-processing steps required by the optimization software. Its behavior depends on the optimizer method specified in the configuration:
For “l-bfgs-b”, it unravels the weights, computes the loss and gradient, flattens the gradient, and returns both the loss value and the flattened gradient.
For other methods, it directly returns the result of the internal loss function, which is a PyTree.
- Parameters:
diff_weights – The differentiable (trainable) weights to be optimized, possibly in a flattened format.
static_weights (Dict) – The static (non-trainable) weights used in the computation.
batch (Dict) – The batch of data used for evaluating the loss and gradient.
- Returns:
Tuple[float, np.ndarray] or Any –
If using the “l-bfgs-b” optimizer: a tuple containing the loss value and the flattened gradient array.
Otherwise: the result of the internal loss function, which is a tuple containing the loss value and the structured gradient tree.
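The “l-bfgs-b” branch described above (unravel the flat weights, evaluate value and gradient, flatten the gradient) can be sketched as follows. This is a minimal illustration, not the tsadar implementation: `toy_loss`, `vg_flat`, and the contents of the weight and batch dictionaries are hypothetical stand-ins, and only `jax.value_and_grad` and `jax.flatten_util.ravel_pytree` are real JAX calls.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree


# Hypothetical toy loss over a pytree of weights; it stands in for the
# internal loss, whose real form is not shown in this documentation.
def toy_loss(diff_weights, static_weights, batch):
    pred = diff_weights["electron"]["Te"] * batch["x"] + static_weights["offset"]
    return jnp.mean((pred - batch["y"]) ** 2)


diff_weights = {"electron": {"Te": jnp.array(0.5)}}
static_weights = {"offset": jnp.array(1.0)}
batch = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([1.0, 3.0, 5.0])}

# ravel_pytree returns the flat vector plus a function that rebuilds the pytree.
flat_w0, unravel = ravel_pytree(diff_weights)

# JIT-compiled value-and-gradient with respect to the first argument.
vg = jax.jit(jax.value_and_grad(toy_loss))


def vg_flat(flat_w, static_weights, batch):
    # 1. Unravel the flat vector (cast to the dtype the pytree was built with).
    weights = unravel(jnp.asarray(flat_w, dtype=flat_w0.dtype))
    # 2. Evaluate the loss and its gradient as a pytree.
    value, grad_tree = vg(weights, static_weights, batch)
    # 3. Flatten the gradient so a flat-vector optimizer can consume it.
    flat_grad, _ = ravel_pytree(grad_tree)
    return float(value), np.asarray(flat_grad)


value, grad = vg_flat(np.asarray(flat_w0), static_weights, batch)
```

A function with this return shape can be passed to a flat-vector optimizer such as `scipy.optimize.minimize(..., jac=True, method="L-BFGS-B")`, which expects the objective to return the value and the gradient together.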
- calc_ei_error(batch, ThryI, lamAxisI, ThryE, lamAxisE, uncert, reduce_func=<function mean>)
Calculates the error metrics for ion and electron spectral fits based on theoretical and experimental data. This function computes the error between measured and theoretical spectra for both ion (IAW) and electron (EPW) features, applying configurable fitting ranges and loss methods. The errors are reduced using the specified reduction function (default is mean), and squared deviations are accumulated for further analysis.
- Parameters:
batch (dict) – Dictionary containing experimental data arrays with keys “i_data” (ion data) and “e_data” (electron data).
ThryI (array-like) – Theoretical ion spectrum corresponding to i_data.
lamAxisI (array-like) – Wavelength axis for the ion spectrum.
ThryE (array-like) – Theoretical electron spectrum corresponding to e_data.
lamAxisE (array-like) – Wavelength axis for the electron spectrum.
uncert (tuple or list) – Uncertainty arrays for ion and electron data, respectively.
reduce_func (callable, optional) – Function to reduce the error array to a scalar (e.g., jnp.mean, jnp.sum). Defaults to jnp.mean.
- Returns:
tuple –
i_error (float): Reduced error metric for the ion feature (IAW).
e_error (float): Reduced error metric for the electron feature (EPW).
sqdev (dict): Dictionary with keys “ion” and “ele” containing arrays of squared deviations for ion and electron data, respectively.
Notes
The function uses configuration options from self.cfg to determine which features to fit and the wavelength ranges.
If both blue and red EPW features are fit, the electron error is averaged accordingly.
NaN values are used to mask out-of-range points and are handled with jnp.nan_to_num when accumulating squared deviations.
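The NaN-masking pattern mentioned in the notes above can be sketched as follows. This is an illustration only: `masked_error`, its arguments, and the wavelength limits are hypothetical, and the use of `jnp.nanmean` to skip masked points in the reduction is an assumption, since the actual reduction in calc_ei_error is not shown here.

```python
import jax.numpy as jnp


def masked_error(data, theory, lam, lam_min, lam_max):
    # Points outside the (hypothetical) fitting range are replaced by NaN.
    in_range = (lam > lam_min) & (lam < lam_max)
    sq = jnp.where(in_range, (data - theory) ** 2, jnp.nan)
    # Assumption: the reduction ignores masked points (nanmean here).
    error = jnp.nanmean(sq)
    # For accumulation, NaNs are converted to zeros so sums stay finite.
    sqdev = jnp.nan_to_num(sq)
    return error, sqdev


lam = jnp.array([500.0, 510.0, 520.0, 530.0])
data = jnp.array([1.0, 2.0, 3.0, 4.0])
theory = jnp.ones(4)

error, sqdev = masked_error(data, theory, lam, lam_min=505.0, lam_max=525.0)
```

Using `jnp.where` rather than boolean indexing keeps the array shape fixed, which is what allows this kind of masking to work inside JIT-compiled functions.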
- calc_loss(ts_params, batch: Dict, denom, reduce_func)
Calculates the total loss for the inverse Thomson scattering model, including electron and ion errors, and applies any necessary penalties. Handles both multiplexed and non-multiplexed angular configurations.
- Parameters:
ts_params (dict) – Dictionary of Thomson scattering parameters, including the electron distribution.
batch (Dict) – Batch of experimental data. If multiplex_ang is True, expects keys “b1” and “b2”.
denom (list) – Denominator(s) for normalization. If empty, it is set to the theoretical values.
reduce_func (callable) – Function to reduce error arrays (e.g., sum, mean).
- Returns:
tuple –
total_loss (float): The computed total loss value (sum of scaled ion error, electron error, and penalties).
sqdev (Any): Squared deviation(s) between theoretical and experimental data.
ThryE (Any): Theoretical electron spectrum.
ThryI (Any): Theoretical ion spectrum.
ts_params (dict): (Possibly updated) Thomson scattering parameters.
- loss(weights, batch: Dict)
High level function that returns the value of the loss function for a given set of weights and a batch of data. Depending on the optimizer method specified in the configuration, this function may first convert the flat weights array into a pytree structure before computing the loss.
- Parameters:
weights – The weights to be used in the loss function, either in a flat format or as a pytree.
batch (Dict) – A dictionary containing the data to be used in the loss function.
- Returns:
float – The computed loss value.
- loss_functionals(d, t, uncert, method='l2')
Computes the loss between data and theoretical values using one of several loss functionals.
- Parameters:
d (array-like) – Data values.
t (array-like) – Theoretical values.
uncert (array-like) – Uncertainty values used for normalization.
method (str, optional) – The loss functional to use. Options are:
“l1”: Mean absolute error, normalized by uncertainty.
“l2”: Mean squared error, normalized by uncertainty.
“log-cosh”: Log-cosh loss.
“poisson”: Poisson loss.
- Returns:
_error_ (array-like) – Computed loss values according to the selected method.
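For orientation, the four options can be sketched with common textbook forms of each functional. The formulas below are assumptions chosen for illustration (in particular how the uncertainty enters and the exact Poisson expression); `loss_functionals_sketch` is a hypothetical name and does not reproduce the tsadar source.

```python
import jax.numpy as jnp


def loss_functionals_sketch(d, t, uncert, method="l2"):
    """Element-wise loss between data d and theory t (illustrative forms only)."""
    resid = d - t
    if method == "l1":
        # Absolute error scaled by the uncertainty.
        return jnp.abs(resid) / uncert
    if method == "l2":
        # Squared error scaled by the uncertainty.
        return resid**2 / uncert
    if method == "log-cosh":
        # log(cosh(x)) written as logaddexp(x, -x) - log(2) for numerical stability.
        return (jnp.logaddexp(resid, -resid) - jnp.log(2.0)) / uncert
    if method == "poisson":
        # Poisson negative log-likelihood up to a term independent of t.
        return t - d * jnp.log(t)
    raise ValueError(f"unknown loss method: {method!r}")


d = jnp.array([1.0, 2.0])
t = jnp.array([1.0, 4.0])
uncert = jnp.array([1.0, 2.0])

l1 = loss_functionals_sketch(d, t, uncert, "l1")
l2 = loss_functionals_sketch(d, t, uncert, "l2")
logcosh = loss_functionals_sketch(d, t, uncert, "log-cosh")
poisson = loss_functionals_sketch(d, t, uncert, "poisson")
```

Each branch returns an array of the same shape as the inputs, so a separate reduction (such as the reduce_func argument of calc_ei_error) can turn it into a scalar.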
- penalties(weights)
Computes the total penalty for the given model parameters (weights), including parameter constraints, optional moment losses, and an optional strict penalty on the electron distribution function.
- Parameters:
weights – Dictionary containing model parameters for each species. Each species entry is itself a dictionary of parameter arrays.
- Returns:
jnp.ndarray – The total penalty value as a scalar.
- Penalties included:
Parameter penalty: Applies a log-based penalty to all parameters except ‘fe’ for each species.
Moment loss: If enabled in the configuration, adds density, temperature, and momentum losses.
Electron distribution penalty: If enabled in the configuration, penalizes increases in the electron distribution function (‘fe’) along the velocity axis.
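The electron distribution penalty can be pictured with a short sketch. It assumes a one-dimensional `fe` sampled on an increasing, non-negative velocity grid and sums the positive finite differences; `fe_increase_penalty` is a hypothetical helper, and the weighting and exact form used by tsadar may differ.

```python
import jax.numpy as jnp


def fe_increase_penalty(fe):
    # Finite differences along the (assumed last) velocity axis.
    steps = jnp.diff(fe, axis=-1)
    # Only increases contribute; decreasing segments cost nothing.
    return jnp.sum(jnp.maximum(steps, 0.0))


monotone = jnp.array([1.0, 0.8, 0.5, 0.1])  # decreases everywhere
bumpy = jnp.array([1.0, 0.8, 0.9, 0.5])     # one increase of about 0.1

p_monotone = fe_increase_penalty(monotone)
p_bumpy = fe_increase_penalty(bumpy)
```

A monotonically decreasing distribution incurs zero penalty under this form, while any bump adds a cost proportional to its height.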