Coverage for cosmolayer / utils.py: 76%

17 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-11 14:25 +0000

1""" 

2.. module:: cosmolayer.utils 

3 :synopsis: Utility functions for the COSMO-related computations. 

4 

5.. functionauthor:: Charlles Abreu <craabreu@gmail.com> 

6""" 

7 

8import inspect 

9 

10import torch 

11 

12 

def log_matmul_exp(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    r"""Compute :math:`\log(\exp(A) \exp(B))` stably in log-space.

    Each output entry is
    :math:`C_{ij} = \log \sum_k \exp(A_{ik} + B_{kj})`, evaluated with
    :func:`torch.logsumexp` so that the exponentials are never formed
    directly and cannot overflow. Leading (batch) dimensions of ``A`` and
    ``B`` are broadcast together following PyTorch broadcasting rules.

    Parameters
    ----------
    A : torch.Tensor
        Tensor of shape (..., M, K), with at least two dimensions.
    B : torch.Tensor
        Tensor of shape (..., K, N), with at least two dimensions.

    Returns
    -------
    torch.Tensor
        Tensor of shape (..., M, N).

    Raises
    ------
    ValueError
        If either input has fewer than two dimensions, or if the last
        dimension of ``A`` differs from the second-to-last dimension of ``B``.

    Notes
    -----
    The computation materializes a broadcast intermediate of shape
    (..., M, K, N), so memory grows with the product M*K*N. In exchange,
    the stabilizing max-shift inside ``torch.logsumexp`` is applied
    separately for every output entry.
    """
    # Without this guard a 1-D ``B`` fails with an opaque IndexError on
    # ``B.shape[-2]``, and a 1-D ``A`` silently broadcasts into a malformed
    # (..., 1, N) result instead of being rejected.
    if A.dim() < 2 or B.dim() < 2:
        raise ValueError(
            "A and B must each have at least two dimensions "
            f"(got A.shape={tuple(A.shape)}, B.shape={tuple(B.shape)})"
        )
    if A.shape[-1] != B.shape[-2]:
        raise ValueError(
            "Last dimension of A must match second-to-last dimension of B "
            f"(got A.shape={tuple(A.shape)}, B.shape={tuple(B.shape)})"
        )
    # (..., M, K, 1) + (..., 1, K, N) -> (..., M, K, N); then reduce over K.
    return torch.logsumexp(A.unsqueeze(-1) + B.unsqueeze(-3), dim=-2)

31 

32 

def is_loss_function(func: object) -> bool:
    """Check whether ``func`` has a loss-function-like call signature.

    A callable qualifies when its first two parameters, in declaration
    order, are named ``input`` and ``target``. Any further parameters are
    allowed and ignored.

    Parameters
    ----------
    func : object
        Candidate object to inspect.

    Returns
    -------
    bool
        ``True`` if ``func`` is callable, its signature can be introspected,
        and that signature starts with ``(input, target, ...)``. ``False``
        otherwise.

    Notes
    -----
    Only parameter names and their order are compared. Parameter kinds
    (positional-only, keyword-only) and the presence of defaults on the
    remaining parameters are not checked.
    """
    if callable(func):
        try:
            # inspect.signature raises ValueError/TypeError when no signature
            # can be determined for the object.
            param_names = list(inspect.signature(func).parameters)
        except (TypeError, ValueError):
            return False
        # Slicing also covers the "fewer than two parameters" case: a shorter
        # list can never equal the two-element reference list.
        return param_names[:2] == ["input", "target"]
    return False