Skip to content

Commit 7927eb0

Browse files
Merge pull request #265 from MyoHub/dev
BUGFIX: Manipulation ENV
2 parents 2be0106 + b4c9ce5 commit 7927eb0

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

myosuite/envs/myo/myochallenge/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def register_env_with_variants(id, entry_point, max_episode_steps, kwargs):
3333

3434
register_env_with_variants(id='myoChallengeBimanual-v0',
3535
entry_point='myosuite.envs.myo.myochallenge.bimanual_v0:BimanualEnvV1',
36-
max_episode_steps=300,
36+
max_episode_steps=1000,
3737
kwargs={
3838
'model_path': curr_dir + '/../assets/arm/myoarm_bionic_bimanual.xml',
3939
'normalize_act': True,

myosuite/envs/myo/myochallenge/bimanual_v0.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from myosuite.envs.myo.base_v0 import BaseV0
1919

2020
CONTACT_TRAJ_MIN_LENGTH = 100
21+
GOAL_CONTACT = 10
22+
MAX_TIME = 10.0
2123

2224

2325
class BimanualEnvV1(BaseV0):
@@ -116,7 +118,7 @@ def _setup(self,
116118
self.over_max = False
117119
self.max_force = 0
118120
self.goal_touch = 0
119-
self.TARGET_GOAL_TOUCH = 5
121+
self.TARGET_GOAL_TOUCH = GOAL_CONTACT
120122

121123

122124
self.touch_history = []
@@ -291,10 +293,10 @@ def get_reward_dict(self, obs_dict):
291293
return rwd_dict
292294

293295
def _get_done(self, z):
294-
if self.obs_dict['time'] > 3.0:
296+
if self.obs_dict['time'] > MAX_TIME:
295297
return 1
296298
elif z < 0.3:
297-
self.obs_dict['time'] = 3.0
299+
self.obs_dict['time'] = MAX_TIME
298300
return 1
299301
elif self.rwd_dict and self.rwd_dict['solved']:
300302
return 1
@@ -448,9 +450,9 @@ def get_touching_objects(model: mujoco.MjModel, data: mujoco.MjData, id_info: Id
448450

449451

450452
def body_id_to_label(body_id, id_info: IdInfo):
451-
if id_info.myo_body_range[0] < body_id < id_info.myo_body_range[1]:
453+
if id_info.myo_body_range[0] <= body_id <= id_info.myo_body_range[1]:
452454
return ObjLabels.MYO
453-
elif id_info.prosth_body_range[0] < body_id < id_info.prosth_body_range[1]:
455+
elif id_info.prosth_body_range[0] <= body_id <= id_info.prosth_body_range[1]:
454456
return ObjLabels.PROSTH
455457
elif body_id == id_info.start_id:
456458
return ObjLabels.START
@@ -474,5 +476,5 @@ def evaluate_contact_trajectory(contact_trajectory: List[set]):
474476
return ContactTrajIssue.PROSTH_SHORT
475477

476478
# Check if only goal was touching object for the last CONTACT_TRAJ_MIN_LENGTH frames
477-
elif not np.all([{ObjLabels.GOAL} == s for s in contact_trajectory[-CONTACT_TRAJ_MIN_LENGTH:]]):
479+
elif not np.all([{ObjLabels.GOAL} == s for s in contact_trajectory[-GOAL_CONTACT + 2:]]): # Subtract 2 from the calculation to maintain a buffer zone around trajectory boundaries for safety/accuracy.
478480
return ContactTrajIssue.NO_GOAL

0 commit comments

Comments
 (0)