Implement FkRotation as subclass of Rotation by rewriting using a rotation tensor
Created by: APJansen
Implements an old TODO by @scarlehoff to write the FkRotation
layer as a subclass of Rotation
.
Unlike for FlavourToEvolution
the rotation matrix is fixed, so I thought it made sense to define it right in the class.
Unrelated to this, the docstring isn't accurate, referring to 8 dimensions while actually there are 9. This may be a good occasion to fix it, but I'm not sure what to write as I don't know where the +1 is from, any suggestion?
I've verified that this gives the same output using the script below, and also checked that the new way is about 3-5 times faster.
rot = FkRotation()
test_x = tf.random.normal(shape=(1, 20, 9))
out1 = rot(text_x)
out2 = rot.call_old(test_x)
print(out1 - out2)