import os
import warnings
from pathlib import Path
from copy import deepcopy
from tempfile import mkdtemp
from itertools import product
import mujoco
from dm_control import mjcf
from mushroom_rl.core import Environment
from mushroom_rl.environments import MultiMuJoCo
from mushroom_rl.utils import spaces
from mushroom_rl.utils.running_stats import *
from mushroom_rl.utils.mujoco import *
from mushroom_rl.utils.record import VideoRecorder
import loco_mujoco
from loco_mujoco.utils import Trajectory
from loco_mujoco.utils import NoReward, CustomReward,\
TargetVelocityReward, PosReward, DomainRandomizationHandler
[docs]
class LocoEnv(MultiMuJoCo):
"""
Base class for all kinds of locomotion environments.
"""
def __init__(self, xml_handles, action_spec, observation_spec, collision_groups=None, gamma=0.99, horizon=1000,
             n_substeps=10, reward_type=None, reward_params=None, traj_params=None, random_start=True,
             init_step_no=None, timestep=0.001, use_foot_forces=False, default_camera_mode="follow",
             use_absorbing_states=True, domain_randomization_config=None, parallel_dom_rand=True,
             N_worker_per_xml_dom_rand=4, **viewer_params):
    """
    Constructor.

    Args:
        xml_handles : MuJoCo xml handle, or a list of such handles. A single
            handle is wrapped into a one-element list;
        action_spec (list): A list specifying the names of the joints
            which should be controllable by the agent. Can be left empty
            when all actuators should be used;
        observation_spec (list): A list containing the names of data that
            should be made available to the agent as an observation and
            their type (ObservationType). They are combined with a key,
            which is used to access the data. An entry in the list
            is given by: (key, name, type). The name can later be used
            to retrieve specific observations;
        collision_groups (list, None): A list containing groups of geoms for
            which collisions should be checked during simulation via
            ``check_collision``. The entries are given as:
            ``(key, geom_names)``, where key is a string for later
            referencing in the "check_collision" method, and geom_names is
            a list of geom names in the XML specification;
        gamma (float): The discounting factor of the environment;
        horizon (int): The maximum horizon for the environment;
        n_substeps (int): The number of substeps to use by the MuJoCo
            simulator. An action given by the agent will be applied for
            n_substeps before the agent receives the next observation and
            can act accordingly. If use_foot_forces is True, this value is
            passed on as the number of intermediate steps instead and the
            number of substeps is set to 1;
        reward_type (string): Type of reward function to be used;
        reward_params (dict): Dictionary of parameters corresponding to
            the chosen reward function;
        traj_params (dict): Dictionary of parameters to construct trajectories.
            If given (and non-empty), the trajectories are loaded right away;
        random_start (bool): If True, a random sample from the trajectories
            is chosen on every reset and the simulation is initialized
            according to that. This requires traj_params to be passed!
        init_step_no (int): If set, the respective sample from the trajectories
            is taken to initialize the simulation;
        timestep (float): The timestep used by the MuJoCo simulator. If None, the
            default timestep specified in the XML will be used;
        use_foot_forces (bool): If True, foot forces are computed and added to
            the observation space;
        default_camera_mode (str): String defining the default camera mode. Available modes are "static",
            "follow", and "top_static";
        use_absorbing_states (bool): If True, absorbing states are defined for each environment. This means
            that episodes can terminate earlier;
        domain_randomization_config (str): Path to the domain/dynamics randomization config file;
        parallel_dom_rand (bool): If True and a domain_randomization_config file is passed, the domain
            randomization will run in parallel to speed up simulation run-time;
        N_worker_per_xml_dom_rand (int): Number of workers used per xml-file for parallel domain randomization.
            If parallel is set to True, this number has to be greater 1;
        **viewer_params: Additional keyword arguments forwarded to the parent constructor.
            "geom_group_visualization_on_startup" defaults to [0, 2] if not given.

    """

    if not isinstance(xml_handles, list):
        xml_handles = [xml_handles]

    if collision_groups is None:
        collision_groups = list()

    # with foot forces, the simulation advances one substep at a time so that
    # the ground forces can be averaged over the intermediate steps
    if use_foot_forces:
        n_intermediate_steps = n_substeps
        n_substeps = 1
    else:
        n_intermediate_steps = 1

    if "geom_group_visualization_on_startup" not in viewer_params:
        viewer_params["geom_group_visualization_on_startup"] = [0, 2]   # enable robot geom [0] and floor visual [2]

    if domain_randomization_config is not None:
        self._domain_rand = DomainRandomizationHandler(xml_handles, domain_randomization_config, parallel_dom_rand,
                                                       N_worker_per_xml_dom_rand)
    else:
        self._domain_rand = None

    super().__init__(xml_handles, action_spec, observation_spec, gamma=gamma, horizon=horizon,
                     n_substeps=n_substeps, n_intermediate_steps=n_intermediate_steps, timestep=timestep,
                     collision_groups=collision_groups, default_camera_mode=default_camera_mode, **viewer_params)

    # specify reward function
    self._reward_function = self._get_reward_function(reward_type, reward_params)

    # optionally use foot forces in the observation space
    self._use_foot_forces = use_foot_forces

    self.info.observation_space = spaces.Box(*self._get_observation_space())

    # the action space is supposed to be between -1 and 1, so we normalize it
    low, high = self.info.action_space.low.copy(), self.info.action_space.high.copy()
    self.norm_act_mean = (high + low) / 2.0
    self.norm_act_delta = (high - low) / 2.0
    self.info.action_space.low[:] = -1.0
    self.info.action_space.high[:] = 1.0

    # setup a running average window for the mean ground forces
    self.mean_grf = self._setup_ground_force_statistics()

    # dataset dummy
    self._dataset = None

    # trajectories stay None unless trajectory parameters were passed
    self.trajectories = None
    if traj_params:
        self.load_trajectory(traj_params)

    self._random_start = random_start
    self._init_step_no = init_step_no
    self._use_absorbing_states = use_absorbing_states
[docs]
def load_trajectory(self, traj_params, warn=True):
    """
    Builds a new Trajectory object from the given parameters and stores it in
    ``self.trajectories``. Previously loaded trajectories are replaced, in which
    case a RuntimeWarning is issued.

    Args:
        traj_params (dict): Dictionary of additional keyword arguments passed
            to the Trajectory constructor.
        warn (bool): Forwarded to the Trajectory constructor; if True, a warning
            will be raised if the trajectory ranges are violated.

    """

    replacing_existing = self.trajectories is not None
    if replacing_existing:
        warnings.warn("New trajectories loaded, which overrides the old ones.", RuntimeWarning)

    # arguments derived from the environment itself
    env_kwargs = dict(
        keys=self.get_all_observation_keys(),
        low=self.info.observation_space.low,
        high=self.info.observation_space.high,
        joint_pos_idx=self.obs_helper.joint_pos_idx,
        interpolate_map=self._interpolate_map,
        interpolate_remap=self._interpolate_remap,
        interpolate_map_params=self._get_interpolate_map_params(),
        interpolate_remap_params=self._get_interpolate_remap_params(),
        warn=warn,
    )

    self.trajectories = Trajectory(**env_kwargs, **traj_params)
[docs]
def reward(self, state, action, next_state, absorbing):
"""
Calls the reward function of the environment.
"""
return self._reward_function(state, action, next_state, absorbing)
[docs]
def reset(self, obs=None):
    """
    Resets the simulation data and the ground force statistics, optionally
    re-randomizes the current model, switches to the next model (randomly or
    round-robin), sets up the initial state and returns the first observation.

    Args:
        obs (np.array, None): Observation to initialize the environment from.
            Passed on to ``setup``.

    Returns:
        The initial observation after passing it through ``_modify_observation``.

    """

    mujoco.mj_resetData(self._model, self._data)
    self.mean_grf.reset()

    # replace the model that was just used by a freshly randomized one
    if self._domain_rand is not None:
        used_idx = self._current_model_idx
        self._models[used_idx] = self._domain_rand.get_randomized_model(used_idx)
        self._datas[used_idx] = mujoco.MjData(self._models[used_idx])

    # pick the model for the upcoming episode
    n_models = len(self._models)
    if self._random_env_reset:
        self._current_model_idx = np.random.randint(0, n_models)
    elif self._current_model_idx < n_models - 1:
        self._current_model_idx = self._current_model_idx + 1
    else:
        self._current_model_idx = 0

    active_idx = self._current_model_idx
    self._model = self._models[active_idx]
    self._data = self._datas[active_idx]
    self.obs_helper = self.obs_helpers[active_idx]

    self.setup(obs)

    # the viewer only needs a reload when there are several models
    if self._viewer is not None and self.more_than_one_env:
        self._viewer.load_new_model(self._model)

    raw_obs = self.obs_helper._build_obs(self._data)
    self._obs = self._create_observation(raw_obs)
    return self._modify_observation(self._obs)
[docs]
def setup(self, obs):
"""
Function to setup the initial state of the simulation. Initialization can be done either
randomly, from a certain initial, or from the default initial state of the model.
Args:
obs (np.array): Observation to initialize the environment from;
"""
self._reward_function.reset_state()
if obs is not None:
self._init_sim_from_obs(obs)
else:
if not self.trajectories and self._random_start:
raise ValueError("Random start not possible without trajectory data.")
elif not self.trajectories and self._init_step_no is not None:
raise ValueError("Setting an initial step is not possible without trajectory data.")
elif self._init_step_no is not None and self._random_start:
raise ValueError("Either use a random start or set an initial step, not both.")
if self.trajectories is not None:
if self._random_start:
sample = self.trajectories.reset_trajectory()
elif self._init_step_no:
traj_len = self.trajectories.trajectory_length
n_traj = self.trajectories.number_of_trajectories
assert self._init_step_no <= traj_len * n_traj
substep_no = int(self._init_step_no % traj_len)
traj_no = int(self._init_step_no / traj_len)
sample = self.trajectories.reset_trajectory(substep_no, traj_no)
else:
# sample random trajectory and use the first sample
sample = self.trajectories.reset_trajectory(substep_no=0)
self.set_sim_state(sample)
[docs]
def is_absorbing(self, obs):
"""
Checks if an observation is an absorbing state or not.
Args:
obs (np.array): Current observation;
Returns:
True, if the observation is an absorbing state; otherwise False;
"""
return self._has_fallen(obs) if self._use_absorbing_states else False
[docs]
def get_kinematic_obs_mask(self):
"""
Returns a mask (np.array) for the observation specified in observation_spec (or part of it).
"""
return np.arange(len(self.obs_helper.observation_spec) - 2)
[docs]
def get_obs_idx(self, key):
"""
Returns a list of indices corresponding to the respective key.
"""
idx = self.obs_helper.obs_idx_map[key]
# shift by 2 to account for deleted x and y
idx = [i-2 for i in idx]
return idx
[docs]
def create_dataset(self, ignore_keys=None):
"""
Creates a dataset from the specified trajectories.
Args:
ignore_keys (list): List of keys to ignore in the dataset.
Returns:
Dictionary containing states, next_states and absorbing flags. For the states the shape is
(N_traj x N_samples_per_traj, dim_state), while the absorbing flag has the shape is
(N_traj x N_samples_per_traj). For perfect and preference datasets, the actions are also provided.
"""
if self._dataset is None:
if self.trajectories is not None:
dataset = self.trajectories.create_dataset(ignore_keys=ignore_keys)
# check that all state in the dataset satisfy the has fallen method.
for state in dataset["states"]:
has_fallen, msg = self._has_fallen(state, return_err_msg=True)
if has_fallen:
err_msg = "Some of the states in the created dataset are terminal states. " \
"This should not happen.\n\nViolations:\n"
err_msg += msg
raise ValueError(err_msg)
else:
raise ValueError("No trajectory was passed to the environment. "
"To create a dataset pass a trajectory first.")
self._dataset = deepcopy(dataset)
return dataset
else:
return deepcopy(self._dataset)
[docs]
def play_trajectory(self, n_episodes=None, n_steps_per_episode=None, render=True,
                    record=False, recorder_params=None):
    """
    Replays the loaded trajectory by writing the trajectory samples into the
    simulation state at every step and running a forward pass of the simulator.

    Args:
        n_episodes (int): Number of episode to replay. If None, the largest
            32-bit integer is used.
        n_steps_per_episode (int): Number of steps to replay per episode. If
            None, the largest 32-bit integer is used.
        render (bool): If True, trajectory will be rendered.
        record (bool): If True, the rendered trajectory will be recorded.
            Requires render to be True.
        recorder_params (dict): Dictionary containing the recorder parameters.

    """

    assert self.trajectories is not None

    recorder = None
    if record:
        assert render
        extra_recorder_kwargs = {} if recorder_params is None else recorder_params
        recorder = VideoRecorder(fps=1/self.dt, **extra_recorder_kwargs)

    self.reset()
    sample = self.trajectories.get_current_sample()
    self.set_sim_state(sample)

    frame = self.render(record) if render else None
    if record:
        recorder(frame)

    # missing limits are replaced by a practically unbounded number
    unbounded = np.iinfo(np.int32).max
    if n_steps_per_episode is None:
        n_steps_per_episode = unbounded
    if n_episodes is None:
        n_episodes = unbounded

    for _episode in range(n_episodes):
        for _step in range(n_steps_per_episode):

            self.set_sim_state(sample)

            self._simulation_pre_step()
            mujoco.mj_forward(self._model, self._data)
            self._simulation_post_step()

            sample = self.trajectories.get_next_sample()
            if sample is None:
                # trajectory is exhausted, start over
                self.reset()
                sample = self.trajectories.get_current_sample()

            obs = self._create_observation(np.concatenate(sample))
            if self._has_fallen(obs):
                print("Has fallen!")

            frame = self.render(record) if render else None
            if record:
                recorder(frame)

        self.reset()

    self.stop()
    if record:
        recorder.stop()
[docs]
def play_trajectory_from_velocity(self, n_episodes=None, n_steps_per_episode=None, render=True,
                                  record=False, recorder_params=None):
    """
    Replays the loaded trajectory based on its joint velocities. The joint
    positions are taken from the trajectory only at the beginning; afterwards,
    the next joint positions are obtained by adding ``dt`` times the trajectory's
    joint velocities to the current joint positions, and the result is written
    into the simulation state at every step.

    Args:
        n_episodes (int): Number of episode to replay. If None, the largest
            32-bit integer is used.
        n_steps_per_episode (int): Number of steps to replay per episode. If
            None, the largest 32-bit integer is used.
        render (bool): If True, trajectory will be rendered.
        record (bool): If True, the replay will be recorded. Requires render
            to be True.
        recorder_params (dict): Dictionary containing the recorder parameters.

    """

    assert self.trajectories is not None

    recorder = None
    if record:
        assert render
        extra_recorder_kwargs = {} if recorder_params is None else recorder_params
        recorder = VideoRecorder(fps=1/self.dt, **extra_recorder_kwargs)

    self.reset()
    sample = self.trajectories.get_current_sample()
    self.set_sim_state(sample)

    frame = self.render(record) if render else None
    if record:
        recorder(frame)

    # missing limits are replaced by a practically unbounded number
    unbounded = np.iinfo(np.int32).max
    if n_steps_per_episode is None:
        n_steps_per_episode = unbounded
    if n_episodes is None:
        n_episodes = unbounded

    len_qpos, len_qvel = self._len_qpos_qvel()
    curr_qpos = sample[0:len_qpos]

    for _episode in range(n_episodes):
        for _step in range(n_steps_per_episode):

            # one explicit integration step of the joint positions
            qvel = sample[len_qpos:len_qpos + len_qvel]
            integrated_qpos = [pos + self.dt * vel for pos, vel in zip(curr_qpos, qvel)]
            sample[:len(integrated_qpos)] = integrated_qpos

            self.set_sim_state(sample)

            self._simulation_pre_step()
            mujoco.mj_forward(self._model, self._data)
            self._simulation_post_step()

            # get current qpos
            curr_qpos = self._get_joint_pos()

            sample = self.trajectories.get_next_sample()
            if sample is None:
                # trajectory is exhausted, start over from the trajectory's positions
                self.reset()
                sample = self.trajectories.get_current_sample()
                curr_qpos = sample[0:len_qpos]

            obs = self._create_observation(np.concatenate(sample))
            if self._has_fallen(obs):
                print("Has fallen!")

            frame = self.render(record) if render else None
            if record:
                recorder(frame)

        self.reset()

        # get current qpos
        curr_qpos = self._get_joint_pos()

    self.stop()
    if record:
        recorder.stop()
[docs]
def set_sim_state(self, sample):
    """
    Writes a sample into the simulation data. Each entry of the sample is
    matched with the entry of the observation_spec at the same position; joint
    positions, joint velocities and site rotations are written, entries of
    other types are skipped.

    Args:
        sample (list or np.array): Sample used to set the state of the simulation.
            Must have as many entries as the observation_spec.

    """

    obs_spec = self.obs_helper.observation_spec
    assert len(sample) == len(obs_spec)

    for (_key, name, obs_type), value in zip(obs_spec, sample):
        if obs_type == ObservationType.JOINT_POS:
            self._data.joint(name).qpos = value
        elif obs_type == ObservationType.JOINT_VEL:
            self._data.joint(name).qvel = value
        elif obs_type == ObservationType.SITE_ROT:
            self._data.site(name).xmat = value
[docs]
def load_dataset_and_get_traj_files(self, dataset_path, freq=None):
    """
    Calculates a dictionary containing the kinematics given a dataset. The dataset
    is also cached (as a deep copy) in the environment. The states of the dataset
    do not contain the first two entries of the observation_spec (x and y position);
    these are filled with zeros or, if freq is provided, calculated from the
    corresponding velocities.

    Args:
        dataset_path (str): Path to the dataset, relative to the directory of the
            loco_mujoco package. The dataset has to provide "states" and "last".
        freq (float): Frequency of the data in obs.

    Returns:
        Dictionary containing the keys specified in observation_spec with the corresponding
        values from the dataset. If the dataset has more than two states, "split_points"
        is added, which holds 0 followed by the indices right after each sample
        whose "last" flag is 1.

    """

    full_path = Path(loco_mujoco.__file__).resolve().parent / dataset_path

    # read everything inside the context manager so that the archive is closed afterwards
    with np.load(str(full_path)) as npz_file:
        dataset = {k: d for k, d in npz_file.items()}

    self._dataset = deepcopy(dataset)

    states = np.atleast_2d(dataset["states"])
    last = dataset["last"]

    rel_keys = [obs_spec[0] for obs_spec in self.obs_helper.observation_spec]
    num_data = len(states)

    trajectories = dict()
    for i, key in enumerate(rel_keys):
        if i >= 2:
            # the states lack the first two entries, hence the shift
            data = states[:, i - 2]
        elif freq is None:
            # fill with zeros for x and y position
            data = np.zeros(num_data)
        else:
            # compute positions from velocities
            dt = 1 / float(freq)
            assert len(states) > 2
            vel_idx = rel_keys.index("d" + key) - 2
            data = [0.0]
            for j, vel in enumerate(states[:-1, vel_idx]):
                if last[j] == 1:
                    # a new trajectory starts, so the position starts at zero again
                    data.append(0.0)
                else:
                    data.append(data[-1] + dt * vel)
            data = np.array(data)
        trajectories[key] = data

    # add split points
    if len(states) > 2:
        # flatnonzero always yields a 1-d array, also if there is exactly one split point
        split_points = np.flatnonzero(np.asarray(last) == 1) + 1
        trajectories["split_points"] = np.concatenate([[0], split_points])

    return trajectories
def _get_observation_space(self):
"""
Returns a tuple of the lows and highs (np.array) of the observation space.
"""
sim_low, sim_high = (self.info.observation_space.low[2:],
self.info.observation_space.high[2:])
if self._use_foot_forces:
grf_low, grf_high = (-np.ones((self._get_grf_size(),)) * np.inf,
np.ones((self._get_grf_size(),)) * np.inf)
return (np.concatenate([sim_low, grf_low]),
np.concatenate([sim_high, grf_high]))
else:
return sim_low, sim_high
def _create_observation(self, obs):
"""
Creates a full vector of observations.
Args:
obs (np.array): Observation vector to be modified or extended;
Returns:
New observation vector (np.array);
"""
if self._use_foot_forces:
obs = np.concatenate([obs[2:],
self.mean_grf.mean / 1000.,
]).flatten()
else:
obs = np.concatenate([obs[2:],
]).flatten()
return obs
def _preprocess_action(self, action):
"""
This function preprocesses all actions. All actions in this environment expected to be between -1 and 1.
Hence, we need to unnormalize the action to send to correct action to the simulation.
Note: If the action is not in [-1, 1], the unnormalized version will be clipped in Mujoco.
Args:
action (np.array): Action to be send to the environment;
Returns:
Unnormalized action (np.array) that is send to the environment;
"""
unnormalized_action = ((action.copy() * self.norm_act_delta) + self.norm_act_mean)
return unnormalized_action
def _simulation_post_step(self):
"""
Update the ground forces statistics if needed.
"""
if self._use_foot_forces:
grf = self._get_ground_forces()
self.mean_grf.update_stats(grf)
def _init_sim_from_obs(self, obs):
"""
Initializes the simulation from an observation.
Args:
obs (np.array): The observation to set the simulation state to.
"""
assert len(obs.shape) == 1
# append x and y pos
obs = np.concatenate([[0.0, 0.0], obs])
obs_spec = self.obs_helper.observation_spec
assert len(obs) >= len(obs_spec)
# remove anything added to obs that is not in obs_spec
obs = obs[:len(obs_spec)]
# set state
self.set_sim_state(obs)
def _setup_ground_force_statistics(self):
    """
    Creates the running average window used for the mean ground forces. Its
    shape is given by the size of the ground force vector and its window size
    by the number of intermediate steps.

    """

    grf_shape = (self._get_grf_size(),)
    return RunningAveragedWindow(shape=grf_shape, window_size=self._n_intermediate_steps)
def _get_ground_forces(self):
"""
Returns the ground forces (np.array). By default, 4 ground force sensors are used.
Environments that use more or less have to override this function.
"""
grf = np.concatenate([self._get_collision_force("floor", "foot_r")[:3],
self._get_collision_force("floor", "front_foot_r")[:3],
self._get_collision_force("floor", "foot_l")[:3],
self._get_collision_force("floor", "front_foot_l")[:3]])
return grf
def _get_reward_function(self, reward_type, reward_params):
"""
Constructs a reward function.
Args:
reward_type (string): Name of the reward.
reward_params (dict): Parameters of the reward function.
Returns:
Reward function.
"""
if reward_type == "custom":
reward_func = CustomReward(**reward_params)
elif reward_type == "target_velocity":
x_vel_idx = self.get_obs_idx("dq_pelvis_tx")
assert len(x_vel_idx) == 1
x_vel_idx = x_vel_idx[0]
reward_func = TargetVelocityReward(x_vel_idx=x_vel_idx, **reward_params)
elif reward_type == "x_pos":
x_idx = self.get_obs_idx("q_pelvis_tx")
assert len(x_idx) == 1
x_idx = x_idx[0]
reward_func = PosReward(pos_idx=x_idx)
elif reward_type is None:
reward_func = NoReward()
else:
raise NotImplementedError("The specified reward has not been implemented: %s" % reward_type)
return reward_func
def _get_joint_pos(self):
"""
Returns a vector (np.array) containing the current joint position of the model in the simulation.
"""
return self.obs_helper.get_joint_pos_from_obs(self.obs_helper._build_obs(self._data))
def _get_joint_vel(self):
"""
Returns a vector (np.array) containing the current joint velocities of the model in the simulation.
"""
return self.obs_helper.get_joint_vel_from_obs(self.obs_helper._build_obs(self._data))
def _get_from_obs(self, obs, keys):
"""
Returns a part of the observation based on the specified keys.
Args:
obs (np.array): Observation array.
keys (list or str): List of keys or just one key which are
used to extract entries from the observation.
Returns:
np.array including the parts of the original observation whose
keys were specified.
"""
# obs has removed x and y positions, add dummy entries
obs = np.concatenate([[0.0, 0.0], obs])
if type(keys) != list:
assert type(keys) == str
keys = list(keys)
entries = []
for key in keys:
entries.append(self.obs_helper.get_from_obs(obs, key))
return np.concatenate(entries)
def _get_idx(self, keys):
"""
Returns the indices of the specified keys.
Args:
keys (list or str): List of keys or just one key which are
used to get the indices from the observation space.
Returns:
np.array including the indices of the specified keys.
"""
if type(keys) != list:
assert type(keys) == str
keys = [keys]
entries = []
for key in keys:
entries.append(self.obs_helper.obs_idx_map[key])
return np.concatenate(entries) - 2
def _len_qpos_qvel(self):
"""
Returns the lengths of the joint position vector and the joint velocity vector, including x and y.
"""
keys = self.get_all_observation_keys()
len_qpos = len([key for key in keys if key.startswith("q_")])
len_qvel = len([key for key in keys if key.startswith("dq_")])
return len_qpos, len_qvel
def _has_fallen(self, obs, return_err_msg=False):
"""
Checks if a model has fallen. This has to be implemented for each environment.
Args:
obs (np.array): Current observation.
return_err_msg (bool): If True, an error message with violations is returned.
Returns:
True, if the model has fallen for the current observation, False otherwise.
"""
raise NotImplementedError
def _get_interpolate_map_params(self):
"""
Returns all parameters needed to do the interpolation mapping for the respective environment.
"""
pass
def _get_interpolate_remap_params(self):
"""
Returns all parameters needed to do the interpolation remapping for the respective environment.
"""
pass
[docs]
@classmethod
def register(cls):
    """
    Adds the class under its class name to the registry of ``Environment`` and
    to the registry of ``LocoEnv``. Names that are already registered are left
    untouched.

    """

    env_name = cls.__name__

    for registry in (Environment._registered_envs, LocoEnv._registered_envs):
        if env_name not in registry:
            registry[env_name] = cls
@staticmethod
def _get_grf_size():
"""
Returns the size of the ground force vector.
"""
return 12
[docs]
@staticmethod
def list_registered_loco_mujoco():
"""
List registered loco_mujoco environments.
Returns:
The list of the registered loco_mujoco environments.
"""
return list(LocoEnv._registered_envs.keys())
@staticmethod
def _interpolate_map(traj, **interpolate_map_params):
"""
A mapping that is supposed to transform a trajectory into a space where interpolation is
allowed. E.g., maps a rotation matrix to a set of angles. If this function is not
overwritten, it just converts the list of np.arrays to a np.array.
Args:
traj (list): List of np.arrays containing each observations. Each np.array
has the shape (n_trajectories, n_samples, (dim_observation)). If dim_observation
is one the shape of the array is just (n_trajectories, n_samples).
interpolate_map_params: Set of parameters needed by the individual environments.
Returns:
A np.array with shape (n_observations, n_trajectories, n_samples). dim_observation
has to be one.
"""
return np.array(traj)
@staticmethod
def _interpolate_remap(traj, **interpolate_remap_params):
"""
The corresponding backwards transformation to _interpolation_map. If this function is
not overwritten, it just converts the np.array to a list of np.arrays.
Args:
traj (np.array): Trajectory as np.array with shape (n_observations, n_trajectories, n_samples).
dim_observation is one.
interpolate_remap_params: Set of parameters needed by the individual environments.
Returns:
List of np.arrays containing each observations. Each np.array has the shape
(n_trajectories, n_samples, (dim_observation)). If dim_observation
is one the shape of the array is just (n_trajectories, n_samples).
"""
return [obs for obs in traj]
@staticmethod
def _delete_from_xml_handle(xml_handle, joints_to_remove, motors_to_remove, equ_constraints):
"""
Deletes certain joints, motors and equality constraints from a Mujoco XML handle.
Args:
xml_handle: Handle to Mujoco XML.
joints_to_remove (list): List of joint names to remove.
motors_to_remove (list): List of motor names to remove.
equ_constraints (list): List of equality constraint names to remove.
Returns:
Modified Mujoco XML handle.
"""
for j in joints_to_remove:
j_handle = xml_handle.find("joint", j)
j_handle.remove()
for m in motors_to_remove:
m_handle = xml_handle.find("actuator", m)
m_handle.remove()
for e in equ_constraints:
e_handle = xml_handle.find("equality", e)
e_handle.remove()
return xml_handle
@staticmethod
def _save_xml_handle(xml_handle, tmp_dir_name, file_name="tmp_model.xml"):
    """
    Exports the Mujoco XML handle together with its assets into a newly created
    temporary directory located inside tmp_dir_name.

    Args:
        xml_handle: Mujoco XML handle.
        tmp_dir_name (str): Path to an existing directory in which the temporary
            directory is created. If None, the default location of ``mkdtemp``
            is used.
        file_name (str): Name of the exported XML file.

    Returns:
        String of the save path.

    """

    if tmp_dir_name is not None:
        assert os.path.exists(tmp_dir_name), "specified directory (\"%s\") does not exist." % tmp_dir_name

    save_dir = mkdtemp(dir=tmp_dir_name)

    # dump data
    mjcf.export_with_assets(xml_handle, save_dir, file_name)

    return os.path.join(save_dir, file_name)
[docs]
@classmethod
def get_all_task_names(cls):
"""
Returns a list of all available tasks in LocoMujoco.
"""
task_names = []
for e in cls.list_registered_loco_mujoco():
env = cls._registered_envs[e]
confs = env.valid_task_confs.get_all_combinations()
for conf in confs:
task_name = list(conf.values())
task_name.insert(0, env.__name__, )
task_name = ".".join(task_name)
task_names.append(task_name)
return task_names
# Registry of loco_mujoco environments: maps the class name of an environment
# to its class. Filled by ``register`` and read by ``list_registered_loco_mujoco``.
_registered_envs = dict()
[docs]
class ValidTaskConf:

    """ Simple class that holds all valid configurations of an environments. """

    def __init__(self, tasks=None, modes=None, data_types=None, non_combinable=None):
        """

        Args:
            tasks (list): List of valid tasks.
            modes (list): List of valid modes.
            data_types (list): List of valid data_types.
            non_combinable (list): List of tuples ("task", "mode", "dataset_type"),
                which are NOT allowed to be combined. If one of them is None, it is neglected.

        """

        self.tasks = tasks
        self.modes = modes
        self.data_types = data_types
        self.non_combinable = non_combinable
        if non_combinable is not None:
            for nc in non_combinable:
                assert len(nc) == 3

    def get_all(self):
        """
        Returns deep copies of the tasks, modes, data_types and non_combinable
        entries as a tuple, in that order.

        """

        return deepcopy(self.tasks), deepcopy(self.modes),\
            deepcopy(self.data_types), deepcopy(self.non_combinable)

    def get_all_combinations(self):
        """
        Returns all possible combinations of configurations.

        Returns:
            List of dictionaries with the keys "task", "mode" and "data_type". A key
            is left out if no values were specified for it. Combinations matching
            any entry of non_combinable are excluded; every other combination
            appears exactly once.

        """

        # an unspecified dimension contributes a single "None" placeholder
        tasks = self.tasks if self.tasks is not None else [None]
        modes = self.modes if self.modes is not None else [None]
        data_types = self.data_types if self.data_types is not None else [None]
        forbidden = self.non_combinable if self.non_combinable is not None else []

        confs = []
        for t, m, dt in product(tasks, modes, data_types):

            # a combination is invalid as soon as ONE non-combinable entry matches
            if any(self._is_forbidden(nc, t, m, dt) for nc in forbidden):
                continue

            conf = dict()
            if t is not None:
                conf["task"] = t
            if m is not None:
                conf["mode"] = m
            if dt is not None:
                conf["data_type"] = dt
            confs.append(conf)

        return confs

    @staticmethod
    def _is_forbidden(non_combinable_entry, task, mode, data_type):
        """ Checks if a combination matches a non-combinable entry; None in the entry matches anything. """

        bad_task, bad_mode, bad_data_type = non_combinable_entry
        return ((bad_task is None or task == bad_task) and
                (bad_mode is None or mode == bad_mode) and
                (bad_data_type is None or data_type == bad_data_type))