Coverage for chempropstereo/featurizers/molecule.py: 90%
52 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-22 21:04 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-22 21:04 +0000
1"""Molecule featurization.
3.. module:: featurizers.molecule
4.. moduleauthor:: Charlles Abreu <craabreu@mit.edu>
5"""
7import chemprop
8import numpy as np
9from rdkit import Chem
11from .. import stereochemistry
12from . import utils
13from .atom import AtomCIPFeaturizer, AtomStereoFeaturizer
14from .bond import BondStereoFeaturizer
class MoleculeCIPFeaturizer(chemprop.featurizers.SimpleMoleculeMolGraphFeaturizer):
    """Molecule featurizer whose atom features carry CIP stereocenter codes.

    Atoms are featurized with :class:`AtomCIPFeaturizer` and bonds with
    chemprop's ``MultiHotBondFeaturizer``. CIP labels are assigned to a copy of
    the molecule right before featurization.

    Examples
    --------
    >>> from chempropstereo import MoleculeCIPFeaturizer
    >>> from rdkit import Chem
    >>> import numpy as np
    >>> r_mol = Chem.MolFromSmiles("C[C@H](N)O")
    >>> s_mol = Chem.MolFromSmiles("C[C@@H](N)O")
    >>> featurizer = MoleculeCIPFeaturizer()
    >>> r_molgraph = featurizer(r_mol)
    >>> s_molgraph = featurizer(s_mol)
    >>> assert not np.array_equal(r_molgraph.V, s_molgraph.V)
    >>> assert np.array_equal(r_molgraph.E, s_molgraph.E)

    """

    def __init__(self):
        # Build the two component featurizers first, then hand them to the
        # parent initializer.
        cip_atom_featurizer = AtomCIPFeaturizer()
        plain_bond_featurizer = chemprop.featurizers.MultiHotBondFeaturizer()
        super().__init__(
            atom_featurizer=cip_atom_featurizer,
            bond_featurizer=plain_bond_featurizer,
        )

    def __call__(
        self,
        mol: Chem.Mol,
        atom_features_extra: np.ndarray | None = None,
        bond_features_extra: np.ndarray | None = None,
    ) -> chemprop.data.MolGraph:
        """Build the molecular graph of a molecule after assigning CIP labels.

        Parameters
        ----------
        mol
            Molecule to be featurized. A copy is made with ``Chem.Mol(mol)``
            and the CIP labels are assigned to that copy.
        atom_features_extra
            Extra atom features, passed through to the parent featurizer.
        bond_features_extra
            Extra bond features, passed through to the parent featurizer.

        Returns
        -------
        chemprop.data.MolGraph
            Graph produced by the parent featurizer for the labeled copy.

        """
        labeled = Chem.Mol(mol)
        Chem.AssignCIPLabels(labeled)
        graph = super().__call__(labeled, atom_features_extra, bond_features_extra)
        return graph
class MoleculeStereoFeaturizer(chemprop.featurizers.SimpleMoleculeMolGraphFeaturizer):
    r"""Molecule featurizer that includes canonical stereochemical information.

    This featurizer includes canonicalized tetrahedral stereocenters and
    cis/trans stereobonds.

    Parameters
    ----------
    mode : str or chemprop.featurizers.AtomFeatureMode
        Atom featurization mode, forwarded to :class:`AtomStereoFeaturizer`.
    divergent_bonds : bool
        Whether to add stereochemical features to the directed bonds that diverge from
        stereocenters and stereobonds, as opposed to those that converge to them.
    extra_atom_fdim : int, default=0
        Number of extra features per atom that will be supplied through
        ``atom_features_extra`` when the featurizer is called. Forwarded to the
        parent featurizer so that ``atom_fdim`` accounts for them.
    extra_bond_fdim : int, default=0
        Number of extra features per bond that will be supplied through
        ``bond_features_extra`` when the featurizer is called. Forwarded to the
        parent featurizer so that ``bond_fdim`` (which sizes the edge feature
        matrix) accounts for them.

    Examples
    --------
    >>> from chempropstereo import featurizers
    >>> from rdkit import Chem
    >>> import numpy as np
    >>> mol = Chem.MolFromSmiles("C[C@@H](N)/C=C(O)/N")
    >>> for divergent in (True, False):  # doctest: +NORMALIZE_WHITESPACE
    ...     print(f"\nWith {'di' if divergent else 'con'}vergent bonds:\n")
    ...     featurizer = featurizers.MoleculeStereoFeaturizer(
    ...         mode="ORGANIC",
    ...         divergent_bonds=divergent,
    ...     )
    ...     print(featurizer.pretty_print(mol))
    <BLANKLINE>
    With divergent bonds:
    <BLANKLINE>
    Vertices:
    0: 0010000000000 0000100 000010 001 000100 00010 0 0.120
    1: 0010000000000 0000100 000010 010 010000 00010 0 0.120
    2: 0001000000000 0001000 000010 001 001000 00010 0 0.140
    3: 0010000000000 0001000 000010 001 010000 00100 0 0.120
    4: 0010000000000 0001000 000010 001 100000 00100 0 0.120
    5: 0000100000000 0010000 000010 001 010000 00100 0 0.160
    6: 0001000000000 0001000 000010 001 001000 00100 0 0.140
    Edges:
    0→1: 0 1000 0 0 0000 00 00
    1→0: 0 1000 0 0 0100 00 00
    1→2: 0 1000 0 0 1000 00 00
    2→1: 0 1000 0 0 0000 00 00
    1→3: 0 1000 0 0 0010 00 00
    3→1: 0 1000 0 0 0000 00 10
    3→4: 0 0100 1 0 0000 01 00
    4→3: 0 0100 1 0 0000 01 00
    4→5: 0 1000 1 0 0000 00 01
    5→4: 0 1000 1 0 0000 00 00
    4→6: 0 1000 1 0 0000 00 10
    6→4: 0 1000 1 0 0000 00 00
    <BLANKLINE>
    With convergent bonds:
    <BLANKLINE>
    Vertices:
    0: 0010000000000 0000100 000010 001 000100 00010 0 0.120
    1: 0010000000000 0000100 000010 010 010000 00010 0 0.120
    2: 0001000000000 0001000 000010 001 001000 00010 0 0.140
    3: 0010000000000 0001000 000010 001 010000 00100 0 0.120
    4: 0010000000000 0001000 000010 001 100000 00100 0 0.120
    5: 0000100000000 0010000 000010 001 010000 00100 0 0.160
    6: 0001000000000 0001000 000010 001 001000 00100 0 0.140
    Edges:
    1→0: 0 1000 0 0 0000 00 00
    0→1: 0 1000 0 0 0100 00 00
    2→1: 0 1000 0 0 1000 00 00
    1→2: 0 1000 0 0 0000 00 00
    3→1: 0 1000 0 0 0010 00 00
    1→3: 0 1000 0 0 0000 00 10
    4→3: 0 0100 1 0 0000 01 00
    3→4: 0 0100 1 0 0000 01 00
    5→4: 0 1000 1 0 0000 00 01
    4→5: 0 1000 1 0 0000 00 00
    6→4: 0 1000 1 0 0000 00 10
    4→6: 0 1000 1 0 0000 00 00

    """

    def __init__(
        self,
        mode: str | chemprop.featurizers.AtomFeatureMode,
        divergent_bonds: bool,
        *,
        extra_atom_fdim: int = 0,
        extra_bond_fdim: int = 0,
    ) -> None:
        # The extra dimensions must reach the parent initializer: it is the one
        # that computes ``atom_fdim``/``bond_fdim``, and ``__call__`` sizes the
        # edge feature matrix from ``bond_fdim``.
        super().__init__(
            atom_featurizer=AtomStereoFeaturizer(mode),
            bond_featurizer=BondStereoFeaturizer(),
            extra_atom_fdim=extra_atom_fdim,
            extra_bond_fdim=extra_bond_fdim,
        )
        self.divergent_bonds = divergent_bonds

    def __call__(
        self,
        mol: Chem.Mol,
        atom_features_extra: np.ndarray | None = None,
        bond_features_extra: np.ndarray | None = None,
    ) -> chemprop.data.MolGraph:
        """Featurize a molecule with canonical stereochemical information.

        Each bond yields two consecutive directed edges. The first one is
        featurized with ``flip_direction=False`` and the second one, which runs
        the opposite way, with ``flip_direction=True``. When
        ``divergent_bonds`` is true the first edge goes from the bond's begin
        atom to its end atom; otherwise it goes from the end atom to the begin
        atom.

        Parameters
        ----------
        mol
            Molecule to be featurized. It is passed as is (not copied) to
            ``stereochemistry.tag_stereogroups``, which presumably annotates it
            in place.
        atom_features_extra
            Extra features to be added to the atoms, one row per atom.
        bond_features_extra
            Extra features to be added to the bonds, one row per bond. Both
            directed edges of a bond receive the same extra features. The
            featurizer must have been created with a matching
            ``extra_bond_fdim``.

        Returns
        -------
        chemprop.data.MolGraph
            Featurized molecule with canonical stereochemical information.

        Raises
        ------
        ValueError
            If the length of ``atom_features_extra`` or ``bond_features_extra``
            differs from the number of atoms or bonds, respectively.

        """
        # force=False presumably skips molecules that were already tagged;
        # confirm against ``stereochemistry.tag_stereogroups``.
        stereochemistry.tag_stereogroups(mol, force=False)

        n_atoms = mol.GetNumAtoms()
        n_bonds = mol.GetNumBonds()

        if atom_features_extra is not None and len(atom_features_extra) != n_atoms:
            raise ValueError(
                "Input molecule must have same number of atoms as "
                "`len(atom_features_extra)`! "
                f"Got: {n_atoms} and {len(atom_features_extra)}, respectively."
            )
        if bond_features_extra is not None and len(bond_features_extra) != n_bonds:
            raise ValueError(
                "Input molecule must have same number of bonds as "
                "`len(bond_features_extra)`! "
                f"Got: {n_bonds} and {len(bond_features_extra)}, respectively."
            )

        if n_atoms == 0:
            # An empty molecule is represented by a single all-zero vertex.
            vertices = np.zeros((1, self.atom_fdim), dtype=np.single)
        else:
            vertices = np.array(
                [self.atom_featurizer(a) for a in mol.GetAtoms()], dtype=np.single
            )
        edges = np.empty((2 * n_bonds, self.bond_fdim))
        edge_index = [[], []]

        if atom_features_extra is not None:
            vertices = np.hstack((vertices, atom_features_extra))

        i = 0
        for bond in mol.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            u, v = (begin, end) if self.divergent_bonds else (end, begin)
            # Row ``i`` is the edge u→v and row ``i + 1`` is its reverse v→u.
            directed = ((u, v, False), (v, u, True))
            for offset, (source, target, flip_direction) in enumerate(directed):
                x_e = self.bond_featurizer(bond, flip_direction)
                if bond_features_extra is not None:
                    x_e = np.concatenate(
                        (x_e, bond_features_extra[bond.GetIdx()]), dtype=np.single
                    )
                edges[i + offset] = x_e
                edge_index[0].append(source)
                edge_index[1].append(target)
            i += 2

        # Edges come in consecutive (forward, reverse) pairs, so the reverse of
        # edge 2k is edge 2k + 1 and vice versa.
        rev_edge_index = np.arange(len(edges)).reshape(-1, 2)[:, ::-1].ravel()
        edge_index = np.array(edge_index, int)

        return chemprop.data.MolGraph(vertices, edges, edge_index, rev_edge_index)

    def pretty_print(self, mol: Chem.Mol) -> str:
        """Return a formatted string representation of the featurized molecule.

        Nothing is written to standard output; pass the result to ``print``
        to display it.

        Parameters
        ----------
        mol
            The molecule to be featurized.

        Returns
        -------
        str
            A string with the following format:

            .. code-block:: text

                Vertices:
                <atom1 features>
                <atom2 features>
                ...
                Edges:
                <bond1 features>
                <bond2 features>
                ...

            The features for each atom and bond are described in terms of the
            one-hot encodings of the properties and the floats for the masses.

        Example
        -------
        >>> from rdkit import Chem
        >>> from chempropstereo import MoleculeStereoFeaturizer
        >>> mol = Chem.MolFromSmiles("C[C@H](N)O")
        >>> featurizer = MoleculeStereoFeaturizer("ORGANIC", divergent_bonds=True)
        >>> print(featurizer.pretty_print(mol)) # doctest: +NORMALIZE_WHITESPACE
        Vertices:
        0: 0010000000000 0000100 000010 001 000100 00010 0 0.120
        1: 0010000000000 0000100 000010 100 010000 00010 0 0.120
        2: 0001000000000 0001000 000010 001 001000 00010 0 0.140
        3: 0000100000000 0010000 000010 001 010000 00010 0 0.160
        Edges:
        0→1: 0 1000 0 0 0000 00 00
        1→0: 0 1000 0 0 0010 00 00
        1→2: 0 1000 0 0 0100 00 00
        2→1: 0 1000 0 0 0000 00 00
        1→3: 0 1000 0 0 1000 00 00
        3→1: 0 1000 0 0 0000 00 00

        """
        molgraph = self(mol)
        vertices = "\n".join(
            utils.describe_atom_features(vertex, features, self.atom_featurizer.sizes)
            for vertex, features in enumerate(molgraph.V)
        )
        # ``edge_index.T`` yields one (source, target) pair per directed edge,
        # aligned with the rows of ``molgraph.E``.
        edges = "\n".join(
            utils.describe_bond_features(edge, features, self.bond_featurizer.sizes)
            for edge, features in zip(molgraph.edge_index.T, molgraph.E)
        )
        return f"Vertices:\n{vertices}\nEdges:\n{edges}"