Source code for mjlab.actuator.learned_actuator

"""Learned actuator models."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal

import mujoco
import mujoco_warp as mjwarp
import torch

from mjlab.actuator.actuator import ActuatorCmd
from mjlab.actuator.dc_actuator import DcMotorActuator, DcMotorActuatorCfg
from mjlab.utils.buffers import CircularBuffer

if TYPE_CHECKING:
  from mjlab.entity import Entity


@dataclass(kw_only=True)
class LearnedMlpActuatorCfg(DcMotorActuatorCfg):
  """Settings for an actuator whose torque comes from a trained MLP.

  Instead of an analytic motor model, the actuator feeds a short window of
  joint position errors and joint velocities to a neural network and uses the
  network output as the torque. This is intended for dynamics that are awkward
  to write down by hand (delays, non-linearities, friction). The network is
  trained offline on real-system data and shipped as a TorchScript file.
  """

  network_file: str
  """Location of the TorchScript file holding the trained MLP."""

  pos_scale: float
  """Multiplier applied to position-error inputs before they reach the network."""

  vel_scale: float
  """Multiplier applied to velocity inputs before they reach the network."""

  torque_scale: float
  """Multiplier applied to the network output to obtain torques."""

  input_order: Literal["pos_vel", "vel_pos"] = "pos_vel"
  """Layout of the concatenated network input.

  - "pos_vel": position-error block first, velocity block second
  - "vel_pos": velocity block first, position-error block second
  """

  history_length: int = 3
  """How many timesteps of history are fed to the network.

  With history_length=3 the network sees the current timestep together with
  the two preceding ones (three frames in total).
  """

  # The PD gains inherited from the parent config are zeroed: the learned model
  # replaces the PD law, so stiffness/damping play no role here.
  stiffness: float = 0.0
  damping: float = 0.0

  def build(
    self, entity: Entity, target_ids: list[int], target_names: list[str]
  ) -> LearnedMlpActuator:
    """Create the runtime actuator described by this configuration.

    Args:
      entity: Entity the actuator is attached to.
      target_ids: Indices of the targets driven by the actuator.
      target_names: Names matching ``target_ids``.

    Returns:
      A ``LearnedMlpActuator`` constructed from this config and the arguments.
    """
    return LearnedMlpActuator(
      cfg=self,
      entity=entity,
      target_ids=target_ids,
      target_names=target_names,
    )
class LearnedMlpActuator(DcMotorActuator[LearnedMlpActuatorCfg]):
  """Actuator that predicts torques with a trained MLP over joint history.

  Two circular buffers hold the most recent joint position errors and joint
  velocities. On every ``compute`` call the buffers are updated, their contents
  are scaled and concatenated into the network input, and the network output is
  scaled and passed through the parent class's ``_clip_effort``.
  """

  def __init__(
    self,
    cfg: LearnedMlpActuatorCfg,
    entity: Entity,
    target_ids: list[int],
    target_names: list[str],
  ) -> None:
    """Store the configuration; heavy state is created later in ``initialize``.

    Args:
      cfg: Configuration for this actuator.
      entity: Entity the actuator is attached to.
      target_ids: Indices of the targets driven by the actuator.
      target_names: Names matching ``target_ids``.
    """
    super().__init__(cfg, entity, target_ids, target_names)

    # All three stay None until initialize() has run.
    self.network: torch.jit.ScriptModule | None = None
    self._pos_error_history: CircularBuffer | None = None
    self._vel_history: CircularBuffer | None = None

  def initialize(
    self,
    mj_model: mujoco.MjModel,
    model: mjwarp.Model,
    data: mjwarp.Data,
    device: str,
  ) -> None:
    """Load the TorchScript network and allocate the history buffers.

    Args:
      mj_model: MuJoCo model, forwarded to the parent initializer.
      model: MuJoCo Warp model, forwarded to the parent initializer.
      data: MuJoCo Warp data; ``data.nworld`` sets the buffer batch size.
      device: Device the network is mapped to and the buffers live on.
    """
    super().initialize(mj_model, model, data, device)

    # The network is only ever used for inference, so switch it to eval mode.
    self.network = torch.jit.load(self.cfg.network_file, map_location=device)
    assert self.network is not None
    self.network.eval()

    num_worlds = data.nworld

    def _new_history() -> CircularBuffer:
      # Both buffers share the same geometry: one slot per history step,
      # one batch entry per world.
      return CircularBuffer(
        max_len=self.cfg.history_length,
        batch_size=num_worlds,
        device=device,
      )

    self._pos_error_history = _new_history()
    self._vel_history = _new_history()

  def reset(self, env_ids: torch.Tensor | slice | None = None) -> None:
    """Clear the history buffers for the given environments.

    Args:
      env_ids: Environments to reset. ``None`` resets every environment; a
        slice is expanded to an explicit index list before being handed to
        the buffers.
    """
    assert self._pos_error_history is not None
    assert self._vel_history is not None
    histories = (self._pos_error_history, self._vel_history)

    if env_ids is None:
      for history in histories:
        history.reset()
      return

    # The buffers are given explicit indices rather than a slice, so resolve
    # the slice against the buffer's batch size first.
    selection = (
      list(range(*env_ids.indices(self._pos_error_history.batch_size)))
      if isinstance(env_ids, slice)
      else env_ids
    )
    for history in histories:
      history.reset(selection)

  def compute(self, cmd: ActuatorCmd) -> torch.Tensor:
    """Run the learned model for one step and return clipped torques.

    Args:
      cmd: Actuator command; ``position_target``, ``pos`` and ``vel`` are read.

    Returns:
      Torques reshaped to (num_envs, num_joints), scaled by ``torque_scale``
      and passed through ``_clip_effort``.

    Raises:
      ValueError: If ``cfg.input_order`` is neither 'pos_vel' nor 'vel_pos'.
    """
    assert self.network is not None
    assert self._pos_error_history is not None
    assert self._vel_history is not None
    assert self._joint_vel_clipped is not None

    # Push the newest sample into each history buffer.
    self._pos_error_history.append(cmd.position_target - cmd.pos)
    self._vel_history.append(cmd.vel)

    # The parent's DC-motor clipping reads the velocity from this buffer.
    self._joint_vel_clipped[:] = cmd.vel

    num_envs, num_joints = cmd.pos.shape[0], cmd.pos.shape[1]
    rows = num_envs * num_joints
    lags = range(self.cfg.history_length)

    # Gather lag 0 .. history_length-1, stack the frames on a trailing axis and
    # collapse (env, joint) so every joint becomes one network sample.
    # Assumes each buffer frame is (num_envs, num_joints).
    pos_flat = torch.stack(
      [self._pos_error_history[lag] for lag in lags], dim=2
    ).reshape(rows, -1)
    vel_flat = torch.stack(
      [self._vel_history[lag] for lag in lags], dim=2
    ).reshape(rows, -1)

    order = self.cfg.input_order
    if order not in ("pos_vel", "vel_pos"):
      raise ValueError(
        f"Invalid input order: {self.cfg.input_order}. Must be 'pos_vel' or 'vel_pos'."
      )

    scaled_pos = pos_flat * self.cfg.pos_scale
    scaled_vel = vel_flat * self.cfg.vel_scale
    blocks = [scaled_pos, scaled_vel] if order == "pos_vel" else [scaled_vel, scaled_pos]
    network_input = torch.cat(blocks, dim=1)

    # Inference only: no autograd bookkeeping is needed for the forward pass.
    with torch.inference_mode():
      raw_output = self.network(network_input)

    torques = raw_output.reshape(num_envs, num_joints) * self.cfg.torque_scale
    return self._clip_effort(torques)