Coverage for chempropstereo/featurizers/atom.py: 96%
50 statements
coverage.py v7.7.1, created at 2025-03-22 21:04 +0000
1"""Atom featurization.
3.. module:: featurizers.atom
4.. moduleauthor:: Charlles Abreu <craabreu@mit.edu>
5"""
7import chemprop
8import numpy as np
9from rdkit import Chem
11from .. import stereochemistry
12from . import utils
class AtomCIPFeaturizer(chemprop.featurizers.MultiHotAtomFeaturizer):
    """Multi-hot atom featurizer that includes a CIP code if the atom is a stereocenter.

    The featurized atoms are expected to be part of an RDKit molecule with CIP labels
    assigned via the `AssignCIPLabels`_ function.

    .. _AssignCIPLabels: https://www.rdkit.org/docs/source/\
rdkit.Chem.rdCIPLabeler.html#rdkit.Chem.rdCIPLabeler.AssignCIPLabels

    Parameters
    ----------
    mode : str | featurizers.AtomFeatureMode, default="V2"
        The mode to use for the featurizer. Available modes are `V1`_, `V2`_, and
        `ORGANIC`_.

    .. _V1: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/\
featurizers/atom/index.html#chemprop.featurizers.atom.MultiHotAtomFeaturizer.v1
    .. _V2: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/\
featurizers/atom/index.html#chemprop.featurizers.atom.MultiHotAtomFeaturizer.v2
    .. _ORGANIC: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/\
featurizers/atom/index.html#chemprop.featurizers.atom.MultiHotAtomFeaturizer.organic

    Examples
    --------
    >>> from chempropstereo import AtomCIPFeaturizer
    >>> from rdkit import Chem
    >>> r_mol = Chem.MolFromSmiles("C[C@H](N)O")
    >>> s_mol = Chem.MolFromSmiles("C[C@@H](N)O")
    >>> for mol in [r_mol, s_mol]:
    ...     Chem.AssignCIPLabels(mol)
    >>> r_atom = r_mol.GetAtomWithIdx(1)
    >>> s_atom = s_mol.GetAtomWithIdx(1)
    >>> featurizer = AtomCIPFeaturizer("ORGANIC")
    >>> for atom in [r_atom, s_atom]:
    ...     features = featurizer(atom)
    ...     assert len(features) == len(featurizer)
    ...     bits = "".join(str(int(value)) for value in features[:-1])
    ...     print(bits, f"{features[-1]:.3f}")
    001000000000000001000000100100010000000100 0.120
    001000000000000001000000100010010000000100 0.120

    """

    def __init__(self, mode: str | chemprop.featurizers.AtomFeatureMode = "V2") -> None:
        featurizer = chemprop.featurizers.get_multi_hot_atom_featurizer(
            chemprop.featurizers.AtomFeatureMode.get(mode)
        )
        # Reuse every choice list of the stock chemprop featurizer for ``mode``,
        # except the chiral-tag slot, which this class repurposes to hold the value
        # returned by ``stereochemistry.get_cip_code`` (see ``__call__``). Values
        # 0, 1 and 2 get dedicated slots; presumably 0 means "no CIP label" — TODO
        # confirm against ``stereochemistry.get_cip_code``.
        super().__init__(
            atomic_nums=featurizer.atomic_nums,
            degrees=featurizer.degrees,
            formal_charges=featurizer.formal_charges,
            chiral_tags=list(range(3)),
            num_Hs=featurizer.num_Hs,
            hybridizations=featurizer.hybridizations,
        )

    def __call__(self, a: Chem.Atom | None) -> np.ndarray:
        """Featurize an RDKit atom with stereochemical information.

        Parameters
        ----------
        a : Chem.Atom | None
            The atom to featurize. If None, returns a zero array.

        Returns
        -------
        np.ndarray
            A 1D float array of shape `(len(self),)` containing the following
            features:
            - One-hot encoding of the atomic number
            - One-hot encoding of the total degree
            - One-hot encoding of the formal charge
            - One-hot encoding of the CIP code
            - One-hot encoding of the total number of hydrogens
            - One-hot encoding of the hybridization
            - Boolean indicating whether the atom is aromatic
            - Mass of the atom divided by 100

        """
        # A float array is required: the last entry stores a fractional mass, which
        # an integer array would silently truncate (to 0 for any atom below 100 Da).
        x = np.zeros(len(self))

        if a is None:
            return x

        feats = [
            a.GetAtomicNum(),
            a.GetTotalDegree(),
            a.GetFormalCharge(),
            stereochemistry.get_cip_code(a),
            int(a.GetTotalNumHs()),
            a.GetHybridization(),
        ]

        # Each categorical feature occupies ``len(choices) + 1`` slots; values not
        # found among the choices land in the trailing "unknown" slot.
        i = 0
        for feat, choices in zip(feats, self._subfeats):
            j = choices.get(feat, len(choices))
            x[i + j] = 1
            i += len(choices) + 1
        x[i] = int(a.GetIsAromatic())
        x[i + 1] = 0.01 * a.GetMass()

        return x
# Scan directions that receive a dedicated one-hot slot in AtomStereoFeaturizer, in
# slot order. Any other value is encoded in the trailing "not in list" slot.
_SCAN_DIRECTIONS: tuple[stereochemistry.ScanDirection, ...] = (
    stereochemistry.ScanDirection.CW,
    stereochemistry.ScanDirection.CCW,
)
class AtomStereoFeaturizer(chemprop.featurizers.base.VectorFeaturizer[Chem.Atom]):
    """Multi-hot atom featurizer that includes a canonical chiral tag for each atom.

    The featurized atoms are expected to be part of an RDKit molecule with canonical
    chiral tags assigned via :func:`~stereochemistry.tag_stereogroups`.

    Parameters
    ----------
    mode
        The mode to use for the featurizer. Available modes are `V1`_, `V2`_, and
        `ORGANIC`_. Defaults to ``"V2"``.

    .. _V1: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/\
featurizers/atom/index.html#chemprop.featurizers.atom.MultiHotAtomFeaturizer.v1
    .. _V2: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/\
featurizers/atom/index.html#chemprop.featurizers.atom.MultiHotAtomFeaturizer.v2
    .. _ORGANIC: https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/\
featurizers/atom/index.html#chemprop.featurizers.atom.MultiHotAtomFeaturizer.organic

    Examples
    --------
    >>> from chempropstereo import featurizers, stereochemistry
    >>> from rdkit import Chem
    >>> featurizer = featurizers.AtomStereoFeaturizer("ORGANIC")
    >>> for smi in ["C[C@H](N)O", "C[C@@H](N)O"]:
    ...     mol = Chem.MolFromSmiles(smi)
    ...     stereochemistry.tag_stereogroups(mol)
    ...     print(f"Molecule: {smi}")
    ...     for atom in mol.GetAtoms():
    ...         print(featurizer.pretty_print(atom))
    Molecule: C[C@H](N)O
     0: 0010000000000 0000100 000010 001 000100 00010 0 0.120
     1: 0010000000000 0000100 000010 100 010000 00010 0 0.120
     2: 0001000000000 0001000 000010 001 001000 00010 0 0.140
     3: 0000100000000 0010000 000010 001 010000 00010 0 0.160
    Molecule: C[C@@H](N)O
     0: 0010000000000 0000100 000010 001 000100 00010 0 0.120
     1: 0010000000000 0000100 000010 010 010000 00010 0 0.120
     2: 0001000000000 0001000 000010 001 001000 00010 0 0.140
     3: 0000100000000 0010000 000010 001 010000 00010 0 0.160

    """

    def __init__(self, mode: str | chemprop.featurizers.AtomFeatureMode = "V2") -> None:
        featurizer = chemprop.featurizers.get_multi_hot_atom_featurizer(
            chemprop.featurizers.AtomFeatureMode.get(mode)
        )
        # Same choice lists as the stock chemprop featurizer for ``mode``, except
        # that the chiral-tag slot is replaced by the canonical scan direction.
        self.atomic_nums = featurizer.atomic_nums
        self.degrees = featurizer.degrees
        self.formal_charges = featurizer.formal_charges
        self.scan_directions = _SCAN_DIRECTIONS
        self.num_Hs = featurizer.num_Hs
        self.hybridizations = featurizer.hybridizations
        self._len = sum(self.sizes)

    def __len__(self) -> int:
        return self._len

    def __call__(self, a: Chem.Atom | None) -> np.ndarray:
        """Featurize an RDKit atom with stereochemical information.

        Parameters
        ----------
        a
            The atom to featurize. If None, an all-zero array is returned.

        Returns
        -------
        np.ndarray
            A 1D float array of shape `(len(self),)` containing the following
            features:
            - `atomic_num`: one-hot encoding of the atomic number
            - `total_degree`: one-hot encoding of the total degree
            - `formal_charge`: one-hot encoding of the formal charge
            - `scan_direction`: one-hot encoding of the scan direction
            - `total_num_hs`: one-hot encoding of the total number of Hs
            - `hybridization`: one-hot encoding of the hybridization
            - `is_aromatic`: boolean indicating whether the atom is aromatic
            - `mass`: mass of the atom divided by 100

        """
        if a is None:
            return np.zeros(len(self))

        # (value, choices) pairs in the same order as ``self.sizes``.
        categorical = (
            (a.GetAtomicNum(), self.atomic_nums),
            (a.GetTotalDegree(), self.degrees),
            (a.GetFormalCharge(), self.formal_charges),
            (stereochemistry.ScanDirection.get_from(a), self.scan_directions),
            (a.GetTotalNumHs(), self.num_Hs),
            (a.GetHybridization(), self.hybridizations),
        )

        features: list[bool | float] = []
        for value, choices in categorical:
            features.extend(value == choice for choice in choices)
            # Trailing slot flags values that are not among the known choices.
            features.append(value not in choices)
        features.append(a.GetIsAromatic())
        features.append(0.01 * a.GetMass())

        return np.array(features, dtype=float)

    @property
    def sizes(self) -> tuple[int, ...]:
        """Get a tuple of sizes corresponding to different atom features.

        The tuple contains the sizes for:
        - Atomic numbers
        - Total degrees
        - Formal charges
        - Scan directions
        - Total numbers of Hs
        - Hybridizations
        - Aromatic indicator
        - Mass

        Each categorical size includes one extra slot for unknown values.

        Returns
        -------
        tuple[int, ...]
            A tuple of integers representing the sizes of each atom feature.

        Examples
        --------
        >>> from chempropstereo import featurizers
        >>> featurizer = featurizers.AtomStereoFeaturizer("ORGANIC")
        >>> featurizer.sizes
        (13, 7, 6, 3, 6, 5, 1, 1)

        """
        return (
            len(self.atomic_nums) + 1,
            len(self.degrees) + 1,
            len(self.formal_charges) + 1,
            len(self.scan_directions) + 1,
            len(self.num_Hs) + 1,
            len(self.hybridizations) + 1,
            1,
            1,
        )

    def pretty_print(self, a: Chem.Atom | None) -> str:
        """Get a formatted string representation of the atom features.

        Parameters
        ----------
        a : Chem.Atom or None
            The atom to be described. If None, a null atom (all-zero features) is
            assumed and -1 is used as its index.

        Returns
        -------
        str
            A formatted string representing the atom features.

        Examples
        --------
        >>> from rdkit import Chem
        >>> from chempropstereo import featurizers
        >>> mol = Chem.MolFromSmiles('CC')
        >>> atom = mol.GetAtomWithIdx(0)
        >>> featurizer = featurizers.AtomStereoFeaturizer("ORGANIC")
        >>> featurizer.pretty_print(atom)
        ' 0: 0010000000000 0000100 000010 001 000100 00010 0 0.120'

        """
        # A null atom has no index; -1 is a placeholder so the documented None case
        # does not crash on ``a.GetIdx()``. NOTE(review): confirm how
        # ``utils.describe_atom_features`` renders a negative index.
        index = -1 if a is None else a.GetIdx()
        return utils.describe_atom_features(index, self(a), self.sizes)