Janus 2.0.0
High-performance C++20 dual-mode numerical framework
Loading...
Searching...
No Matches
Quaternion.hpp
Go to the documentation of this file.
1#pragma once
7
11#include "janus/math/Linalg.hpp"
12#include "janus/math/Logic.hpp"
13#include "janus/math/Trig.hpp"
14
15namespace janus {
16
25template <typename Scalar> class Quaternion {
26 public:
27 Scalar w, x, y, z;
28
29 // --- Constructors ---
30
33 : w(static_cast<Scalar>(1.0)), x(static_cast<Scalar>(0.0)), y(static_cast<Scalar>(0.0)),
34 z(static_cast<Scalar>(0.0)) {}
35
37 Quaternion(Scalar w, Scalar x, Scalar y, Scalar z) : w(w), x(x), y(y), z(z) {}
38
40 explicit Quaternion(const Vec4<Scalar> &v) : w(v(0)), x(v(1)), y(v(2)), z(v(3)) {}
41
42 // --- Algebraic Operations ---
43
47 Quaternion operator*(const Quaternion &other) const {
48 return Quaternion(w * other.w - x * other.x - y * other.y - z * other.z,
49 w * other.x + x * other.w + y * other.z - z * other.y,
50 w * other.y - x * other.z + y * other.w + z * other.x,
51 w * other.z + x * other.y - y * other.x + z * other.w);
52 }
53
57 Quaternion operator*(const Scalar &s) const { return Quaternion(w * s, x * s, y * s, z * s); }
58
62 Quaternion operator+(const Quaternion &other) const {
63 return Quaternion(w + other.w, x + other.x, y + other.y, z + other.z);
64 }
65
68 Quaternion conjugate() const { return Quaternion(w, -x, -y, -z); }
69
72 Quaternion inverse() const { return conjugate() * (static_cast<Scalar>(1.0) / squared_norm()); }
73
76 Scalar squared_norm() const { return w * w + x * x + y * y + z * z; }
77
80 Scalar norm() const { return janus::sqrt(squared_norm()); }
81
85 Scalar n = norm();
86 // Avoid division by zero check for symbolic if possible,
87 // strictly speaking we should probably use a safe variant or assume it's not zero.
88 // For now, standard division.
89 return Quaternion(w / n, x / n, y / n, z / n);
90 }
91
96 // Optimization: q * (0, v) * q_conj
97 // Or specific formula: v + 2 * cross(q_vec, cross(q_vec, v) + q_w * v)
98
99 Scalar q0 = w;
100 Scalar q1 = x;
101 Scalar q2 = y;
102 Scalar q3 = z;
103
104 // Extract vector part of quaternion
105 Vec3<Scalar> q_vec;
106 q_vec << q1, q2, q3;
107
108 Vec3<Scalar> t = static_cast<Scalar>(2.0) * janus::cross(q_vec, v);
109 return v + (q0 * t) + janus::cross(q_vec, t);
110 }
111
112 // --- Conversions ---
113
117 Mat3<Scalar> R;
118 Scalar one = static_cast<Scalar>(1.0);
119 Scalar two = static_cast<Scalar>(2.0);
120
121 Scalar xx = x * x;
122 Scalar yy = y * y;
123 Scalar zz = z * z;
124 Scalar xy = x * y;
125 Scalar xz = x * z;
126 Scalar yz = y * z;
127 Scalar wx = w * x;
128 Scalar wy = w * y;
129 Scalar wz = w * z;
130
131 R(0, 0) = one - two * (yy + zz);
132 R(0, 1) = two * (xy - wz);
133 R(0, 2) = two * (xz + wy);
134
135 R(1, 0) = two * (xy + wz);
136 R(1, 1) = one - two * (xx + zz);
137 R(1, 2) = two * (yz - wx);
138
139 R(2, 0) = two * (xz - wy);
140 R(2, 1) = two * (yz + wx);
141 R(2, 2) = one - two * (xx + yy);
142
143 return R;
144 }
145
149 Vec4<Scalar> res;
150 res << w, x, y, z;
151 return res;
152 }
153
154 // --- Static Factories ---
155
164 static Quaternion from_euler(Scalar roll, Scalar pitch, Scalar yaw) {
165 Scalar half = static_cast<Scalar>(0.5);
166 Scalar cr = janus::cos(roll * half);
167 Scalar sr = janus::sin(roll * half);
168 Scalar cp = janus::cos(pitch * half);
169 Scalar sp = janus::sin(pitch * half);
170 Scalar cy = janus::cos(yaw * half);
171 Scalar sy = janus::sin(yaw * half);
172
173 return Quaternion(cr * cp * cy + sr * sp * sy, // w
174 sr * cp * cy - cr * sp * sy, // x
175 cr * sp * cy + sr * cp * sy, // y
176 cr * cp * sy - sr * sp * cy // z
177 );
178 }
179
186 static Quaternion from_axis_angle(const Vec3<Scalar> &axis, Scalar angle) {
187 Scalar half = static_cast<Scalar>(0.5);
188 Scalar s = janus::sin(angle * half);
189 Scalar c = janus::cos(angle * half);
190
191 // Assume axis is normalized? Usually safer to normalize.
192 // If symbolic, normalization adds complexity, but for correctness it's good.
193 // Let's assume user passes normalized axis or we normalize it.
194 // Standard library implementations usually assume normalized or normalize.
195 // We will normalize to be safe.
196 auto n_axis = axis / janus::norm(axis);
197
198 return Quaternion(c, n_axis(0) * s, n_axis(1) * s, n_axis(2) * s);
199 }
200
207 Scalar half = static_cast<Scalar>(0.5);
208 Scalar eps = static_cast<Scalar>(1e-12);
209 Scalar angle = janus::norm(rot_vec);
210 Scalar half_angle = angle * half;
211 Scalar safe_angle = angle + eps;
212
213 // sin(angle/2)/angle with small-angle fallback (limit = 0.5)
214 Scalar scale_raw = janus::sin(half_angle) / safe_angle;
215 Scalar scale = janus::where(angle > eps, scale_raw, half);
216
217 return Quaternion(janus::cos(half_angle), rot_vec(0) * scale, rot_vec(1) * scale,
218 rot_vec(2) * scale);
219 }
220
227 // Implementation based on standard robust algorithms (e.g., Eigen's or Shepperd's)
228 // Here we use a simplified version for brevity but covering standard cases.
229 // For symbolic compatibility, we need to be careful with branching.
230 // It is notoriously hard to do robust rotation matrix -> quaternion symbolicly because of
231 // the 4-way branching. If we must be symbolic, we might pick one branch (e.g. max trace)
232 // and hope, or use 'where'.
233
234 // For now, let's implement a standard numeric-friendly trace check.
235 // If Scalar is symbolic (casadi::MX), regular if/else won't work on values.
236
237 Scalar trace = mat.trace();
238 Scalar q_w, q_x, q_y, q_z;
239 Scalar one = static_cast<Scalar>(1.0);
240 Scalar half = static_cast<Scalar>(0.5);
241 Scalar two = static_cast<Scalar>(2.0);
242
243 if constexpr (std::is_floating_point_v<Scalar>) {
244 // Numeric implementation (efficient branching)
245 if (trace > 0) {
246 Scalar s = static_cast<Scalar>(0.5) / janus::sqrt(trace + 1.0);
247 q_w = 0.25 / s;
248 q_x = (mat(2, 1) - mat(1, 2)) * s;
249 q_y = (mat(0, 2) - mat(2, 0)) * s;
250 q_z = (mat(1, 0) - mat(0, 1)) * s;
251 } else {
252 if (mat(0, 0) > mat(1, 1) && mat(0, 0) > mat(2, 2)) {
253 Scalar s = 2.0 * janus::sqrt(1.0 + mat(0, 0) - mat(1, 1) - mat(2, 2));
254 q_w = (mat(2, 1) - mat(1, 2)) / s;
255 q_x = 0.25 * s;
256 q_y = (mat(0, 1) + mat(1, 0)) / s;
257 q_z = (mat(0, 2) + mat(2, 0)) / s;
258 } else if (mat(1, 1) > mat(2, 2)) {
259 Scalar s = 2.0 * janus::sqrt(1.0 + mat(1, 1) - mat(0, 0) - mat(2, 2));
260 q_w = (mat(0, 2) - mat(2, 0)) / s;
261 q_x = (mat(0, 1) + mat(1, 0)) / s;
262 q_y = 0.25 * s;
263 q_z = (mat(1, 2) + mat(2, 1)) / s;
264 } else {
265 Scalar s = 2.0 * janus::sqrt(1.0 + mat(2, 2) - mat(0, 0) - mat(1, 1));
266 q_w = (mat(1, 0) - mat(0, 1)) / s;
267 q_x = (mat(0, 2) + mat(2, 0)) / s;
268 q_y = (mat(1, 2) + mat(2, 1)) / s;
269 q_z = 0.25 * s;
270 }
271 }
272 } else {
273 // Symbolic: Full 4-branch using nested janus::where (Shepperd's method)
274
275 // Guard radicands: in symbolic mode all branches are eagerly evaluated,
276 // so untaken branches can have negative radicands. Clamp to eps.
277 Scalar eps = static_cast<Scalar>(1e-12);
278
279 // Branch 0: trace > 0
280 Scalar r0 = janus::where(trace + one > eps, trace + one, eps);
281 Scalar s0 = half / janus::sqrt(r0);
282 Scalar w0 = static_cast<Scalar>(0.25) / s0;
283 Scalar x0 = (mat(2, 1) - mat(1, 2)) * s0;
284 Scalar y0 = (mat(0, 2) - mat(2, 0)) * s0;
285 Scalar z0 = (mat(1, 0) - mat(0, 1)) * s0;
286
287 // Branch 1: mat(0,0) is largest diagonal
288 Scalar r1 = one + mat(0, 0) - mat(1, 1) - mat(2, 2);
289 Scalar safe_r1 = janus::where(r1 > eps, r1, eps);
290 Scalar s1 = two * janus::sqrt(safe_r1);
291 Scalar w1 = (mat(2, 1) - mat(1, 2)) / s1;
292 Scalar x1 = static_cast<Scalar>(0.25) * s1;
293 Scalar y1 = (mat(0, 1) + mat(1, 0)) / s1;
294 Scalar z1 = (mat(0, 2) + mat(2, 0)) / s1;
295
296 // Branch 2: mat(1,1) is largest diagonal
297 Scalar r2 = one + mat(1, 1) - mat(0, 0) - mat(2, 2);
298 Scalar safe_r2 = janus::where(r2 > eps, r2, eps);
299 Scalar s2 = two * janus::sqrt(safe_r2);
300 Scalar w2 = (mat(0, 2) - mat(2, 0)) / s2;
301 Scalar x2 = (mat(0, 1) + mat(1, 0)) / s2;
302 Scalar y2 = static_cast<Scalar>(0.25) * s2;
303 Scalar z2 = (mat(1, 2) + mat(2, 1)) / s2;
304
305 // Branch 3: mat(2,2) is largest diagonal
306 Scalar r3 = one + mat(2, 2) - mat(0, 0) - mat(1, 1);
307 Scalar safe_r3 = janus::where(r3 > eps, r3, eps);
308 Scalar s3 = two * janus::sqrt(safe_r3);
309 Scalar w3 = (mat(1, 0) - mat(0, 1)) / s3;
310 Scalar x3 = (mat(0, 2) + mat(2, 0)) / s3;
311 Scalar y3 = (mat(1, 2) + mat(2, 1)) / s3;
312 Scalar z3 = static_cast<Scalar>(0.25) * s3;
313
314 // Select via nested where
315 auto cond_trace = trace > static_cast<Scalar>(0.0);
316 auto cond_r00 = janus::logical_and(mat(0, 0) > mat(1, 1), mat(0, 0) > mat(2, 2));
317 auto cond_r11 = mat(1, 1) > mat(2, 2);
318
319 // Inner: branch2 vs branch3
320 Scalar wi = janus::where(cond_r11, w2, w3);
321 Scalar xi = janus::where(cond_r11, x2, x3);
322 Scalar yi = janus::where(cond_r11, y2, y3);
323 Scalar zi = janus::where(cond_r11, z2, z3);
324
325 // Middle: branch1 vs inner
326 Scalar wm = janus::where(cond_r00, w1, wi);
327 Scalar xm = janus::where(cond_r00, x1, xi);
328 Scalar ym = janus::where(cond_r00, y1, yi);
329 Scalar zm = janus::where(cond_r00, z1, zi);
330
331 // Outer: branch0 vs middle
332 q_w = janus::where(cond_trace, w0, wm);
333 q_x = janus::where(cond_trace, x0, xm);
334 q_y = janus::where(cond_trace, y0, ym);
335 q_z = janus::where(cond_trace, z0, zm);
336 }
337 return Quaternion(q_w, q_x, q_y, q_z);
338 }
339
343 // Roll (x-axis rotation)
344 Scalar sinr_cosp = static_cast<Scalar>(2.0) * (w * x + y * z);
345 Scalar cosr_cosp = static_cast<Scalar>(1.0) - static_cast<Scalar>(2.0) * (x * x + y * y);
346 Scalar roll = janus::atan2(sinr_cosp, cosr_cosp);
347
348 // Pitch (y-axis rotation)
349 Scalar sinp = static_cast<Scalar>(2.0) * (w * y - z * x);
350 Scalar pitch;
351 // Check for gimbal lock
352 // Logic::where ideally
353 if constexpr (std::is_floating_point_v<Scalar>) {
354 if (std::abs(sinp) >= 1)
355 pitch = std::copysign(std::numbers::pi_v<double> / 2,
356 sinp); // use 90 degrees if out of range
357 else
358 pitch = std::asin(sinp);
359 } else {
360 // Symbolic: assume no gimbal lock or underlying library handles asin domain
361 pitch = janus::asin(sinp);
362 }
363
364 // Yaw (z-axis rotation)
365 Scalar siny_cosp = static_cast<Scalar>(2.0) * (w * z + x * y);
366 Scalar cosy_cosp = static_cast<Scalar>(1.0) - static_cast<Scalar>(2.0) * (y * y + z * z);
367 Scalar yaw = janus::atan2(siny_cosp, cosy_cosp);
368
369 return Vec3<Scalar>(roll, pitch, yaw);
370 }
371};
372
373// --- Free Functions ---
374
386template <typename Scalar>
388 Scalar one = static_cast<Scalar>(1.0);
389 Scalar zero = static_cast<Scalar>(0.0);
390 Scalar dot_threshold = static_cast<Scalar>(0.9995); // Threshold for linear fallback
391
392 // Compute dot product
393 Scalar dot = q0.w * q1.w + q0.x * q1.x + q0.y * q1.y + q0.z * q1.z;
394
395 // --- Shortest path fix ---
396 // If dot < 0, negate q1 to take shorter arc.
397 // We compute a sign factor: sign = where(dot < 0, -1, 1)
398 Scalar sign = janus::where(dot < zero, -one, one);
399
400 // Effective q1 and dot (flipped if needed)
401 Quaternion<Scalar> q1_eff(q1.w * sign, q1.x * sign, q1.y * sign, q1.z * sign);
402 Scalar dot_eff = dot * sign; // Now dot_eff >= 0
403
404 // --- Numerical stability: handle near-identity case ---
405 // If dot_eff is very close to 1, theta ≈ 0 and sin(theta) ≈ 0 (division issues).
406 // Fall back to normalized linear interpolation (nlerp).
407
408 Scalar theta = janus::acos(dot_eff);
409 Scalar sin_theta = janus::sin(theta);
410
411 // Compute slerp weights
412 Scalar wa_slerp = janus::sin((one - t) * theta) / sin_theta;
413 Scalar wb_slerp = janus::sin(t * theta) / sin_theta;
414
415 // Compute nlerp weights (simple linear blend, then normalize result)
416 Scalar wa_nlerp = one - t;
417 Scalar wb_nlerp = t;
418
419 // Use janus::where to select between slerp and nlerp based on dot_eff
420 Scalar use_slerp = dot_eff < dot_threshold; // True if slerp is safe
421
422 Scalar wa = janus::where(use_slerp, wa_slerp, wa_nlerp);
423 Scalar wb = janus::where(use_slerp, wb_slerp, wb_nlerp);
424
425 // Compute result
426 Quaternion<Scalar> result = q0 * wa + q1_eff * wb;
427
428 // Normalize for nlerp case (harmless for slerp case, just ensures unit quaternion)
429 return result.normalized();
430}
431
432} // namespace janus
Scalar and element-wise arithmetic functions (abs, sqrt, pow, exp, log, etc.).
C++20 concepts constraining valid Janus scalar types.
Core type aliases for numeric and symbolic Eigen/CasADi interop.
Linear algebra operations (solve, inverse, determinant, eigendecomposition, norms).
Conditional selection, comparison, and logical operations.
Trigonometric and inverse trigonometric functions.
Quaternion class for rotation representation.
Definition Quaternion.hpp:25
Scalar z
Definition Quaternion.hpp:27
Quaternion(const Vec4< Scalar > &v)
Construct from Vec4 [w, x, y, z].
Definition Quaternion.hpp:40
static Quaternion from_axis_angle(const Vec3< Scalar > &axis, Scalar angle)
Create from axis-angle representation.
Definition Quaternion.hpp:186
Scalar y
Definition Quaternion.hpp:27
Vec3< Scalar > to_euler() const
Extract Euler angles as [roll, pitch, yaw] for the Z-Y-X (yaw-pitch-roll) sequence; this is the inverse of from_euler.
Definition Quaternion.hpp:342
static Quaternion from_euler(Scalar roll, Scalar pitch, Scalar yaw)
Create from Euler Angles (Yaw-Pitch-Roll / Z-Y-X sequence).
Definition Quaternion.hpp:164
Quaternion operator*(const Scalar &s) const
Scalar multiplication.
Definition Quaternion.hpp:57
Scalar x
Definition Quaternion.hpp:27
Scalar squared_norm() const
Squared norm (w^2 + x^2 + y^2 + z^2).
Definition Quaternion.hpp:76
Scalar norm() const
Quaternion norm.
Definition Quaternion.hpp:80
Quaternion inverse() const
Inverse (conjugate / norm_sq).
Definition Quaternion.hpp:72
Vec4< Scalar > coeffs() const
Export as vector [w, x, y, z].
Definition Quaternion.hpp:148
Mat3< Scalar > to_rotation_matrix() const
Convert to 3x3 rotation matrix.
Definition Quaternion.hpp:116
Quaternion normalized() const
Return unit quaternion.
Definition Quaternion.hpp:84
Vec3< Scalar > rotate(const Vec3< Scalar > &v) const
Rotate a 3D vector: v_rot = q * v * q_conj (valid for unit quaternions).
Definition Quaternion.hpp:95
Scalar w
Definition Quaternion.hpp:27
Quaternion(Scalar w, Scalar x, Scalar y, Scalar z)
Component constructor.
Definition Quaternion.hpp:37
Quaternion conjugate() const
Conjugate (w, -x, -y, -z).
Definition Quaternion.hpp:68
static Quaternion from_rotation_matrix(const Mat3< Scalar > &mat)
Create from 3x3 rotation matrix.
Definition Quaternion.hpp:226
static Quaternion from_rotation_vector(const Vec3< Scalar > &rot_vec)
Create from rotation vector (axis * angle).
Definition Quaternion.hpp:206
Quaternion operator+(const Quaternion &other) const
Quaternion addition.
Definition Quaternion.hpp:62
Quaternion operator*(const Quaternion &other) const
Hamilton product.
Definition Quaternion.hpp:47
Quaternion()
Default constructor: Identity quaternion (1, 0, 0, 0).
Definition Quaternion.hpp:32
Definition Diagnostics.hpp:19
auto where(const Cond &cond, const T1 &if_true, const T2 &if_false)
Select values based on condition (ternary operator) Returns: cond ? if_true : if_false Supports mixed...
Definition Logic.hpp:43
auto dot(const Eigen::MatrixBase< DerivedA > &a, const Eigen::MatrixBase< DerivedB > &b)
Computes dot product of two vectors.
Definition Linalg.hpp:500
T sqrt(const T &x)
Computes the square root of a scalar.
Definition Arithmetic.hpp:46
T acos(const T &x)
Computes arc cosine of x.
Definition Trig.hpp:121
Eigen::Matrix< Scalar, 4, 1 > Vec4
Definition JanusTypes.hpp:58
Quaternion< Scalar > slerp(const Quaternion< Scalar > &q0, const Quaternion< Scalar > &q1, Scalar t)
Spherical linear interpolation along the shortest arc, falling back to normalized linear interpolation when the inputs are nearly parallel.
Definition Quaternion.hpp:387
T cos(const T &x)
Computes cosine of x.
Definition Trig.hpp:46
T sign(const T &x)
Computes sign of x.
Definition Arithmetic.hpp:326
Eigen::Matrix< Scalar, 3, 3 > Mat3
Definition JanusTypes.hpp:61
auto norm(const Eigen::MatrixBase< Derived > &x, NormType type=NormType::L2)
Computes vector/matrix norm.
Definition Linalg.hpp:607
T asin(const T &x)
Computes arc sine of x.
Definition Trig.hpp:96
T sin(const T &x)
Computes sine of x.
Definition Trig.hpp:21
auto logical_and(const T1 &x1, const T2 &x2)
Logical AND (x && y).
Definition Logic.hpp:422
Eigen::Matrix< Scalar, 3, 1 > Vec3
Definition JanusTypes.hpp:57
auto cross(const Eigen::MatrixBase< DerivedA > &a, const Eigen::MatrixBase< DerivedB > &b)
Computes 3D cross product.
Definition Linalg.hpp:513
T atan2(const T &y, const T &x)
Computes arc tangent of y/x using signs of both arguments.
Definition Trig.hpp:172