Skip to content

Make n3fit compatible with TF > 2.1

Emanuele Roberto Nocera requested to merge beyond_tf21 into master

Created by: scarlehoff

Now that the release candidate of TF 2.3 is out I thought it was time to revisit this. This PR aims to make TF compatible with the versions beyond 2.1.

The current code is already functional and seem to be as fast as it was before (actually, almost a 10% faster, but it could be a fluctuation, the speed seem to be the same for 2.1). But I want to do some proper benchmarks first and I need to clean it so it looks less hacky (making sure the imports are in the correct place and so).

The problem

TF 2.2 removed the target_tensors argument from the model.compile method (without any mentions in the changelog) which broke the way n3fit works. I should note they silently removed it from the documentation as well. It might be that it was a bug of TF 2.1 that I was exploiting and not a feature, who knows. The right approach is to pass the target when performing the fit, but doing it in this way triggers a partial recompilation of the code for every epoch. This resulted in an overhead of a factor of 2 when performing fits. This is true for the training and well as the validation model.

The solution

The solution for the training model is easy. Doing it "the correct way" now I call the fit method with the total number of epochs and pass it two callbacks, one for the dynamical positivity lambda and another for early stopping. Then at the end of every epoch it does whatever it needs to do without triggering recompilations.

The solution for the validation model not so much because our training and validation models are different (different covmats, for instance) so we cannot use the TF API for that. Tensorflow assumes everywhere that the validation and the training refer to the same model (i.e., you are allowed to have training/validation data but not training/validation models).

The solution I've come up with (I'll iterate it a bit before marking the PR as ready, the commit is called first trial because I wasn't sure it would work) is to create a different function and compile the target data together with the loss. This is the hacky part of the code and I need to make it a bit more robust before calling it final (so that when TF 2.4 comes out it doesn't break because of some other stupid reason).

This works with TF 2.2 and 2.3.

Note: the fact that target_tensors is broken is right now accepted as a bug https://github.com/tensorflow/tensorflow/issues/41248 but I think given they conciously removed it from the docs it is likely they will drop support for it, so I think we should merge this rather than hoping for TF to fix the bug.

Merge request reports

Loading