Example Cases

Example cases demonstrate physical examples with known analytical solutions or well-studied phenomena. Each case follows the recommended workflow shown here. Feel free to use them as an initial template to build your own case study.

Axial Stretching

  1""" Axial stretching test-case
  2
  3    Assume we have a rod lying aligned in the x-direction, with high internal
  4    damping.
  5
  6    We fix one end (say, the left end) of the rod to a wall. On the right
  7    end we apply a force directed axially pulling the rods tip. Linear
  8    theory (assuming small displacements) predict that the net displacement
  9    experienced by the rod tip is Δx = FL/AE where the symbols carry their
 10    usual meaning (the rod is just a linear spring). We compare our results
 11    with the above result.
 12
 13    We can "improve" the theory by having a better estimate for the rod's
 14    spring constant by assuming that it equilibriates under the new position,
 15    with
 16    Δx = F * (L + Δx)/ (A * E)
 17    which results in Δx = (F*l)/(A*E - F). Our rod reaches equilibrium wrt to
 18    this position.
 19
 20    Note that if the damping is not high, the rod oscillates about the eventual
 21    resting position (and this agrees with the theoretical predictions without
 22    any damping : we should see the rod oscillating simple-harmonically in time).
 23
 24    isort:skip_file
 25"""
 26
 27import numpy as np
 28from matplotlib import pyplot as plt
 29
 30import elastica as ea
 31
 32
 33class StretchingBeamSimulator(
 34    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.Damping, ea.CallBacks
 35):
 36    pass
 37
 38
 39stretch_sim = StretchingBeamSimulator()
 40final_time = 200.0
 41
 42# Options
 43PLOT_FIGURE = True
 44SAVE_FIGURE = False
 45SAVE_RESULTS = False
 46
 47# setting up test params
 48n_elem = 19
 49start = np.zeros((3,))
 50direction = np.array([1.0, 0.0, 0.0])
 51normal = np.array([0.0, 1.0, 0.0])
 52base_length = 1.0
 53base_radius = 0.025
 54base_area = np.pi * base_radius**2
 55density = 1000
 56youngs_modulus = 1e4
 57# For shear modulus of 1e4, nu is 99!
 58poisson_ratio = 0.5
 59shear_modulus = youngs_modulus / (poisson_ratio + 1.0)
 60
 61stretchable_rod = ea.CosseratRod.straight_rod(
 62    n_elem,
 63    start,
 64    direction,
 65    normal,
 66    base_length,
 67    base_radius,
 68    density,
 69    youngs_modulus=youngs_modulus,
 70    shear_modulus=shear_modulus,
 71)
 72
 73stretch_sim.append(stretchable_rod)
 74stretch_sim.constrain(stretchable_rod).using(
 75    ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 76)
 77
 78end_force_x = 1.0
 79end_force = np.array([end_force_x, 0.0, 0.0])
 80stretch_sim.add_forcing_to(stretchable_rod).using(
 81    ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=1e-2
 82)
 83
 84# add damping
 85dl = base_length / n_elem
 86dt = 0.1 * dl
 87damping_constant = 0.1
 88stretch_sim.dampen(stretchable_rod).using(
 89    ea.AnalyticalLinearDamper,
 90    damping_constant=damping_constant,
 91    time_step=dt,
 92)
 93
 94
 95# Add call backs
 96class AxialStretchingCallBack(ea.CallBackBaseClass):
 97    """
 98    Tracks the velocity norms of the rod
 99    """
100
101    def __init__(self, step_skip: int, callback_params: dict):
102        ea.CallBackBaseClass.__init__(self)
103        self.every = step_skip
104        self.callback_params = callback_params
105
106    def make_callback(self, system, time, current_step: int):
107
108        if current_step % self.every == 0:
109
110            self.callback_params["time"].append(time)
111            # Collect only x
112            self.callback_params["position"].append(
113                system.position_collection[0, -1].copy()
114            )
115            self.callback_params["velocity_norms"].append(
116                np.linalg.norm(system.velocity_collection.copy())
117            )
118            return
119
120
recorded_history = ea.defaultdict(list)
stretch_sim.collect_diagnostics(stretchable_rod).using(
    AxialStretchingCallBack, step_skip=200, callback_params=recorded_history
)

stretch_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, stretch_sim, final_time, total_steps)

if PLOT_FIGURE:
    # First-order theory with base-length: Δx = F L / (A E)
    expected_tip_disp = end_force_x * base_length / base_area / youngs_modulus
    # First-order theory with modified-length, gives better estimates:
    # Δx = F L / (A E - F)
    expected_tip_disp_improved = (
        end_force_x * base_length / (base_area * youngs_modulus - end_force_x)
    )

    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    ax.plot(
        recorded_history["time"],
        recorded_history["position"],
        lw=2.0,
        label="Simulation",
    )
    ax.hlines(
        base_length + expected_tip_disp,
        0.0,
        final_time,
        "k",
        "dashdot",
        lw=1.0,
        label="Linear theory",
    )
    ax.hlines(
        base_length + expected_tip_disp_improved,
        0.0,
        final_time,
        "k",
        "dashed",
        lw=2.0,
        label="Improved theory",
    )
    ax.set_xlabel("Time")
    ax.set_ylabel("Tip x-position")
    ax.legend()
    if SAVE_FIGURE:
        fig.savefig("axial_stretching.pdf")
    plt.show()

if SAVE_RESULTS:
    import pickle

    filename = "axial_stretching_data.dat"
    # The context manager closes the file even if pickling raises
    with open(filename, "wb") as data_file:
        pickle.dump(stretchable_rod, data_file)

    # One row per recorded sample: time, velocity norm
    time_series = np.column_stack(
        (
            np.asarray(recorded_history["time"]),
            np.asarray(recorded_history["velocity_norms"]),
        )
    )
    np.savetxt("velocity_norms.csv", time_series, delimiter=",")

Timoshenko

  1__doc__ = """Timoshenko beam validation case, for detailed explanation refer to
  2Gazzola et. al. R. Soc. 2018  section 3.4.3 """
  3
  4import numpy as np
  5import elastica as ea
  6from examples.TimoshenkoBeamCase.timoshenko_postprocessing import plot_timoshenko
  7
  8
  9class TimoshenkoBeamSimulator(
 10    ea.BaseSystemCollection, ea.Constraints, ea.Forcing, ea.CallBacks, ea.Damping
 11):
 12    pass
 13
 14
 15timoshenko_sim = TimoshenkoBeamSimulator()
 16final_time = 5000.0
 17
 18# Options
 19PLOT_FIGURE = True
 20SAVE_FIGURE = True
 21SAVE_RESULTS = False
 22ADD_UNSHEARABLE_ROD = False
 23
 24# setting up test params
 25n_elem = 100
 26start = np.zeros((3,))
 27direction = np.array([0.0, 0.0, 1.0])
 28normal = np.array([0.0, 1.0, 0.0])
 29base_length = 3.0
 30base_radius = 0.25
 31base_area = np.pi * base_radius**2
 32density = 5000
 33nu = 0.1 / 7 / density / base_area
 34E = 1e6
 35# For shear modulus of 1e4, nu is 99!
 36poisson_ratio = 99
 37shear_modulus = E / (poisson_ratio + 1.0)
 38
 39shearable_rod = ea.CosseratRod.straight_rod(
 40    n_elem,
 41    start,
 42    direction,
 43    normal,
 44    base_length,
 45    base_radius,
 46    density,
 47    youngs_modulus=E,
 48    shear_modulus=shear_modulus,
 49)
 50
 51timoshenko_sim.append(shearable_rod)
 52# add damping
 53dl = base_length / n_elem
 54dt = 0.07 * dl
 55timoshenko_sim.dampen(shearable_rod).using(
 56    ea.AnalyticalLinearDamper,
 57    damping_constant=nu,
 58    time_step=dt,
 59)
 60
 61timoshenko_sim.constrain(shearable_rod).using(
 62    ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 63)
 64
 65end_force = np.array([-15.0, 0.0, 0.0])
 66timoshenko_sim.add_forcing_to(shearable_rod).using(
 67    ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
 68)
 69
 70
 71if ADD_UNSHEARABLE_ROD:
 72    # Start into the plane
 73    unshearable_start = np.array([0.0, -1.0, 0.0])
 74    shear_modulus = E / (-0.7 + 1.0)
 75    unshearable_rod = ea.CosseratRod.straight_rod(
 76        n_elem,
 77        unshearable_start,
 78        direction,
 79        normal,
 80        base_length,
 81        base_radius,
 82        density,
 83        youngs_modulus=E,
 84        # Unshearable rod needs G -> inf, which is achievable with -ve poisson ratio
 85        shear_modulus=shear_modulus,
 86    )
 87
 88    timoshenko_sim.append(unshearable_rod)
 89
 90    # add damping
 91    timoshenko_sim.dampen(unshearable_rod).using(
 92        ea.AnalyticalLinearDamper,
 93        damping_constant=nu,
 94        time_step=dt,
 95    )
 96    timoshenko_sim.constrain(unshearable_rod).using(
 97        ea.OneEndFixedBC, constrained_position_idx=(0,), constrained_director_idx=(0,)
 98    )
 99    timoshenko_sim.add_forcing_to(unshearable_rod).using(
100        ea.EndpointForces, 0.0 * end_force, end_force, ramp_up_time=final_time / 2.0
101    )
102
103
# Add call backs
class VelocityCallBack(ea.CallBackBaseClass):
    """
    Every ``step_skip`` steps, records the simulation time ("time") and the
    norm of the rod's whole velocity array ("velocity_norms") into
    ``callback_params``.
    """

    def __init__(self, step_skip: int, callback_params: dict):
        ea.CallBackBaseClass.__init__(self)
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):

        if current_step % self.every == 0:

            self.callback_params["time"].append(time)
            # np.linalg.norm does not modify its input, so no copy is needed
            self.callback_params["velocity_norms"].append(
                np.linalg.norm(system.velocity_collection)
            )
            return


recorded_history = ea.defaultdict(list)
timoshenko_sim.collect_diagnostics(shearable_rod).using(
    VelocityCallBack, step_skip=500, callback_params=recorded_history
)

timoshenko_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, timoshenko_sim, final_time, total_steps)

if PLOT_FIGURE:
    plot_timoshenko(shearable_rod, end_force, SAVE_FIGURE, ADD_UNSHEARABLE_ROD)

if SAVE_RESULTS:
    import pickle

    filename = "Timoshenko_beam_data.dat"
    # The context manager closes the file even if pickling raises
    with open(filename, "wb") as data_file:
        pickle.dump(shearable_rod, data_file)

    # One row per recorded sample: time, velocity norm
    time_series = np.column_stack(
        (
            np.asarray(recorded_history["time"]),
            np.asarray(recorded_history["velocity_norms"]),
        )
    )
    np.savetxt("velocity_norms.csv", time_series, delimiter=",")

Butterfly

  1import numpy as np
  2from matplotlib import pyplot as plt
  3from matplotlib.colors import to_rgb
  4
  5
  6import elastica as ea
  7from elastica.utils import MaxDimension
  8
  9
 10class ButterflySimulator(ea.BaseSystemCollection, ea.CallBacks):
 11    pass
 12
 13
 14butterfly_sim = ButterflySimulator()
 15final_time = 40.0
 16
 17# Options
 18PLOT_FIGURE = True
 19SAVE_FIGURE = True
 20SAVE_RESULTS = True
 21ADD_UNSHEARABLE_ROD = False
 22
 23# setting up test params
 24# FIXME : Doesn't work with elements > 10 (the inverse rotate kernel fails)
 25n_elem = 4  # Change based on requirements, but be careful
 26n_elem += n_elem % 2
 27half_n_elem = n_elem // 2
 28
 29origin = np.zeros((3, 1))
 30angle_of_inclination = np.deg2rad(45.0)
 31
 32# in-plane
 33horizontal_direction = np.array([0.0, 0.0, 1.0]).reshape(-1, 1)
 34vertical_direction = np.array([1.0, 0.0, 0.0]).reshape(-1, 1)
 35
 36# out-of-plane
 37normal = np.array([0.0, 1.0, 0.0])
 38
 39total_length = 3.0
 40base_radius = 0.25
 41base_area = np.pi * base_radius**2
 42density = 5000
 43youngs_modulus = 1e4
 44poisson_ratio = 0.5
 45shear_modulus = youngs_modulus / (poisson_ratio + 1.0)
 46
 47positions = np.empty((MaxDimension.value(), n_elem + 1))
 48dl = total_length / n_elem
 49
 50# First half of positions stem from slope angle_of_inclination
 51first_half = np.arange(half_n_elem + 1.0).reshape(1, -1)
 52positions[..., : half_n_elem + 1] = origin + dl * first_half * (
 53    np.cos(angle_of_inclination) * horizontal_direction
 54    + np.sin(angle_of_inclination) * vertical_direction
 55)
 56positions[..., half_n_elem:] = positions[
 57    ..., half_n_elem : half_n_elem + 1
 58] + dl * first_half * (
 59    np.cos(angle_of_inclination) * horizontal_direction
 60    - np.sin(angle_of_inclination) * vertical_direction
 61)
 62
 63butterfly_rod = ea.CosseratRod.straight_rod(
 64    n_elem,
 65    start=origin.reshape(3),
 66    direction=np.array([0.0, 0.0, 1.0]),
 67    normal=normal,
 68    base_length=total_length,
 69    base_radius=base_radius,
 70    density=density,
 71    youngs_modulus=youngs_modulus,
 72    shear_modulus=shear_modulus,
 73    position=positions,
 74)
 75
 76butterfly_sim.append(butterfly_rod)
 77
 78
 79# Add call backs
 80class VelocityCallBack(ea.CallBackBaseClass):
 81    """
 82    Call back function for continuum snake
 83    """
 84
 85    def __init__(self, step_skip: int, callback_params: dict):
 86        ea.CallBackBaseClass.__init__(self)
 87        self.every = step_skip
 88        self.callback_params = callback_params
 89
 90    def make_callback(self, system, time, current_step: int):
 91
 92        if current_step % self.every == 0:
 93
 94            self.callback_params["time"].append(time)
 95            # Collect x
 96            self.callback_params["position"].append(system.position_collection.copy())
 97            # Collect energies as well
 98            self.callback_params["te"].append(system.compute_translational_energy())
 99            self.callback_params["re"].append(system.compute_rotational_energy())
100            self.callback_params["se"].append(system.compute_shear_energy())
101            self.callback_params["be"].append(system.compute_bending_energy())
102            return
103
104
recorded_history = ea.defaultdict(list)

# Seed the history with the state before integration (t = 0), using the same
# keys (and the same order) as the callback.
recorded_history["time"].append(0.0)
recorded_history["position"].append(butterfly_rod.position_collection.copy())
for energy_key, energy_of_rod in (
    ("te", butterfly_rod.compute_translational_energy),
    ("re", butterfly_rod.compute_rotational_energy),
    ("se", butterfly_rod.compute_shear_energy),
    ("be", butterfly_rod.compute_bending_energy),
):
    recorded_history[energy_key].append(energy_of_rod())

butterfly_sim.collect_diagnostics(butterfly_rod).using(
    VelocityCallBack, step_skip=100, callback_params=recorded_history
)


butterfly_sim.finalize()
timestepper = ea.PositionVerlet()
# timestepper = PEFRL()

# Small time step, proportional to the element length
dt = 0.01 * dl
total_steps = int(final_time / dt)
print("Total steps", total_steps)
ea.integrate(timestepper, butterfly_sim, final_time, total_steps)

if PLOT_FIGURE:
    # Plot the histories
    fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    # Read the snapshots without popping, so recorded_history stays intact and
    # a history with a single snapshot does not raise IndexError.
    position_history = recorded_history["position"]
    # first (initial) position, dashed red
    first_position = position_history[0]
    ax.plot(first_position[2, ...], first_position[0, ...], "r--", lw=2.0)
    later_positions = position_history[1:]
    n_positions = len(later_positions)
    for i, pos in enumerate(later_positions):
        # earlier snapshots are drawn more transparently
        alpha = np.exp(i / n_positions - 1)
        ax.plot(pos[2, ...], pos[0, ...], "b", lw=0.6, alpha=alpha)
    # final position is also drawn separately, dashed black
    last_position = position_history[-1]
    ax.plot(last_position[2, ...], last_position[0, ...], "k--", lw=2.0)
    # don't block
    fig.show()

    # Plot the energies
    energy_fig = plt.figure(figsize=(5, 4), frameon=True, dpi=150)
    energy_ax = energy_fig.add_subplot(111)
    times = np.asarray(recorded_history["time"])
    te = np.asarray(recorded_history["te"])
    re = np.asarray(recorded_history["re"])
    be = np.asarray(recorded_history["be"])
    se = np.asarray(recorded_history["se"])

    energy_ax.plot(times, te, c=to_rgb("xkcd:reddish"), lw=2.0, label="Translations")
    energy_ax.plot(times, re, c=to_rgb("xkcd:bluish"), lw=2.0, label="Rotation")
    energy_ax.plot(times, be, c=to_rgb("xkcd:burple"), lw=2.0, label="Bend")
    energy_ax.plot(times, se, c=to_rgb("xkcd:goldenrod"), lw=2.0, label="Shear")
    energy_ax.plot(times, te + re + be + se, c="k", lw=2.0, label="Total energy")
    energy_ax.legend()
    # don't block
    energy_fig.show()

    if SAVE_FIGURE:
        fig.savefig("butterfly.png")
        energy_fig.savefig("energies.png")

    plt.show()

if SAVE_RESULTS:
    import pickle

    filename = "butterfly_data.dat"
    # The context manager closes the file even if pickling raises
    with open(filename, "wb") as data_file:
        pickle.dump(butterfly_rod, data_file)

Helical Buckling

  1__doc__ = """Helical buckling validation case, for detailed explanation refer to
  2Gazzola et. al. R. Soc. 2018  section 3.4.1 """
  3
  4import numpy as np
  5import elastica as ea
  6from examples.HelicalBucklingCase.helicalbuckling_postprocessing import (
  7    plot_helicalbuckling,
  8)
  9
 10
 11class HelicalBucklingSimulator(
 12    ea.BaseSystemCollection, ea.Constraints, ea.Damping, ea.Forcing
 13):
 14    pass
 15
 16
 17helicalbuckling_sim = HelicalBucklingSimulator()
 18
 19# Options
 20PLOT_FIGURE = True
 21SAVE_FIGURE = True
 22SAVE_RESULTS = False
 23
 24# setting up test params
 25n_elem = 100
 26start = np.zeros((3,))
 27direction = np.array([0.0, 0.0, 1.0])
 28normal = np.array([0.0, 1.0, 0.0])
 29base_length = 100.0
 30base_radius = 0.35
 31base_area = np.pi * base_radius**2
 32density = 1.0 / (base_area)
 33nu = 0.01 / density / base_area
 34E = 1e6
 35slack = 3
 36number_of_rotations = 27
 37# For shear modulus of 1e5, nu is 99!
 38poisson_ratio = 9
 39shear_modulus = E / (poisson_ratio + 1.0)
 40shear_matrix = np.repeat(
 41    shear_modulus * np.identity((3))[:, :, np.newaxis], n_elem, axis=2
 42)
 43temp_bend_matrix = np.zeros((3, 3))
 44np.fill_diagonal(temp_bend_matrix, [1.345, 1.345, 0.789])
 45bend_matrix = np.repeat(temp_bend_matrix[:, :, np.newaxis], n_elem - 1, axis=2)
 46
 47shearable_rod = ea.CosseratRod.straight_rod(
 48    n_elem,
 49    start,
 50    direction,
 51    normal,
 52    base_length,
 53    base_radius,
 54    density,
 55    youngs_modulus=E,
 56    shear_modulus=shear_modulus,
 57)
 58# TODO: CosseratRod has to be able to take shear matrix as input, we should change it as done below
 59
 60shearable_rod.shear_matrix = shear_matrix
 61shearable_rod.bend_matrix = bend_matrix
 62
 63
 64helicalbuckling_sim.append(shearable_rod)
 65# add damping
 66dl = base_length / n_elem
 67dt = 1e-3 * dl
 68helicalbuckling_sim.dampen(shearable_rod).using(
 69    ea.AnalyticalLinearDamper,
 70    damping_constant=nu,
 71    time_step=dt,
 72)
 73
 74helicalbuckling_sim.constrain(shearable_rod).using(
 75    ea.HelicalBucklingBC,
 76    constrained_position_idx=(0, -1),
 77    constrained_director_idx=(0, -1),
 78    twisting_time=500,
 79    slack=slack,
 80    number_of_rotations=number_of_rotations,
 81)
 82
 83helicalbuckling_sim.finalize()
 84timestepper = ea.PositionVerlet()
 85shearable_rod.velocity_collection[..., int((n_elem) / 2)] += np.array([0, 1e-6, 0.0])
 86# timestepper = PEFRL()
 87
 88final_time = 10500.0
 89total_steps = int(final_time / dt)
 90print("Total steps", total_steps)
 91ea.integrate(timestepper, helicalbuckling_sim, final_time, total_steps)
 92
 93if PLOT_FIGURE:
 94    plot_helicalbuckling(shearable_rod, SAVE_FIGURE)
 95
 96if SAVE_RESULTS:
 97    import pickle
 98
 99    filename = "HelicalBuckling_data.dat"
100    file = open(filename, "wb")
101    pickle.dump(shearable_rod, file)
102    file.close()

Continuum Snake

  1__doc__ = """Snake friction case from X. Zhang et. al. Nat. Comm. 2021"""
  2
  3import os
  4import numpy as np
  5import elastica as ea
  6
  7from examples.ContinuumSnakeCase.continuum_snake_postprocessing import (
  8    plot_snake_velocity,
  9    plot_video,
 10    compute_projected_velocity,
 11    plot_curvature,
 12)
 13
 14
 15class SnakeSimulator(
 16    ea.BaseSystemCollection,
 17    ea.Constraints,
 18    ea.Forcing,
 19    ea.Damping,
 20    ea.CallBacks,
 21    ea.Contact,
 22):
 23    pass
 24
 25
 26def run_snake(
 27    b_coeff, PLOT_FIGURE=False, SAVE_FIGURE=False, SAVE_VIDEO=False, SAVE_RESULTS=False
 28):
 29    # Initialize the simulation class
 30    snake_sim = SnakeSimulator()
 31
 32    # Simulation parameters
 33    period = 2
 34    final_time = (11.0 + 0.01) * period
 35
 36    # setting up test params
 37    n_elem = 50
 38    start = np.zeros((3,))
 39    direction = np.array([0.0, 0.0, 1.0])
 40    normal = np.array([0.0, 1.0, 0.0])
 41    base_length = 0.35
 42    base_radius = base_length * 0.011
 43    density = 1000
 44    E = 1e6
 45    poisson_ratio = 0.5
 46    shear_modulus = E / (poisson_ratio + 1.0)
 47
 48    shearable_rod = ea.CosseratRod.straight_rod(
 49        n_elem,
 50        start,
 51        direction,
 52        normal,
 53        base_length,
 54        base_radius,
 55        density,
 56        youngs_modulus=E,
 57        shear_modulus=shear_modulus,
 58    )
 59
 60    snake_sim.append(shearable_rod)
 61
 62    # Add gravitational forces
 63    gravitational_acc = -9.80665
 64    snake_sim.add_forcing_to(shearable_rod).using(
 65        ea.GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
 66    )
 67
 68    # Add muscle torques
 69    wave_length = b_coeff[-1]
 70    snake_sim.add_forcing_to(shearable_rod).using(
 71        ea.MuscleTorques,
 72        base_length=base_length,
 73        b_coeff=b_coeff[:-1],
 74        period=period,
 75        wave_number=2.0 * np.pi / (wave_length),
 76        phase_shift=0.0,
 77        rest_lengths=shearable_rod.rest_lengths,
 78        ramp_up_time=period,
 79        direction=normal,
 80        with_spline=True,
 81    )
 82
 83    # Add friction forces
 84    ground_plane = ea.Plane(
 85        plane_origin=np.array([0.0, -base_radius, 0.0]), plane_normal=normal
 86    )
 87    snake_sim.append(ground_plane)
 88    slip_velocity_tol = 1e-8
 89    froude = 0.1
 90    mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
 91    kinetic_mu_array = np.array(
 92        [mu, 1.5 * mu, 2.0 * mu]
 93    )  # [forward, backward, sideways]
 94    static_mu_array = np.zeros(kinetic_mu_array.shape)
 95    snake_sim.detect_contact_between(shearable_rod, ground_plane).using(
 96        ea.RodPlaneContactWithAnisotropicFriction,
 97        k=1.0,
 98        nu=1e-6,
 99        slip_velocity_tol=slip_velocity_tol,
100        static_mu_array=static_mu_array,
101        kinetic_mu_array=kinetic_mu_array,
102    )
103
104    # add damping
105    damping_constant = 2e-3
106    time_step = 1e-4
107    snake_sim.dampen(shearable_rod).using(
108        ea.AnalyticalLinearDamper,
109        damping_constant=damping_constant,
110        time_step=time_step,
111    )
112
113    total_steps = int(final_time / time_step)
114    rendering_fps = 60
115    step_skip = int(1.0 / (rendering_fps * time_step))
116
117    # Add call backs
118    class ContinuumSnakeCallBack(ea.CallBackBaseClass):
119        """
120        Call back function for continuum snake
121        """
122
123        def __init__(self, step_skip: int, callback_params: dict):
124            ea.CallBackBaseClass.__init__(self)
125            self.every = step_skip
126            self.callback_params = callback_params
127
128        def make_callback(self, system, time, current_step: int):
129
130            if current_step % self.every == 0:
131
132                self.callback_params["time"].append(time)
133                self.callback_params["step"].append(current_step)
134                self.callback_params["position"].append(
135                    system.position_collection.copy()
136                )
137                self.callback_params["velocity"].append(
138                    system.velocity_collection.copy()
139                )
140                self.callback_params["avg_velocity"].append(
141                    system.compute_velocity_center_of_mass()
142                )
143
144                self.callback_params["center_of_mass"].append(
145                    system.compute_position_center_of_mass()
146                )
147                self.callback_params["curvature"].append(system.kappa.copy())
148
149                return
150
151    pp_list = ea.defaultdict(list)
152    snake_sim.collect_diagnostics(shearable_rod).using(
153        ContinuumSnakeCallBack, step_skip=step_skip, callback_params=pp_list
154    )
155
156    snake_sim.finalize()
157
158    timestepper = ea.PositionVerlet()
159    ea.integrate(timestepper, snake_sim, final_time, total_steps)
160
161    if PLOT_FIGURE:
162        filename_plot = "continuum_snake_velocity.png"
163        plot_snake_velocity(pp_list, period, filename_plot, SAVE_FIGURE)
164        plot_curvature(pp_list, shearable_rod.rest_lengths, period, SAVE_FIGURE)
165
166        if SAVE_VIDEO:
167            filename_video = "continuum_snake.mp4"
168            plot_video(
169                pp_list,
170                video_name=filename_video,
171                fps=rendering_fps,
172                xlim=(0, 4),
173                ylim=(-1, 1),
174            )
175
176    if SAVE_RESULTS:
177        import pickle
178
179        filename = "continuum_snake.dat"
180        file = open(filename, "wb")
181        pickle.dump(pp_list, file)
182        file.close()
183
184    # Compute the average forward velocity. These will be used for optimization.
185    [_, _, avg_forward, avg_lateral] = compute_projected_velocity(pp_list, period)
186
187    return avg_forward, avg_lateral, pp_list
188
189
if __name__ == "__main__":

    # Options
    PLOT_FIGURE = True
    SAVE_FIGURE = True
    SAVE_VIDEO = True
    SAVE_RESULTS = False
    CMA_OPTION = False

    # File used both to store CMA-optimized coefficients and to reload them
    coefficients_filename = "optimized_coefficients.txt"

    if CMA_OPTION:
        import cma

        SAVE_OPTIMIZED_COEFFICIENTS = False

        def optimize_snake(spline_coefficient):
            """CMA-ES objective: negated average forward velocity (CMA minimizes)."""
            [avg_forward, _, _] = run_snake(
                spline_coefficient,
                PLOT_FIGURE=False,
                SAVE_FIGURE=False,
                SAVE_VIDEO=False,
                SAVE_RESULTS=False,
            )
            return -avg_forward

        # Optimize snake for forward velocity. In cma.fmin first input is function
        # to be optimized, second input is initial guess for coefficients you are
        # optimizing for (6 spline coefficients + wave length) and third input is
        # the initial standard deviation.
        optimization_result = cma.fmin(optimize_snake, 7 * [0], 0.5)
        # cma.fmin returns a list of results; its first entry is the best solution
        optimized_spline_coefficients = optimization_result[0]

        # Save the optimized coefficients to a file
        if SAVE_OPTIMIZED_COEFFICIENTS:
            np.savetxt(
                coefficients_filename, optimized_spline_coefficients, delimiter=","
            )

    else:
        # Add muscle forces on the rod
        if os.path.exists(coefficients_filename):
            t_coeff_optimized = np.genfromtxt(coefficients_filename, delimiter=",")
        else:
            wave_length = 1.0
            t_coeff_optimized = np.array(
                [3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
            )
            # run_snake expects the wave length as the last coefficient
            t_coeff_optimized = np.hstack((t_coeff_optimized, wave_length))

        # run the simulation
        [avg_forward, avg_lateral, pp_list] = run_snake(
            t_coeff_optimized, PLOT_FIGURE, SAVE_FIGURE, SAVE_VIDEO, SAVE_RESULTS
        )

        print("average forward velocity:", avg_forward)
        print("average lateral velocity:", avg_lateral)