Janus 2.0.0
High-performance C++20 dual-mode numerical framework
Loading...
Searching...
No Matches
MultiShooting.hpp
Go to the documentation of this file.
1
5
6#pragma once
7
10#include <string>
11#include <tuple>
12
13namespace janus {
14
// Options for MultipleShooting (fields of `struct MultiShootingOptions`).
// NOTE(review): the `struct MultiShootingOptions {` header (listing line 20)
// is absent from this extraction; only the field lines are shown.
/// Number of shooting intervals; setup() throws InvalidArgument if < 1.
21 int n_intervals = 20;
/// Integrator plugin name handed to casadi::integrator (e.g. "cvodes", "rk", "idas").
22 std::string integrator = "cvodes";
/// Used as both `abstol` and `reltol`, but only when `integrator == "cvodes"`;
/// other plugins receive no tolerance options from this class.
23 double tol = 1e-8;
/// Intended to select integration on normalized time with the ODE scaled by dt.
/// NOTE(review): the MultipleShooting code shown never reads this flag and always
/// applies the dt scaling; setting it to false appears to have no effect. TODO confirm.
24 bool normalize_time = true;
25};
26
// Multiple-shooting transcription.
// - One Opti decision variable per (node, state) for n_intervals + 1 nodes.
// - One Opti decision variable per (interval, control).
// - One continuity constraint x_{k+1} == integrator(x_k, [u_k; dt]) per interval.
// NOTE(review): several listing lines are missing from this extraction (e.g. 34,
// 79, 96-97, 110, 127, 176-177). Assignments to n_states_/n_controls_, the states_
// allocation, the opts_ declaration and the x_k/u_k declarations are not visible here.
33class MultipleShooting : public TranscriptionBase<MultipleShooting> {
35
36 public:
42
/// Set up the shooting problem with a fixed final time `tf`.
/// Delegates to setup_impl with variable_tf = false.
/// @return {states (n_nodes x n_states), controls (n_intervals x n_controls), tau}
/// @throws InvalidArgument on invalid sizes (see setup_impl).
52 std::tuple<SymbolicMatrix, SymbolicMatrix, NumericVector>
53 setup(int n_states, int n_controls, double t0, double tf,
54 const MultiShootingOptions &opts = {}) {
55 return setup_impl(n_states, n_controls, t0, tf, false, opts);
56 }
57
/// Set up the shooting problem with a symbolic (variable) final time `tf`.
/// The 1.0 passed as tf_val is a placeholder; setup_impl ignores tf_val when
/// variable_tf is true.
/// NOTE(review): tf_symbolic_ is overwritten before setup_impl validates its
/// arguments, so a throwing setup() still leaves the new tf stored.
67 std::tuple<SymbolicMatrix, SymbolicMatrix, NumericVector>
68 setup(int n_states, int n_controls, double t0, const SymbolicScalar &tf,
69 const MultiShootingOptions &opts = {}) {
70 tf_symbolic_ = tf;
71 return setup_impl(n_states, n_controls, t0, 1.0, true, opts);
72 }
73
/// Number of shooting intervals (0 until setup() has run).
76 int n_intervals() const { return n_intervals_; }
77
78 private:
// Cached CasADi integrator. Null until ensure_integrator() builds it, and reset
// to null on every setup_impl() call so a new setup forces a rebuild.
80 casadi::Function integrator_;
81 int n_intervals_ = 0;
82
// Shared implementation of both setup() overloads.
// - Validates sizes.
// - Stores options and time-grid data.
// - Creates zero-initialized decision variables for states and controls.
// - Resets the dynamics/integrator state.
// @throws InvalidArgument if n_intervals < 1, n_states < 1, or n_controls < 0.
83 std::tuple<SymbolicMatrix, SymbolicMatrix, NumericVector>
84 setup_impl(int n_states, int n_controls, double t0, double tf_val, bool variable_tf,
85 const MultiShootingOptions &opts) {
86 if (opts.n_intervals < 1) {
87 throw InvalidArgument("MultipleShooting: n_intervals must be >= 1");
88 }
89 if (n_states < 1) {
90 throw InvalidArgument("MultipleShooting: n_states must be >= 1");
91 }
92 if (n_controls < 0) {
93 throw InvalidArgument("MultipleShooting: n_controls must be >= 0");
94 }
95
98 opts_ = opts;
99 n_intervals_ = opts.n_intervals;
// N intervals need N + 1 state nodes (both interval endpoints).
100 n_nodes_ = n_intervals_ + 1;
101 t0_ = t0;
102
// tf_fixed_ is only written in the fixed-tf case; tf_val is ignored otherwise.
103 if (variable_tf) {
104 tf_is_variable_ = true;
105 } else {
106 tf_fixed_ = tf_val;
107 tf_is_variable_ = false;
108 }
109
// One state decision variable per (node, state). The 0.0 argument presumably
// sets the Opti initial guess — TODO confirm against Opti::variable.
// NOTE(review): the states_ (n_nodes_ x n_states_) allocation is on a listing line
// not shown here.
111 for (int k = 0; k < n_nodes_; ++k) {
112 for (int i = 0; i < n_states_; ++i) {
113 states_(k, i) = opti_.variable(0.0);
114 }
115 }
116
// Controls are piecewise constant: one row per interval, not per node.
117 controls_ = SymbolicMatrix(n_intervals_, n_controls_);
118 for (int k = 0; k < n_intervals_; ++k) {
119 for (int i = 0; i < n_controls_; ++i) {
120 controls_(k, i) = opti_.variable(0.0);
121 }
122 }
123
// Normalized node times on [0, 1].
124 tau_ = linspace(0.0, 1.0, n_nodes_);
125 setup_complete_ = true;
// Each new setup invalidates previously supplied dynamics and the cached integrator.
126 dynamics_set_ = false;
128 integrator_ = casadi::Function();
129
130 return {states_, controls_, tau_};
131 }
132
// Lazily build the CasADi integrator for one shooting interval.
// - State: x (n_states_).
// - Parameters: p = [u; dt] (n_controls_ + 1).
// - ODE: dynamics_(x, u, t) * dt, i.e. the derivative with respect to a
//   normalized time variable.
// NOTE(review): no integration horizon/grid is passed to casadi::integrator; this
// presumably relies on CasADi's default horizon [0, 1] matching the dt scaling.
// Confirm for the CasADi version in use.
133 void ensure_integrator() {
134 if (!integrator_.is_null()) {
135 return;
136 }
137
138 SymbolicScalar x_sym = sym("x", n_states_);
139 SymbolicScalar p_sym = sym("p", n_controls_ + 1); // p = [u; dt]
// NOTE(review): `t` is not placed in the DAE dict below (only "x", "p" and "ode").
// If dynamics_ uses t, the ODE has a free symbol, which CasADi is likely to reject.
// Even if accepted, t would not carry the interval's physical time
// (t0 + k*dt + tau*dt).
140 SymbolicScalar t = sym("t");
141
142 SymbolicScalar u_sym;
143 if (n_controls_ > 0) {
144 u_sym = p_sym(casadi::Slice(0, n_controls_));
145 } else {
// Control-free problems pass an empty 0x1 control vector.
146 u_sym = casadi::MX(0, 1);
147 }
// The last parameter entry is the interval length dt.
148 SymbolicScalar dt_sym = p_sym(n_controls_);
149
150 SymbolicVector x_vec = as_vector(x_sym);
151 SymbolicVector u_vec = as_vector(u_sym);
152 SymbolicVector dxdt = dynamics_(x_vec, u_vec, t);
// NOTE(review): scaling is applied unconditionally; opts_.normalize_time is not read.
153 SymbolicVector ode_scaled = dxdt * dt_sym;
154
155 casadi::MXDict dae = {{"x", x_sym}, {"p", p_sym}, {"ode", to_mx(ode_scaled)}};
156
// Tolerances are forwarded only for "cvodes"; other plugins get default options.
157 casadi::Dict intg_opts;
158 if (opts_.integrator == "cvodes") {
159 intg_opts["abstol"] = opts_.tol;
160 intg_opts["reltol"] = opts_.tol;
161 }
162
163 integrator_ = casadi::integrator("shooting_intg", opts_.integrator, dae, intg_opts);
164 }
165
// Add one continuity constraint per interval: the state at node k+1 must equal
// the integrator's end state started from node k with control u_k over dt.
// @throws RuntimeError if set_dynamics() has not been called since the last setup.
166 void add_dynamics_constraints_impl() {
167 if (!dynamics_set_) {
168 throw RuntimeError(
169 "MultipleShooting: call set_dynamics() before add_dynamics_constraints()");
170 }
171 ensure_integrator();
172
// Uniform interval length. It is symbolic when tf is variable, so the integrator
// receives dt as a parameter rather than a baked-in constant.
173 const SymbolicScalar dt = get_duration() / static_cast<double>(n_intervals_);
174
// NOTE(review): x_k and u_k are declared on listing lines not shown here.
// They are presumably get_state_at_node(k) and get_control_at_node(k) — TODO confirm.
175 for (int k = 0; k < n_intervals_; ++k) {
178 SymbolicVector x_kp1 = get_state_at_node(k + 1);
179
180 SymbolicScalar p = SymbolicScalar::vertcat({to_mx(u_k), dt});
// Calling the integrator on MX arguments embeds it symbolically in the NLP graph.
181 casadi::MXDict args = {{"x0", to_mx(x_k)}, {"p", p}};
182 casadi::MXDict res = integrator_(args);
183 SymbolicScalar x_integrated = res.at("xf");
184
185 opti_.subject_to(to_mx(x_kp1) == x_integrated);
186 }
187 }
188};
189
190} // namespace janus
Point distribution generators (linspace, cosine, sine, log, geometric).
Shared CRTP base for trajectory transcription methods.
Input validation failed (e.g., mismatched sizes, invalid parameters).
Definition JanusError.hpp:31
MultipleShooting(Opti &opti)
Construct with a reference to the optimization environment.
Definition MultiShooting.hpp:41
int n_intervals() const
Get the number of shooting intervals.
Definition MultiShooting.hpp:76
std::tuple< SymbolicMatrix, SymbolicMatrix, NumericVector > setup(int n_states, int n_controls, double t0, const SymbolicScalar &tf, const MultiShootingOptions &opts={})
Set up the shooting problem with variable final time.
Definition MultiShooting.hpp:68
std::tuple< SymbolicMatrix, SymbolicMatrix, NumericVector > setup(int n_states, int n_controls, double t0, double tf, const MultiShootingOptions &opts={})
Set up the shooting problem with fixed final time.
Definition MultiShooting.hpp:53
Main optimization environment class.
Definition Opti.hpp:167
bool tf_is_variable_
Definition TranscriptionBase.hpp:161
SymbolicMatrix controls_
Definition TranscriptionBase.hpp:169
int n_controls_
Definition TranscriptionBase.hpp:155
int n_controls() const
Definition TranscriptionBase.hpp:136
int n_states_
Definition TranscriptionBase.hpp:154
double t0_
Definition TranscriptionBase.hpp:158
int n_states() const
Definition TranscriptionBase.hpp:133
NumericVector tau_
Definition TranscriptionBase.hpp:167
TranscriptionBase(Opti &opti)
Definition TranscriptionBase.hpp:34
double tf_fixed_
Definition TranscriptionBase.hpp:159
std::function< SymbolicVector(const SymbolicVector &, const SymbolicVector &, const SymbolicScalar &)> dynamics_
Definition TranscriptionBase.hpp:173
SymbolicScalar tf_symbolic_
Definition TranscriptionBase.hpp:160
SymbolicMatrix states_
Definition TranscriptionBase.hpp:168
bool setup_complete_
Definition TranscriptionBase.hpp:163
int n_nodes_
Definition TranscriptionBase.hpp:156
SymbolicVector get_state_at_node(int k) const
Definition TranscriptionBase.hpp:189
SymbolicScalar get_duration() const
Definition TranscriptionBase.hpp:175
SymbolicVector get_control_at_node(int k) const
Definition TranscriptionBase.hpp:200
bool dynamics_constraints_added_
Definition TranscriptionBase.hpp:165
bool dynamics_set_
Definition TranscriptionBase.hpp:164
Opti & opti_
Definition TranscriptionBase.hpp:153
Definition Diagnostics.hpp:19
JanusVector< SymbolicScalar > SymbolicVector
Eigen vector of MX elements.
Definition JanusTypes.hpp:72
SymbolicVector as_vector(const casadi::MX &m)
Convert CasADi MX vector to SymbolicVector (Eigen container of MX).
Definition JanusTypes.hpp:228
casadi::MX to_mx(const Eigen::MatrixBase< Derived > &e)
Convert Eigen matrix of MX (or numeric) to CasADi MX.
Definition JanusTypes.hpp:189
JanusMatrix< SymbolicScalar > SymbolicMatrix
Eigen matrix of MX elements.
Definition JanusTypes.hpp:71
SymbolicScalar sym(const std::string &name)
Create a named symbolic scalar variable.
Definition JanusTypes.hpp:90
JanusVector< T > linspace(const T &start, const T &end, int n)
Generates linearly spaced vector.
Definition Spacing.hpp:26
casadi::MX SymbolicScalar
CasADi MX symbolic scalar.
Definition JanusTypes.hpp:70
Options for MultipleShooting.
Definition MultiShooting.hpp:20
bool normalize_time
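If true, integrates on normalized time and scales the ODE by dt. (The integrator setup in the listing above always applies this scaling and does not appear to read this flag.)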
Definition MultiShooting.hpp:24
std::string integrator
Integrator plugin name passed to `casadi::integrator` (e.g. "cvodes", "rk", "idas").
Definition MultiShooting.hpp:22
double tol
Integrator tolerance, passed as both `abstol` and `reltol` when `integrator` is "cvodes" (not forwarded to other plugins).
Definition MultiShooting.hpp:23
int n_intervals
Number of shooting intervals (must be >= 1).
Definition MultiShooting.hpp:21