Coverage for chempropstereo/featurizers/atom.py: 96%

50 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-22 21:04 +0000

1"""Atom featurization. 

2 

3.. module:: featurizers.atom 

4.. moduleauthor:: Charlles Abreu <craabreu@mit.edu> 

5""" 

6 

7import chemprop 

8import numpy as np 

9from rdkit import Chem 

10 

11from .. import stereochemistry 

12from . import utils 

13 

14 

class AtomCIPFeaturizer(chemprop.featurizers.MultiHotAtomFeaturizer):
    """Multi-hot atom featurizer that includes a CIP code if the atom is a stereocenter.

    The featurized atoms are expected to be part of an RDKit molecule with CIP labels
    assigned via the `AssignCIPLabels`_ function.

    .. _AssignCIPLabels: https://www.rdkit.org/docs/source/\
rdkit.Chem.rdCIPLabeler.html#rdkit.Chem.rdCIPLabeler.AssignCIPLabels

    Parameters
    ----------
    mode : str | featurizers.AtomFeatureMode, default="V2"
        The mode to use for the featurizer. Available modes are `V1`_, `V2`_, and
        `ORGANIC`_. Strings are converted via ``AtomFeatureMode.get``.

    .. _V1: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/\
featurizers/atom/index.html#chemprop.featurizers.atom.MultiHotAtomFeaturizer.v1
    .. _V2: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/\
featurizers/atom/index.html#chemprop.featurizers.atom.MultiHotAtomFeaturizer.v2
    .. _ORGANIC: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/\
featurizers/atom/index.html#chemprop.featurizers.atom.MultiHotAtomFeaturizer.organic

    Examples
    --------
    >>> from chempropstereo import AtomCIPFeaturizer
    >>> from rdkit import Chem
    >>> r_mol = Chem.MolFromSmiles("C[C@H](N)O")
    >>> s_mol = Chem.MolFromSmiles("C[C@@H](N)O")
    >>> for mol in [r_mol, s_mol]:
    ...     Chem.AssignCIPLabels(mol)
    >>> r_atom = r_mol.GetAtomWithIdx(1)
    >>> s_atom = s_mol.GetAtomWithIdx(1)
    >>> featurizer = AtomCIPFeaturizer("ORGANIC")
    >>> for atom in [r_atom, s_atom]:
    ...     features = featurizer(atom)
    ...     assert len(features) == len(featurizer)
    ...     bits = "".join(str(int(value)) for value in features[:-1])
    ...     print(bits, f"{features[-1]:.3f}")
    001000000000000001000000100100010000000100 0.120
    001000000000000001000000100010010000000100 0.120

    """

    def __init__(self, mode: str | chemprop.featurizers.AtomFeatureMode = "V2") -> None:
        # Reuse the vocabularies of chemprop's multi-hot featurizer for the
        # requested mode. The chiral-tag slot is repurposed for the CIP code
        # returned by stereochemistry.get_cip_code: values 0, 1 and 2 are
        # recognized, and anything else falls into the trailing "other" slot.
        # NOTE(review): which CIP label each of 0/1/2 denotes is defined in
        # stereochemistry.get_cip_code; confirm there.
        featurizer = chemprop.featurizers.get_multi_hot_atom_featurizer(
            chemprop.featurizers.AtomFeatureMode.get(mode)
        )
        super().__init__(
            atomic_nums=featurizer.atomic_nums,
            degrees=featurizer.degrees,
            formal_charges=featurizer.formal_charges,
            chiral_tags=list(range(3)),
            num_Hs=featurizer.num_Hs,
            hybridizations=featurizer.hybridizations,
        )

    def __call__(self, a: Chem.Atom | None) -> np.ndarray:
        """Featurize an RDKit atom with stereochemical information.

        Parameters
        ----------
        a : Chem.Atom | None
            The atom to featurize. If None, returns a zero array.

        Returns
        -------
        np.ndarray
            A 1D float array of shape `(len(self),)` containing the following
            features:
            - One-hot encoding of the atomic number
            - One-hot encoding of the total degree
            - One-hot encoding of the formal charge
            - One-hot encoding of the CIP code
            - One-hot encoding of the total number of hydrogens
            - One-hot encoding of the hybridization
            - Boolean indicating whether the atom is aromatic
            - Mass of the atom divided by 100

        """
        # Float array: the mass feature written at the end is fractional and
        # would be truncated (to 0 for light atoms) by an integer array.
        x = np.zeros(len(self))

        if a is None:
            return x

        feats = [
            a.GetAtomicNum(),
            a.GetTotalDegree(),
            a.GetFormalCharge(),
            stereochemistry.get_cip_code(a),  # occupies the chiral-tag slot
            int(a.GetTotalNumHs()),
            a.GetHybridization(),
        ]

        i = 0
        for feat, choices in zip(feats, self._subfeats):
            # Values outside the known choices map to the extra trailing slot
            # of each sub-feature, which is why the offset advances by len + 1.
            j = choices.get(feat, len(choices))
            x[i + j] = 1
            i += len(choices) + 1
        x[i] = int(a.GetIsAromatic())
        x[i + 1] = 0.01 * a.GetMass()

        return x

115 

116 

# Scan directions that AtomStereoFeaturizer one-hot encodes, in slot order. Any
# other value from ScanDirection.get_from lands in the trailing "other" slot.
_SCAN_DIRECTIONS: tuple[stereochemistry.ScanDirection, ...] = tuple(
    [stereochemistry.ScanDirection.CW, stereochemistry.ScanDirection.CCW]
)

121 

122 

class AtomStereoFeaturizer(chemprop.featurizers.base.VectorFeaturizer[Chem.Atom]):
    """Multi-hot atom featurizer that includes a canonical chiral tag for each atom.

    The featurized atoms are expected to be part of an RDKit molecule with canonical
    chiral tags assigned via :func:`~stereochemistry.tag_stereogroups`, as done in
    the example below.

    Parameters
    ----------
    mode : str | chemprop.featurizers.AtomFeatureMode, default="V2"
        The mode to use for the featurizer. Available modes are `V1`_, `V2`_, and
        `ORGANIC`_. Strings are converted via ``AtomFeatureMode.get``.

    .. _V1: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/\
featurizers/atom/index.html#chemprop.featurizers.atom.MultiHotAtomFeaturizer.v1
    .. _V2: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/\
featurizers/atom/index.html#chemprop.featurizers.atom.MultiHotAtomFeaturizer.v2
    .. _ORGANIC: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/\
featurizers/atom/index.html#chemprop.featurizers.atom.MultiHotAtomFeaturizer.organic

    Examples
    --------
    >>> from chempropstereo import featurizers, stereochemistry
    >>> from rdkit import Chem
    >>> featurizer = featurizers.AtomStereoFeaturizer("ORGANIC")
    >>> for smi in ["C[C@H](N)O", "C[C@@H](N)O"]:
    ...     mol = Chem.MolFromSmiles(smi)
    ...     stereochemistry.tag_stereogroups(mol)
    ...     print(f"Molecule: {smi}")
    ...     for atom in mol.GetAtoms():
    ...         print(featurizer.pretty_print(atom))
    Molecule: C[C@H](N)O
     0: 0010000000000 0000100 000010 001 000100 00010 0 0.120
     1: 0010000000000 0000100 000010 100 010000 00010 0 0.120
     2: 0001000000000 0001000 000010 001 001000 00010 0 0.140
     3: 0000100000000 0010000 000010 001 010000 00010 0 0.160
    Molecule: C[C@@H](N)O
     0: 0010000000000 0000100 000010 001 000100 00010 0 0.120
     1: 0010000000000 0000100 000010 010 010000 00010 0 0.120
     2: 0001000000000 0001000 000010 001 001000 00010 0 0.140
     3: 0000100000000 0010000 000010 001 010000 00010 0 0.160

    """

    def __init__(
        self, mode: str | chemprop.featurizers.AtomFeatureMode = "V2"
    ) -> None:
        # Reuse the vocabularies of chemprop's multi-hot featurizer for the
        # requested mode; scan directions are encoded after the formal charges.
        featurizer = chemprop.featurizers.get_multi_hot_atom_featurizer(
            chemprop.featurizers.AtomFeatureMode.get(mode)
        )
        self.atomic_nums = featurizer.atomic_nums
        self.degrees = featurizer.degrees
        self.formal_charges = featurizer.formal_charges
        self.scan_directions = _SCAN_DIRECTIONS
        self.num_Hs = featurizer.num_Hs
        self.hybridizations = featurizer.hybridizations
        # Cache the total length; the vocabularies are fixed after construction.
        self._len = sum(self.sizes)

    def __len__(self) -> int:
        return self._len

    def __call__(self, a: Chem.Atom | None) -> np.ndarray:
        """Featurize an RDKit atom with stereochemical information.

        Parameters
        ----------
        a
            The atom to featurize. If None, a float zero vector of length
            `len(self)` is returned.

        Returns
        -------
        np.ndarray
            A 1D float array of shape `(len(self),)` containing the following
            features:
            - `atomic_num`: one-hot encoding of the atomic number
            - `total_degree`: one-hot encoding of the total degree
            - `formal_charge`: one-hot encoding of the formal charge
            - `scan_direction`: one-hot encoding of the scan direction
            - `total_num_hs`: one-hot encoding of the total number of Hs
            - `hybridization`: one-hot encoding of the hybridization
            - `is_aromatic`: boolean indicating whether the atom is aromatic
            - `mass`: mass of the atom divided by 100

        """
        if a is None:
            return np.zeros(len(self))

        atomic_num = a.GetAtomicNum()
        total_degree = a.GetTotalDegree()
        formal_charge = a.GetFormalCharge()
        total_num_hs = a.GetTotalNumHs()
        hybridization = a.GetHybridization()
        scan_direction = stereochemistry.ScanDirection.get_from(a)

        # Each one-hot group is followed by an "other" flag that is set when the
        # value is not in the vocabulary, matching the +1 in `sizes`.
        return np.array(
            [
                *(atomic_num == item for item in self.atomic_nums),
                atomic_num not in self.atomic_nums,
                *(total_degree == item for item in self.degrees),
                total_degree not in self.degrees,
                *(formal_charge == item for item in self.formal_charges),
                formal_charge not in self.formal_charges,
                *(scan_direction == item for item in self.scan_directions),
                scan_direction not in self.scan_directions,
                *(total_num_hs == item for item in self.num_Hs),
                total_num_hs not in self.num_Hs,
                *(hybridization == item for item in self.hybridizations),
                hybridization not in self.hybridizations,
                a.GetIsAromatic(),
                0.01 * a.GetMass(),
            ],
            dtype=float,
        )

    @property
    def sizes(self) -> tuple[int, ...]:
        """Get a tuple of sizes corresponding to different atom features.

        The tuple contains the sizes for:
        - Atomic numbers
        - Total degrees
        - Formal charges
        - Scan directions
        - Total numbers of Hs
        - Hybridizations
        - Aromatic indicator
        - Mass

        Each one-hot group has one extra entry for values outside its vocabulary.

        Returns
        -------
        tuple[int, ...]
            A tuple of integers representing the sizes of each atom feature.

        Examples
        --------
        >>> from chempropstereo import featurizers
        >>> featurizer = featurizers.AtomStereoFeaturizer("ORGANIC")
        >>> featurizer.sizes
        (13, 7, 6, 3, 6, 5, 1, 1)

        """
        return (
            len(self.atomic_nums) + 1,
            len(self.degrees) + 1,
            len(self.formal_charges) + 1,
            len(self.scan_directions) + 1,
            len(self.num_Hs) + 1,
            len(self.hybridizations) + 1,
            1,
            1,
        )

    def pretty_print(self, a: Chem.Atom | None) -> str:
        """Get a formatted string representation of the atom features.

        Parameters
        ----------
        a : Chem.Atom or None
            The atom to be described. If None, a null atom (all-zero features) is
            assumed and -1 is used as its index.

        Returns
        -------
        str
            A formatted string representing the atom features.

        Examples
        --------
        >>> from rdkit import Chem
        >>> from chempropstereo import featurizers
        >>> mol = Chem.MolFromSmiles('CC')
        >>> atom = mol.GetAtomWithIdx(0)
        >>> featurizer = featurizers.AtomStereoFeaturizer("ORGANIC")
        >>> featurizer.pretty_print(atom)
        ' 0: 0010000000000 0000100 000010 001 000100 00010 0 0.120'

        """
        # A null atom has no index, so use -1 as a placeholder instead of calling
        # GetIdx() on None (which would raise AttributeError).
        # NOTE(review): confirm utils.describe_atom_features renders a negative
        # index acceptably.
        index = -1 if a is None else a.GetIdx()
        return utils.describe_atom_features(index, self(a), self.sizes)