Coverage for chempropstereo/featurizers/bond.py: 92%
36 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-22 21:04 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-22 21:04 +0000
1"""Bond featurization.
3.. module:: featurizers.bond
4.. moduleauthor:: Charlles Abreu <craabreu@mit.edu>
5"""
7import chemprop
8import numpy as np
9from rdkit import Chem
11from .. import stereochemistry
12from . import utils
# Lookup tables for the categorical bond sub-features.
#
# Element order is significant: BondStereoFeaturizer.__call__ emits one 0/1
# entry per element, in exactly this order, so reordering any tuple changes
# the layout of the encoded bond vector.

# Bond orders that get their own slot in the "bond type" sub-vector.
_BOND_TYPES: tuple[Chem.BondType, ...] = (
    Chem.BondType.SINGLE,
    Chem.BondType.DOUBLE,
    Chem.BondType.TRIPLE,
    Chem.BondType.AROMATIC,
)

# Vertex ranks compared against ``VertexRank.from_bond`` for the
# "canonical vertex rank" sub-vector.
_VERTEX_RANKS: tuple[stereochemistry.VertexRank, ...] = (
    stereochemistry.VertexRank.FIRST,
    stereochemistry.VertexRank.SECOND,
    stereochemistry.VertexRank.THIRD,
    stereochemistry.VertexRank.FOURTH,
)

# Stem arrangements compared against ``StemArrangement.get_from`` for the
# "canonical stem arrangement" sub-vector.
_STEM_ARRANGEMENTS: tuple[stereochemistry.StemArrangement, ...] = (
    stereochemistry.StemArrangement.CIS,
    stereochemistry.StemArrangement.TRANS,
)

# Branch ranks compared against ``BranchRank.from_bond`` for the
# "canonical branch rank" sub-vector.
_BRANCH_RANKS: tuple[stereochemistry.BranchRank, ...] = (
    stereochemistry.BranchRank.MAJOR,
    stereochemistry.BranchRank.MINOR,
)
class BondStereoFeaturizer(chemprop.featurizers.base.VectorFeaturizer[Chem.Bond]):
    r"""Multi-hot bond featurizer that includes canonical stereochemistry information.

    The featurizer encodes the position of the end atom in the canonical order of
    neighbors when the begin atom has a canonical chiral tag.

    The featurized bonds are expected to be part of an RDKit molecule with canonical
    chiral tags assigned via :func:`tetrahedral.tag_tetrahedral_stereocenters`.

    Attributes
    ----------
    bond_types : tuple[Chem.BondType, ...]
        Bond types that receive a slot in the bond-type subfeature.
    vertex_ranks : tuple[stereochemistry.VertexRank, ...]
        Vertex ranks that receive a slot in the vertex-rank subfeature.
    stem_arrangements : tuple[stereochemistry.StemArrangement, ...]
        Stem arrangements that receive a slot in the stem-arrangement subfeature.
    branch_ranks : tuple[stereochemistry.BranchRank, ...]
        Branch ranks that receive a slot in the branch-rank subfeature.
    sizes : tuple[int, ...]
        A tuple of integers representing the sizes of each bond subfeature.

    Examples
    --------
    >>> from chempropstereo import featurizers, stereochemistry
    >>> from rdkit import Chem
    >>> import numpy as np
    >>> mol = Chem.MolFromSmiles("C\C(=C(O)/C=C(/N)O)[C@@H]([C@H](N)O)O")
    >>> stereochemistry.tag_tetrahedral_stereocenters(mol)
    >>> featurizer = featurizers.BondStereoFeaturizer()
    >>> def describe_bonds_from_atom(index):
    ...     for bond in mol.GetAtomWithIdx(index).GetBonds():
    ...         atom_is_begin = bond.GetBeginAtomIdx() == index
    ...         for reverse in (not atom_is_begin, atom_is_begin):
    ...             print(featurizer.pretty_print(bond, reverse))
    >>> stereochemistry.tag_stereogroups(mol)
    >>> stereochemistry.describe_stereobond(mol.GetBondBetweenAtoms(1, 2))
    'C0 C8 C1 (CIS) C2 C4 O3'
    >>> describe_bonds_from_atom(1) # doctest: +NORMALIZE_WHITESPACE
    1→0: 0 1000 0 0 0000 00 10
    0→1: 0 1000 0 0 0000 00 00
    1→2: 0 0100 1 0 0000 10 00
    2→1: 0 0100 1 0 0000 10 00
    1→8: 0 1000 0 0 0000 00 01
    8→1: 0 1000 0 0 0010 00 00
    >>> describe_bonds_from_atom(2) # doctest: +NORMALIZE_WHITESPACE
    2→1: 0 0100 1 0 0000 10 00
    1→2: 0 0100 1 0 0000 10 00
    2→3: 0 1000 1 0 0000 00 01
    3→2: 0 1000 1 0 0000 00 00
    2→4: 0 1000 1 0 0000 00 10
    4→2: 0 1000 1 0 0000 00 10
    >>> stereochemistry.describe_stereobond(mol.GetBondBetweenAtoms(4, 5))
    'C2 C4 (TRANS) C5 N6 O7'
    >>> describe_bonds_from_atom(4) # doctest: +NORMALIZE_WHITESPACE
    4→2: 0 1000 1 0 0000 00 10
    2→4: 0 1000 1 0 0000 00 10
    4→5: 0 0100 1 0 0000 01 00
    5→4: 0 0100 1 0 0000 01 00
    >>> describe_bonds_from_atom(5) # doctest: +NORMALIZE_WHITESPACE
    5→4: 0 0100 1 0 0000 01 00
    4→5: 0 0100 1 0 0000 01 00
    5→6: 0 1000 1 0 0000 00 10
    6→5: 0 1000 1 0 0000 00 00
    5→7: 0 1000 1 0 0000 00 01
    7→5: 0 1000 1 0 0000 00 00
    >>> stereochemistry.describe_stereocenter(mol.GetAtomWithIdx(8))
    'C8 (CCW) O12 C9 C1'
    >>> describe_bonds_from_atom(8) # doctest: +NORMALIZE_WHITESPACE
    8→1: 0 1000 0 0 0010 00 00
    1→8: 0 1000 0 0 0000 00 01
    8→9: 0 1000 0 0 0100 00 00
    9→8: 0 1000 0 0 0010 00 00
    8→12: 0 1000 0 0 1000 00 00
    12→8: 0 1000 0 0 0000 00 00
    >>> stereochemistry.describe_stereocenter(mol.GetAtomWithIdx(9))
    'C9 (CW) O11 N10 C8'
    >>> describe_bonds_from_atom(9) # doctest: +NORMALIZE_WHITESPACE
    9→8: 0 1000 0 0 0010 00 00
    8→9: 0 1000 0 0 0100 00 00
    9→10: 0 1000 0 0 0100 00 00
    10→9: 0 1000 0 0 0000 00 00
    9→11: 0 1000 0 0 1000 00 00
    11→9: 0 1000 0 0 0000 00 00

    """

    def __init__(self):
        self.bond_types = _BOND_TYPES
        self.vertex_ranks = _VERTEX_RANKS
        self.stem_arrangements = _STEM_ARRANGEMENTS
        self.branch_ranks = _BRANCH_RANKS
        # Cache the total vector length; ``sizes`` reads the attributes above.
        self._len = sum(self.sizes)

    def __len__(self) -> int:
        """Return the total length of the encoded bond vector."""
        return self._len

    def __call__(self, b: Chem.Bond | None, flip_direction: bool = False) -> np.ndarray:
        """Encode a bond in a molecule with canonical stereochemistry information.

        Parameters
        ----------
        b : Chem.Bond | None
            The bond to be encoded. If None, a null-bond vector is returned:
            all zeros except for the leading null-bond indicator.
        flip_direction : bool, optional
            Whether to reverse the direction of the bond (default is False).
            It is forwarded to the vertex-rank and branch-rank lookups.

        Returns
        -------
        np.ndarray
            An integer vector of length ``len(self)`` encoding the bond.

        Notes
        -----
        The vector includes the following information:
        - Null bond indicator
        - Bond types
        - Conjugation indicator
        - Ring indicator
        - Canonical vertex rank
        - Canonical stem arrangement
        - Canonical branch rank

        """
        if b is None:
            x = np.zeros(len(self), int)
            x[0] = 1
            return x

        bond_type = b.GetBondType()
        vertex_rank = stereochemistry.VertexRank.from_bond(b, flip_direction)
        arrangement = stereochemistry.StemArrangement.get_from(b)
        branch_rank = stereochemistry.BranchRank.from_bond(b, flip_direction)

        # Iterate over the instance attributes (the same ones ``sizes`` measures)
        # so that this vector always has the same layout and length as the
        # null-bond vector built above.
        return np.array(
            [
                False,  # null-bond indicator: ``b`` is a real bond on this path
                *(bond_type == item for item in self.bond_types),
                b.GetIsConjugated(),
                b.IsInRing(),
                *(vertex_rank == item for item in self.vertex_ranks),
                *(arrangement == item for item in self.stem_arrangements),
                *(branch_rank == item for item in self.branch_ranks),
            ],
            dtype=int,
        )

    @property
    def sizes(self) -> tuple[int, ...]:
        """Get a tuple of sizes corresponding to different bond features.

        The tuple contains the sizes for:
        - Null bond indicator
        - Bond types
        - Conjugation indicator
        - Ring indicator
        - Tetrahedral vertex ranks
        - Cis/trans stem arrangements
        - Cis/trans branch ranks

        Returns
        -------
        tuple[int, ...]
            A tuple of integers representing the sizes of each bond feature.

        Examples
        --------
        >>> from chempropstereo import featurizers
        >>> featurizer = featurizers.BondStereoFeaturizer()
        >>> featurizer.sizes
        (1, 4, 1, 1, 4, 2, 2)

        """
        return (
            1,
            len(self.bond_types),
            1,
            1,
            len(self.vertex_ranks),
            len(self.stem_arrangements),
            len(self.branch_ranks),
        )

    def pretty_print(self, b: Chem.Bond | None, flip_direction: bool = False) -> str:
        """Get a formatted string representation of the bond features.

        Parameters
        ----------
        b : Chem.Bond or None
            The bond to be described.
            NOTE(review): although the annotation admits None, the begin/end
            atom indices are read from ``b`` unconditionally, so passing None
            raises ``AttributeError`` here. Supporting a null bond would need a
            decision on which atom labels to hand to
            ``utils.describe_bond_features``.
        flip_direction : bool, optional
            Whether to reverse the direction of the bond (default is False).
            When True, the atom indices are printed in end→begin order.

        Returns
        -------
        str
            A formatted string representing the bond features.

        Examples
        --------
        >>> from rdkit import Chem
        >>> from chempropstereo import featurizers
        >>> mol = Chem.MolFromSmiles('CC')
        >>> bond = mol.GetBondWithIdx(0)
        >>> featurizer = featurizers.BondStereoFeaturizer()
        >>> featurizer.pretty_print(bond) # doctest: +NORMALIZE_WHITESPACE
        ' 0→1: 0 1000 0 0 0000 00 00'

        """
        atoms = [b.GetBeginAtomIdx(), b.GetEndAtomIdx()]
        if flip_direction:
            atoms.reverse()
        return utils.describe_bond_features(atoms, self(b, flip_direction), self.sizes)