Coverage for cosmolayer / cosmosac / visualize.py: 89%

192 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-11 14:25 +0000

1""" 

2.. module:: cosmolayer.cosmosac.visualize 

3 :synopsis: Visualize COSMO-SAC surface segments. 

4 

5.. functionauthor:: Charlles Abreu <craabreu@gmail.com> 

6""" 

7 

8import argparse 

9import pathlib 

10 

11import cmap 

12import networkx as nx 

13import numpy as np 

14import open3d as o3d 

15import periodictable as pt 

16 

17from cosmolayer.cosmosac import Component 

18 

# Ball radii for the ball-pivoting surface reconstruction, expressed as
# multiples of the mean nearest-neighbour spacing of the segment point cloud
# (consumed by ``ball_pivoting_algorithm``).
RADII_MULTIPLIERS: tuple[float, float, float] = (1.5, 2.5, 4.0)

# Per-element RGB colours (components in [0, 1]) for atom spheres and bond
# sticks. Source: https://pymolwiki.org/Color_Values
ELEMENT_COLORS: dict[str, tuple[float, float, float]] = {
    "Br": (0.650980392, 0.160784314, 0.160784314),
    "C": (0.2, 1.0, 0.2),
    "Cl": (0.121568627, 0.941176471, 0.121568627),
    "F": (0.701960784, 1.0, 1.0),
    "H": (0.9, 0.9, 0.9),
    "I": (0.580392157, 0.0, 0.580392157),
    "N": (0.2, 0.2, 1.0),
    "O": (1.0, 0.3, 0.3),
    "P": (1.0, 0.501960784, 0.0),
    "Si": (0.941176471, 0.784313725, 0.627450980),
    "S": (0.9, 0.775, 0.25),
}


# Squared cross-product norm below which two unit vectors are treated as
# parallel/antiparallel in ``compute_rotation_matrix``.
TOLERANCE: float = 1e-10
# When the x-component of an axis has magnitude below this, the x axis is not
# nearly parallel to it and can serve as an auxiliary direction; otherwise the
# y axis is used instead.
DOT_PRODUCT_TOLERANCE: float = 0.9
# Cartesian unit vectors, each an independent float64 array.
X_AXIS: np.ndarray = np.eye(3)[0].copy()
Y_AXIS: np.ndarray = np.eye(3)[1].copy()
Z_AXIS: np.ndarray = np.eye(3)[2].copy()

41 

42 

def estimate_vdw_radius(element: str) -> float:
    """Return a rough van der Waals radius, in Å, for a chemical element.

    The estimate is the element's covalent radius as reported by
    ``periodictable`` plus a fixed offset of 0.8 Å.

    Parameters
    ----------
    element : str
        Element symbol understood by ``periodictable`` (e.g. ``"C"``, ``"Cl"``).

    Returns
    -------
    float
        Estimated radius in Å.
    """
    covalent_radius = pt.elements.symbol(element).covalent_radius
    vdw_offset = 0.8  # Å
    return float(covalent_radius) + vdw_offset

45 

46 

def create_atom_spheres(
    component: Component,
    radius_scale: float,
    resolution: int = 40,
    default_color: tuple[float, float, float] = (0.7, 0.7, 0.7),
) -> list[o3d.geometry.TriangleMesh]:
    """Build one uniformly coloured sphere mesh per atom of ``component``.

    Parameters
    ----------
    component : Component
        Component whose ``atom_data`` rows (``element``, ``x``, ``y``, ``z``)
        are rendered.
    radius_scale : float
        Factor applied to ``estimate_vdw_radius`` of each atom's element.
    resolution : int, optional
        Sphere tessellation resolution passed to Open3D. Default is 40.
    default_color : tuple of float, optional
        RGB colour for elements missing from ``ELEMENT_COLORS``.

    Returns
    -------
    list of TriangleMesh
        Spheres in the same order as the rows of ``component.atom_data``.
    """

    def build_sphere(row) -> o3d.geometry.TriangleMesh:
        # Symbols may carry surrounding whitespace; normalise before lookups.
        symbol = str(row.element).strip()
        mesh = o3d.geometry.TriangleMesh.create_sphere(
            radius=estimate_vdw_radius(symbol) * radius_scale,
            resolution=resolution,
        )
        mesh.compute_vertex_normals()
        mesh.translate((row.x, row.y, row.z))
        mesh.paint_uniform_color(np.array(ELEMENT_COLORS.get(symbol, default_color)))
        return mesh

    return [
        build_sphere(row) for row in component.atom_data.itertuples(index=False)
    ]

68 

69 

def compute_rotation_matrix(
    original_axis: np.ndarray, target_axis: np.ndarray, normalize: bool = False
) -> np.ndarray:
    """Return the rotation matrix that turns one direction onto another.

    Uses Rodrigues' rotation formula. The parallel and antiparallel cases,
    where the cross product vanishes and the rotation axis is undefined, are
    handled explicitly.

    Parameters
    ----------
    original_axis : np.ndarray
        Direction to rotate from (3-vector). Must be unit length unless
        ``normalize`` is ``True``.
    target_axis : np.ndarray
        Direction to rotate to (3-vector). Same length requirement.
    normalize : bool, optional
        If ``True``, both vectors are normalised first. Default is ``False``.

    Returns
    -------
    np.ndarray
        A 3x3 rotation matrix ``R`` with ``R @ original_axis`` equal to
        ``target_axis`` (for unit-length inputs).
    """
    if normalize:
        original_axis = original_axis / np.linalg.norm(original_axis)
        target_axis = target_axis / np.linalg.norm(target_axis)
    v = np.cross(original_axis, target_axis)
    c = original_axis.dot(target_axis)
    s2 = v.dot(v)
    if s2 < TOLERANCE:
        if c > 0:
            return np.eye(3)  # Parallel (c ≈ 1) → identity
        # Antiparallel (c ≈ -1): a 180° turn about any axis orthogonal to
        # original_axis. Start from a helper axis that is not nearly parallel
        # to it and remove its component along original_axis (Gram-Schmidt).
        helper = X_AXIS if abs(original_axis[0]) < DOT_PRODUCT_TOLERANCE else Y_AXIS
        # Out-of-place on purpose: ``helper -= ...`` would overwrite the
        # shared module-level X_AXIS / Y_AXIS arrays.
        helper = helper - original_axis * original_axis.dot(helper)
        orthogonal = helper / np.linalg.norm(helper)
        # Rotation by pi about unit axis u: R = 2 u u^T - I.
        return 2.0 * np.outer(orthogonal, orthogonal) - np.eye(3)
    # Skew-symmetric cross-product matrix of v.
    kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
    rotation: np.ndarray = np.eye(3) + kmat + ((1 - c) / s2) * kmat @ kmat
    return rotation

90 

91 

def create_bond_sticks(
    component: Component,
    atom_radius_scale: float,
    bond_radius: float,
    resolution: int = 100,
    default_color: tuple[float, float, float] = (0.7, 0.7, 0.7),
) -> list[o3d.geometry.TriangleMesh]:
    """Build two-coloured cylinder meshes for the bonds of ``component``.

    Each bond ``(i, j)`` is drawn as two cylinders of radius ``bond_radius``.
    One runs from atom ``i`` to a split point and the other from atom ``j``
    to the same point. Each half is painted with its own atom's element
    colour. The split point is the midpoint of the gap between the two atom
    spheres, whose radii are ``estimate_vdw_radius`` scaled by
    ``atom_radius_scale``. Both halves therefore show the same visible
    length. Bonds whose spheres touch or overlap are skipped.

    Parameters
    ----------
    component : Component
        Component providing ``atom_data`` (``element``, ``x``, ``y``, ``z``)
        and ``bonds`` (pairs of positional atom indices).
    atom_radius_scale : float
        Scale applied to atomic radii; should match the one used for spheres.
    bond_radius : float
        Cylinder radius.
    resolution : int, optional
        Cylinder tessellation resolution passed to Open3D. Default is 100.
    default_color : tuple of float, optional
        RGB colour for elements missing from ``ELEMENT_COLORS``.

    Returns
    -------
    list of TriangleMesh
        Two cylinders per drawn bond.
    """
    atom_df = component.atom_data
    coords = atom_df[["x", "y", "z"]].values
    # Normalise symbols exactly like create_atom_spheres so radii and colours
    # agree with the spheres even if the element column carries whitespace.
    elements = [str(symbol).strip() for symbol in atom_df["element"]]
    radii = np.array([estimate_vdw_radius(e) for e in elements]) * atom_radius_scale
    cylinders: list[o3d.geometry.TriangleMesh] = []
    for i, j in component.bonds:
        vector = coords[j] - coords[i]
        length = np.linalg.norm(vector)
        if length < radii[i] + radii[j]:
            continue  # Spheres overlap: no visible gap to draw.
        axis = vector / length
        # Open3D builds cylinders along the z axis; rotate onto the bond axis.
        rotation = compute_rotation_matrix(Z_AXIS, axis)
        # Midpoint between the sphere surfaces c_i + r_i*axis and c_j - r_j*axis.
        midpoint = (coords[i] + coords[j] + (radii[i] - radii[j]) * axis) / 2
        for k in (i, j):
            cylinder = o3d.geometry.TriangleMesh.create_cylinder(
                radius=bond_radius,
                height=np.linalg.norm(coords[k] - midpoint),
                resolution=resolution,
            )
            cylinder.rotate(rotation, center=np.zeros(3))
            # The cylinder is centred at the origin: move its centre halfway
            # between atom k and the split point.
            cylinder.translate((coords[k] + midpoint) / 2)
            cylinder.compute_vertex_normals()
            rgb = np.array(ELEMENT_COLORS.get(elements[k], default_color))
            cylinder.paint_uniform_color(rgb)
            cylinders.append(cylinder)
    return cylinders

126 

127 

def ball_pivoting_algorithm(
    points: np.ndarray,
    normals: np.ndarray,
    vertex_rgb: np.ndarray,
    radii_multipliers: tuple[float, float, float],
) -> tuple[o3d.geometry.TriangleMesh, np.ndarray]:
    """Triangulate an oriented point cloud with Open3D's ball-pivoting method.

    The ball radii are ``radii_multipliers`` times the mean nearest-neighbour
    distance of the cloud. After reconstruction, the mesh is cleaned by
    removing the following:

    - degenerate triangles;
    - duplicated triangles;
    - non-manifold edges;
    - unreferenced vertices.

    Every surviving mesh vertex is then matched to its nearest input point
    and takes that point's colour from ``vertex_rgb``.

    Parameters
    ----------
    points, normals : np.ndarray
        Point positions and their normals (one row per point).
    vertex_rgb : np.ndarray
        RGB colour per input point (rows aligned with ``points``).
    radii_multipliers : tuple of float
        Ball radii in units of the mean point spacing.

    Returns
    -------
    mesh : TriangleMesh
        The cleaned, coloured mesh.
    indices : np.ndarray
        For each mesh vertex, the index of its nearest input point.
    """
    cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))
    cloud.normals = o3d.utility.Vector3dVector(normals)

    mean_spacing = np.asarray(cloud.compute_nearest_neighbor_distance()).mean().item()
    ball_radii = o3d.utility.DoubleVector(
        [multiplier * mean_spacing for multiplier in radii_multipliers]
    )
    mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
        cloud, ball_radii
    )

    for cleanup in (
        mesh.remove_degenerate_triangles,
        mesh.remove_duplicated_triangles,
        mesh.remove_non_manifold_edges,
        mesh.remove_unreferenced_vertices,
    ):
        cleanup()

    # Vertex removal may reorder/drop points, so map each vertex back to the
    # cloud by nearest-neighbour search.
    tree = o3d.geometry.KDTreeFlann(cloud)
    nearest = np.array(
        [int(tree.search_knn_vector_3d(vertex, 1)[1][0]) for vertex in mesh.vertices],
        dtype=int,
    )

    mesh.vertex_colors = o3d.utility.Vector3dVector(vertex_rgb[nearest])
    return mesh, nearest

159 

160 

def find_loops(
    mesh: o3d.geometry.TriangleMesh, edge_color: str | None
) -> list[o3d.geometry.LineSet]:
    """Trace closed polylines formed by the second-to-third-vertex triangle edges.

    Only the edge joining the second and third vertex of every triangle is
    added to the graph. In meshes built by ``surface_tessellation``, each
    triangle is ``[vertex, edge midpoint, centroid]``. Those edges are then
    the midpoint-centroid edges, and the resulting loops outline individual
    surface segments.

    Parameters
    ----------
    mesh : TriangleMesh
        Mesh whose triangles define the edge graph.
    edge_color : str or None
        Colour spec passed to ``cmap.Color``. If ``None``, the line sets are
        left unpainted.

    Returns
    -------
    list of LineSet
        One closed polyline per cycle of ``networkx.cycle_basis`` on the graph.
    """
    graph = nx.Graph()
    for triangle in mesh.triangles:
        _, j, k = map(int, triangle)
        graph.add_edge(j, k)
    loops = nx.cycle_basis(graph)

    vertices = np.asarray(mesh.vertices, dtype=float)
    # The colour is loop-invariant: resolve it once.
    rgb = None if edge_color is None else np.asarray(cmap.Color(edge_color))[:3]

    linesets: list[o3d.geometry.LineSet] = []
    for loop in loops:
        # Repeat the first vertex so the polyline closes on itself.
        idx = np.asarray(loop + [loop[0]], dtype=int)
        lines = np.column_stack(
            [np.arange(len(idx) - 1), np.arange(1, len(idx))]
        ).astype(np.int32)
        lineset = o3d.geometry.LineSet(
            points=o3d.utility.Vector3dVector(vertices[idx]),
            lines=o3d.utility.Vector2iVector(lines),
        )
        if rgb is not None:
            lineset.paint_uniform_color(rgb)
        linesets.append(lineset)

    return linesets

189 

190 

def geodesic_centroid(center: np.ndarray, *vertices: np.ndarray) -> np.ndarray:
    """Return an "average" of points lying roughly on a sphere about ``center``.

    The result points from ``center`` along the mean of the offsets
    ``vertex - center``. Its distance from ``center`` is the mean length of
    those offsets, so it stays on the (approximate) sphere instead of cutting
    through it the way a plain average would.

    Parameters
    ----------
    center : np.ndarray
        Centre of the sphere.
    *vertices : np.ndarray
        Points to average; at least one is required.

    Returns
    -------
    np.ndarray
        The averaged point. If the offsets cancel out (zero mean), the
        direction is undefined and the result is not finite.
    """
    count = len(vertices)
    offsets = [vertex - center for vertex in vertices]
    mean_radius = sum(np.linalg.norm(offset) for offset in offsets) / count
    mean_offset = sum(offsets) / count
    unit_direction = mean_offset / np.linalg.norm(mean_offset)
    return center + mean_radius * unit_direction

199 

200 

def surface_tessellation(
    component: Component,
    original_charge_densities: bool = False,
    interpolated_colors: bool = False,
    colormap: str = "jet",
) -> o3d.geometry.TriangleMesh:
    """Build a triangle mesh of the COSMO surface coloured by charge density.

    Segment positions are triangulated with ``ball_pivoting_algorithm``. Each
    point's normal is the unit vector from its owning atom's centre to the
    segment. Charge densities are clipped to the range of
    ``component.sigma_grid`` and mapped to RGB with ``colormap``.

    Parameters
    ----------
    component : Component
        Component providing ``segment_data``, ``atom_data`` and ``sigma_grid``.
    original_charge_densities : bool, optional
        Use the ``"sigma"`` column instead of ``"sigma_avg"``. Default ``False``.
    interpolated_colors : bool, optional
        If ``True``, return the ball-pivoting mesh with per-vertex colours,
        which the renderer blends across triangles. Otherwise each triangle
        is subdivided so every segment is painted uniformly (see
        ``_split_into_segment_cells``). Default ``False``.
    colormap : str, optional
        Colormap name passed to ``cmap.Colormap``. Default ``"jet"``.

    Returns
    -------
    TriangleMesh
        The coloured surface mesh.
    """
    segment_data = component.segment_data
    atom_data = component.atom_data
    sigma_grid = component.sigma_grid
    vmin, vmax = sigma_grid[0], sigma_grid[-1]
    sigmas = segment_data[
        "sigma" if original_charge_densities else "sigma_avg"
    ].values.clip(vmin, vmax)

    # Centre of the atom owning each segment: one row per *segment*.
    atom_coords = np.stack(
        [segment_data["atom"].map(atom_data[axis]).values for axis in "xyz"], axis=1
    )
    pts = segment_data[["x", "y", "z"]].values
    displacements = pts - atom_coords
    normals = displacements / np.linalg.norm(displacements, axis=1, keepdims=True)

    # ``sigmas`` is already clipped, so this lies in [0, 1].
    normalized_sigmas = (sigmas - vmin) / (vmax - vmin)
    mapper = cmap.Colormap(colormap)
    vertex_rgb = mapper(normalized_sigmas)[:, :3]

    mesh_bpa, indices = ball_pivoting_algorithm(
        pts, normals, vertex_rgb, RADII_MULTIPLIERS
    )

    if interpolated_colors:
        return mesh_bpa

    # ``indices`` maps mesh vertices to segments. Index the per-segment arrays
    # with it, never with an atom label, to get per-vertex atom data.
    return _split_into_segment_cells(
        mesh_bpa,
        segment_data["atom"].values[indices],
        atom_coords[indices],
    )


def _split_into_segment_cells(
    mesh: o3d.geometry.TriangleMesh,
    atoms: np.ndarray,
    centers: np.ndarray,
) -> o3d.geometry.TriangleMesh:
    """Subdivide ``mesh`` so each vertex owns a uniformly coloured cell.

    Each triangle ``(i, j, k)`` is split into six children around its edge
    midpoints and centroid. Midpoint and centroid positions are duplicated
    once per adjacent original vertex, and each copy carries that vertex's
    colour. The union of children touching a vertex thus forms its cell.

    Parameters
    ----------
    mesh : TriangleMesh
        Vertex-coloured mesh to subdivide.
    atoms : np.ndarray
        Owning atom label of every mesh vertex.
    centers : np.ndarray
        Centre of the owning atom of every mesh vertex (one row per vertex).

    Returns
    -------
    TriangleMesh
        The subdivided mesh, with vertex normals computed.
    """
    vertices = np.asarray(mesh.vertices, dtype=float)
    triangles = np.asarray(mesh.triangles, dtype=int)
    colors = np.asarray(mesh.vertex_colors, dtype=float)

    new_vertices = vertices.tolist()
    new_colors = colors.tolist()

    def add_vertex(v: np.ndarray, c: np.ndarray) -> int:
        idx = len(new_vertices)
        new_vertices.append(v)
        new_colors.append(c)
        return idx

    def place(*corners: int) -> np.ndarray:
        # Stay on the atom's sphere when all corners share one atom; otherwise
        # fall back to the straight average.
        first = corners[0]
        if all(atoms[v] == atoms[first] for v in corners[1:]):
            return geodesic_centroid(centers[first], *(vertices[v] for v in corners))
        return sum(vertices[v] for v in corners) / len(corners)

    midpoint_cache: dict[tuple[int, int], int] = {}

    def midpoint_vertices(i: int, j: int) -> tuple[int, int]:
        # Returns the (i-coloured, j-coloured) copies of edge i-j's midpoint,
        # shared by both triangles adjacent to that edge.
        if (i, j) not in midpoint_cache:
            midpoint = place(i, j)
            midpoint_cache[(i, j)] = add_vertex(midpoint, colors[i])
            midpoint_cache[(j, i)] = add_vertex(midpoint, colors[j])
        return midpoint_cache[(i, j)], midpoint_cache[(j, i)]

    new_triangles: list[list[int]] = []
    for triangle in triangles:
        i, j, k = map(int, triangle)
        mij, mji = midpoint_vertices(i, j)
        mjk, mkj = midpoint_vertices(j, k)
        mik, mki = midpoint_vertices(i, k)

        centroid = place(i, j, k)
        mijk = add_vertex(centroid, colors[i])
        mjki = add_vertex(centroid, colors[j])
        mkij = add_vertex(centroid, colors[k])

        # All six children keep the parent's winding, so their normals agree
        # in compute_vertex_normals. The second/third vertices of each child
        # are always the midpoint-centroid edge (used by find_loops).
        new_triangles += [
            [i, mij, mijk],
            [i, mijk, mik],
            [j, mjki, mji],
            [j, mjk, mjki],
            [k, mkij, mkj],
            [k, mki, mkij],
        ]

    result = o3d.geometry.TriangleMesh(
        vertices=o3d.utility.Vector3dVector(new_vertices),
        triangles=o3d.utility.Vector3iVector(new_triangles),
    )
    result.vertex_colors = o3d.utility.Vector3dVector(new_colors)
    result.compute_vertex_normals()
    return result

300 

301 

def generate_geometries(
    component: Component,
    original_charge_densities: bool = False,
    use_continuous_colors: bool = False,
    colormap: str = "jet",
    segment_edge_color: str | None = None,
) -> tuple[o3d.geometry.Geometry3D, ...]:
    """Assemble every Open3D geometry needed to display a COSMO surface.

    The returned tuple holds, in order:

    (1) the surface mesh, coloured by screening charge density;
    (2) segment-boundary line sets, only when ``segment_edge_color`` is given
        and ``use_continuous_colors`` is ``False``;
    (3) one sphere per atom; and
    (4) two cylinders per drawn bond.

    Parameters
    ----------
    component : Component
        Molecular component whose COSMO surface is shown.
    original_charge_densities : bool, optional
        Colour by the raw segment charge densities rather than the averaged
        ones. Default is ``False``.
    use_continuous_colors : bool, optional
        Blend colours across the surface instead of painting each segment
        uniformly. Default is ``False``.
    colormap : str, optional
        Colormap name for the charge density (e.g. ``"jet"``, ``"viridis"``).
        Default is ``"jet"``.
    segment_edge_color : str or None, optional
        Colour of the lines outlining segments (e.g. ``"black"``). Ignored
        when ``None`` or when ``use_continuous_colors`` is ``True``. Default
        is ``None``.

    Returns
    -------
    tuple of Geometry3D
        Mesh, then any loops, then atom spheres, then bond sticks.

    Examples
    --------
    >>> from importlib.resources import files
    >>> from cosmolayer.cosmosac import Component
    >>> from cosmolayer.cosmosac.visualize import generate_geometries
    >>> path = files("cosmolayer.data") / "C=C(N)O.cosmo"
    >>> component = Component(path.read_text())
    >>> geometries = generate_geometries(component)
    >>> len(geometries) >= 1
    True
    >>> type(geometries[0]).__name__
    'TriangleMesh'
    >>> geometries_loops = generate_geometries(component, segment_edge_color="black")
    >>> len(geometries_loops) > len(geometries)
    True
    """
    mesh = surface_tessellation(
        component,
        original_charge_densities=original_charge_densities,
        interpolated_colors=use_continuous_colors,
        colormap=colormap,
    )
    draw_loops = segment_edge_color is not None and not use_continuous_colors
    geometries: list[o3d.geometry.Geometry3D] = [mesh]
    if draw_loops:
        geometries.extend(find_loops(mesh, segment_edge_color))
    geometries.extend(create_atom_spheres(component, 0.25))
    geometries.extend(create_bond_sticks(component, 0.25, 0.1))
    return tuple(geometries)

372 

373 

def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for ``cosmoview``.

    Kept as a standalone function so that sphinx-argparse can render the CLI
    reference from it.
    """
    parser = argparse.ArgumentParser(
        description="Visualize COSMO files",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    # (flag, add_argument keyword settings), in the order they should appear.
    argument_specs = (
        (
            "cosmo_file",
            dict(
                type=pathlib.Path,
                help="Path to a COSMO quantum mechanical output file",
            ),
        ),
        (
            "--show-original-charge-densities",
            dict(
                action="store_true",
                help="Show original charge densities instead of smoothed ones",
            ),
        ),
        (
            "--use-continuous-colors",
            dict(
                action="store_true",
                help="Use continuous colors instead of uniformly colored segments",
            ),
        ),
        (
            "--segment-edge-color",
            dict(
                type=str,
                default=None,
                help="Color of the edges between segments (default: None)",
            ),
        ),
        (
            "--colormap",
            dict(
                type=str,
                default="jet",
                help="Matplotlib colormap name (default: jet)",
            ),
        ),
    )
    for flag, settings in argument_specs:
        parser.add_argument(flag, **settings)
    return parser

408 

409 

def main() -> None:
    """Run ``cosmoview``: parse the CLI, build the geometries, open a viewer."""
    options = get_parser().parse_args()
    cosmo_path = options.cosmo_file
    component = Component(cosmo_path.read_text())
    o3d.visualization.draw_geometries(
        generate_geometries(
            component,
            options.show_original_charge_densities,
            options.use_continuous_colors,
            options.colormap,
            options.segment_edge_color,
        ),
        window_name=f"Surface Segments from {cosmo_path.name}",
        # Surfaces are drawn from both sides.
        mesh_show_back_face=True,
    )

425 

426 

# Allow running the module directly, e.g. ``python -m cosmolayer.cosmosac.visualize``.
if __name__ == "__main__":
    main()