Add a new loss in the `cross_entropy_loss.py` file that inherits from the SFT loss but calls Liger's `fused_linear_cross_entropy` loss. It will need to detect whether the input is a `DTensor` and convert it to a plain local tensor before calling the Liger loss.
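A minimal sketch of the dispatch logic described above. The SFT base, the Liger kernel, and `DTensor` are all stubbed here with stdlib-only stand-ins so the unwrap-then-call pattern can run standalone; in the real implementation these would be the actual `torch.distributed.tensor.DTensor` (which exposes `full_tensor()`) and `liger_kernel`'s fused loss.

```python
class DTensor:
    """Stub standing in for torch.distributed.tensor.DTensor."""

    def __init__(self, local):
        self._local = local

    def full_tensor(self):
        # Real DTensor gathers the full unsharded tensor here.
        return self._local


def fused_linear_cross_entropy(hidden, weight, targets):
    # Stub for the Liger fused linear + cross-entropy kernel;
    # returns a placeholder scalar instead of a real loss.
    return 0.0


class LigerFusedLinearCrossEntropyLoss:
    """Hypothetical loss that unwraps DTensor inputs before the Liger call."""

    def __call__(self, hidden, weight, targets):
        # The fused kernel (assumed) cannot consume sharded tensors,
        # so convert any DTensor to a regular tensor first.
        if isinstance(hidden, DTensor):
            hidden = hidden.full_tensor()
        if isinstance(weight, DTensor):
            weight = weight.full_tensor()
        return fused_linear_cross_entropy(hidden, weight, targets)
```

In the real class the same check-and-convert step would run on each forward, then delegate to the parent SFT loss's bookkeeping with the Liger kernel swapped in for the compiled cross-entropy.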
Edge case: the output weight may be a tied embedding that is TP-sharded (a `DTensor`). In that case we would either have to unshard and then reshard the weight every step, or raise an error for that configuration. (This assumes the Liger losses don't work with sharded weights.)
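One way to handle the edge case above is to fail fast rather than silently unshard every step. The check below is a sketch with hypothetical names; `Shard` and `DTensor` are stdlib stubs standing in for the torch placements API, and `tied` would come from the model config.

```python
class Shard:
    """Stub for torch.distributed.tensor.placement_types.Shard."""


class DTensor:
    """Stub for torch.distributed.tensor.DTensor, keeping only placements."""

    def __init__(self, placements):
        self.placements = placements


def validate_output_weight(weight, tied: bool):
    """Raise if a tied embedding weight is TP-sharded.

    Unsharding and resharding the weight on every step would be
    wasteful, so (under the assumption that Liger losses cannot
    consume sharded weights) this configuration is rejected.
    """
    is_sharded = isinstance(weight, DTensor) and any(
        isinstance(p, Shard) for p in weight.placements
    )
    if tied and is_sharded:
        raise ValueError(
            "Liger fused_linear_cross_entropy does not support a "
            "TP-sharded tied embedding weight; use the compiled "
            "linear cross-entropy loss for this configuration."
        )
    return weight
```

The alternative (gathering the full weight each step via `full_tensor()` and discarding it afterwards) keeps the feature working but trades memory and a per-step all-gather for it, so erroring out seems like the safer default.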
A good validation of this feature would be to check whether this loss improves the numbers here beyond the compiled linear cross-entropy loss.