Multi Replica PDF
Created by: APJansen
Question
This will be some work, so before continuing past this point I'd like to confirm that you agree that, once finished, this will be a beneficial change.
Idea
The idea of this PR is to refactor the TensorFlow model so that, instead of taking a list of single-replica pdfs, it takes a single multi-replica pdf: one pdf whose output has an extra axis representing the replica. This is much faster on the GPU; see the tests below.
The main ingredient that makes this possible is a MultiDense layer (see here), which is essentially a dense layer whose weights have one extra dimension, of size equal to the number of replicas. For the first layer, which takes x's as input, that is all there is to it. For deeper layers, the input already has a replica axis, so each replica slice of the input has to be multiplied by the corresponding replica slice of the weights.
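To illustrate the contraction pattern described above, here is a minimal NumPy sketch (names and shapes are illustrative only, not the actual MultiDense implementation): the first layer broadcasts a replica-free input over all replica kernels, while deeper layers contract each replica slice of the input with its own kernel.

```python
import numpy as np

rng = np.random.default_rng(0)
replicas, batch, n_in, n_out = 3, 5, 8, 4

# First layer: x carries no replica axis yet, so the same input is
# broadcast over every replica's kernel.
x = rng.normal(size=(batch, n_in))
kernel_first = rng.normal(size=(replicas, n_in, n_out))
out_first = np.einsum("bi,rio->bro", x, kernel_first)  # (batch, replicas, n_out)

# Deeper layer: the input already has a replica axis, so each replica
# slice is contracted with the corresponding slice of the weights.
kernel_deep = rng.normal(size=(replicas, n_out, n_out))
out_deep = np.einsum("bro,roj->brj", out_first, kernel_deep)

# Reference: the same thing computed one replica at a time.
ref = np.stack(
    [out_first[:, r] @ kernel_deep[r] for r in range(replicas)], axis=1
)
assert np.allclose(out_deep, ref)
```

The same einsum patterns carry over directly to `tf.einsum`; the point of doing it in one contraction is that the GPU sees a single large batched operation instead of one small matmul per replica.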
Development Strategy
To integrate this into the code, many small changes are necessary. To make it as simple as possible to review and test, I aim to make small, independent changes that are ideally beneficial, or at least not detrimental, on their own. Wherever it's sensible I'll first create a unit test that covers the changes I want to make and ensure it still passes afterwards, and wherever possible I'll try to have the outputs be identical up to numerical errors. I'll put each of these on its own branch with its own PR (maybe I should create a special label for those PRs?).
Once those small changes are merged, the actual implementation should be easily manageable to review.
This PR itself is for now a placeholder: I just added the commit so that I could create a draft PR and you can check out the MultiDense layer.
I expect that as a final result you'll still want single-replica pdfs. I will add code that, once all computations are done, splits the multi-replica pdf into single ones, so the saving and any interaction with validphys will remain unchanged.
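A minimal sketch of that splitting step (purely illustrative, not the actual code): slicing the replica axis of the combined output recovers per-replica arrays that downstream code can consume unchanged.

```python
import numpy as np

# Hypothetical multi-replica output with axes (batch, replicas, flavours).
multi_out = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)

# Split along the replica axis into a list of single-replica outputs,
# one (batch, flavours) array per replica.
single = [multi_out[:, r, :] for r in range(multi_out.shape[1])]

assert len(single) == 3
assert single[0].shape == (2, 4)
```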
Performance
Timing
These are the timing tests I did on a 1/4 node on Snellius, with one GPU. I'm reporting the average seconds per epoch as printed in debug mode.
| runcard | replicas | multi_replica_pdf_test | trvl-mask-layers | master |
|---|---|---|---|---|
| Basic | 200 | 0.12 | 1.2 | 2.3 |
| NNPDF40_nnlo_as_01180_1000 | 200 | out of memory | out of memory | - |
| NNPDF40_nnlo_as_01180_1000 | 100 | 0.76 | 1.12 | - |
Memory
Memory usage also appears to be significantly reduced. I checked the peak CPU memory usage using libmemprofile on the basic runcard with 200 replicas, and found 3.5 GB versus 16.5 GB for the trvl-mask-layers branch.
Status
I have a test branch where this is working up to the end of the model training, which is what I used to obtain the timings above.
| branch | finished | tested | merged | comments |
|---|---|---|---|---|
| refactor_xintegrator | X | unit | X | |
| refactor_msr | X | unit | X | |
| refactor_preprocessing | X | unit | X | |
| refactor_rotations | X | unit | X | |
| refactor_stopping | X | unit | X | |
| multi-dense-logistics | currently working on this | | | |
| multi_replica_pdf-test | | | | This is my test branch: it has the 4 above, plus trvl-mask-layers, merged into it, and it has the code that will eventually go into this PR |