Example Cases
Example cases demonstrate physical examples with known analytical solutions or well-studied phenomena. Each case follows the recommended workflow shown here. Feel free to use them as an initial template to build your own case study.
Axial Stretching
"""Axial stretching test-case

Assume we have a rod lying aligned in the x-direction, with high internal
damping.

We fix one end (say, the left end) of the rod to a wall. On the right
end we apply a force directed axially, pulling the rod's tip. Linear
theory (assuming small displacements) predicts that the net displacement
experienced by the rod tip is Δx = FL/AE, where the symbols carry their
usual meaning (the rod is just a linear spring). We compare our results
with the above result.

We can "improve" the theory by having a better estimate for the rod's
spring constant by assuming that it equilibrates under the new position,
with
Δx = F * (L + Δx)/ (A * E)
which results in Δx = (F*L)/(A*E - F). Our rod reaches equilibrium with
respect to this position.

Note that if the damping is not high, the rod oscillates about the eventual
resting position (and this agrees with the theoretical predictions without
any damping: we should see the rod oscillating simple-harmonically in time).

isort:skip_file
"""
26
27import numpy as np
28from matplotlib import pyplot as plt
29
30import elastica as ea
31
32
class StretchingBeamSimulator(
    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.Damping, ea.CallBacks
):
    """Simulator for the axial-stretching case.

    It combines the elastica mixins this example needs: a system collection,
    boundary constraints, external forcing, damping and call-backs.
    """
37
38
stretch_sim = StretchingBeamSimulator()
final_time = 200.0

# Options
PLOT_FIGURE = True
SAVE_FIGURE = False
SAVE_RESULTS = False

# setting up test params
n_elem = 19
start = np.zeros((3,))
direction = np.array([1.0, 0.0, 0.0])  # rod axis along +x
normal = np.array([0.0, 1.0, 0.0])
base_length = 1.0
base_radius = 0.025
base_area = np.pi * base_radius**2
density = 1000
youngs_modulus = 1e4
# Shear modulus from the relation used in these examples,
# G = E / (1 + poisson_ratio); with poisson_ratio = 0.5 this gives G = E / 1.5.
poisson_ratio = 0.5
shear_modulus = youngs_modulus / (poisson_ratio + 1.0)

stretchable_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    youngs_modulus=youngs_modulus,
    shear_modulus=shear_modulus,
)

stretch_sim.append(stretchable_rod)
# Clamp the left end: position of node 0 and director of element 0 are fixed
stretch_sim.constrain(stretchable_rod).using(
    ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
)

# Axial pull on the free (right) tip; the fixed end receives 0.0 * end_force.
end_force_x = 1.0
end_force = np.array([end_force_x, 0.0, 0.0])
stretch_sim.add_forcing_to(stretchable_rod).using(
    ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=1e-2
)

# add damping
dl = base_length / n_elem
dt = 0.1 * dl  # also used as the integration time step
damping_constant = 0.1
stretch_sim.dampen(stretchable_rod).using(
    ea.AnalyticalLinearDamper,
    damping_constant=damping_constant,
    time_step=dt,
)
93
94
# Add call backs
class AxialStretchingCallBack(ea.CallBackBaseClass):
    """
    Tracks the tip position and the velocity norms of the rod.

    Every ``step_skip`` steps it appends to ``callback_params``:
    the current time (``"time"``), the x-coordinate of the last node, i.e. the
    free tip (``"position"``), and the norm of the rod's velocity collection
    (``"velocity_norms"``).
    """

    def __init__(self, step_skip: int, callback_params: dict):
        ea.CallBackBaseClass.__init__(self)
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):
        if current_step % self.every == 0:
            self.callback_params["time"].append(time)
            # Collect only the x-coordinate of the tip (last node)
            self.callback_params["position"].append(
                system.position_collection[0, -1].copy()
            )
            self.callback_params["velocity_norms"].append(
                np.linalg.norm(system.velocity_collection.copy())
            )
        return
119
120
# Diagnostics: a sample is recorded every 200 steps into `recorded_history`
recorded_history = ea.defaultdict(list)
stretch_sim.collect_diagnostics(stretchable_rod).using(
    AxialStretchingCallBack,
    step_skip=200,
    callback_params=recorded_history,
)

# Freeze the system collection, then time-march with Position Verlet
# (PEFRL is an alternative integrator).
stretch_sim.finalize()
timestepper = ea.PositionVerlet()

total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, stretch_sim, final_time, total_steps)
133
if PLOT_FIGURE:
    # Tip displacement from first-order theory on the undeformed length:
    # dx = F L / (A E)
    expected_tip_disp = end_force_x * base_length / base_area / youngs_modulus
    # Same theory evaluated on the stretched length: dx = F L / (A E - F)
    expected_tip_disp_improved = (
        end_force_x * base_length / (base_area * youngs_modulus - end_force_x)
    )

    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    # Simulated tip x-position over time
    ax.plot(recorded_history["time"], recorded_history["position"], lw=2.0)
    # Reference lines: base-length theory (dash-dot), improved theory (dashed)
    ax.hlines(
        base_length + expected_tip_disp,
        0.0,
        final_time,
        colors="k",
        linestyles="dashdot",
        lw=1.0,
    )
    ax.hlines(
        base_length + expected_tip_disp_improved,
        0.0,
        final_time,
        colors="k",
        linestyles="dashed",
        lw=2.0,
    )
    if SAVE_FIGURE:
        fig.savefig("axial_stretching.pdf")
    plt.show()
152
if SAVE_RESULTS:
    import pickle

    filename = "axial_stretching_data.dat"
    # `with` guarantees the file is closed even if pickling raises
    with open(filename, "wb") as file:
        pickle.dump(stretchable_rod, file)

    tv = (
        np.asarray(recorded_history["time"]),
        np.asarray(recorded_history["velocity_norms"]),
    )

    def as_time_series(v):
        """Transpose a (2, N) stack into N rows of (time, velocity_norm)."""
        return v.T

    np.savetxt(
        "velocity_norms.csv",
        as_time_series(np.stack(tv)),
        delimiter=",",
    )
Timoshenko
__doc__ = """Timoshenko beam validation case, for detailed explanation refer to
Gazzola et al. R. Soc. 2018 section 3.4.3 """
3
4import numpy as np
5import elastica as ea
6from examples.TimoshenkoBeamCase.timoshenko_postprocessing import plot_timoshenko
7
8
class TimoshenkoBeamSimulator(
    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.CallBacks, ea.Damping
):
    """Simulator for the Timoshenko beam case.

    It mixes in constraints, external forcing, call-backs and damping on top of
    the base system collection.
    """
13
14
timoshenko_sim = TimoshenkoBeamSimulator()
final_time = 5000.0

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = False
ADD_UNSHEARABLE_ROD = False

# Geometry and material of the (shearable) cantilever along +z
n_elem = 100
start = np.zeros((3,))
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 3.0
base_radius = 0.25
base_area = np.pi * base_radius**2
density = 5000
# NOTE: `nu` is the damping constant handed to the damper below, not a
# Poisson ratio.
nu = 0.1 / 7 / density / base_area
E = 1e6
# A Poisson ratio of 99 gives G = E / (1 + 99) = 1e4.
poisson_ratio = 99
shear_modulus = E / (poisson_ratio + 1.0)

shearable_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start=start,
    direction=direction,
    normal=normal,
    base_length=base_length,
    base_radius=base_radius,
    density=density,
    youngs_modulus=E,
    shear_modulus=shear_modulus,
)
timoshenko_sim.append(shearable_rod)

# Damping; dt doubles as the integration time step
dl = base_length / n_elem
dt = 0.07 * dl
timoshenko_sim.dampen(shearable_rod).using(
    ea.AnalyticalLinearDamper, damping_constant=nu, time_step=dt
)

# Clamp the base: node 0 position and element 0 director are fixed
timoshenko_sim.constrain(shearable_rod).using(
    ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
)

# Transverse (-x) load on the tip, ramped over half the simulated time
end_force = np.array([-15.0, 0.0, 0.0])
timoshenko_sim.add_forcing_to(shearable_rod).using(
    ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
)
69
70
if ADD_UNSHEARABLE_ROD:
    # Start the second rod offset along -y (into the plane), apart from the
    # shearable rod, which starts at the origin.
    unshearable_start = np.array([0.0, -1.0, 0.0])
    # Unshearable rod needs G -> inf, which is approached here with a -ve
    # Poisson ratio (-0.7 gives G = E / 0.3). A dedicated name is used so the
    # module-level `shear_modulus` of the shearable rod is not overwritten.
    unshearable_shear_modulus = E / (-0.7 + 1.0)
    unshearable_rod = ea.CosseratRod.straight_rod(
        n_elem,
        unshearable_start,
        direction,
        normal,
        base_length,
        base_radius,
        density,
        youngs_modulus=E,
        shear_modulus=unshearable_shear_modulus,
    )

    timoshenko_sim.append(unshearable_rod)

    # add damping (same constant and time step as the shearable rod)
    timoshenko_sim.dampen(unshearable_rod).using(
        ea.AnalyticalLinearDamper,
        damping_constant=nu,
        time_step=dt,
    )
    timoshenko_sim.constrain(unshearable_rod).using(
        ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
    )
    timoshenko_sim.add_forcing_to(unshearable_rod).using(
        ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
    )
102
103
# Add call backs
class VelocityCallBack(ea.CallBackBaseClass):
    """
    Tracks the velocity norms of the rod.

    Every ``step_skip`` steps it appends the current time to
    ``callback_params["time"]`` and the norm of the rod's velocity collection
    to ``callback_params["velocity_norms"]``.
    """

    def __init__(self, step_skip: int, callback_params: dict):
        ea.CallBackBaseClass.__init__(self)
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):
        if current_step % self.every == 0:
            self.callback_params["time"].append(time)
            # Collect the norm of all nodal velocities
            self.callback_params["velocity_norms"].append(
                np.linalg.norm(system.velocity_collection.copy())
            )
        return
125
126
# Diagnostics are sampled every 500 steps
recorded_history = ea.defaultdict(list)
timoshenko_sim.collect_diagnostics(shearable_rod).using(
    VelocityCallBack,
    step_skip=500,
    callback_params=recorded_history,
)

# Finalize, then integrate with Position Verlet (PEFRL is an alternative)
timoshenko_sim.finalize()
timestepper = ea.PositionVerlet()

total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, timoshenko_sim, final_time, total_steps)

if PLOT_FIGURE:
    # Post-processing helper imported from the example package
    plot_timoshenko(shearable_rod, end_force, SAVE_FIGURE, ADD_UNSHEARABLE_ROD)
142
if SAVE_RESULTS:
    import pickle

    filename = "Timoshenko_beam_data.dat"
    # `with` guarantees the file is closed even if pickling raises
    with open(filename, "wb") as file:
        pickle.dump(shearable_rod, file)

    tv = (
        np.asarray(recorded_history["time"]),
        np.asarray(recorded_history["velocity_norms"]),
    )

    def as_time_series(v):
        """Transpose a (2, N) stack into N rows of (time, velocity_norm)."""
        return v.T

    np.savetxt(
        "velocity_norms.csv",
        as_time_series(np.stack(tv)),
        delimiter=",",
    )
Butterfly
1import numpy as np
2from matplotlib import pyplot as plt
3from matplotlib.colors import to_rgb
4
5
6import elastica as ea
7from elastica.utils import MaxDimension
8
9
class ButterflySimulator(ea.BaseSystemCollection, ea.CallBacks):
    """Simulator for the free butterfly case: a system collection plus call-backs."""
12
13
butterfly_sim = ButterflySimulator()
final_time = 40.0

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = True
ADD_UNSHEARABLE_ROD = False  # not referenced elsewhere in this script

# setting up test params
# FIXME : Doesn't work with elements > 10 (the inverse rotate kernel fails)
n_elem = 4  # Change based on requirements, but be careful
# Round an odd element count up so the rod splits into two equal halves
n_elem = n_elem + (n_elem % 2)
half_n_elem = n_elem // 2

origin = np.zeros((3, 1))
angle_of_inclination = np.deg2rad(45.0)

# In-plane directions, stored as (3, 1) columns so they broadcast over nodes
horizontal_direction = np.array([0.0, 0.0, 1.0]).reshape(-1, 1)
vertical_direction = np.array([1.0, 0.0, 0.0]).reshape(-1, 1)

# Out-of-plane direction
normal = np.array([0.0, 1.0, 0.0])

# Material / geometry
total_length = 3.0
base_radius = 0.25
base_area = np.pi * base_radius**2
density = 5000
youngs_modulus = 1e4
poisson_ratio = 0.5
shear_modulus = youngs_modulus / (poisson_ratio + 1.0)
46
positions = np.empty((MaxDimension.value(), n_elem + 1))
dl = total_length / n_elem

# Node offsets 0..half_n_elem as a (1, k) row, reused for both arms
first_half = np.arange(half_n_elem + 1.0).reshape(1, -1)
rising_tangent = (
    np.cos(angle_of_inclination) * horizontal_direction
    + np.sin(angle_of_inclination) * vertical_direction
)
falling_tangent = (
    np.cos(angle_of_inclination) * horizontal_direction
    - np.sin(angle_of_inclination) * vertical_direction
)

# First arm climbs from the origin at +angle_of_inclination ...
positions[..., : half_n_elem + 1] = origin + dl * first_half * rising_tangent
# ... the second arm starts at the apex node and descends symmetrically
positions[..., half_n_elem:] = (
    positions[..., half_n_elem : half_n_elem + 1] + dl * first_half * falling_tangent
)

# The explicit `position` array overrides the straight-line layout
butterfly_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start=origin.reshape(3),
    direction=np.array([0.0, 0.0, 1.0]),
    normal=normal,
    base_length=total_length,
    base_radius=base_radius,
    density=density,
    youngs_modulus=youngs_modulus,
    shear_modulus=shear_modulus,
    position=positions,
)

butterfly_sim.append(butterfly_rod)
77
78
# Add call backs
class VelocityCallBack(ea.CallBackBaseClass):
    """
    Call back function for the butterfly rod.

    Every ``step_skip`` steps it appends to ``callback_params`` the current
    time, a copy of the nodal positions, and the translational (``"te"``),
    rotational (``"re"``), shear (``"se"``) and bending (``"be"``) energies.
    """

    def __init__(self, step_skip: int, callback_params: dict):
        ea.CallBackBaseClass.__init__(self)
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):
        if current_step % self.every == 0:
            self.callback_params["time"].append(time)
            # Collect nodal positions
            self.callback_params["position"].append(system.position_collection.copy())
            # Collect energies as well
            self.callback_params["te"].append(system.compute_translational_energy())
            self.callback_params["re"].append(system.compute_rotational_energy())
            self.callback_params["se"].append(system.compute_shear_energy())
            self.callback_params["be"].append(system.compute_bending_energy())
        return
103
104
recorded_history = ea.defaultdict(list)
# Store the initial state (t = 0) before any time step is taken
recorded_history["time"].append(0.0)
recorded_history["position"].append(butterfly_rod.position_collection.copy())
for energy_key, energy_fn in (
    ("te", butterfly_rod.compute_translational_energy),
    ("re", butterfly_rod.compute_rotational_energy),
    ("se", butterfly_rod.compute_shear_energy),
    ("be", butterfly_rod.compute_bending_energy),
):
    recorded_history[energy_key].append(energy_fn())

# Then sample every 100 steps during the run
butterfly_sim.collect_diagnostics(butterfly_rod).using(
    VelocityCallBack,
    step_skip=100,
    callback_params=recorded_history,
)

butterfly_sim.finalize()
# Position Verlet integrator (PEFRL is an alternative)
timestepper = ea.PositionVerlet()

dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, butterfly_sim, final_time, total_steps)
127
if PLOT_FIGURE:
    # Plot the histories
    fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    # Read the snapshots without popping from recorded_history and without
    # rebinding the module-level `positions` array used to build the rod.
    position_history = recorded_history["position"]
    # initial configuration (red dashed)
    first_position = position_history[0]
    ax.plot(first_position[2, ...], first_position[0, ...], "r--", lw=2.0)
    # later snapshots in blue, becoming more opaque as time advances
    later_positions = position_history[1:]
    n_positions = len(later_positions)
    for i, pos in enumerate(later_positions):
        alpha = np.exp(i / n_positions - 1)
        ax.plot(pos[2, ...], pos[0, ...], "b", lw=0.6, alpha=alpha)
    # final configuration (black dashed); skipped when no snapshot was recorded
    if later_positions:
        last_position = later_positions[-1]
        ax.plot(last_position[2, ...], last_position[0, ...], "k--", lw=2.0)
    # don't block
    fig.show()

    # Plot the energies
    energy_fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    energy_ax = energy_fig.add_subplot(111)
    times = np.asarray(recorded_history["time"])
    te = np.asarray(recorded_history["te"])
    re = np.asarray(recorded_history["re"])
    be = np.asarray(recorded_history["be"])
    se = np.asarray(recorded_history["se"])

    energy_ax.plot(times, te, c=to_rgb("xkcd:reddish"), lw=2.0, label="Translations")
    energy_ax.plot(times, re, c=to_rgb("xkcd:bluish"), lw=2.0, label="Rotation")
    energy_ax.plot(times, be, c=to_rgb("xkcd:burple"), lw=2.0, label="Bend")
    energy_ax.plot(times, se, c=to_rgb("xkcd:goldenrod"), lw=2.0, label="Shear")
    energy_ax.plot(times, te + re + be + se, c="k", lw=2.0, label="Total energy")
    energy_ax.legend()
    # don't block
    energy_fig.show()

    if SAVE_FIGURE:
        fig.savefig("butterfly.png")
        energy_fig.savefig("energies.png")

    plt.show()
169
if SAVE_RESULTS:
    import pickle

    filename = "butterfly_data.dat"
    # `with` guarantees the file is closed even if pickling raises
    with open(filename, "wb") as file:
        pickle.dump(butterfly_rod, file)
Helical Buckling
__doc__ = """Helical buckling validation case, for detailed explanation refer to
Gazzola et al. R. Soc. 2018 section 3.4.1 """
3
4import numpy as np
5import elastica as ea
6from examples.HelicalBucklingCase.helicalbuckling_postprocessing import (
7 plot_helicalbuckling,
8)
9
10
class HelicalBucklingSimulator(
    ea.BaseSystemCollection, ea.Constraints, ea.Damping, ea.Forcing
):
    """Simulator for the helical buckling case.

    It combines constraints, damping and forcing with the base system collection.
    """
15
16
helicalbuckling_sim = HelicalBucklingSimulator()

# Options
PLOT_FIGURE = True
SAVE_FIGURE = True
SAVE_RESULTS = False

# setting up test params
n_elem = 100
start = np.zeros((3,))
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 100.0
base_radius = 0.35
base_area = np.pi * base_radius**2
# density chosen so that density * base_area == 1
density = 1.0 / (base_area)
# damping constant passed to the damper (not a Poisson ratio)
nu = 0.01 / density / base_area
E = 1e6
slack = 3
number_of_rotations = 27
# For shear modulus of 1e5, poisson_ratio is 9: G = E / (1 + 9) = 1e5
poisson_ratio = 9
shear_modulus = E / (poisson_ratio + 1.0)
# Isotropic shear stiffness: one 3x3 block per element
shear_matrix = np.repeat(
    shear_modulus * np.identity((3))[:, :, np.newaxis], n_elem, axis=2
)
# Diagonal bend/twist stiffness: one 3x3 block per pair of adjacent elements
# (n_elem - 1 of them); presumably the third entry is the twist stiffness.
temp_bend_matrix = np.zeros((3, 3))
np.fill_diagonal(temp_bend_matrix, [1.345, 1.345, 0.789])
bend_matrix = np.repeat(temp_bend_matrix[:, :, np.newaxis], n_elem - 1, axis=2)

shearable_rod = ea.CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    youngs_modulus=E,
    shear_modulus=shear_modulus,
)
# TODO: CosseratRod has to be able to take shear matrix as input, we should change it as done below

# Overwrite the stiffness matrices computed by straight_rod
shearable_rod.shear_matrix = shear_matrix
shearable_rod.bend_matrix = bend_matrix
62
63
helicalbuckling_sim.append(shearable_rod)

# Damping; dt is also the integration time step
dl = base_length / n_elem
dt = 1e-3 * dl
helicalbuckling_sim.dampen(shearable_rod).using(
    ea.AnalyticalLinearDamper, damping_constant=nu, time_step=dt
)

# Both end nodes/elements are constrained by the helical-buckling BC, which
# presumably applies `number_of_rotations` twists and `slack` end shortening
# over `twisting_time` (see ea.HelicalBucklingBC).
helicalbuckling_sim.constrain(shearable_rod).using(
    ea.HelicalBucklingBC,
    constrained_position_idx=(0, -1),
    constrained_director_idx=(0, -1),
    twisting_time=500,
    slack=slack,
    number_of_rotations=number_of_rotations,
)
82
helicalbuckling_sim.finalize()
timestepper = ea.PositionVerlet()
# Tiny transverse (+y) velocity kick at the middle node, presumably to break
# the symmetry so the buckled shape can develop.
shearable_rod.velocity_collection[..., n_elem // 2] += np.array([0, 1e-6, 0.0])
# timestepper = PEFRL()

final_time = 10500.0
total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, helicalbuckling_sim, final_time, total_steps)
92
if PLOT_FIGURE:
    plot_helicalbuckling(shearable_rod, SAVE_FIGURE)

if SAVE_RESULTS:
    import pickle

    filename = "HelicalBuckling_data.dat"
    # `with` guarantees the file is closed even if pickling raises
    with open(filename, "wb") as file:
        pickle.dump(shearable_rod, file)
Continuum Snake
__doc__ = """Snake friction case from X. Zhang et al. Nat. Comm. 2021"""
2
3import os
4import numpy as np
5import elastica as ea
6
7from examples.ContinuumSnakeCase.continuum_snake_postprocessing import (
8 plot_snake_velocity,
9 plot_video,
10 compute_projected_velocity,
11 plot_curvature,
12)
13
14
class SnakeSimulator(
    ea.BaseSystemCollection,
    ea.Constraints,
    ea.Forcing,
    ea.Damping,
    ea.CallBacks,
    ea.Contact,
):
    """Simulator for the continuum snake.

    It adds constraints, forcing, damping, call-backs and contact to the base
    system collection.
    """
24
25
def run_snake(
    b_coeff, PLOT_FIGURE=False, SAVE_FIGURE=False, SAVE_VIDEO=False, SAVE_RESULTS=False
):
    """Simulate a continuum snake crawling on a frictional ground plane.

    Parameters
    ----------
    b_coeff : sequence of float
        Muscle-torque spline coefficients followed by the wave length as the
        last entry: ``b_coeff[:-1]`` is passed to ``ea.MuscleTorques`` and
        ``b_coeff[-1]`` sets the wave number ``2*pi / b_coeff[-1]``.
    PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS : bool
        Toggle velocity/curvature plots, saving those figures, rendering a
        video, and pickling the recorded callback data.

    Returns
    -------
    avg_forward, avg_lateral, pp_list
        Average forward and lateral velocities from
        ``compute_projected_velocity`` and the recorded callback history.
    """
    # Initialize the simulation class
    snake_sim = SnakeSimulator()

    # Simulation parameters
    period = 2
    final_time = (11.0 + 0.01) * period

    # setting up test params
    n_elem = 50
    start = np.zeros((3,))
    direction = np.array([0.0, 0.0, 1.0])
    normal = np.array([0.0, 1.0, 0.0])
    base_length = 0.35
    base_radius = base_length * 0.011
    density = 1000
    E = 1e6
    poisson_ratio = 0.5
    shear_modulus = E / (poisson_ratio + 1.0)

    shearable_rod = ea.CosseratRod.straight_rod(
        n_elem,
        start,
        direction,
        normal,
        base_length,
        base_radius,
        density,
        youngs_modulus=E,
        shear_modulus=shear_modulus,
    )

    snake_sim.append(shearable_rod)

    # Add gravitational forces (along -y, towards the ground plane)
    gravitational_acc = -9.80665
    snake_sim.add_forcing_to(shearable_rod).using(
        ea.GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
    )

    # Add muscle torques; the last entry of b_coeff is the wave length
    wave_length = b_coeff[-1]
    snake_sim.add_forcing_to(shearable_rod).using(
        ea.MuscleTorques,
        base_length=base_length,
        b_coeff=b_coeff[:-1],
        period=period,
        wave_number=2.0 * np.pi / (wave_length),
        phase_shift=0.0,
        rest_lengths=shearable_rod.rest_lengths,
        ramp_up_time=period,
        direction=normal,
        with_spline=True,
    )

    # Add friction forces: ground plane one radius below the rod centreline
    ground_plane = ea.Plane(
        plane_origin=np.array([0.0, -base_radius, 0.0]), plane_normal=normal
    )
    snake_sim.append(ground_plane)
    slip_velocity_tol = 1e-8
    froude = 0.1
    mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
    kinetic_mu_array = np.array(
        [mu, 1.5 * mu, 2.0 * mu]
    )  # [forward, backward, sideways]
    static_mu_array = np.zeros(kinetic_mu_array.shape)
    snake_sim.detect_contact_between(shearable_rod, ground_plane).using(
        ea.RodPlaneContactWithAnisotropicFriction,
        k=1.0,
        nu=1e-6,
        slip_velocity_tol=slip_velocity_tol,
        static_mu_array=static_mu_array,
        kinetic_mu_array=kinetic_mu_array,
    )

    # add damping; time_step is also the integration time step
    damping_constant = 2e-3
    time_step = 1e-4
    snake_sim.dampen(shearable_rod).using(
        ea.AnalyticalLinearDamper,
        damping_constant=damping_constant,
        time_step=time_step,
    )

    total_steps = int(final_time / time_step)
    rendering_fps = 60
    # Record roughly once per rendered video frame
    step_skip = int(1.0 / (rendering_fps * time_step))

    # Add call backs
    class ContinuumSnakeCallBack(ea.CallBackBaseClass):
        """
        Call back function for continuum snake.

        Every ``step_skip`` steps it records time, step, nodal positions and
        velocities, center-of-mass velocity and position, and curvature.
        """

        def __init__(self, step_skip: int, callback_params: dict):
            ea.CallBackBaseClass.__init__(self)
            self.every = step_skip
            self.callback_params = callback_params

        def make_callback(self, system, time, current_step: int):
            if current_step % self.every == 0:
                self.callback_params["time"].append(time)
                self.callback_params["step"].append(current_step)
                self.callback_params["position"].append(
                    system.position_collection.copy()
                )
                self.callback_params["velocity"].append(
                    system.velocity_collection.copy()
                )
                self.callback_params["avg_velocity"].append(
                    system.compute_velocity_center_of_mass()
                )

                self.callback_params["center_of_mass"].append(
                    system.compute_position_center_of_mass()
                )
                self.callback_params["curvature"].append(system.kappa.copy())

            return

    pp_list = ea.defaultdict(list)
    snake_sim.collect_diagnostics(shearable_rod).using(
        ContinuumSnakeCallBack, step_skip=step_skip, callback_params=pp_list
    )

    snake_sim.finalize()

    timestepper = ea.PositionVerlet()
    ea.integrate(timestepper, snake_sim, final_time, total_steps)

    if PLOT_FIGURE:
        filename_plot = "continuum_snake_velocity.png"
        plot_snake_velocity(pp_list, period, filename_plot, SAVE_FIGURE)
        plot_curvature(pp_list, shearable_rod.rest_lengths, period, SAVE_FIGURE)

    if SAVE_VIDEO:
        filename_video = "continuum_snake.mp4"
        plot_video(
            pp_list,
            video_name=filename_video,
            fps=rendering_fps,
            xlim=(0, 4),
            ylim=(-1, 1),
        )

    if SAVE_RESULTS:
        import pickle

        filename = "continuum_snake.dat"
        # `with` guarantees the file is closed even if pickling raises
        with open(filename, "wb") as file:
            pickle.dump(pp_list, file)

    # Compute the average forward velocity. These will be used for optimization.
    [_, _, avg_forward, avg_lateral] = compute_projected_velocity(pp_list, period)

    return avg_forward, avg_lateral, pp_list
188
189
if __name__ == "__main__":

    # Options
    PLOT_FIGURE = True
    SAVE_FIGURE = True
    SAVE_VIDEO = True
    SAVE_RESULTS = False
    CMA_OPTION = False

    if CMA_OPTION:
        import cma

        SAVE_OPTIMIZED_COEFFICIENTS = False

        def optimize_snake(spline_coefficient):
            """Return the negated average forward velocity (CMA-ES minimizes)."""
            [avg_forward, _, _] = run_snake(
                spline_coefficient,
                PLOT_FIGURE=False,
                SAVE_FIGURE=False,
                SAVE_VIDEO=False,
                SAVE_RESULTS=False,
            )
            return -avg_forward

        # Optimize snake for forward velocity. In cma.fmin first input is function
        # to be optimized, second input is initial guess for coefficients you are
        # optimizing for and third input is standard deviation you initially set.
        # NOTE(review): run_snake uses the last coefficient as the wave length
        # (wave_number = 2*pi / wave_length), so an all-zero initial guess starts
        # from wave_length = 0 -- consider a non-zero initial wave length.
        optimization_result = cma.fmin(optimize_snake, 7 * [0], 0.5)
        # cma.fmin returns a result sequence whose first entry is the best
        # solution found; only that coefficient vector is saved / reused.
        optimized_spline_coefficients = optimization_result[0]

        # Save the optimized coefficients to a file
        filename_data = "optimized_coefficients.txt"
        if SAVE_OPTIMIZED_COEFFICIENTS:
            assert filename_data != "", "provide a file name for coefficients"
            np.savetxt(filename_data, optimized_spline_coefficients, delimiter=",")

    else:
        # Add muscle forces on the rod
        if os.path.exists("optimized_coefficients.txt"):
            t_coeff_optimized = np.genfromtxt(
                "optimized_coefficients.txt", delimiter=","
            )
        else:
            wave_length = 1.0
            t_coeff_optimized = np.array(
                [3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
            )
            # The wave length is appended as the last coefficient
            t_coeff_optimized = np.hstack((t_coeff_optimized, wave_length))

        # run the simulation
        [avg_forward, avg_lateral, pp_list] = run_snake(
            t_coeff_optimized, PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS
        )

        print("average forward velocity:", avg_forward)
        print("average lateral velocity:", avg_lateral)