diff --git a/n3fit/src/n3fit/checks.py b/n3fit/src/n3fit/checks.py index 0c3c9e59ecfd84147c635d56fc14fa7c7e298435..ef7385a812d8f30b58cd93d34295b5fde6e0c67a 100644 --- a/n3fit/src/n3fit/checks.py +++ b/n3fit/src/n3fit/checks.py @@ -2,6 +2,8 @@ This module contains checks to be perform by n3fit on the input """ import logging +import numbers +import numpy as np from reportengine.checks import make_argcheck, CheckError from validphys.pdfbases import check_basis from n3fit.hyper_optimization import penalties as penalties_module @@ -11,6 +13,19 @@ log = logging.getLogger(__name__) NN_PARAMETERS = ["nodes_per_layer", "optimizer", "activation_per_layer"] + +def _is_floatable(num): + """Check that num is a number or, worst case scenario, a number that can + be casted to a float (such as a tf scalar)""" + if isinstance(num, numbers.Number): + return True + try: + np.float(num) + return True + except (ValueError, TypeError): + return False + + # Checks on the NN parameters def check_existing_parameters(parameters): """ Check that non-optional parameters are defined and are not empty """ @@ -96,6 +111,7 @@ def check_dropout(parameters): def check_tensorboard(tensorboard): + """ Check that the tensorbard callback can enabled correctly """ if tensorboard is not None: weight_freq = tensorboard.get("weight_freq", 0) if weight_freq < 0: @@ -104,6 +120,20 @@ def check_tensorboard(tensorboard): ) +def check_lagrange_multipliers(parameters, key): + """Checks the parameters in a lagrange multiplier dictionary + are correct (positivity, integrability)""" + lagrange_dict = parameters.get(key) + if lagrange_dict is None: + return + multiplier = lagrange_dict.get("multiplier") + if multiplier is not None and multiplier <= 0: + log.warning("The %s multiplier is below 0, it will produce a negative loss", key) + threshold = lagrange_dict.get("threshold") + if threshold is not None and not _is_floatable(threshold): + raise CheckError(f"The {key}::threshold must be a number, received: {threshold}") + + @make_argcheck def wrapper_check_NN(fitting): """ Wrapper function for all NN-related checks """ @@ -114,6 +144,8 @@ def wrapper_check_NN(fitting): check_basis_with_layers(fitting, parameters) check_stopping(parameters) check_dropout(parameters) + check_lagrange_multipliers(parameters, "integrability") + check_lagrange_multipliers(parameters, "positivity") # Checks that need to import the backend (and thus take longer) should be done last check_optimizer(parameters["optimizer"]) check_initializer(parameters["initializer"]) @@ -167,14 +199,6 @@ def check_hyperopt_positivity(positivity_dict): ) if max_ini <= min_ini: raise CheckError("The minimum initial value cannot be greater than the maximum") - threshold = positivity_dict.get("threshold") - if threshold is not None: - try: - 4.0 < threshold - except TypeError as e: - raise CheckError( - f"The positivity::threshold must be a number, received: {threshold}" - ) from e def check_kfold_options(kfold):