Improved memory utilization in API
Created by: stephenroller
🚀 API Performance Improvements
The current API implementation seems to have a memory leak, preventing larger batch sizes from being used. While I haven't instrumented this yet, my gut instinct is that it stems from memory fragmentation.
In particular, I believe the beam reordering steps of generation are causing fragmentation. Whenever we reorder the incremental state or the logprobs, we likely create partial views that keep the underlying tensors alive, so memory is still held after that state is no longer needed.
Furthermore, the returned logits are also likely highly fragmented.
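As an illustration of how per-step logits can fragment memory (the shapes below are hypothetical, not the actual API sizes): each decoding step allocates its own small tensor, and keeping them all in a Python list leaves many scattered allocations. Consolidating them into one buffer and dropping the list releases the per-step tensors.

```python
import torch

# Hypothetical sizes: 16 decoding steps, batch*beam of 8, vocab of 50k.
step_logits = [torch.randn(8, 50_000) for _ in range(16)]

# torch.stack copies every step into one contiguous buffer; deleting
# the list afterwards lets the 16 scattered allocations be freed.
logits = torch.stack(step_logits, dim=0)
del step_logits
```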
In all likelihood, one or two carefully placed `.contiguous()` calls would significantly relieve memory pressure.
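A minimal sketch of the view-pinning behavior in question: a narrow view keeps the entire backing storage of the original tensor alive, while `.contiguous()` copies the view into a fresh, tightly packed buffer so the large storage can be reclaimed.

```python
import torch

big = torch.zeros(1024, 1024)      # large backing buffer (~4 MB)
view = big[:, :2]                  # narrow, non-contiguous view

# The view pins the entire storage of `big`:
assert view.data_ptr() == big.data_ptr()
assert not view.is_contiguous()

packed = view.contiguous()         # copies into a fresh 1024x2 buffer
assert packed.data_ptr() != big.data_ptr()
# Once `big` and `view` are dropped, the full ~4 MB can be freed.
```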
As a first step, we can add a special-case check for identity reorderings: if the reorder indices are [0, 1, 2, 3, ...], the reorder creates an expensive view whose contents are identical to the input. When we detect such an identity reordering, we should skip calling reorder_incremental_state entirely.
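The identity check could look something like the sketch below (`is_identity_reorder` is a hypothetical helper name, and the guard around `reorder_incremental_state` is illustrative):

```python
import torch

def is_identity_reorder(new_order: torch.Tensor) -> bool:
    # True when new_order is exactly [0, 1, ..., n-1], i.e. a no-op reorder.
    return bool(
        torch.equal(
            new_order,
            torch.arange(new_order.numel(), device=new_order.device),
        )
    )

# Hypothetical call site:
#
# if not is_identity_reorder(new_order):
#     model.reorder_incremental_state(incremental_state, new_order)
```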