This RFC proposes adding a new API to the array API specification for computing the logarithm of a sum of exponentials (`logsumexp`).
Overview
The array API specification currently includes `logaddexp`, which performs an element-wise operation on two input arrays, but it does not include the corresponding reduction, `logsumexp`. Accelerator libraries commonly provide this reduction as a dedicated function because it is more numerically stable than composing it from primitives, which matters in deep learning applications.
- `logaddexp`: `def logaddexp(x1: array, x2: array, /) -> array:` (array-api/src/array_api_stubs/_2022_12/elementwise_functions.py, line 1533 in 3d91878)
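A `logsumexp` reduction can be composed as `log(sum(exp(x)))`. However, such an implementation is not numerically stable: `exp` overflows to `inf` for large inputs and underflows to `0` for very negative ones, so the result becomes `inf` or `-inf` even when the true value is finite.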
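The instability is easy to reproduce. The sketch below assumes NumPy as a stand-in array library. It compares the direct `log(sum(exp(x)))` with the common max-shift rewrite, which uses the identity log(sum(exp(x))) = m + log(sum(exp(x - m))) with m = max(x). Under that shift every exponent is at most 0 and the largest term is exactly 1. This is one standard technique, not necessarily what any particular library does internally, and this sketch does not handle empty or non-finite inputs.

```python
import numpy as np


def naive_logsumexp(x):
    # Direct translation of log(sum(exp(x))): exp(x) overflows to inf for
    # large x and underflows to 0 for very negative x.
    return np.log(np.sum(np.exp(x)))


def stable_logsumexp(x):
    # Max-shift rewrite: log(sum(exp(x))) == m + log(sum(exp(x - m))).
    # With m = max(x), every exponent is <= 0, so nothing overflows, and
    # the largest term is exactly exp(0) == 1, so the sum is never 0.
    # Sketch only: assumes x is non-empty and m is finite.
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))


big = np.array([1000.0, 1000.0])
small = np.array([-1000.0, -1000.0])

# Silence the overflow / log(0) warnings the naive version triggers.
with np.errstate(over="ignore", divide="ignore"):
    print(naive_logsumexp(big))    # inf  (true value is 1000 + log 2)
    print(naive_logsumexp(small))  # -inf (true value is -1000 + log 2)

print(stable_logsumexp(big))    # approximately 1000.693 (= 1000 + log 2)
print(stable_logsumexp(small))  # approximately -999.307 (= -1000 + log 2)
```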
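Prior art
- NumPy: `logaddexp.reduce`
- SciPy: `logsumexp` in the `scipy.special` namespace
- PyTorch: `logsumexp` (also in the `torch.special` namespace: https://pytorch.org/docs/stable/special.html#torch.special.logsumexp)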
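Since `logaddexp(a, b)` computes `log(exp(a) + exp(b))`, a `logsumexp` reduction over a 1-D array is a left fold of the element-wise `logaddexp`. The sketch below assumes NumPy as a stand-in array library. It compares a `functools.reduce` fold, NumPy's `np.logaddexp.reduce`, and the naive formula on small inputs where the naive formula is still safe. All three should agree.

```python
import functools

import numpy as np

x = np.array([1.0, 2.0, 3.0])

# Pairwise fold: logaddexp(logaddexp(x[0], x[1]), x[2]).
folded = functools.reduce(np.logaddexp, x)

# NumPy's built-in ufunc reduction over the same array.
ufunc_reduce = np.logaddexp.reduce(x)

# Naive formula; safe here only because the inputs are small.
naive = np.log(np.sum(np.exp(x)))

print(folded, ufunc_reduce, naive)
```

`functools.reduce` only folds over the first axis of its argument. The proposed API would instead follow the `axis` and `keepdims` conventions of the specification's other reductions.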
Proposal:
`def logsumexp(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None, keepdims: bool = False) -> array`
The `dtype` keyword argument is included for consistency with `sum` and similar reductions (e.g., `prod`).
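Below is a hypothetical, non-normative sketch of the proposed signature, again assuming NumPy as the backing array library. It reduces over `axis`, which may be an `int`, a tuple of `int`s, or `None` for all axes. It casts the input to `dtype` before computing, mirroring how the specification describes `dtype` for `sum`, and it honors `keepdims`. It replaces non-finite per-slice maxima with 0 before shifting, so slices containing `inf`, `-inf`, or `nan` propagate sensibly. That replacement is a design choice of this sketch, not part of the proposal.

```python
from typing import Optional, Tuple, Union

import numpy as np


def logsumexp(
    x: np.ndarray,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype=None,
    keepdims: bool = False,
) -> np.ndarray:
    # Hypothetical reference sketch of the proposed API; assumes the
    # reduced slices are non-empty (np.max raises on empty slices).
    if dtype is not None:
        x = x.astype(dtype)  # cast first, then compute in that dtype
    # Per-slice maximum, kept as size-1 dims so it broadcasts against x.
    m = np.max(x, axis=axis, keepdims=True)
    # Non-finite maxima (all -inf, any +inf, or nan) are replaced by 0 so
    # that x - m does not produce spurious nan values from inf - inf.
    m = np.where(np.isfinite(m), m, np.zeros_like(m))
    out = np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True)) + m
    if not keepdims:
        # Drop exactly the reduced dims (all dims when axis is None).
        out = np.squeeze(out, axis=axis)
    return out


x = np.array([[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]])
print(logsumexp(x, axis=1))                       # one value per row
print(logsumexp(x, axis=1, keepdims=True).shape)  # (2, 1)
print(logsumexp(x).shape)                         # ()
```

Shifting by the per-slice maximum keeps the second row finite (about 1000 + log 3), where the naive formula would return `inf`.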
Related
cc @kgryte