Janus 2.0.0
High-performance C++20 dual-mode numerical framework
Loading...
Searching...
No Matches
Logic.hpp
Go to the documentation of this file.
1#pragma once
7
12#include <Eigen/Dense>
13#include <algorithm>
14#include <type_traits>
15#include <utility>
16
17namespace janus {
18
19// --- Trait to deduce the boolean type for a scalar ---
20template <typename T> struct BooleanType {
21 using type = bool;
22};
23
24template <> struct BooleanType<SymbolicScalar> {
26};
27
28template <typename T> using BooleanType_t = typename BooleanType<T>::type;
29
30// --- where (Scalar) ---
41// Relaxed to allow mixed types (e.g. MX and double)
42template <typename Cond, JanusScalar T1, JanusScalar T2>
43auto where(const Cond &cond, const T1 &if_true, const T2 &if_false) {
44 if constexpr (std::is_floating_point_v<T1> && std::is_floating_point_v<T2>) {
45 return cond ? if_true : if_false;
46 } else {
47 // Assume non-floating point (e.g. CasADi) handles mixed ops
48 return if_else(cond, if_true, if_false);
49 }
50}
51
52namespace detail {
53
54template <typename Cond, typename DerivedTrue, typename DerivedFalse>
55auto select(const Cond &cond, const Eigen::MatrixBase<DerivedTrue> &if_true,
56 const Eigen::MatrixBase<DerivedFalse> &if_false) {
57 if (if_true.rows() != if_false.rows() || if_true.cols() != if_false.cols()) {
58 throw InvalidArgument("select: matrix inputs must have the same shape");
59 }
60
61 using ResultScalar = std::decay_t<decltype(janus::where(
62 std::declval<Cond>(), std::declval<typename DerivedTrue::Scalar>(),
63 std::declval<typename DerivedFalse::Scalar>()))>;
64 using ResultMatrix =
65 Eigen::Matrix<ResultScalar, DerivedTrue::RowsAtCompileTime, DerivedTrue::ColsAtCompileTime,
66 DerivedTrue::Options, DerivedTrue::MaxRowsAtCompileTime,
67 DerivedTrue::MaxColsAtCompileTime>;
68
69 ResultMatrix res(if_true.rows(), if_true.cols());
70 for (Eigen::Index i = 0; i < if_true.rows(); ++i) {
71 for (Eigen::Index j = 0; j < if_true.cols(); ++j) {
72 res(i, j) = janus::where(cond, if_true(i, j), if_false(i, j));
73 }
74 }
75 return res;
76}
77
78inline bool is_symbolic_predicate(const SymbolicScalar &expr) {
79 if (expr.n_dep() == 0) {
80 return false;
81 }
82
83 switch (expr.op()) {
84 case casadi::OP_LT:
85 case casadi::OP_LE:
86 case casadi::OP_EQ:
87 case casadi::OP_NE:
88 case casadi::OP_AND:
89 case casadi::OP_OR:
90 case casadi::OP_NOT:
91 return true;
92 default:
93 return false;
94 }
95}
96
98 return is_symbolic_predicate(expr) ? expr : expr != 0;
99}
100
101template <typename Derived>
102SymbolicScalar count_symbolic_truthy(const Eigen::MatrixBase<Derived> &a) {
103 SymbolicScalar result = 0.0;
104 for (Eigen::Index i = 0; i < a.rows(); ++i) {
105 for (Eigen::Index j = 0; j < a.cols(); ++j) {
106 result = result + as_symbolic_predicate(a(i, j));
107 }
108 }
109 return result;
110}
111
112} // namespace detail
113
114// --- where (Vector/Matrix) ---
123template <typename DerivedCond, typename DerivedTrue, typename DerivedFalse>
124auto where(const Eigen::ArrayBase<DerivedCond> &cond, const Eigen::MatrixBase<DerivedTrue> &if_true,
125 const Eigen::MatrixBase<DerivedFalse> &if_false) {
126 using Scalar = typename DerivedTrue::Scalar;
127 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
128 // Manual element-wise select for generic types/CasADi
129 // Assuming if_true and if_false have same dimensions as cond
130 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> res(if_true.rows(), if_true.cols());
131 for (Eigen::Index i = 0; i < if_true.rows(); ++i) {
132 for (Eigen::Index j = 0; j < if_true.cols(); ++j) {
133 // cond(i,j) might be an expression, evaluate it.
134 res(i, j) = janus::where(cond.derived().coeff(i, j), if_true(i, j), if_false(i, j));
135 }
136 }
137 return res;
138 } else {
139 return cond.select(if_true, if_false);
140 }
141}
142
143// --- Min ---
150// Relaxed for mixed types
151template <JanusScalar T1, JanusScalar T2> auto min(const T1 &a, const T2 &b) {
152 if constexpr (std::is_floating_point_v<T1> && std::is_floating_point_v<T2>) {
153 return std::min(a, b);
154 } else {
155 // use fmin for mixed (fmin(double, MX) works in CasADi)
156 // std::min(double, MX) does NOT work usually?
157 // Actually, CasADi overloads fmin.
158 return fmin(a, b);
159 }
160}
161
169template <typename Derived>
170auto min(const Eigen::MatrixBase<Derived> &a, const Eigen::MatrixBase<Derived> &b) {
171 using Scalar = typename Derived::Scalar;
172 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
173 Eigen::Matrix<Scalar, Derived::RowsAtCompileTime, Derived::ColsAtCompileTime> res(a.rows(),
174 a.cols());
175 for (Eigen::Index i = 0; i < a.rows(); ++i) {
176 for (Eigen::Index j = 0; j < a.cols(); ++j) {
177 res(i, j) = janus::min(a(i, j), b(i, j));
178 }
179 }
180 return res;
181 } else {
182 return a.cwiseMin(b);
183 }
184}
185
186// --- Max ---
193// Relaxed for mixed types
194template <JanusScalar T1, JanusScalar T2> auto max(const T1 &a, const T2 &b) {
195 if constexpr (std::is_floating_point_v<T1> && std::is_floating_point_v<T2>) {
196 return std::max(a, b);
197 } else {
198 return fmax(a, b);
199 }
200}
201
209template <typename Derived>
210auto max(const Eigen::MatrixBase<Derived> &a, const Eigen::MatrixBase<Derived> &b) {
211 using Scalar = typename Derived::Scalar;
212 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
213 Eigen::Matrix<Scalar, Derived::RowsAtCompileTime, Derived::ColsAtCompileTime> res(a.rows(),
214 a.cols());
215 for (Eigen::Index i = 0; i < a.rows(); ++i) {
216 for (Eigen::Index j = 0; j < a.cols(); ++j) {
217 res(i, j) = janus::max(a(i, j), b(i, j));
218 }
219 }
220 return res;
221 } else {
222 return a.cwiseMax(b);
223 }
224}
225
226// --- Clamp ---
234// Relaxed for mixed types
235template <JanusScalar T, JanusScalar TLow, JanusScalar THigh>
236auto clamp(const T &val, const TLow &low, const THigh &high) {
237 return janus::min(janus::max(val, low), high);
238}
239
249template <typename Derived, typename Scalar>
250auto clamp(const Eigen::MatrixBase<Derived> &val, const Scalar &low, const Scalar &high) {
251 using MatrixScalar = typename Derived::Scalar;
252 if constexpr (std::is_same_v<MatrixScalar, SymbolicScalar>) {
253 return val.unaryExpr([=](const auto &x) { return janus::clamp(x, low, high); });
254 } else {
255 return val.cwiseMax(low).cwiseMin(high);
256 }
257}
258
259// --- Less Than (lt) ---
268// Returns expressions suitable for 'where' condition
269template <typename DerivedA, typename DerivedB>
270auto lt(const Eigen::MatrixBase<DerivedA> &a, const Eigen::MatrixBase<DerivedB> &b) {
271 using Scalar = typename DerivedA::Scalar;
272 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
273 return a.binaryExpr(b, [](const auto &x, const auto &y) { return x < y; });
274 } else {
275 return (a.array() < b.array());
276 }
277}
278
279// --- Greater Than (gt) ---
288template <typename DerivedA, typename DerivedB>
289auto gt(const Eigen::MatrixBase<DerivedA> &a, const Eigen::MatrixBase<DerivedB> &b) {
290 using Scalar = typename DerivedA::Scalar;
291 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
292 return a.binaryExpr(b, [](const auto &x, const auto &y) { return x > y; });
293 } else {
294 return (a.array() > b.array());
295 }
296}
297
298// --- Less Than or Equal (le) ---
307template <typename DerivedA, typename DerivedB>
308auto le(const Eigen::MatrixBase<DerivedA> &a, const Eigen::MatrixBase<DerivedB> &b) {
309 using Scalar = typename DerivedA::Scalar;
310 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
311 return a.binaryExpr(b, [](const auto &x, const auto &y) { return x <= y; });
312 } else {
313 return (a.array() <= b.array());
314 }
315}
316
317// --- Greater Than or Equal (ge) ---
326template <typename DerivedA, typename DerivedB>
327auto ge(const Eigen::MatrixBase<DerivedA> &a, const Eigen::MatrixBase<DerivedB> &b) {
328 using Scalar = typename DerivedA::Scalar;
329 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
330 return a.binaryExpr(b, [](const auto &x, const auto &y) { return x >= y; });
331 } else {
332 return (a.array() >= b.array());
333 }
334}
335
336// --- Equal (eq) ---
345template <typename DerivedA, typename DerivedB>
346auto eq(const Eigen::MatrixBase<DerivedA> &a, const Eigen::MatrixBase<DerivedB> &b) {
347 using Scalar = typename DerivedA::Scalar;
348 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
349 return a.binaryExpr(b, [](const auto &x, const auto &y) { return x == y; });
350 } else {
351 return (a.array() == b.array());
352 }
353}
354
355// --- Not Equal (neq) ---
364template <typename DerivedA, typename DerivedB>
365auto neq(const Eigen::MatrixBase<DerivedA> &a, const Eigen::MatrixBase<DerivedB> &b) {
366 using Scalar = typename DerivedA::Scalar;
367 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
368 return a.binaryExpr(b, [](const auto &x, const auto &y) { return x != y; });
369 } else {
370 return (a.array() != b.array());
371 }
372}
373
374// --- sigmoid_blend ---
385// Relaxed for mixed types
386template <JanusScalar T, JanusScalar TLow, JanusScalar THigh, JanusScalar Sharpness = double>
387auto sigmoid_blend(const T &x, const TLow &val_low, const THigh &val_high,
388 const Sharpness &sharpness = 1.0) {
389 // using janus::exp from Arithmetic.hpp
390 auto alpha = 1.0 / (1.0 + janus::exp(-sharpness * x));
391 return val_low + alpha * (val_high - val_low);
392}
393
394// Vectorized sigmoid_blend could be added here if needed,
395// strictly relying on .array() operations in implementation code might be enough
396// if we make a vectorized wrapper like in Arithmetic.hpp
397
408template <typename Derived, typename Scalar>
409auto sigmoid_blend(const Eigen::MatrixBase<Derived> &x, const Scalar &val_low,
410 const Scalar &val_high, const Scalar &sharpness = 1.0) {
411 auto alpha = (1.0 + (-sharpness * x.array()).exp()).inverse();
412 return (val_low + alpha * (val_high - val_low)).matrix();
413}
414
415// --- Logical AND ---
422template <JanusScalar T1, JanusScalar T2> auto logical_and(const T1 &x1, const T2 &x2) {
423 // Both define operator &&
424 return x1 && x2;
425}
426
435template <typename DerivedA, typename DerivedB>
436auto logical_and(const Eigen::MatrixBase<DerivedA> &a, const Eigen::MatrixBase<DerivedB> &b) {
437 using Scalar = typename DerivedA::Scalar;
438 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
439 return a.binaryExpr(b, [](const auto &x, const auto &y) { return x && y; });
440 } else {
441 // Ensure boolean context for Eigen arrays
442 return ((a.array() != 0) && (b.array() != 0));
443 }
444}
445
446// --- Logical OR ---
453template <JanusScalar T1, JanusScalar T2> auto logical_or(const T1 &x1, const T2 &x2) {
454 return x1 || x2;
455}
456
465template <typename DerivedA, typename DerivedB>
466auto logical_or(const Eigen::MatrixBase<DerivedA> &a, const Eigen::MatrixBase<DerivedB> &b) {
467 using Scalar = typename DerivedA::Scalar;
468 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
469 return a.binaryExpr(b, [](const auto &x, const auto &y) { return x || y; });
470 } else {
471 return ((a.array() != 0) || (b.array() != 0));
472 }
473}
474
475// --- Logical NOT ---
481template <JanusScalar T> auto logical_not(const T &x) { return !x; }
482
489template <typename Derived> auto logical_not(const Eigen::MatrixBase<Derived> &a) {
490 using Scalar = typename Derived::Scalar;
491 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
492 return a.unaryExpr([](const auto &x) { return !x; });
493 } else {
494 return (a.array() == 0);
495 }
496}
497
498// --- All ---
504template <typename Derived> auto all(const Eigen::MatrixBase<Derived> &a) {
505 using Scalar = typename Derived::Scalar;
506 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
507 return detail::count_symbolic_truthy(a) == static_cast<double>(a.size());
508 } else {
509 return (a.array() != 0).all();
510 }
511}
512
513// --- Any ---
519template <typename Derived> auto any(const Eigen::MatrixBase<Derived> &a) {
520 using Scalar = typename Derived::Scalar;
521 if constexpr (std::is_same_v<Scalar, SymbolicScalar>) {
522 return detail::count_symbolic_truthy(a) >= 1.0;
523 } else {
524 return (a.array() != 0).any();
525 }
526}
527
528// --- Select (Multi-way branching like switch/case) ---
546template <typename CondType, typename Scalar>
547Scalar select(const std::vector<CondType> &conditions, const std::vector<Scalar> &values,
548 const Scalar &default_value) {
549 if (conditions.size() != values.size()) {
550 throw InvalidArgument("select: conditions and values must have same size");
551 }
552
553 // Start with default
554 Scalar result = default_value;
555
556 // Work backwards so earlier conditions override later ones
557 for (int i = static_cast<int>(conditions.size()) - 1; i >= 0; --i) {
558 result = where(conditions[i], values[i], result);
559 }
560
561 return result;
562}
563
/**
 * @brief Brace-list convenience overload of multi-way `select`.
 *
 * Enables calls like `select({c1, c2}, {v1, v2}, fallback)`. The lists are copied
 * into vectors and forwarded to the std::vector overload.
 */
template <typename CondType, typename Scalar>
Scalar select(std::initializer_list<CondType> conditions, std::initializer_list<Scalar> values,
              const Scalar &default_value) {
    const std::vector<CondType> condition_list(conditions);
    const std::vector<Scalar> value_list(values);
    return select(condition_list, value_list, default_value);
}
570
571} // namespace janus
Scalar and element-wise arithmetic functions (abs, sqrt, pow, exp, log, etc.).
C++20 concepts constraining valid Janus scalar types.
Custom exception hierarchy for Janus framework.
Core type aliases for numeric and symbolic Eigen/CasADi interop.
Input validation failed (e.g., mismatched sizes, invalid parameters).
Definition JanusError.hpp:31
SymbolicScalar as_symbolic_predicate(const SymbolicScalar &expr)
Definition Logic.hpp:97
bool is_symbolic_predicate(const SymbolicScalar &expr)
Definition Logic.hpp:78
auto select(const Cond &cond, const Eigen::MatrixBase< DerivedTrue > &if_true, const Eigen::MatrixBase< DerivedFalse > &if_false)
Definition Logic.hpp:55
SymbolicScalar count_symbolic_truthy(const Eigen::MatrixBase< Derived > &a)
Definition Logic.hpp:102
Definition Diagnostics.hpp:19
auto sigmoid_blend(const T &x, const TLow &val_low, const THigh &val_high, const Sharpness &sharpness=1.0)
Smoothly blends between val_low and val_high using a sigmoid function: blend = val_low + (val_high - val_low) / (1 + exp(-sharpness * x)).
Definition Logic.hpp:387
Scalar select(const std::vector< CondType > &conditions, const std::vector< Scalar > &values, const Scalar &default_value)
Multi-way conditional selection (cleaner alternative to nested where).
Definition Logic.hpp:547
auto where(const Cond &cond, const T1 &if_true, const T2 &if_false)
Select values based on a condition (ternary operator). Returns: cond ? if_true : if_false. Supports mixed types (e.g. MX and double).
Definition Logic.hpp:43
auto logical_not(const T &x)
Logical NOT (!x).
Definition Logic.hpp:481
auto clamp(const T &val, const TLow &low, const THigh &high)
Clamps value between low and high.
Definition Logic.hpp:236
auto le(const Eigen::MatrixBase< DerivedA > &a, const Eigen::MatrixBase< DerivedB > &b)
Element-wise less than or equal comparison.
Definition Logic.hpp:308
auto ge(const Eigen::MatrixBase< DerivedA > &a, const Eigen::MatrixBase< DerivedB > &b)
Element-wise greater than or equal comparison.
Definition Logic.hpp:327
auto any(const Eigen::MatrixBase< Derived > &a)
Returns true if any element is true (non-zero); for symbolic input, returns a symbolic expression instead of a bool.
Definition Logic.hpp:519
typename BooleanType< T >::type BooleanType_t
Definition Logic.hpp:28
auto logical_and(const T1 &x1, const T2 &x2)
Logical AND (x && y).
Definition Logic.hpp:422
auto lt(const Eigen::MatrixBase< DerivedA > &a, const Eigen::MatrixBase< DerivedB > &b)
Element-wise less than comparison.
Definition Logic.hpp:270
auto gt(const Eigen::MatrixBase< DerivedA > &a, const Eigen::MatrixBase< DerivedB > &b)
Element-wise greater than comparison.
Definition Logic.hpp:289
auto neq(const Eigen::MatrixBase< DerivedA > &a, const Eigen::MatrixBase< DerivedB > &b)
Element-wise inequality comparison.
Definition Logic.hpp:365
auto logical_or(const T1 &x1, const T2 &x2)
Logical OR (x || y).
Definition Logic.hpp:453
auto eq(const Eigen::MatrixBase< DerivedA > &a, const Eigen::MatrixBase< DerivedB > &b)
Element-wise equality comparison.
Definition Logic.hpp:346
auto min(const T1 &a, const T2 &b)
Computes minimum of two values.
Definition Logic.hpp:151
casadi::MX SymbolicScalar
CasADi MX symbolic scalar.
Definition JanusTypes.hpp:70
auto max(const T1 &a, const T2 &b)
Computes maximum of two values.
Definition Logic.hpp:194
auto all(const Eigen::MatrixBase< Derived > &a)
Returns true if all elements are true (non-zero); for symbolic input, returns a symbolic expression instead of a bool.
Definition Logic.hpp:504
T exp(const T &x)
Computes the exponential function e^x.
Definition Arithmetic.hpp:131
SymbolicScalar type
Definition Logic.hpp:25
Definition Logic.hpp:20
bool type
Definition Logic.hpp:21