Coverage for cosmolayer / utils.py: 76%
17 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 14:25 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 14:25 +0000
1"""
2.. module:: cosmolayer.utils
3 :synopsis: Utility functions for the COSMO-related computations.
5.. functionauthor:: Charlles Abreu <craabreu@gmail.com>
6"""
8import inspect
10import torch
def log_matmul_exp(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    r"""Evaluate :math:`\log(\exp(A) \exp(B))` without leaving log-space.

    Entry ``(i, j)`` of the result is
    :math:`\log \sum_k \exp(A_{ik} + B_{kj})`. The sum over ``k`` is done by
    a log-sum-exp reduction, so inputs with very large or very negative
    magnitudes never have to be exponentiated directly.

    Parameters
    ----------
    A : torch.Tensor
        Tensor of shape (..., M, K).
    B : torch.Tensor
        Tensor of shape (..., K, N).

    Returns
    -------
    torch.Tensor
        Tensor of shape (..., M, N).

    Raises
    ------
    ValueError
        If the last dimension of ``A`` differs from the second-to-last
        dimension of ``B``.
    """
    k_a, k_b = A.shape[-1], B.shape[-2]
    if k_a != k_b:
        raise ValueError("Last dimension of A must match second-to-last dimension of B")
    # Pair every A[..., i, k] with every B[..., k, j]. The broadcast sum has
    # shape (..., M, K, N), so memory scales as M * K * N before the
    # contraction axis K is reduced away.
    pairwise_sums = A[..., None] + B[..., None, :, :]
    return pairwise_sums.logsumexp(dim=-2)
def is_loss_function(func: object) -> bool:
    """Tell whether ``func`` is a callable whose signature starts ``(input, target)``.

    The test is purely name-based: the callable's first two parameters must
    be named ``input`` and ``target``, in that order. Any further parameters
    are ignored. (Presumably this mirrors the argument naming of PyTorch loss
    functions; confirm against callers.)

    Parameters
    ----------
    func : object
        Object to inspect.

    Returns
    -------
    bool
        ``True`` if ``func`` is callable, its signature can be introspected,
        and its first two parameter names are ``input`` and ``target``;
        ``False`` otherwise.
    """
    if callable(func):
        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            # inspect.signature raises these when it cannot build a
            # signature for the object; treat that as "not a loss function".
            return False
        # Iterating ``parameters`` yields names in declaration order; a
        # signature with fewer than two parameters yields a shorter slice
        # and therefore cannot compare equal.
        first_two = tuple(signature.parameters)[:2]
        return first_two == ("input", "target")
    return False