Coverage for cosmolayer / cosmosac / visualize.py: 89%
192 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 14:25 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 14:25 +0000
1"""
2.. module:: cosmolayer.cosmosac.visualize
3 :synopsis: Visualize COSMO-SAC surface segments.
5.. functionauthor:: Charlles Abreu <craabreu@gmail.com>
6"""
8import argparse
9import pathlib
11import cmap
12import networkx as nx
13import numpy as np
14import open3d as o3d
15import periodictable as pt
17from cosmolayer.cosmosac import Component
# Factors multiplied by the mean nearest-neighbor spacing of the surface points
# to obtain the ball radii used by ``ball_pivoting_algorithm``.
RADII_MULTIPLIERS: tuple[float, float, float] = (1.5, 2.5, 4.0)

# RGB triplets (components in [0, 1]) keyed by element symbol. Callers look
# symbols up with ``.get`` and fall back to a default color for other elements.
ELEMENT_COLORS = {  # https://pymolwiki.org/Color_Values
    "Br": (0.650980392, 0.160784314, 0.160784314),
    "C": (0.2, 1.0, 0.2),
    "Cl": (0.121568627, 0.941176471, 0.121568627),
    "F": (0.701960784, 1.0, 1.0),
    "H": (0.9, 0.9, 0.9),
    "I": (0.580392157, 0.0, 0.580392157),
    "N": (0.2, 0.2, 1.0),
    "O": (1.0, 0.3, 0.3),
    "P": (1.0, 0.501960784, 0.0),
    "Si": (0.941176471, 0.784313725, 0.627450980),
    "S": (0.9, 0.775, 0.25),
}

# Threshold on the squared norm of the cross product below which
# ``compute_rotation_matrix`` treats two axes as parallel or antiparallel.
TOLERANCE: float = 1e-10
# Threshold on the absolute x-component of an axis: below it
# ``compute_rotation_matrix`` uses X_AXIS as the helper direction, otherwise Y_AXIS.
DOT_PRODUCT_TOLERANCE: float = 0.9
# Cartesian unit vectors. These are mutable arrays shared by the whole module.
X_AXIS: np.ndarray = np.array([1.0, 0.0, 0.0])
Y_AXIS: np.ndarray = np.array([0.0, 1.0, 0.0])
Z_AXIS: np.ndarray = np.array([0.0, 0.0, 1.0])
def estimate_vdw_radius(element: str) -> float:
    """Estimate the van der Waals radius of an element.

    The estimate is the covalent radius reported by ``periodictable`` for the
    given symbol plus a fixed offset of 0.8.

    Parameters
    ----------
    element : str
        Element symbol understood by ``periodictable`` (e.g. ``"C"``).

    Returns
    -------
    float
        The estimated radius, in Å.
    """
    covalent_radius = pt.elements.symbol(element).covalent_radius
    return float(covalent_radius) + 0.8  # Å
def create_atom_spheres(
    component: Component,
    radius_scale: float,
    resolution: int = 40,
    default_color: tuple[float, float, float] = (0.7, 0.7, 0.7),
) -> list[o3d.geometry.TriangleMesh]:
    """Create one colored sphere mesh per atom of a component.

    Parameters
    ----------
    component : Component
        Component whose ``atom_data`` rows provide ``element``, ``x``, ``y``
        and ``z`` for every atom.
    radius_scale : float
        Factor applied to the estimated van der Waals radius of each atom.
    resolution : int, optional
        Resolution passed to Open3D's sphere factory. Default is 40.
    default_color : tuple of float, optional
        RGB color used for elements missing from ``ELEMENT_COLORS``.

    Returns
    -------
    list of TriangleMesh
        The spheres, in the row order of ``component.atom_data``.
    """

    def build_sphere(
        symbol: str, center: tuple[float, float, float]
    ) -> o3d.geometry.TriangleMesh:
        ball = o3d.geometry.TriangleMesh.create_sphere(
            radius=estimate_vdw_radius(symbol) * radius_scale,
            resolution=resolution,
        )
        ball.compute_vertex_normals()
        ball.translate(center)
        ball.paint_uniform_color(np.array(ELEMENT_COLORS.get(symbol, default_color)))
        return ball

    return [
        # Symbols are stripped before being used for radius and color lookup.
        build_sphere(str(row.element).strip(), (row.x, row.y, row.z))
        for row in component.atom_data.itertuples(index=False)
    ]
def compute_rotation_matrix(
    original_axis: np.ndarray, target_axis: np.ndarray, normalize: bool = False
) -> np.ndarray:
    """Return the rotation matrix that takes one direction onto another.

    Uses Rodrigues' rotation formula between two unit-direction vectors.

    Parameters
    ----------
    original_axis : np.ndarray
        Direction to rotate from. Must have unit length unless ``normalize``
        is ``True``.
    target_axis : np.ndarray
        Direction to rotate onto. Must have unit length unless ``normalize``
        is ``True``.
    normalize : bool, optional
        If ``True``, both vectors are divided by their norms first. Default is
        ``False``.

    Returns
    -------
    np.ndarray
        A 3x3 matrix ``R`` such that ``R @ original_axis`` equals
        ``target_axis``. The inputs and the module-level axis constants are
        never modified.
    """
    if normalize:
        original_axis = original_axis / np.linalg.norm(original_axis)
        target_axis = target_axis / np.linalg.norm(target_axis)
    v = np.cross(original_axis, target_axis)
    c = original_axis.dot(target_axis)
    s2 = v.dot(v)
    if s2 < TOLERANCE:
        if c > 0:
            return np.eye(3)  # Parallel (c ≈ 1) → identity
        # Antiparallel: rotate by 180° about any axis orthogonal to the input.
        # Pick a helper direction that is not (nearly) collinear with it.
        reference = X_AXIS if abs(original_axis[0]) < DOT_PRODUCT_TOLERANCE else Y_AXIS
        # Build a new array here: an in-place ``-=`` would overwrite the shared
        # module-level X_AXIS / Y_AXIS constant for every later caller.
        perpendicular = reference - original_axis * original_axis.dot(reference)
        orthogonal = perpendicular / np.linalg.norm(perpendicular)
        return 2.0 * np.outer(orthogonal, orthogonal) - np.eye(3)
    # Skew-symmetric cross-product matrix of v, so that kmat @ x == v × x.
    kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
    rotation: np.ndarray = np.eye(3) + kmat + ((1 - c) / s2) * kmat @ kmat
    return rotation
def create_bond_sticks(
    component: Component,
    atom_radius_scale: float,
    bond_radius: float,
    resolution: int = 100,
    default_color: tuple[float, float, float] = (0.7, 0.7, 0.7),
) -> list[o3d.geometry.TriangleMesh]:
    """Create two half-bond cylinders for every drawable bond of a component.

    Each bond is split at the point halfway between the surfaces of the two
    atom spheres. The half next to each atom is painted with that atom's
    element color. Bonds shorter than the sum of the two scaled atom radii
    are skipped.

    Parameters
    ----------
    component : Component
        Component providing ``atom_data`` (``element``, ``x``, ``y``, ``z``)
        and ``bonds`` (pairs of positional atom indices).
    atom_radius_scale : float
        Factor applied to the estimated van der Waals radii; it should match
        the scale used for the atom spheres.
    bond_radius : float
        Radius of the cylinders.
    resolution : int, optional
        Resolution passed to Open3D's cylinder factory. Default is 100.
    default_color : tuple of float, optional
        RGB color used for elements missing from ``ELEMENT_COLORS``.

    Returns
    -------
    list of TriangleMesh
        Two cylinders per drawn bond, in bond order.
    """
    atom_df = component.atom_data
    coords = atom_df[["x", "y", "z"]].values
    # Normalize the symbols exactly as create_atom_spheres does, so padded
    # symbols get the same radius and color in both representations.
    elements = [str(element).strip() for element in atom_df["element"].values]
    radii = (
        np.array([estimate_vdw_radius(element) for element in elements], dtype=float)
        * atom_radius_scale
    )
    cylinders: list[o3d.geometry.TriangleMesh] = []
    for i, j in component.bonds:
        vector = coords[j] - coords[i]
        length = np.linalg.norm(vector)
        if length < radii[i] + radii[j]:
            continue
        axis = vector / length
        # Open3D cylinders are created along the z axis; turn them onto the bond.
        rotation = compute_rotation_matrix(Z_AXIS, axis)
        # Midpoint between the two sphere surfaces along the bond:
        # ((coords[i] + radii[i] * axis) + (coords[j] - radii[j] * axis)) / 2
        midpoint = (coords[i] + coords[j] + (radii[i] - radii[j]) * axis) / 2
        for k in (i, j):
            # Half-bond running from the center of atom k to the midpoint.
            cylinder = o3d.geometry.TriangleMesh.create_cylinder(
                radius=bond_radius,
                height=np.linalg.norm(coords[k] - midpoint),
                resolution=resolution,
            )
            cylinder.rotate(rotation, center=np.zeros(3))
            cylinder.translate((coords[k] + midpoint) / 2)
            cylinder.compute_vertex_normals()
            rgb = np.array(ELEMENT_COLORS.get(elements[k], default_color))
            cylinder.paint_uniform_color(rgb)
            cylinders.append(cylinder)
    return cylinders
def ball_pivoting_algorithm(
    points: np.ndarray,
    normals: np.ndarray,
    vertex_rgb: np.ndarray,
    radii_multipliers: tuple[float, float, float],
) -> tuple[o3d.geometry.TriangleMesh, np.ndarray]:
    """Reconstruct a colored triangle mesh from oriented points.

    Parameters
    ----------
    points : np.ndarray
        Coordinates of the points to triangulate.
    normals : np.ndarray
        Normal vector of each point.
    vertex_rgb : np.ndarray
        RGB color of each point.
    radii_multipliers : tuple of float
        Factors multiplied by the mean nearest-neighbor distance of the points
        to obtain the ball radii.

    Returns
    -------
    tuple of (TriangleMesh, np.ndarray)
        The cleaned mesh with per-vertex colors, and for every mesh vertex the
        index of the nearest input point.
    """
    cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))
    cloud.normals = o3d.utility.Vector3dVector(normals)

    neighbor_distances = np.asarray(cloud.compute_nearest_neighbor_distance())
    mean_spacing = neighbor_distances.mean().item()
    ball_radii = o3d.utility.DoubleVector(
        [multiplier * mean_spacing for multiplier in radii_multipliers]
    )

    mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
        cloud, ball_radii
    )

    # Clean-up passes, applied in this exact order.
    mesh.remove_degenerate_triangles()
    mesh.remove_duplicated_triangles()
    mesh.remove_non_manifold_edges()
    mesh.remove_unreferenced_vertices()

    # The clean-up may drop vertices, so recover for each remaining vertex the
    # input point it corresponds to through a nearest-neighbor query.
    tree = o3d.geometry.KDTreeFlann(cloud)

    def nearest_point(vertex: np.ndarray) -> int:
        _, neighbors, _ = tree.search_knn_vector_3d(vertex, 1)
        return int(neighbors[0])

    indices = np.array([nearest_point(vertex) for vertex in mesh.vertices], dtype=int)
    mesh.vertex_colors = o3d.utility.Vector3dVector(vertex_rgb[indices])

    return mesh, indices
def find_loops(
    mesh: o3d.geometry.TriangleMesh, edge_color: str
) -> list[o3d.geometry.LineSet]:
    """Extract closed loops formed by the trailing edge of every triangle.

    For each triangle, only the edge joining its second and third vertices is
    added to an undirected graph; the first vertex is ignored. A cycle basis of
    that graph is then converted into closed polylines.

    Parameters
    ----------
    mesh : TriangleMesh
        Mesh whose triangles list the vertices of the edge to be traced in
        second and third position.
    edge_color : str
        Color specification understood by ``cmap.Color``. If ``None``, the
        line sets are left unpainted.

    Returns
    -------
    list of LineSet
        One closed polyline per cycle found.
    """
    graph = nx.Graph()
    graph.add_edges_from(
        (int(triangle[1]), int(triangle[2])) for triangle in mesh.triangles
    )
    coordinates = np.asarray(mesh.vertices, dtype=float)

    def to_lineset(cycle: list[int]) -> o3d.geometry.LineSet:
        # Repeat the first vertex at the end so that the polyline is closed.
        closed = np.asarray([*cycle, cycle[0]], dtype=int)
        starts = np.arange(len(cycle))
        connectivity = np.column_stack([starts, starts + 1]).astype(np.int32)
        outline = o3d.geometry.LineSet(
            points=o3d.utility.Vector3dVector(coordinates[closed]),
            lines=o3d.utility.Vector2iVector(connectivity),
        )
        if edge_color is not None:
            outline.paint_uniform_color(np.asarray(cmap.Color(edge_color))[:3])
        return outline

    return [to_lineset(cycle) for cycle in nx.cycle_basis(graph)]
def geodesic_centroid(center: np.ndarray, *vertices: np.ndarray) -> np.ndarray:
    """Return the centroid of points projected onto a sphere around ``center``.

    The direction is that of the mean offset of the vertices from ``center``;
    the distance from ``center`` is the mean of the individual vertex distances.

    Parameters
    ----------
    center : np.ndarray
        Center of the sphere.
    *vertices : np.ndarray
        One or more points around ``center``.

    Returns
    -------
    np.ndarray
        The projected centroid.
    """
    offsets = [vertex - center for vertex in vertices]
    count = len(offsets)
    mean_radius = sum(np.linalg.norm(offset) for offset in offsets) / count
    mean_offset = sum(offsets) / count
    projected: np.ndarray = (
        center + mean_radius * mean_offset / np.linalg.norm(mean_offset)
    )
    return projected
def surface_tessellation(
    component: Component,
    original_charge_densities: bool = False,
    interpolated_colors: bool = False,
    colormap: str = "jet",
) -> o3d.geometry.TriangleMesh:
    """Triangulate the surface segments of a component, colored by charge density.

    The segment positions are meshed with the ball-pivoting algorithm, using
    the direction from the owning atom to each segment as the point normal.
    With ``interpolated_colors`` the resulting mesh is returned as is, so
    colors blend across each triangle. Otherwise every triangle is split into
    six sub-triangles around its edge midpoints and centroid, and each
    sub-triangle takes the color of the original vertex it touches, which
    yields uniformly colored patches around every segment.

    Parameters
    ----------
    component : Component
        Component providing ``segment_data`` (``x``, ``y``, ``z``, ``atom``,
        ``sigma``, ``sigma_avg``), ``atom_data`` (``x``, ``y``, ``z``) and
        ``sigma_grid``.
    original_charge_densities : bool, optional
        If ``True``, color by the ``sigma`` column instead of ``sigma_avg``.
        Default is ``False``.
    interpolated_colors : bool, optional
        If ``True``, return the ball-pivoting mesh with per-vertex colors and
        skip the subdivision. Default is ``False``.
    colormap : str, optional
        Name passed to ``cmap.Colormap``. Default is ``"jet"``.

    Returns
    -------
    TriangleMesh
        The colored surface mesh. In the subdivided mesh every triangle lists
        the original segment vertex first, followed by the midpoint and
        centroid copies, so the edge between the last two vertices lies on the
        boundary of that segment's patch.
    """
    segment_data = component.segment_data
    atom_data = component.atom_data
    sigma_grid = component.sigma_grid
    vmin, vmax = sigma_grid[0], sigma_grid[-1]
    sigma_column = "sigma" if original_charge_densities else "sigma_avg"
    # Values outside the sigma grid saturate at the ends of the colormap.
    sigmas = segment_data[sigma_column].values.clip(vmin, vmax)

    # One row per SEGMENT: the position of the atom that owns that segment.
    atom_coords = np.stack(
        [segment_data["atom"].map(atom_data[axis]).values for axis in "xyz"], axis=1
    )
    pts = segment_data[["x", "y", "z"]].values
    displacements = pts - atom_coords
    normals = displacements / np.linalg.norm(displacements, axis=1, keepdims=True)

    normalized_sigmas = (sigmas - vmin) / (vmax - vmin)
    mapper = cmap.Colormap(colormap)
    vertex_rgb = mapper(normalized_sigmas)[:, :3]

    mesh_bpa, indices = ball_pivoting_algorithm(
        pts, normals, vertex_rgb, RADII_MULTIPLIERS
    )

    if interpolated_colors:
        return mesh_bpa

    vertices = np.asarray(mesh_bpa.vertices, dtype=float)
    triangles = np.asarray(mesh_bpa.triangles, dtype=int)
    colors = np.asarray(mesh_bpa.vertex_colors, dtype=float)
    # ``indices`` maps every mesh vertex to its segment row, so these two
    # arrays hold, per mesh vertex, the owning atom and that atom's position.
    # ``atom_coords`` is indexed by segment row and must not be indexed with
    # atom identifiers.
    atoms = segment_data["atom"].values[indices]
    centers = atom_coords[indices]

    new_vertices = vertices.tolist()
    new_colors = colors.tolist()

    def add_vertex(v: np.ndarray, c: np.ndarray) -> int:
        """Append a vertex with its color and return its index."""
        idx = len(new_vertices)
        new_vertices.append(v)
        new_colors.append(c)
        return idx

    midpoint_cache: dict[tuple[int, int], int] = {}

    def midpoint_vertices(i: int, j: int) -> tuple[int, int]:
        """Return the two copies of the midpoint of edge (i, j).

        The first copy carries the color of vertex ``i`` and the second the
        color of vertex ``j``. Both copies are cached so that triangles
        sharing an edge reuse them.
        """
        if (i, j) in midpoint_cache:
            return midpoint_cache[(i, j)], midpoint_cache[(j, i)]

        if atoms[i] == atoms[j]:
            # Same atom: keep the midpoint on the sphere around that atom.
            midpoint = geodesic_centroid(centers[i], vertices[i], vertices[j])
        else:
            midpoint = (vertices[i] + vertices[j]) / 2

        mij = midpoint_cache[(i, j)] = add_vertex(midpoint, colors[i])
        mji = midpoint_cache[(j, i)] = add_vertex(midpoint, colors[j])
        return mij, mji

    new_triangles = []

    for triangle in triangles:
        i, j, k = map(int, triangle)
        mij, mji = midpoint_vertices(i, j)
        mjk, mkj = midpoint_vertices(j, k)
        mik, mki = midpoint_vertices(i, k)

        if atoms[i] == atoms[j] == atoms[k]:
            centroid = geodesic_centroid(
                centers[i], vertices[i], vertices[j], vertices[k]
            )
        else:
            centroid = (vertices[i] + vertices[j] + vertices[k]) / 3

        # One centroid copy per corner so that each copy carries one color.
        mijk = add_vertex(centroid, colors[i])
        mjki = add_vertex(centroid, colors[j])
        mkij = add_vertex(centroid, colors[k])

        # Every sub-triangle keeps the winding of its parent (i, j, k), so all
        # faces point the same way and vertex normals do not cancel out. The
        # segment vertex stays first; the last two vertices are always a
        # midpoint copy and a centroid copy.
        new_triangles += [
            [i, mij, mijk],
            [i, mijk, mik],
            [j, mjk, mjki],
            [j, mjki, mji],
            [k, mki, mkij],
            [k, mkij, mkj],
        ]

    mesh = o3d.geometry.TriangleMesh(
        vertices=o3d.utility.Vector3dVector(new_vertices),
        triangles=o3d.utility.Vector3iVector(new_triangles),
    )
    mesh.vertex_colors = o3d.utility.Vector3dVector(new_colors)
    mesh.compute_vertex_normals()

    return mesh
def generate_geometries(
    component: Component,
    original_charge_densities: bool = False,
    use_continuous_colors: bool = False,
    colormap: str = "jet",
    segment_edge_color: str | None = None,
) -> tuple[o3d.geometry.Geometry3D, ...]:
    """Assemble the Open3D geometries that depict a component's COSMO surface.

    The result contains, in this order:

    (1) the surface mesh colored by screening charge density;
    (2) the loops outlining the segments, only when ``segment_edge_color`` is
        given and colors are not continuous;
    (3) one sphere per atom; and
    (4) the sticks representing the bonds.

    Parameters
    ----------
    component : Component
        The molecular component to be visualized.
    original_charge_densities : bool, optional
        Whether to color the surface by the original (unsmoothed) segment
        charge densities rather than by the averaged ones. Default is
        ``False``.
    use_continuous_colors : bool, optional
        Whether colors are interpolated across the surface instead of being
        uniform within each segment. Default is ``False``.
    colormap : str, optional
        Name of the colormap that translates charge density into color (e.g.
        ``"jet"``, ``"viridis"``). Default is ``"jet"``.
    segment_edge_color : str or None, optional
        Color name for the outlines between segments (e.g. ``"black"``). No
        outlines are produced if it is ``None`` or if
        ``use_continuous_colors`` is ``True``. Default is ``None``.

    Returns
    -------
    tuple of Geometry3D
        The mesh, followed by the loops (if any), the atom spheres, and the
        bond sticks.

    Examples
    --------
    >>> from importlib.resources import files
    >>> from cosmolayer.cosmosac import Component
    >>> from cosmolayer.cosmosac.visualize import generate_geometries
    >>> path = files("cosmolayer.data") / "C=C(N)O.cosmo"
    >>> component = Component(path.read_text())
    >>> geometries = generate_geometries(component)
    >>> len(geometries) >= 1
    True
    >>> type(geometries[0]).__name__
    'TriangleMesh'
    >>> geometries_loops = generate_geometries(component, segment_edge_color="black")
    >>> len(geometries_loops) > len(geometries)
    True
    """
    mesh = surface_tessellation(
        component,
        original_charge_densities=original_charge_densities,
        interpolated_colors=use_continuous_colors,
        colormap=colormap,
    )
    geometries: list[o3d.geometry.Geometry3D] = [mesh]

    # Outlines only make sense on the uniformly colored, subdivided mesh.
    draw_outlines = segment_edge_color is not None and not use_continuous_colors
    if draw_outlines:
        geometries.extend(find_loops(mesh, segment_edge_color))

    geometries.extend(create_atom_spheres(component, 0.25))
    geometries.extend(create_bond_sticks(component, 0.25, 0.1))
    return tuple(geometries)
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser of cosmoview (also used by sphinx-argparse).

    Returns
    -------
    argparse.ArgumentParser
        Parser with the positional ``cosmo_file`` argument, two boolean
        switches, and the ``--segment-edge-color`` and ``--colormap`` options.
    """
    parser = argparse.ArgumentParser(
        description="Visualize COSMO files",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    add = parser.add_argument

    add(
        "cosmo_file",
        type=pathlib.Path,
        help="Path to a COSMO quantum mechanical output file",
    )

    switches = (
        (
            "--show-original-charge-densities",
            "Show original charge densities instead of smoothed ones",
        ),
        (
            "--use-continuous-colors",
            "Use continuous colors instead of uniformly colored segments",
        ),
    )
    for flag, description in switches:
        add(flag, action="store_true", help=description)

    text_options = (
        (
            "--segment-edge-color",
            None,
            "Color of the edges between segments (default: None)",
        ),
        ("--colormap", "jet", "Matplotlib colormap name (default: jet)"),
    )
    for flag, fallback, description in text_options:
        add(flag, type=str, default=fallback, help=description)

    return parser
def main() -> None:
    """Parse the command line, build the geometries, and open the viewer."""
    options = get_parser().parse_args()
    cosmo_path = options.cosmo_file
    component = Component(cosmo_path.read_text())
    geometries = generate_geometries(
        component,
        original_charge_densities=options.show_original_charge_densities,
        use_continuous_colors=options.use_continuous_colors,
        colormap=options.colormap,
        segment_edge_color=options.segment_edge_color,
    )
    window_title = f"Surface Segments from {cosmo_path.name}"
    o3d.visualization.draw_geometries(
        geometries,
        mesh_show_back_face=True,
        window_name=window_title,
    )


if __name__ == "__main__":
    main()