Janus 2.0.0
High-performance C++20 dual-mode numerical framework
Loading...
Searching...
No Matches
JanusIO.hpp
Go to the documentation of this file.
1#pragma once
4#include <Eigen/Dense>
5#include <casadi/casadi.hpp>
6#include <cstdlib>
7#include <fstream>
8#include <iomanip>
9#include <iostream>
10#include <limits>
11#include <map>
12#include <set>
13#include <sstream>
14#include <vector>
15
27
28namespace janus {
29
30// Forward declaration
31class Function;
32
39template <typename Derived>
40void print(const std::string &label, const Eigen::MatrixBase<Derived> &mat) {
41 std::cout << label << ":\n" << mat << "\n" << std::endl;
42}
43
51template <typename Derived>
52void disp(const std::string &label, const Eigen::MatrixBase<Derived> &mat) {
53 print(label, mat);
54}
55
62template <typename Derived> auto eval(const Eigen::MatrixBase<Derived> &mat) {
63 using Scalar = typename Derived::Scalar;
64 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
65 // Flatten to MX, evaluate, map back
66 SymbolicScalar flat = SymbolicScalar::zeros(mat.rows(), mat.cols());
67 for (int i = 0; i < mat.rows(); ++i) {
68 for (int j = 0; j < mat.cols(); ++j) {
69 flat(i, j) = mat(i, j);
70 }
71 }
72
73 try {
74 casadi::Function f("f", {}, {flat});
75 auto res = f(std::vector<casadi::DM>{});
76 casadi::DM res_dm = res[0];
77
78 NumericMatrix res_eigen(mat.rows(), mat.cols());
79 for (int i = 0; i < mat.rows(); ++i) {
80 for (int j = 0; j < mat.cols(); ++j) {
81 res_eigen(i, j) = static_cast<double>(res_dm(i, j));
82 }
83 }
84 return res_eigen;
85 } catch (const std::exception &e) {
86 throw RuntimeError("eval failed (expression contains free variables): " +
87 std::string(e.what()));
88 }
89 } else {
90 return mat.eval();
91 }
92}
93
98inline double eval(const SymbolicScalar &val) {
99 if (val.size1() != 1 || val.size2() != 1) {
100 throw RuntimeError("eval scalar failed: expected 1x1 symbolic expression, got " +
101 std::to_string(val.size1()) + "x" + std::to_string(val.size2()));
102 }
103 try {
104 casadi::Function f("f", {}, {val});
105 auto res = f(std::vector<casadi::DM>{});
106 casadi::DM res_dm = res[0];
107 return static_cast<double>(res_dm);
108 } catch (const std::exception &e) {
109 throw RuntimeError("eval scalar failed: " + std::string(e.what()));
110 }
111}
112
/**
 * @brief Identity overload of eval() for plain arithmetic values.
 *
 * Lets generic code call eval() uniformly on numeric and symbolic inputs.
 *
 * @param val Arithmetic value.
 * @return A copy of @p val, unchanged.
 */
template <typename T, std::enable_if_t<std::is_arithmetic_v<T>, int> = 0>
T eval(const T &val) {
  T copy = val;
  return copy;
}
120
121// ======================================================================
122// Graph Visualization
123// ======================================================================
124
125namespace detail {
126
/**
 * @brief Escape text for use inside a double-quoted Graphviz DOT label.
 *
 * Escaping rules:
 * - '"' and '\\' are backslash-escaped.
 * - A newline becomes the DOT line-break escape `\n`.
 * - '<' and '>' become `&lt;` / `&gt;`.
 * - '{' and '}' are replaced by spaces.
 *
 * Inputs longer than 40 bytes are shortened to at most 37 bytes followed by
 * "...". Truncation is done on the raw text before escaping, so an escape
 * sequence or entity is never cut in half (previously labels such as `&l...`
 * could be produced). The cut is also moved back onto a UTF-8 code-point
 * boundary so the DOT file stays valid UTF-8.
 *
 * @param s Raw label text.
 * @return The escaped (and possibly shortened) label.
 */
inline std::string escape_dot_label(const std::string &s) {
  constexpr std::size_t kMaxLen = 40;  // longest raw label shown in full
  constexpr std::size_t kKeepLen = 37; // bytes kept before "..." when truncating

  std::size_t keep = s.size();
  bool truncated = false;
  if (s.size() > kMaxLen) {
    keep = kKeepLen;
    // s[keep] is the first dropped byte; if it is a UTF-8 continuation byte
    // (10xxxxxx) the cut would split a character, so back up to its lead byte.
    while (keep > 0 && (static_cast<unsigned char>(s[keep]) & 0xC0) == 0x80) {
      --keep;
    }
    truncated = true;
  }

  std::string result;
  result.reserve(keep + 3);
  for (std::size_t i = 0; i < keep; ++i) {
    const char c = s[i];
    switch (c) {
    case '"':
      result += "\\\"";
      break;
    case '\\':
      result += "\\\\";
      break;
    case '\n':
      result += "\\n";
      break;
    case '<':
      result += "&lt;";
      break;
    case '>':
      result += "&gt;";
      break;
    case '{':
    case '}':
      result += ' ';
      break;
    default:
      result += c;
      break;
    }
  }
  if (truncated) {
    result += "...";
  }
  return result;
}
164
168inline std::string get_op_name(const SymbolicScalar &mx) {
169 if (mx.is_symbolic()) {
170 std::ostringstream ss;
171 mx.disp(ss, false);
172 return ss.str();
173 }
174 if (mx.is_constant()) {
175 std::ostringstream ss;
176 mx.disp(ss, false);
177 std::string s = ss.str();
178 if (s.size() > 20)
179 s = s.substr(0, 17) + "...";
180 return s;
181 }
182 // Try to get operation name from string representation
183 std::ostringstream ss;
184 mx.disp(ss, false);
185 std::string s = ss.str();
186
187 // Common operations
188 if (s.find("sq(") == 0)
189 return "sq";
190 if (s.find("sin(") == 0)
191 return "sin";
192 if (s.find("cos(") == 0)
193 return "cos";
194 if (s.find("exp(") == 0)
195 return "exp";
196 if (s.find("log(") == 0)
197 return "log";
198 if (s.find("sqrt(") == 0)
199 return "sqrt";
200 if (s.find("tanh(") == 0)
201 return "tanh";
202
203 // For binary ops, return truncated expression
204 if (s.size() > 30)
205 s = s.substr(0, 27) + "...";
206 return s;
207}
208
209} // namespace detail
210
221inline void export_graph_dot(const SymbolicScalar &expr, const std::string &filename,
222 const std::string &name = "expression") {
223 // Get all free variables
224 std::vector<SymbolicScalar> free_vars = SymbolicScalar::symvar(expr);
225
226 // Build DOT file
227 std::string dot_filename = filename + ".dot";
228 std::ofstream out(dot_filename);
229 if (!out.is_open()) {
230 throw RuntimeError("Failed to open file for writing: " + dot_filename);
231 }
232
233 out << "digraph \"" << name << "\" {\n";
234 out << " rankdir=BT;\n"; // Bottom to top (inputs at bottom)
235 out << " splines=ortho;\n";
236 out << " node [shape=box, style=\"rounded,filled\", fontname=\"Helvetica\"];\n";
237 out << " edge [color=\"#666666\", arrowsize=0.7];\n\n";
238
239 // Title
240 out << " labelloc=\"t\";\n";
241 out << " label=\"" << detail::escape_dot_label(name) << "\";\n";
242 out << " fontsize=16;\n\n";
243
244 // Collect nodes by traversing the expression
245 std::map<std::string, int> node_ids; // ptr-based ID -> sequential ID
246 std::vector<std::pair<int, int>> edges; // edges as sequential IDs
247 std::set<const void *> visited;
248 int node_counter = 0;
249
250 // BFS traversal to collect all nodes
251 std::vector<SymbolicScalar> queue = {expr};
252 std::map<const void *, int> ptr_to_id;
253
254 while (!queue.empty()) {
255 SymbolicScalar current = queue.back();
256 queue.pop_back();
257
258 const void *ptr = current.get();
259 if (visited.count(ptr))
260 continue;
261 visited.insert(ptr);
262
263 int current_id = node_counter++;
264 ptr_to_id[ptr] = current_id;
265
266 // Get dependencies
267 casadi_int n = current.n_dep();
268 for (casadi_int i = 0; i < n; ++i) {
269 SymbolicScalar dep = current.dep(i);
270 queue.push_back(dep);
271 }
272 }
273
274 // Second pass: generate nodes and edges
275 visited.clear();
276 queue = {expr};
277
278 while (!queue.empty()) {
279 SymbolicScalar current = queue.back();
280 queue.pop_back();
281
282 const void *ptr = current.get();
283 if (visited.count(ptr))
284 continue;
285 visited.insert(ptr);
286
287 int current_id = ptr_to_id[ptr];
288 std::string label = detail::get_op_name(current);
289
290 // Determine node style based on type
291 std::string color = "#87CEEB"; // light blue default
292 std::string shape = "box";
293
294 if (current.is_symbolic()) {
295 color = "#90EE90"; // light green for inputs
296 shape = "ellipse";
297 } else if (current.is_constant()) {
298 color = "#FFE4B5"; // moccasin for constants
299 shape = "ellipse";
300 } else if (current.n_dep() == 0) {
301 color = "#DDA0DD"; // plum for leaf operations
302 }
303
304 out << " node_" << current_id << " [label=\"" << detail::escape_dot_label(label)
305 << "\", fillcolor=\"" << color << "\", shape=" << shape << "];\n";
306
307 // Add edges from dependencies
308 casadi_int n = current.n_dep();
309 for (casadi_int i = 0; i < n; ++i) {
310 SymbolicScalar dep = current.dep(i);
311 const void *dep_ptr = dep.get();
312 int dep_id = ptr_to_id[dep_ptr];
313 out << " node_" << dep_id << " -> node_" << current_id << ";\n";
314 queue.push_back(dep);
315 }
316 }
317
318 // Mark output node specially
319 if (!ptr_to_id.empty()) {
320 const void *out_ptr = expr.get();
321 int out_id = ptr_to_id[out_ptr];
322 out << "\n // Output marker\n";
323 out << " output [label=\"Output\", shape=doublecircle, fillcolor=\"#FFD700\"];\n";
324 out << " node_" << out_id << " -> output;\n";
325 }
326
327 out << "}\n";
328 out.close();
329}
330
337// Note: This overload is implemented after Function class is defined
338// Use the casadi::Function version directly for now
339
/**
 * @brief Render a DOT file with the Graphviz `dot` command-line tool.
 *
 * The output format is the extension of the last path component of
 * @p output_file (e.g. "graph.png" -> `-Tpng`). "pdf" is used when there is no
 * extension. The command runs through std::system, so `dot` must be on PATH.
 *
 * Because a shell interprets the command, arguments are sanitised:
 * - the format may only contain letters, digits, '_', '-' and ':';
 * - each path is quoted as a single shell word (a leading '-' gets a "./"
 *   prefix so dot does not read it as an option).
 * Inputs that cannot be passed safely are rejected instead of executed.
 *
 * @param dot_file    Path of the input DOT file.
 * @param output_file Path of the file to produce.
 * @return true if `dot` exited with status 0; false if it failed or the
 *         arguments were rejected as unsafe.
 */
inline bool render_graph(const std::string &dot_file, const std::string &output_file) {
  // Only a '.' inside the final path component starts an extension, so a
  // directory such as "out.d/graph" does not yield the bogus format "d/graph".
  std::string format = "pdf"; // default
  const std::size_t sep_pos = output_file.find_last_of("/\\");
  const std::size_t base_start = (sep_pos == std::string::npos) ? 0 : sep_pos + 1;
  const std::size_t dot_pos = output_file.rfind('.');
  if (dot_pos != std::string::npos && dot_pos >= base_start && dot_pos + 1 < output_file.size()) {
    format = output_file.substr(dot_pos + 1);
  }

  // The format is spliced into a shell command unquoted: accept only plain
  // Graphviz format names such as "png", "svg", "png:cairo" or "plain-ext".
  for (const char c : format) {
    const bool allowed = (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
                         (c >= '0' && c <= '9') || c == '_' || c == '-' || c == ':';
    if (!allowed) {
      return false;
    }
  }

  // Quote `path` as one shell word for the platform's std::system shell.
  // Returns false when that cannot be done safely.
  const auto quote_path = [](const std::string &path, std::string &quoted) -> bool {
    if (path.empty() || path.find('\0') != std::string::npos) {
      return false;
    }
    const std::string arg = (path[0] == '-') ? "./" + path : path;
#ifdef _WIN32
    // cmd.exe: '"' cannot be nested and '%' triggers variable expansion.
    if (arg.find_first_of("\"%") != std::string::npos) {
      return false;
    }
    quoted = "\"" + arg + "\"";
#else
    // POSIX sh: single quotes make everything literal; embed ' as '\''.
    quoted = "'";
    for (const char c : arg) {
      if (c == '\'') {
        quoted += "'\\''";
      } else {
        quoted += c;
      }
    }
    quoted += '\'';
#endif
    return true;
  };

  std::string quoted_input;
  std::string quoted_output;
  if (!quote_path(dot_file, quoted_input) || !quote_path(output_file, quoted_output)) {
    return false;
  }

  // Build command: dot -T<format> <input.dot> -o <output.ext>
  const std::string cmd = "dot -T" + format + " " + quoted_input + " -o " + quoted_output;
  return std::system(cmd.c_str()) == 0;
}
363
373inline bool visualize_graph(const SymbolicScalar &expr, const std::string &output_base) {
374 try {
375 export_graph_dot(expr, output_base);
376 return render_graph(output_base + ".dot", output_base + ".pdf");
377 } catch (const std::exception &) {
378 return false;
379 }
380}
381
382namespace detail {
383
/**
 * @brief Escape text for a double-quoted JavaScript string literal inside an
 *        HTML <script> element.
 *
 * Escaping rules:
 * - Backslash, double quote, LF and CR are backslash-escaped.
 * - '<' is written as \u003C so that sequences such as "</script>" cannot
 *   close the enclosing script element.
 * - All other bytes are copied unchanged.
 */
inline std::string escape_for_js(const std::string &content) {
  std::string out;
  out.reserve(content.size());
  for (const char ch : content) {
    switch (ch) {
    case '\\':
      out += "\\\\";
      break;
    case '"':
      out += "\\\"";
      break;
    case '\n':
      out += "\\n";
      break;
    case '\r':
      out += "\\r";
      break;
    case '<':
      out += "\\u003C";
      break;
    default:
      out += ch;
      break;
    }
  }
  return out;
}
405
/**
 * @brief Escape text for use as the contents of a JSON string literal.
 *
 * Escaping rules:
 * - '"' and '\\' are backslash-escaped.
 * - LF, CR and TAB use their short escapes.
 * - '<' is written as \u003C (safe inside an HTML <script> block).
 * - Every other control character U+0000..U+001F is written as \u00XX.
 *   RFC 8259 forbids these raw in strings; previously characters such as
 *   '\b' or '\x01' were emitted unescaped, producing invalid JSON.
 * - Bytes >= 0x20 (including UTF-8 multi-byte sequences) pass through unchanged.
 */
inline std::string escape_for_json(const std::string &s) {
  static const char kHexDigits[] = "0123456789ABCDEF";
  std::string result;
  result.reserve(s.size());
  for (const char c : s) {
    switch (c) {
    case '"':
      result += "\\\"";
      break;
    case '\\':
      result += "\\\\";
      break;
    case '\n':
      result += "\\n";
      break;
    case '\r':
      result += "\\r";
      break;
    case '\t':
      result += "\\t";
      break;
    case '<':
      result += "\\u003C";
      break;
    default: {
      // Cast first: plain char may be signed, which would misclassify UTF-8 bytes.
      const auto uc = static_cast<unsigned char>(c);
      if (uc < 0x20) {
        result += "\\u00";
        result += kHexDigits[uc >> 4];
        result += kHexDigits[uc & 0x0F];
      } else {
        result += c;
      }
      break;
    }
    }
  }
  return result;
}
429
/**
 * @brief Write a self-contained interactive HTML viewer for a DOT graph.
 *
 * The page loads viz.js 2.1.2 from cdnjs to render the graph in the browser. It
 * offers wheel zoom, drag pan, reset and fit-to-screen controls. Clicking a node
 * highlights it together with every edge entering or leaving it, and shows its
 * metadata from @p node_data_json in a sidebar.
 *
 * @param out             Stream receiving the HTML document.
 * @param title           Page title as plain text; it is HTML-escaped here.
 * @param escaped_dot     DOT source, already escaped for a double-quoted JS
 *                        string literal (see escape_for_js()).
 * @param node_data_json  JSON object mapping SVG node ids (e.g. "node_3") to
 *                        objects with id, short, full, type and deps fields.
 * @param edges_json      JSON array of [from, to] pairs, exposed to the page as
 *                        the `edges` constant (e.g. for @p extra_header_js).
 * @param extra_header_js Optional JavaScript inserted right after the data constants.
 */
inline void write_graph_html(std::ostream &out, const std::string &title,
                             const std::string &escaped_dot, const std::string &node_data_json,
                             const std::string &edges_json,
                             const std::string &extra_header_js = "") {
  // `title` is caller-supplied (typically the expression name); escape it so that
  // characters such as '<' or '&' cannot inject markup into the page.
  std::string safe_title;
  safe_title.reserve(title.size());
  for (const char c : title) {
    switch (c) {
    case '&':
      safe_title += "&amp;";
      break;
    case '<':
      safe_title += "&lt;";
      break;
    case '>':
      safe_title += "&gt;";
      break;
    case '"':
      safe_title += "&quot;";
      break;
    default:
      safe_title += c;
      break;
    }
  }

  out << R"HTMLSTART(<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>)HTMLSTART"
      << safe_title << R"HTMLMID(</title>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/viz.js/2.1.2/viz.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/viz.js/2.1.2/full.render.js"></script>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        html, body { height: 100%; width: 100%; }
        body { font-family: 'Segoe UI', sans-serif; background: #1a1a2e; color: #eee; overflow: hidden; display: flex; }
        #controls { position: fixed; top: 10px; left: 10px; z-index: 100; background: rgba(0,0,0,0.8);
                    padding: 12px; border-radius: 8px; }
        #controls button { margin: 2px; padding: 8px 14px; cursor: pointer; border: none;
                           border-radius: 4px; background: #4a4a6a; color: white; font-size: 13px; }
        #controls button:hover { background: #6a6a8a; }
        #graph { flex: 1; cursor: grab; overflow: hidden; height: 100%; }
        #graph:active { cursor: grabbing; }
        #graph svg { display: block; }
        #sidebar { width: 320px; height: 100%; background: #16213e; padding: 16px; overflow-y: auto;
                   border-left: 2px solid #0f3460; }
        #sidebar h2 { color: #e94560; margin-bottom: 12px; font-size: 16px; }
        #sidebar .section { margin-bottom: 16px; }
        #sidebar .label { color: #888; font-size: 11px; text-transform: uppercase; margin-bottom: 4px; }
        #sidebar .value { background: #0f3460; padding: 10px; border-radius: 6px; font-family: monospace;
                          font-size: 13px; word-break: break-all; white-space: pre-wrap; max-height: 200px; overflow-y: auto; }
        #sidebar .type-badge { display: inline-block; padding: 3px 8px; border-radius: 4px; font-size: 11px;
                               margin-left: 8px; }
        .type-input { background: #90EE90; color: #000; }
        .type-constant { background: #FFE4B5; color: #000; }
        .type-operation { background: #87CEEB; color: #000; }
        .type-leaf { background: #DDA0DD; color: #000; }
        #info { position: fixed; bottom: 10px; left: 10px; background: rgba(0,0,0,0.8);
                padding: 10px; border-radius: 4px; font-size: 12px; }
        #stats { position: fixed; top: 10px; right: 340px; background: rgba(0,0,0,0.8);
                 padding: 10px; border-radius: 4px; font-size: 12px; }
        .node-highlighted polygon, .node-highlighted ellipse, .node-highlighted path { stroke: #e94560 !important; stroke-width: 3px !important; }
        .edge-highlighted path { stroke: #e94560 !important; stroke-width: 2px !important; }
        .edge-highlighted polygon { stroke: #e94560 !important; fill: #e94560 !important; }
        svg .node { cursor: pointer; }
        svg .node:hover polygon, svg .node:hover ellipse { stroke: #fff !important; stroke-width: 2px !important; }
    </style>
</head>
<body>
    <div id="controls">
        <button onclick="zoomIn()">Zoom +</button>
        <button onclick="zoomOut()">Zoom -</button>
        <button onclick="resetView()">Reset</button>
        <button onclick="fitToScreen()">Fit</button>
    </div>
    <div id="graph"></div>
    <div id="sidebar">
        <h2>Node Info</h2>
        <div id="node-info">
            <p style="color:#666; font-style:italic;">Click on a node to see details</p>
        </div>
    </div>
    <div id="info">Scroll to zoom - Drag to pan - Click nodes for details</div>
    <div id="stats"></div>
    <script>
        const dotSrc = ")HTMLMID"
      << escaped_dot << R"HTMLMID2(";
        const nodeData = )HTMLMID2"
      << node_data_json << R"HTMLMID3(;
        const edges = )HTMLMID3"
      << edges_json << R"HTMLEND(;
)HTMLEND";

  // Insert optional extra JS (e.g. stats computation for SX graphs)
  if (!extra_header_js.empty()) {
    out << "        " << extra_header_js << "\n";
  }

  out << R"HTMLEND2(
        let scale = 1, panX = 0, panY = 0, isDragging = false, startX, startY;
        let selectedNode = null;
        const container = document.getElementById('graph');
        const sidebar = document.getElementById('node-info');

        new Viz().renderSVGElement(dotSrc).then(svg => {
            container.appendChild(svg);
            svg.style.transformOrigin = '0 0';
            fitToScreen();
            setupPanZoom(svg);
            setupNodeInteraction(svg);
        });

        function updateTransform(svg) {
            svg.style.transform = `translate(${panX}px, ${panY}px) scale(${scale})`;
        }
        function zoomIn() { scale *= 1.3; updateTransform(container.querySelector('svg')); }
        function zoomOut() { scale /= 1.3; updateTransform(container.querySelector('svg')); }
        function resetView() { scale = 1; panX = 0; panY = 0; updateTransform(container.querySelector('svg')); }
        function fitToScreen() {
            const svg = container.querySelector('svg');
            if (!svg) return;
            const bbox = svg.getBBox();
            const availWidth = window.innerWidth - 320;
            const scaleX = (availWidth - 40) / (bbox.width + 40);
            const scaleY = (window.innerHeight - 40) / (bbox.height + 40);
            scale = Math.min(scaleX, scaleY);
            panX = (availWidth - bbox.width * scale) / 2;
            panY = (window.innerHeight - bbox.height * scale) / 2;
            updateTransform(svg);
        }

        function setupPanZoom(svg) {
            container.addEventListener('wheel', e => {
                e.preventDefault();
                const rect = container.getBoundingClientRect();
                const mouseX = e.clientX - rect.left, mouseY = e.clientY - rect.top;
                const zoomFactor = e.deltaY < 0 ? 1.1 : 0.9;
                panX = mouseX - (mouseX - panX) * zoomFactor;
                panY = mouseY - (mouseY - panY) * zoomFactor;
                scale *= zoomFactor;
                updateTransform(svg);
            });
            container.addEventListener('mousedown', e => {
                if (e.target.closest('.node')) return;
                isDragging = true; startX = e.clientX - panX; startY = e.clientY - panY;
            });
            container.addEventListener('mousemove', e => { if (isDragging) { panX = e.clientX - startX; panY = e.clientY - startY; updateTransform(svg); } });
            container.addEventListener('mouseup', () => isDragging = false);
            container.addEventListener('mouseleave', () => isDragging = false);
        }

        function setupNodeInteraction(svg) {
            const nodes = svg.querySelectorAll('.node');
            nodes.forEach(node => {
                const nodeId = node.id;
                node.addEventListener('click', e => {
                    e.stopPropagation();
                    selectNode(svg, nodeId);
                });
            });
        }

        function selectNode(svg, nodeId) {
            svg.querySelectorAll('.node-highlighted').forEach(n => n.classList.remove('node-highlighted'));
            svg.querySelectorAll('.edge-highlighted').forEach(e => e.classList.remove('edge-highlighted'));

            const node = svg.getElementById(nodeId);
            if (!node) return;

            node.classList.add('node-highlighted');
            selectedNode = nodeId;

            // Graphviz titles every edge "tail->head" using the DOT node names, which
            // are also the node ids. Compare whole endpoint names (a substring test
            // would let node_1 match node_12) so exactly the incident edges light up,
            // including edges into output nodes.
            svg.querySelectorAll('.edge').forEach(edge => {
                const title = edge.querySelector('title');
                if (!title) return;
                const ends = title.textContent.split('->');
                if (ends[0] === nodeId || ends[1] === nodeId) {
                    edge.classList.add('edge-highlighted');
                }
            });

            const data = nodeData[nodeId];
            if (data) {
                const nodeLabel = data.short || data.label || '';
                const fullExpr = data.full || data.label || '';
                sidebar.innerHTML = `
                    <div class="section">
                        <div class="label">Node ID</div>
                        <div class="value">${data.id} <span class="type-badge type-${data.type}">${data.type}</span></div>
                    </div>
                    <div class="section">
                        <div class="label">Label</div>
                        <div class="value">${escapeHtml(nodeLabel)}</div>
                    </div>
                    ${fullExpr !== nodeLabel ? `<div class="section">
                        <div class="label">Full Expression</div>
                        <div class="value">${escapeHtml(fullExpr)}</div>
                    </div>` : ''}
                    <div class="section">
                        <div class="label">Dependencies (${data.deps.length})</div>
                        <div class="value">${data.deps.length > 0 ? data.deps.map(d => 'node_' + d).join(', ') : 'None'}</div>
                    </div>
                `;
            } else if (nodeId === 'output' || nodeId.startsWith('output_')) {
                const outIdx = nodeId.replace('output_', '').replace('output', '0');
                sidebar.innerHTML = `
                    <div class="section">
                        <div class="label">Node</div>
                        <div class="value">Output[${outIdx}] <span class="type-badge" style="background:#FFD700;color:#000;">output</span></div>
                    </div>
                    <div class="section">
                        <div class="label">Description</div>
                        <div class="value">Output element ${outIdx} of the expression</div>
                    </div>
                `;
            }
        }

        function escapeHtml(str) {
            return str.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;').replace(/"/g, '&quot;');
        }
    </script>
</body>
</html>
)HTMLEND2";
}
650
651} // namespace detail
652
663inline void export_graph_html(const SymbolicScalar &expr, const std::string &filename,
664 const std::string &name = "expression") {
665 // First generate the DOT content and collect node metadata
666 std::ostringstream dot_stream;
667 std::ostringstream node_data_stream; // JSON for node metadata
668
669 // Get all free variables
670 std::vector<SymbolicScalar> free_vars = SymbolicScalar::symvar(expr);
671
672 dot_stream << "digraph \"" << name << "\" {\n";
673 dot_stream << " rankdir=BT;\n";
674 dot_stream << " splines=ortho;\n";
675 dot_stream << " node [shape=box, style=\"rounded,filled\", fontname=\"Helvetica\"];\n";
676 dot_stream << " edge [color=\"#666666\", arrowsize=0.7];\n\n";
677 dot_stream << " labelloc=\"t\";\n";
678 dot_stream << " label=\"" << detail::escape_dot_label(name) << "\";\n";
679 dot_stream << " fontsize=16;\n\n";
680
681 // Collect nodes
682 std::set<const void *> visited;
683 std::map<const void *, int> ptr_to_id;
684 int node_counter = 0;
685
686 std::vector<SymbolicScalar> queue = {expr};
687 while (!queue.empty()) {
688 SymbolicScalar current = queue.back();
689 queue.pop_back();
690
691 const void *ptr = current.get();
692 if (visited.count(ptr))
693 continue;
694 visited.insert(ptr);
695 ptr_to_id[ptr] = node_counter++;
696
697 casadi_int n = current.n_dep();
698 for (casadi_int i = 0; i < n; ++i) {
699 queue.push_back(current.dep(i));
700 }
701 }
702
703 // Build node data JSON and edges JSON
704 node_data_stream << "{";
705 std::ostringstream edges_stream;
706 edges_stream << "[";
707 bool first_node = true;
708 bool first_edge = true;
709
710 // Second pass: generate nodes and edges
711 visited.clear();
712 queue = {expr};
713 while (!queue.empty()) {
714 SymbolicScalar current = queue.back();
715 queue.pop_back();
716
717 const void *ptr = current.get();
718 if (visited.count(ptr))
719 continue;
720 visited.insert(ptr);
721
722 int current_id = ptr_to_id[ptr];
723
724 // Get full expression string (no truncation)
725 std::ostringstream full_expr;
726 current.disp(full_expr, false);
727 std::string full_label = full_expr.str();
728
729 // Get short label for display
730 std::string short_label = detail::get_op_name(current);
731 std::string node_type = "operation";
732 std::string color = "#87CEEB";
733 std::string shape = "box";
734
735 if (current.is_symbolic()) {
736 color = "#90EE90";
737 shape = "ellipse";
738 node_type = "input";
739 } else if (current.is_constant()) {
740 color = "#FFE4B5";
741 shape = "ellipse";
742 node_type = "constant";
743 } else if (current.n_dep() == 0) {
744 color = "#DDA0DD";
745 node_type = "leaf";
746 }
747
748 // Add to node data JSON
749 if (!first_node)
750 node_data_stream << ",";
751 first_node = false;
752 node_data_stream << "\"node_" << current_id << "\":{";
753 node_data_stream << "\"id\":" << current_id << ",";
754 node_data_stream << "\"short\":\"" << detail::escape_for_json(short_label) << "\",";
755 node_data_stream << "\"full\":\"" << detail::escape_for_json(full_label) << "\",";
756 node_data_stream << "\"type\":\"" << node_type << "\",";
757 node_data_stream << "\"deps\":[";
758
759 casadi_int n = current.n_dep();
760 for (casadi_int i = 0; i < n; ++i) {
761 if (i > 0)
762 node_data_stream << ",";
763 SymbolicScalar dep = current.dep(i);
764 node_data_stream << ptr_to_id[dep.get()];
765 }
766 node_data_stream << "]}";
767
768 dot_stream << " node_" << current_id << " [label=\""
769 << detail::escape_dot_label(short_label) << "\", fillcolor=\"" << color
770 << "\", shape=" << shape << ", id=\"node_" << current_id << "\"];\n";
771
772 for (casadi_int i = 0; i < n; ++i) {
773 SymbolicScalar dep = current.dep(i);
774 const void *dep_ptr = dep.get();
775 int dep_id = ptr_to_id[dep_ptr];
776 dot_stream << " node_" << dep_id << " -> node_" << current_id << ";\n";
777
778 if (!first_edge)
779 edges_stream << ",";
780 first_edge = false;
781 edges_stream << "[" << dep_id << "," << current_id << "]";
782
783 queue.push_back(dep);
784 }
785 }
786
787 // Mark output
788 if (!ptr_to_id.empty()) {
789 const void *out_ptr = expr.get();
790 int out_id = ptr_to_id[out_ptr];
791 dot_stream << "\n output [label=\"Output\", shape=doublecircle, fillcolor=\"#FFD700\", "
792 "id=\"output\"];\n";
793 dot_stream << " node_" << out_id << " -> output;\n";
794
795 if (!first_edge)
796 edges_stream << ",";
797 edges_stream << "[" << out_id << ",-1]";
798 }
799 dot_stream << "}\n";
800 node_data_stream << "}";
801 edges_stream << "]";
802
803 std::string dot_content = dot_stream.str();
804
805 // Write HTML file
806 std::string html_filename = filename + ".html";
807 std::ofstream out(html_filename);
808 if (!out.is_open()) {
809 throw RuntimeError("Failed to open file for writing: " + html_filename);
810 }
811
813 detail::escape_for_js(dot_content), node_data_stream.str(),
814 edges_stream.str());
815 out.close();
816}
817
818// ======================================================================
819// Deep Graph Visualization (SX-based)
820// ======================================================================
821
822namespace detail {
823
830inline std::string get_sx_operation(const casadi::SXElem &elem) {
831 if (elem.is_symbolic()) {
832 return elem.name();
833 }
834 if (elem.is_constant()) {
835 double val = static_cast<double>(elem);
836 if (val == 0.0)
837 return "0";
838 if (val == 1.0)
839 return "1";
840 if (val == -1.0)
841 return "-1";
842 std::ostringstream oss;
843 oss << std::setprecision(4) << val;
844 return oss.str();
845 }
846 if (!elem.is_leaf()) {
847 // Non-leaf node: has an operation
848 // CasADi operation codes from casadi/core/casadi_math.hpp
849 casadi_int op = elem.op();
850 switch (op) {
851 // Unary operations
852 case casadi::OP_ASSIGN:
853 return "=";
854 case casadi::OP_NEG:
855 return "neg";
856 case casadi::OP_NOT:
857 return "not";
858 case casadi::OP_SQ:
859 return "sq";
860 case casadi::OP_SQRT:
861 return "sqrt";
862 case casadi::OP_EXP:
863 return "exp";
864 case casadi::OP_LOG:
865 return "log";
866 case casadi::OP_SIN:
867 return "sin";
868 case casadi::OP_COS:
869 return "cos";
870 case casadi::OP_TAN:
871 return "tan";
872 case casadi::OP_ASIN:
873 return "asin";
874 case casadi::OP_ACOS:
875 return "acos";
876 case casadi::OP_ATAN:
877 return "atan";
878 case casadi::OP_SINH:
879 return "sinh";
880 case casadi::OP_COSH:
881 return "cosh";
882 case casadi::OP_TANH:
883 return "tanh";
884 case casadi::OP_ASINH:
885 return "asinh";
886 case casadi::OP_ACOSH:
887 return "acosh";
888 case casadi::OP_ATANH:
889 return "atanh";
890 case casadi::OP_FABS:
891 return "abs";
892 case casadi::OP_FLOOR:
893 return "floor";
894 case casadi::OP_CEIL:
895 return "ceil";
896 case casadi::OP_SIGN:
897 return "sign";
898 case casadi::OP_ERF:
899 return "erf";
900 case casadi::OP_ERFINV:
901 return "erfinv";
902 case casadi::OP_INV:
903 return "inv";
904
905 // Binary operations
906 case casadi::OP_ADD:
907 return "+";
908 case casadi::OP_SUB:
909 return "-";
910 case casadi::OP_MUL:
911 return "*";
912 case casadi::OP_DIV:
913 return "/";
914 case casadi::OP_POW:
915 return "pow";
916 case casadi::OP_ATAN2:
917 return "atan2";
918 case casadi::OP_FMIN:
919 return "min";
920 case casadi::OP_FMAX:
921 return "max";
922 case casadi::OP_FMOD:
923 return "mod";
924 case casadi::OP_COPYSIGN:
925 return "copysign";
926 case casadi::OP_HYPOT:
927 return "hypot";
928
929 // Comparison operations
930 case casadi::OP_LT:
931 return "<";
932 case casadi::OP_LE:
933 return "<=";
934 case casadi::OP_EQ:
935 return "==";
936 case casadi::OP_NE:
937 return "!=";
938 case casadi::OP_AND:
939 return "&&";
940 case casadi::OP_OR:
941 return "||";
942
943 // Conditional
944 case casadi::OP_IF_ELSE_ZERO:
945 return "if_else_zero";
946
947 default:
948 return "op" + std::to_string(op);
949 }
950 }
951 return "?";
952}
953
957inline void get_sx_node_style(const casadi::SXElem &elem, std::string &color, std::string &shape) {
958 if (elem.is_symbolic()) {
959 color = "#90EE90"; // light green for inputs
960 shape = "ellipse";
961 } else if (elem.is_constant()) {
962 color = "#FFE4B5"; // moccasin for constants
963 shape = "ellipse";
964 } else if (!elem.is_leaf()) {
965 // Non-leaf: operation node
966 casadi_int op = elem.op();
967 // Color by operation category
968 if (op == casadi::OP_ADD || op == casadi::OP_SUB || op == casadi::OP_MUL ||
969 op == casadi::OP_DIV || op == casadi::OP_NEG) {
970 color = "#87CEEB"; // light blue for arithmetic
971 } else if (op == casadi::OP_SIN || op == casadi::OP_COS || op == casadi::OP_TAN ||
972 op == casadi::OP_ASIN || op == casadi::OP_ACOS || op == casadi::OP_ATAN) {
973 color = "#DDA0DD"; // plum for trig
974 } else if (op == casadi::OP_EXP || op == casadi::OP_LOG || op == casadi::OP_SQRT ||
975 op == casadi::OP_SQ || op == casadi::OP_POW) {
976 color = "#FFB6C1"; // light pink for power/exp
977 } else if (op == casadi::OP_LT || op == casadi::OP_LE || op == casadi::OP_EQ ||
978 op == casadi::OP_NE || op == casadi::OP_AND || op == casadi::OP_OR) {
979 color = "#98FB98"; // pale green for comparison
980 } else {
981 color = "#B0C4DE"; // light steel blue for other
982 }
983 shape = "box";
984 } else {
985 color = "#D3D3D3"; // light gray for unknown
986 shape = "box";
987 }
988}
989
990} // namespace detail
991
1002inline void export_sx_graph_dot(const casadi::SX &expr, const std::string &filename,
1003 const std::string &name = "expression") {
1004 std::string dot_filename = filename + ".dot";
1005 std::ofstream out(dot_filename);
1006 if (!out.is_open()) {
1007 throw RuntimeError("Failed to open file for writing: " + dot_filename);
1008 }
1009
1010 out << "digraph \"" << name << "\" {\n";
1011 out << " rankdir=BT;\n";
1012 out << " splines=ortho;\n";
1013 out << " node [shape=box, style=\"rounded,filled\", fontname=\"Helvetica\"];\n";
1014 out << " edge [color=\"#666666\", arrowsize=0.7];\n\n";
1015 out << " labelloc=\"t\";\n";
1016 out << " label=\"" << detail::escape_dot_label(name) << "\";\n";
1017 out << " fontsize=16;\n\n";
1018
1019 // Traverse all elements in the SX matrix
1020 std::map<const void *, int> ptr_to_id;
1021 std::set<const void *> visited;
1022 std::vector<casadi::SXElem> queue;
1023 int node_counter = 0;
1024
1025 // Get nonzeros as SXElem vector
1026 const std::vector<casadi::SXElem> &nz = expr.nonzeros();
1027 casadi_int n_elem = static_cast<casadi_int>(nz.size());
1028
1029 // Collect all output elements
1030 for (casadi_int i = 0; i < n_elem; ++i) {
1031 queue.push_back(nz[i]);
1032 }
1033
1034 // First pass: assign IDs to all nodes
1035 size_t queue_idx = 0;
1036 while (queue_idx < queue.size()) {
1037 casadi::SXElem current = queue[queue_idx++];
1038 const void *ptr = current.get();
1039 if (visited.count(ptr))
1040 continue;
1041 visited.insert(ptr);
1042 ptr_to_id[ptr] = node_counter++;
1043
1044 casadi_int n = current.n_dep();
1045 for (casadi_int i = 0; i < n; ++i) {
1046 queue.push_back(current.dep(i));
1047 }
1048 }
1049
1050 // Second pass: generate nodes and edges
1051 visited.clear();
1052 queue.clear();
1053 for (casadi_int i = 0; i < n_elem; ++i) {
1054 queue.push_back(nz[i]);
1055 }
1056
1057 std::vector<int> output_node_ids;
1058 for (casadi_int i = 0; i < n_elem; ++i) {
1059 const void *ptr = nz[i].get();
1060 output_node_ids.push_back(ptr_to_id[ptr]);
1061 }
1062
1063 queue_idx = 0;
1064 while (queue_idx < queue.size()) {
1065 casadi::SXElem current = queue[queue_idx++];
1066 const void *ptr = current.get();
1067 if (visited.count(ptr))
1068 continue;
1069 visited.insert(ptr);
1070
1071 int current_id = ptr_to_id[ptr];
1072 std::string label = detail::get_sx_operation(current);
1073 std::string color, shape;
1074 detail::get_sx_node_style(current, color, shape);
1075
1076 out << " node_" << current_id << " [label=\"" << detail::escape_dot_label(label)
1077 << "\", fillcolor=\"" << color << "\", shape=" << shape << "];\n";
1078
1079 casadi_int n = current.n_dep();
1080 for (casadi_int i = 0; i < n; ++i) {
1081 casadi::SXElem dep = current.dep(i);
1082 const void *dep_ptr = dep.get();
1083 int dep_id = ptr_to_id[dep_ptr];
1084 out << " node_" << dep_id << " -> node_" << current_id << ";\n";
1085 queue.push_back(dep);
1086 }
1087 }
1088
1089 // Mark output nodes
1090 if (!output_node_ids.empty()) {
1091 out << "\n // Output markers\n";
1092 for (size_t i = 0; i < output_node_ids.size(); ++i) {
1093 out << " output_" << i << " [label=\"out[" << i
1094 << "]\", shape=doublecircle, fillcolor=\"#FFD700\"];\n";
1095 out << " node_" << output_node_ids[i] << " -> output_" << i << ";\n";
1096 }
1097 }
1098
1099 out << "}\n";
1100 out.close();
1101}
1102
1113inline void export_sx_graph_html(const casadi::SX &expr, const std::string &filename,
1114 const std::string &name = "expression") {
1115 std::ostringstream dot_stream;
1116 std::ostringstream node_data_stream;
1117 std::ostringstream edges_stream;
1118
1119 dot_stream << "digraph \"" << name << "\" {\n";
1120 dot_stream << " rankdir=BT;\n";
1121 dot_stream << " splines=ortho;\n";
1122 dot_stream << " node [shape=box, style=\"rounded,filled\", fontname=\"Helvetica\"];\n";
1123 dot_stream << " edge [color=\"#666666\", arrowsize=0.7];\n\n";
1124 dot_stream << " labelloc=\"t\";\n";
1125 dot_stream << " label=\"" << detail::escape_dot_label(name) << "\";\n";
1126 dot_stream << " fontsize=16;\n\n";
1127
1128 // Traverse all elements
1129 std::map<const void *, int> ptr_to_id;
1130 std::set<const void *> visited;
1131 std::vector<casadi::SXElem> queue;
1132 int node_counter = 0;
1133
1134 // Get nonzeros as SXElem vector
1135 const std::vector<casadi::SXElem> &nz = expr.nonzeros();
1136 casadi_int n_elem = static_cast<casadi_int>(nz.size());
1137
1138 for (casadi_int i = 0; i < n_elem; ++i) {
1139 queue.push_back(nz[i]);
1140 }
1141
1142 // First pass: assign IDs
1143 size_t queue_idx = 0;
1144 while (queue_idx < queue.size()) {
1145 casadi::SXElem current = queue[queue_idx++];
1146 const void *ptr = current.get();
1147 if (visited.count(ptr))
1148 continue;
1149 visited.insert(ptr);
1150 ptr_to_id[ptr] = node_counter++;
1151
1152 casadi_int n = current.n_dep();
1153 for (casadi_int i = 0; i < n; ++i) {
1154 queue.push_back(current.dep(i));
1155 }
1156 }
1157
1158 // Collect output node IDs
1159 std::vector<int> output_node_ids;
1160 for (casadi_int i = 0; i < n_elem; ++i) {
1161 const void *ptr = nz[i].get();
1162 output_node_ids.push_back(ptr_to_id[ptr]);
1163 }
1164
1165 // Second pass: generate nodes and edges
1166 visited.clear();
1167 queue.clear();
1168 for (casadi_int i = 0; i < n_elem; ++i) {
1169 queue.push_back(nz[i]);
1170 }
1171
1172 node_data_stream << "{";
1173 edges_stream << "[";
1174 bool first_node = true;
1175 bool first_edge = true;
1176
1177 queue_idx = 0;
1178 while (queue_idx < queue.size()) {
1179 casadi::SXElem current = queue[queue_idx++];
1180 const void *ptr = current.get();
1181 if (visited.count(ptr))
1182 continue;
1183 visited.insert(ptr);
1184
1185 int current_id = ptr_to_id[ptr];
1186 std::string label = detail::get_sx_operation(current);
1187 std::string color, shape;
1188 detail::get_sx_node_style(current, color, shape);
1189
1190 // Determine node type
1191 std::string node_type = "operation";
1192 if (current.is_symbolic())
1193 node_type = "input";
1194 else if (current.is_constant())
1195 node_type = "constant";
1196
1197 // Add to node data JSON
1198 if (!first_node)
1199 node_data_stream << ",";
1200 first_node = false;
1201 node_data_stream << "\"node_" << current_id << "\":{";
1202 node_data_stream << "\"id\":" << current_id << ",";
1203 node_data_stream << "\"label\":\"" << detail::escape_for_json(label) << "\",";
1204 node_data_stream << "\"type\":\"" << node_type << "\",";
1205 node_data_stream << "\"deps\":[";
1206
1207 casadi_int n = current.n_dep();
1208 for (casadi_int i = 0; i < n; ++i) {
1209 if (i > 0)
1210 node_data_stream << ",";
1211 casadi::SXElem dep = current.dep(i);
1212 node_data_stream << ptr_to_id[dep.get()];
1213 }
1214 node_data_stream << "]}";
1215
1216 dot_stream << " node_" << current_id << " [label=\"" << detail::escape_dot_label(label)
1217 << "\", fillcolor=\"" << color << "\", shape=" << shape << ", id=\"node_"
1218 << current_id << "\"];\n";
1219
1220 for (casadi_int i = 0; i < n; ++i) {
1221 casadi::SXElem dep = current.dep(i);
1222 const void *dep_ptr = dep.get();
1223 int dep_id = ptr_to_id[dep_ptr];
1224 dot_stream << " node_" << dep_id << " -> node_" << current_id << ";\n";
1225
1226 if (!first_edge)
1227 edges_stream << ",";
1228 first_edge = false;
1229 edges_stream << "[" << dep_id << "," << current_id << "]";
1230
1231 queue.push_back(dep);
1232 }
1233 }
1234
1235 // Mark outputs
1236 for (size_t i = 0; i < output_node_ids.size(); ++i) {
1237 dot_stream << " output_" << i << " [label=\"out[" << i
1238 << "]\", shape=doublecircle, fillcolor=\"#FFD700\", id=\"output_" << i
1239 << "\"];\n";
1240 dot_stream << " node_" << output_node_ids[i] << " -> output_" << i << ";\n";
1241
1242 if (!first_edge)
1243 edges_stream << ",";
1244 first_edge = false;
1245 edges_stream << "[" << output_node_ids[i] << ",-" << (i + 1) << "]";
1246 }
1247
1248 dot_stream << "}\n";
1249 node_data_stream << "}";
1250 edges_stream << "]";
1251
1252 std::string dot_content = dot_stream.str();
1253
1254 // Write HTML file
1255 std::string html_filename = filename + ".html";
1256 std::ofstream out(html_filename);
1257 if (!out.is_open()) {
1258 throw RuntimeError("Failed to open file for writing: " + html_filename);
1259 }
1260
1261 std::string stats_js = "const nodeCount = Object.keys(nodeData).length;\n"
1262 " document.getElementById('stats').textContent = "
1263 "'Nodes: ' + nodeCount + ' | Edges: ' + edges.length;";
1264 detail::write_graph_html(out, detail::escape_dot_label(name) + " - Deep Graph",
1265 detail::escape_for_js(dot_content), node_data_stream.str(),
1266 edges_stream.str(), stats_js);
1267 out.close();
1268}
1269
1278
1290inline void export_graph_deep(const casadi::Function &fn, const std::string &filename,
1292 const std::string &name = "") {
1293 // Use function name if no name provided
1294 std::string graph_name = name.empty() ? fn.name() : name;
1295
1296 // Expand function to SX (inlines all nested calls)
1297 casadi::Function expanded = fn.expand();
1298
1299 // Create SX symbolic inputs matching the function signature
1300 std::vector<casadi::SX> sx_inputs;
1301 for (casadi_int i = 0; i < expanded.n_in(); ++i) {
1302 sx_inputs.push_back(
1303 casadi::SX::sym(expanded.name_in(i), expanded.size1_in(i), expanded.size2_in(i)));
1304 }
1305
1306 // Evaluate to get SX outputs
1307 std::vector<casadi::SX> sx_outputs = expanded(sx_inputs);
1308
1309 // Combine all outputs into a single SX for visualization
1310 casadi::SX combined = casadi::SX::vertcat(sx_outputs);
1311
1312 // Export using SX-specific traversal
1313 switch (format) {
1315 export_sx_graph_dot(combined, filename, graph_name);
1316 break;
1318 export_sx_graph_html(combined, filename, graph_name);
1319 break;
1321 export_sx_graph_dot(combined, filename, graph_name);
1322 render_graph(filename + ".dot", filename + ".pdf");
1323 break;
1324 }
1325}
1326
1334inline bool visualize_graph_deep(const casadi::Function &fn, const std::string &output_base) {
1335 try {
1336 export_graph_deep(fn, output_base, DeepGraphFormat::PDF);
1337 return true;
1338 } catch (const std::exception &) {
1339 return false;
1340 }
1341}
1342
1343} // namespace janus
Custom exception hierarchy for Janus framework.
Core type aliases for numeric and symbolic Eigen/CasADi interop.
Wrapper around casadi::Function providing Eigen-native IO.
Definition Function.hpp:46
Operation failed at runtime (e.g., CasADi eval with free variables).
Definition JanusError.hpp:41
std::string escape_for_json(const std::string &s)
Escape a string for embedding in JSON.
Definition JanusIO.hpp:409
void get_sx_node_style(const casadi::SXElem &elem, std::string &color, std::string &shape)
Get node styling based on SX element type.
Definition JanusIO.hpp:957
std::string get_op_name(const SymbolicScalar &mx)
Get a short description of an MX operation type.
Definition JanusIO.hpp:168
void write_graph_html(std::ostream &out, const std::string &title, const std::string &escaped_dot, const std::string &node_data_json, const std::string &edges_json, const std::string &extra_header_js="")
Write a complete interactive graph HTML page.
Definition JanusIO.hpp:437
std::string escape_dot_label(const std::string &s)
Escape special characters for DOT format.
Definition JanusIO.hpp:130
std::string get_sx_operation(const casadi::SXElem &elem)
Map CasADi SX operation codes to readable labels.
Definition JanusIO.hpp:830
std::string escape_for_js(const std::string &content)
Escape a string for embedding in a JavaScript string literal.
Definition JanusIO.hpp:387
Definition Diagnostics.hpp:19
void export_graph_html(const SymbolicScalar &expr, const std::string &filename, const std::string &name="expression")
Export a symbolic expression to an interactive HTML file.
Definition JanusIO.hpp:663
void export_graph_deep(const casadi::Function &fn, const std::string &filename, DeepGraphFormat format=DeepGraphFormat::HTML, const std::string &name="")
Export a CasADi Function to deep graph format showing all operations.
Definition JanusIO.hpp:1290
JanusMatrix< NumericScalar > NumericMatrix
Eigen::MatrixXd equivalent.
Definition JanusTypes.hpp:66
bool visualize_graph(const SymbolicScalar &expr, const std::string &output_base)
Convenience function: export expression to DOT and render to PDF.
Definition JanusIO.hpp:373
void export_graph_dot(const SymbolicScalar &expr, const std::string &filename, const std::string &name="expression")
Export a symbolic expression to DOT format for visualization.
Definition JanusIO.hpp:221
bool render_graph(const std::string &dot_file, const std::string &output_file)
Render a Graphviz DOT file to an output file (e.g. a PDF) via Graphviz.
Definition JanusIO.hpp:349
void export_sx_graph_html(const casadi::SX &expr, const std::string &filename, const std::string &name="expression")
Export an SX expression to an interactive HTML file for deep visualization.
Definition JanusIO.hpp:1113
DeepGraphFormat
Export format for deep graph visualization.
Definition JanusIO.hpp:1273
@ DOT
Graphviz DOT text format.
Definition JanusIO.hpp:1274
@ HTML
Self-contained interactive HTML.
Definition JanusIO.hpp:1275
@ PDF
Rendered PDF via Graphviz.
Definition JanusIO.hpp:1276
bool visualize_graph_deep(const casadi::Function &fn, const std::string &output_base)
Convenience function: export Function to deep graph and render to PDF.
Definition JanusIO.hpp:1334
void print(const std::string &label, const Eigen::MatrixBase< Derived > &mat)
Print a matrix to stdout with a label.
Definition JanusIO.hpp:40
auto eval(const Eigen::MatrixBase< Derived > &mat)
Evaluate a symbolic matrix to a numeric Eigen matrix.
Definition JanusIO.hpp:62
void disp(const std::string &label, const Eigen::MatrixBase< Derived > &mat)
Deprecated alias for print.
Definition JanusIO.hpp:52
casadi::MX SymbolicScalar
CasADi MX symbolic scalar.
Definition JanusTypes.hpp:70
void export_sx_graph_dot(const casadi::SX &expr, const std::string &filename, const std::string &name="expression")
Export an SX expression to DOT format for deep visualization.
Definition JanusIO.hpp:1002