Coverage for chempropstereo/featurizers/bond.py: 92%

36 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-22 21:04 +0000

1"""Bond featurization. 

2 

3.. module:: featurizers.bond 

4.. moduleauthor:: Charlles Abreu <craabreu@mit.edu> 

5""" 

6 

7import chemprop 

8import numpy as np 

9from rdkit import Chem 

10 

11from .. import stereochemistry 

12from . import utils 

13 

# Categories for the one-hot subfeatures of BondStereoFeaturizer. The order of
# the members in each tuple fixes the order of the corresponding feature bits.
_BOND_TYPES: tuple[Chem.BondType, ...] = (
    Chem.BondType.SINGLE, Chem.BondType.DOUBLE,
    Chem.BondType.TRIPLE, Chem.BondType.AROMATIC,
)
_VERTEX_RANKS: tuple[stereochemistry.VertexRank, ...] = (
    stereochemistry.VertexRank.FIRST, stereochemistry.VertexRank.SECOND,
    stereochemistry.VertexRank.THIRD, stereochemistry.VertexRank.FOURTH,
)
_STEM_ARRANGEMENTS: tuple[stereochemistry.StemArrangement, ...] = (
    stereochemistry.StemArrangement.CIS, stereochemistry.StemArrangement.TRANS,
)
_BRANCH_RANKS: tuple[stereochemistry.BranchRank, ...] = (
    stereochemistry.BranchRank.MAJOR, stereochemistry.BranchRank.MINOR,
)

34 

35 

class BondStereoFeaturizer(chemprop.featurizers.base.VectorFeaturizer[Chem.Bond]):
    r"""Multi-hot bond featurizer that includes canonical stereochemistry information.

    Each (directed) bond is encoded as a null-bond indicator, a one-hot bond type,
    conjugation and ring indicators, and three canonical stereochemistry
    subfeatures:

    - the tetrahedral vertex rank, i.e. the position of the end atom in the
      canonical order of neighbors when the begin atom has a canonical chiral tag;
    - the cis/trans stem arrangement;
    - the cis/trans branch rank.

    The featurized bonds are expected to be part of an RDKit molecule whose
    canonical stereochemistry tags have been assigned as in the example below,
    i.e. via :func:`stereochemistry.tag_tetrahedral_stereocenters` and
    :func:`stereochemistry.tag_stereogroups`.

    Attributes
    ----------
    bond_types: tuple[Chem.BondType, ...]
        Bond types that get a one-hot slot, in bit order.
    vertex_ranks: tuple[stereochemistry.VertexRank, ...]
        Tetrahedral vertex ranks that get a one-hot slot, in bit order.
    stem_arrangements: tuple[stereochemistry.StemArrangement, ...]
        Cis/trans stem arrangements that get a one-hot slot, in bit order.
    branch_ranks: tuple[stereochemistry.BranchRank, ...]
        Cis/trans branch ranks that get a one-hot slot, in bit order.
    sizes: tuple[int, ...]
        A tuple of integers representing the sizes of each bond subfeature.

    Examples
    --------
    >>> from chempropstereo import featurizers, stereochemistry
    >>> from rdkit import Chem
    >>> import numpy as np
    >>> mol = Chem.MolFromSmiles("C\C(=C(O)/C=C(/N)O)[C@@H]([C@H](N)O)O")
    >>> stereochemistry.tag_tetrahedral_stereocenters(mol)
    >>> featurizer = featurizers.BondStereoFeaturizer()
    >>> def describe_bonds_from_atom(index):
    ...     for bond in mol.GetAtomWithIdx(index).GetBonds():
    ...         atom_is_begin = bond.GetBeginAtomIdx() == index
    ...         for reverse in (not atom_is_begin, atom_is_begin):
    ...             print(featurizer.pretty_print(bond, reverse))
    >>> stereochemistry.tag_stereogroups(mol)
    >>> stereochemistry.describe_stereobond(mol.GetBondBetweenAtoms(1, 2))
    'C0 C8 C1 (CIS) C2 C4 O3'
    >>> describe_bonds_from_atom(1)  # doctest: +NORMALIZE_WHITESPACE
    1→0: 0 1000 0 0 0000 00 10
    0→1: 0 1000 0 0 0000 00 00
    1→2: 0 0100 1 0 0000 10 00
    2→1: 0 0100 1 0 0000 10 00
    1→8: 0 1000 0 0 0000 00 01
    8→1: 0 1000 0 0 0010 00 00
    >>> describe_bonds_from_atom(2)  # doctest: +NORMALIZE_WHITESPACE
    2→1: 0 0100 1 0 0000 10 00
    1→2: 0 0100 1 0 0000 10 00
    2→3: 0 1000 1 0 0000 00 01
    3→2: 0 1000 1 0 0000 00 00
    2→4: 0 1000 1 0 0000 00 10
    4→2: 0 1000 1 0 0000 00 10
    >>> stereochemistry.describe_stereobond(mol.GetBondBetweenAtoms(4, 5))
    'C2 C4 (TRANS) C5 N6 O7'
    >>> describe_bonds_from_atom(4)  # doctest: +NORMALIZE_WHITESPACE
    4→2: 0 1000 1 0 0000 00 10
    2→4: 0 1000 1 0 0000 00 10
    4→5: 0 0100 1 0 0000 01 00
    5→4: 0 0100 1 0 0000 01 00
    >>> describe_bonds_from_atom(5)  # doctest: +NORMALIZE_WHITESPACE
    5→4: 0 0100 1 0 0000 01 00
    4→5: 0 0100 1 0 0000 01 00
    5→6: 0 1000 1 0 0000 00 10
    6→5: 0 1000 1 0 0000 00 00
    5→7: 0 1000 1 0 0000 00 01
    7→5: 0 1000 1 0 0000 00 00
    >>> stereochemistry.describe_stereocenter(mol.GetAtomWithIdx(8))
    'C8 (CCW) O12 C9 C1'
    >>> describe_bonds_from_atom(8)  # doctest: +NORMALIZE_WHITESPACE
    8→1: 0 1000 0 0 0010 00 00
    1→8: 0 1000 0 0 0000 00 01
    8→9: 0 1000 0 0 0100 00 00
    9→8: 0 1000 0 0 0010 00 00
    8→12: 0 1000 0 0 1000 00 00
    12→8: 0 1000 0 0 0000 00 00
    >>> stereochemistry.describe_stereocenter(mol.GetAtomWithIdx(9))
    'C9 (CW) O11 N10 C8'
    >>> describe_bonds_from_atom(9)  # doctest: +NORMALIZE_WHITESPACE
    9→8: 0 1000 0 0 0010 00 00
    8→9: 0 1000 0 0 0100 00 00
    9→10: 0 1000 0 0 0100 00 00
    10→9: 0 1000 0 0 0000 00 00
    9→11: 0 1000 0 0 1000 00 00
    11→9: 0 1000 0 0 0000 00 00

    """

    def __init__(self):
        self.bond_types = _BOND_TYPES
        self.vertex_ranks = _VERTEX_RANKS
        self.stem_arrangements = _STEM_ARRANGEMENTS
        self.branch_ranks = _BRANCH_RANKS
        # Cached so that __len__ does not recompute the sum on every call.
        self._len = sum(self.sizes)

    def __len__(self) -> int:
        return self._len

    def __call__(self, b: Chem.Bond | None, flip_direction: bool = False) -> np.ndarray:
        """Encode a bond in a molecule with canonical stereochemistry information.

        Parameters
        ----------
        b : Chem.Bond | None
            The bond to be encoded. If None, a null-bond vector is returned, with
            only the null-bond indicator set.
        flip_direction : bool, optional
            Whether to reverse the direction of the bond (default is False).

        Returns
        -------
        np.ndarray
            An integer vector of length ``len(self)`` encoding the bond.

        Notes
        -----
        The vector includes the following information, in this order:
        - Null bond indicator
        - Bond types
        - Conjugation indicator
        - Ring indicator
        - Canonical vertex rank
        - Canonical stem arrangement
        - Canonical branch rank

        The one-hot categories are taken from the instance attributes, so the
        encoding always agrees with :attr:`sizes` and ``len(self)``.

        """
        if b is None:
            x = np.zeros(len(self), int)
            x[0] = 1
            return x

        bond_type = b.GetBondType()
        # The vertex and branch ranks depend on the bond direction. The stem
        # arrangement is queried without it.
        vertex_rank = stereochemistry.VertexRank.from_bond(b, flip_direction)
        arrangement = stereochemistry.StemArrangement.get_from(b)
        branch_rank = stereochemistry.BranchRank.from_bond(b, flip_direction)

        return np.array(
            [
                0,  # null-bond indicator: b is a real bond on this path
                *(bond_type == item for item in self.bond_types),
                b.GetIsConjugated(),
                b.IsInRing(),
                *(vertex_rank == item for item in self.vertex_ranks),
                *(arrangement == item for item in self.stem_arrangements),
                *(branch_rank == item for item in self.branch_ranks),
            ],
            dtype=int,
        )

    @property
    def sizes(self) -> tuple[int, ...]:
        """Get a tuple of sizes corresponding to different bond features.

        The tuple contains the sizes for:
        - Null bond indicator
        - Bond types
        - Conjugation indicator
        - Ring indicator
        - Tetrahedral vertex ranks
        - Cis/trans stem arrangements
        - Cis/trans branch ranks

        Returns
        -------
        tuple[int, ...]
            A tuple of integers representing the sizes of each bond feature.

        Examples
        --------
        >>> from chempropstereo import featurizers
        >>> featurizer = featurizers.BondStereoFeaturizer()
        >>> featurizer.sizes
        (1, 4, 1, 1, 4, 2, 2)

        """
        return (
            1,
            len(self.bond_types),
            1,
            1,
            len(self.vertex_ranks),
            len(self.stem_arrangements),
            len(self.branch_ranks),
        )

    def pretty_print(self, b: Chem.Bond, flip_direction: bool = False) -> str:
        """Get a formatted string representation of the bond features.

        Parameters
        ----------
        b : Chem.Bond
            The bond to be described. Null bonds (None) are not supported because
            they have no atom indices to display.
        flip_direction : bool, optional
            Whether to reverse the direction of the bond (default is False).

        Returns
        -------
        str
            A formatted string representing the bond features.

        Raises
        ------
        ValueError
            If `b` is None.

        Examples
        --------
        >>> from rdkit import Chem
        >>> from chempropstereo import featurizers
        >>> mol = Chem.MolFromSmiles('CC')
        >>> bond = mol.GetBondWithIdx(0)
        >>> featurizer = featurizers.BondStereoFeaturizer()
        >>> featurizer.pretty_print(bond)
        ' 0→1: 0 1000 0 0 0000 00 00'

        """
        if b is None:
            # Previously this surfaced as an opaque AttributeError on
            # ``None.GetBeginAtomIdx``.
            raise ValueError(
                "pretty_print requires a bond; null bonds have no atom indices"
            )
        atoms = [b.GetBeginAtomIdx(), b.GetEndAtomIdx()]
        if flip_direction:
            atoms.reverse()
        return utils.describe_bond_features(atoms, self(b, flip_direction), self.sizes)