Go from one step being an epoch to one step being a batch
Created by: APJansen
As part of the changes mentioned in issue #1803, this is a very simple change that results in a factor 2 speedup.
It simply copies the input x grids by the number of epochs, and calls one step a batch rather than an epoch. This avoids some Tensorflow overhead that, with the other improvements mentioned there, take up nearly 50% of the total training time.
The current state is that the fit runs, but some changes need to be made downstream as it's crashing (perhaps just undoing the changes I made just before the fit, just after the fit?).
If anyone wants to take this up, please do.
To illustrate what this does, here is a tensorboard profile without doing this: These gaps are almost completely removed by this PR.