Coverage for chempropstereo/featurizers/molecule.py: 90%

52 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-22 21:04 +0000

"""Molecule featurization.

Molecule-level featurizers built on Chemprop's
``SimpleMoleculeMolGraphFeaturizer`` that add stereochemical information:

* :class:`MoleculeCIPFeaturizer` -- atom features include CIP labels, which are
  assigned to a copy of each input molecule.
* :class:`MoleculeStereoFeaturizer` -- canonicalized tetrahedral stereocenters
  and cis/trans stereobonds, encoded on atoms and on directed bonds.

.. module:: featurizers.molecule
.. moduleauthor:: Charlles Abreu <craabreu@mit.edu>
"""

6 

7import chemprop 

8import numpy as np 

9from rdkit import Chem 

10 

11from .. import stereochemistry 

12from . import utils 

13from .atom import AtomCIPFeaturizer, AtomStereoFeaturizer 

14from .bond import BondStereoFeaturizer 

15 

16 

class MoleculeCIPFeaturizer(chemprop.featurizers.SimpleMoleculeMolGraphFeaturizer):
    """Molecule featurizer that includes CIP codes for stereocenters.

    Atoms are featurized with :class:`AtomCIPFeaturizer` and bonds with Chemprop's
    :class:`~chemprop.featurizers.MultiHotBondFeaturizer`. CIP labels are assigned
    to a copy of each input molecule, so the caller's molecule is not modified.

    Parameters
    ----------
    extra_atom_fdim : int, default=0
        Number of extra atom features that will be supplied to :meth:`__call__`
        through ``atom_features_extra``. Forwarded to the Chemprop parent class,
        which adds it to ``atom_fdim``.
    extra_bond_fdim : int, default=0
        Number of extra bond features that will be supplied to :meth:`__call__`
        through ``bond_features_extra``. Forwarded to the Chemprop parent class,
        which adds it to ``bond_fdim``. It must be set whenever bond extras are
        used, otherwise they do not fit in the edge-feature array.

    Examples
    --------
    >>> from chempropstereo import MoleculeCIPFeaturizer
    >>> from rdkit import Chem
    >>> import numpy as np
    >>> r_mol = Chem.MolFromSmiles("C[C@H](N)O")
    >>> s_mol = Chem.MolFromSmiles("C[C@@H](N)O")
    >>> featurizer = MoleculeCIPFeaturizer()
    >>> r_molgraph = featurizer(r_mol)
    >>> s_molgraph = featurizer(s_mol)
    >>> assert not np.array_equal(r_molgraph.V, s_molgraph.V)
    >>> assert np.array_equal(r_molgraph.E, s_molgraph.E)

    """

    def __init__(self, *, extra_atom_fdim: int = 0, extra_bond_fdim: int = 0) -> None:
        super().__init__(
            atom_featurizer=AtomCIPFeaturizer(),
            bond_featurizer=chemprop.featurizers.MultiHotBondFeaturizer(),
            extra_atom_fdim=extra_atom_fdim,
            extra_bond_fdim=extra_bond_fdim,
        )

    def __call__(
        self,
        mol: Chem.Mol,
        atom_features_extra: np.ndarray | None = None,
        bond_features_extra: np.ndarray | None = None,
    ) -> chemprop.data.MolGraph:
        """Featurize a molecule after assigning CIP labels to its stereocenters.

        Parameters
        ----------
        mol
            Molecule to be featurized. It is copied before CIP labels are
            assigned, so the caller's object is left unchanged.
        atom_features_extra
            Extra features to be added to the atoms.
        bond_features_extra
            Extra features to be added to the bonds.

        Returns
        -------
        chemprop.data.MolGraph
            Featurized molecule whose atom features include CIP labels.

        """
        # Work on a copy: AssignCIPLabels writes labels onto the molecule.
        mol = Chem.Mol(mol)
        Chem.AssignCIPLabels(mol)
        return super().__call__(mol, atom_features_extra, bond_features_extra)

67 

68 

class MoleculeStereoFeaturizer(chemprop.featurizers.SimpleMoleculeMolGraphFeaturizer):
    r"""Molecule featurizer that includes canonical stereochemical information.

    This featurizer includes canonicalized tetrahedral stereocenters and
    cis/trans stereobonds.

    Parameters
    ----------
    mode : str | chemprop.featurizers.AtomFeatureMode
        Atom featurization mode forwarded to :class:`AtomStereoFeaturizer`
        (e.g. ``"ORGANIC"``).
    divergent_bonds : bool
        Whether to add stereochemical features to the directed bonds that diverge from
        stereocenters and stereobonds, as opposed to those that converge to them.
    extra_atom_fdim : int, default=0
        Number of extra atom features that will be supplied to :meth:`__call__`
        through ``atom_features_extra``. Forwarded to the Chemprop parent class,
        which adds it to ``atom_fdim``.
    extra_bond_fdim : int, default=0
        Number of extra bond features that will be supplied to :meth:`__call__`
        through ``bond_features_extra``. Forwarded to the Chemprop parent class,
        which adds it to ``bond_fdim``. It must be set whenever bond extras are
        used, because the edge-feature array is sized from ``bond_fdim``.

    Examples
    --------
    >>> from chempropstereo import featurizers
    >>> from rdkit import Chem
    >>> import numpy as np
    >>> mol = Chem.MolFromSmiles("C[C@@H](N)/C=C(O)/N")
    >>> for divergent in (True, False):  # doctest: +NORMALIZE_WHITESPACE
    ...     print(f"\nWith {'di' if divergent else 'con'}vergent bonds:\n")
    ...     featurizer = featurizers.MoleculeStereoFeaturizer(
    ...         mode="ORGANIC",
    ...         divergent_bonds=divergent,
    ...     )
    ...     print(featurizer.pretty_print(mol))
    <BLANKLINE>
    With divergent bonds:
    <BLANKLINE>
    Vertices:
    0: 0010000000000 0000100 000010 001 000100 00010 0 0.120
    1: 0010000000000 0000100 000010 010 010000 00010 0 0.120
    2: 0001000000000 0001000 000010 001 001000 00010 0 0.140
    3: 0010000000000 0001000 000010 001 010000 00100 0 0.120
    4: 0010000000000 0001000 000010 001 100000 00100 0 0.120
    5: 0000100000000 0010000 000010 001 010000 00100 0 0.160
    6: 0001000000000 0001000 000010 001 001000 00100 0 0.140
    Edges:
    0→1: 0 1000 0 0 0000 00 00
    1→0: 0 1000 0 0 0100 00 00
    1→2: 0 1000 0 0 1000 00 00
    2→1: 0 1000 0 0 0000 00 00
    1→3: 0 1000 0 0 0010 00 00
    3→1: 0 1000 0 0 0000 00 10
    3→4: 0 0100 1 0 0000 01 00
    4→3: 0 0100 1 0 0000 01 00
    4→5: 0 1000 1 0 0000 00 01
    5→4: 0 1000 1 0 0000 00 00
    4→6: 0 1000 1 0 0000 00 10
    6→4: 0 1000 1 0 0000 00 00
    <BLANKLINE>
    With convergent bonds:
    <BLANKLINE>
    Vertices:
    0: 0010000000000 0000100 000010 001 000100 00010 0 0.120
    1: 0010000000000 0000100 000010 010 010000 00010 0 0.120
    2: 0001000000000 0001000 000010 001 001000 00010 0 0.140
    3: 0010000000000 0001000 000010 001 010000 00100 0 0.120
    4: 0010000000000 0001000 000010 001 100000 00100 0 0.120
    5: 0000100000000 0010000 000010 001 010000 00100 0 0.160
    6: 0001000000000 0001000 000010 001 001000 00100 0 0.140
    Edges:
    1→0: 0 1000 0 0 0000 00 00
    0→1: 0 1000 0 0 0100 00 00
    2→1: 0 1000 0 0 1000 00 00
    1→2: 0 1000 0 0 0000 00 00
    3→1: 0 1000 0 0 0010 00 00
    1→3: 0 1000 0 0 0000 00 10
    4→3: 0 0100 1 0 0000 01 00
    3→4: 0 0100 1 0 0000 01 00
    5→4: 0 1000 1 0 0000 00 01
    4→5: 0 1000 1 0 0000 00 00
    6→4: 0 1000 1 0 0000 00 10
    4→6: 0 1000 1 0 0000 00 00

    """

    def __init__(
        self,
        mode: str | chemprop.featurizers.AtomFeatureMode,
        divergent_bonds: bool,
        *,
        extra_atom_fdim: int = 0,
        extra_bond_fdim: int = 0,
    ) -> None:
        super().__init__(
            atom_featurizer=AtomStereoFeaturizer(mode),
            bond_featurizer=BondStereoFeaturizer(),
            extra_atom_fdim=extra_atom_fdim,
            extra_bond_fdim=extra_bond_fdim,
        )
        self.divergent_bonds = divergent_bonds

    def __call__(
        self,
        mol: Chem.Mol,
        atom_features_extra: np.ndarray | None = None,
        bond_features_extra: np.ndarray | None = None,
    ) -> chemprop.data.MolGraph:
        """Featurize a molecule with canonical stereochemical information.

        Each bond yields two directed edges stored in consecutive rows. The first
        runs ``u → v`` with unflipped bond features. The second runs ``v → u``
        with flipped features. ``(u, v)`` is the bond's (begin, end) atom pair
        when ``divergent_bonds`` is true and (end, begin) otherwise.

        Parameters
        ----------
        mol
            Molecule to be featurized.
        atom_features_extra
            Extra features to be added to the atoms.
        bond_features_extra
            Extra features to be added to the bonds.

        Returns
        -------
        chemprop.data.MolGraph
            Featurized molecule with canonical stereochemical information.

        Raises
        ------
        ValueError
            If ``atom_features_extra`` or ``bond_features_extra`` does not have
            one entry per atom or per bond, respectively.

        """
        # NOTE(review): no copy of ``mol`` is made here (unlike
        # MoleculeCIPFeaturizer), so any tags written by ``tag_stereogroups``
        # presumably end up on the caller's molecule -- confirm this is intended.
        stereochemistry.tag_stereogroups(mol, force=False)

        n_atoms = mol.GetNumAtoms()
        n_bonds = mol.GetNumBonds()

        if atom_features_extra is not None and len(atom_features_extra) != n_atoms:
            raise ValueError(
                "Input molecule must have same number of atoms as "
                "`len(atom_features_extra)`! "
                f"Got: {n_atoms} and {len(atom_features_extra)}, respectively."
            )
        if bond_features_extra is not None and len(bond_features_extra) != n_bonds:
            raise ValueError(
                "Input molecule must have same number of bonds as "
                "`len(bond_features_extra)`! "
                f"Got: {n_bonds} and {len(bond_features_extra)}, respectively."
            )

        if n_atoms == 0:
            # A single all-zero row stands in for an empty molecule. There are no
            # atoms to attach extra features to (stacking a 1-row placeholder with
            # a 0-row extras array would raise), so the extras are skipped here.
            vertices = np.zeros((1, self.atom_fdim), dtype=np.single)
        else:
            vertices = np.array(
                [self.atom_featurizer(atom) for atom in mol.GetAtoms()],
                dtype=np.single,
            )
            if atom_features_extra is not None:
                vertices = np.hstack((vertices, atom_features_extra))

        # Sized from bond_fdim, which must account for any bond extras
        # (see ``extra_bond_fdim``).
        edges = np.empty((2 * n_bonds, self.bond_fdim))
        edge_index = [[], []]

        for k, bond in enumerate(mol.GetBonds()):
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            u, v = (begin, end) if self.divergent_bonds else (end, begin)
            for offset, flip_direction in enumerate((False, True)):
                x_e = self.bond_featurizer(bond, flip_direction)
                if bond_features_extra is not None:
                    x_e = np.concatenate(
                        (x_e, bond_features_extra[bond.GetIdx()]), dtype=np.single
                    )
                edges[2 * k + offset] = x_e
            # Row 2k is u→v (unflipped features); row 2k + 1 is v→u (flipped).
            edge_index[0].extend((u, v))
            edge_index[1].extend((v, u))

        # Rows 2k and 2k + 1 are reverses of each other: pair them up and swap.
        rev_edge_index = np.arange(len(edges)).reshape(-1, 2)[:, ::-1].ravel()
        edge_index = np.array(edge_index, int)

        return chemprop.data.MolGraph(vertices, edges, edge_index, rev_edge_index)

    def pretty_print(self, mol: Chem.Mol) -> str:
        """Return a formatted string representation of the featurized molecule.

        Parameters
        ----------
        mol
            The molecule to be featurized.

        Returns
        -------
        str
            A string with the following format:

            .. code-block:: text

                Vertices:
                <atom1 features>
                <atom2 features>
                ...
                Edges:
                <edge1 features>
                <edge2 features>
                ...

            There is one line per atom and one line per directed edge (two per
            bond). The features for each atom and bond are described in terms of
            the one-hot encodings of the properties and the floats for the masses.

        Examples
        --------
        >>> from rdkit import Chem
        >>> from chempropstereo import MoleculeStereoFeaturizer
        >>> mol = Chem.MolFromSmiles("C[C@H](N)O")
        >>> featurizer = MoleculeStereoFeaturizer("ORGANIC", divergent_bonds=True)
        >>> print(featurizer.pretty_print(mol))  # doctest: +NORMALIZE_WHITESPACE
        Vertices:
        0: 0010000000000 0000100 000010 001 000100 00010 0 0.120
        1: 0010000000000 0000100 000010 100 010000 00010 0 0.120
        2: 0001000000000 0001000 000010 001 001000 00010 0 0.140
        3: 0000100000000 0010000 000010 001 010000 00010 0 0.160
        Edges:
        0→1: 0 1000 0 0 0000 00 00
        1→0: 0 1000 0 0 0010 00 00
        1→2: 0 1000 0 0 0100 00 00
        2→1: 0 1000 0 0 0000 00 00
        1→3: 0 1000 0 0 1000 00 00
        3→1: 0 1000 0 0 0000 00 00

        """
        molgraph = self(mol)
        vertices = "\n".join(
            utils.describe_atom_features(vertex, features, self.atom_featurizer.sizes)
            for vertex, features in enumerate(molgraph.V)
        )
        edges = "\n".join(
            utils.describe_bond_features(edge, features, self.bond_featurizer.sizes)
            for edge, features in zip(molgraph.edge_index.T, molgraph.E)
        )
        return f"Vertices:\n{vertices}\nEdges:\n{edges}"