18
18
from myosuite .envs .myo .base_v0 import BaseV0
19
19
20
20
CONTACT_TRAJ_MIN_LENGTH = 100
21
+ GOAL_CONTACT = 10
22
+ MAX_TIME = 10.0
21
23
22
24
23
25
class BimanualEnvV1 (BaseV0 ):
@@ -116,7 +118,7 @@ def _setup(self,
116
118
self .over_max = False
117
119
self .max_force = 0
118
120
self .goal_touch = 0
119
- self .TARGET_GOAL_TOUCH = 5
121
+ self .TARGET_GOAL_TOUCH = GOAL_CONTACT
120
122
121
123
122
124
self .touch_history = []
@@ -291,10 +293,10 @@ def get_reward_dict(self, obs_dict):
291
293
return rwd_dict
292
294
293
295
def _get_done (self , z ):
294
- if self .obs_dict ['time' ] > 3.0 :
296
+ if self .obs_dict ['time' ] > MAX_TIME :
295
297
return 1
296
298
elif z < 0.3 :
297
- self .obs_dict ['time' ] = 3.0
299
+ self .obs_dict ['time' ] = MAX_TIME
298
300
return 1
299
301
elif self .rwd_dict and self .rwd_dict ['solved' ]:
300
302
return 1
@@ -448,9 +450,9 @@ def get_touching_objects(model: mujoco.MjModel, data: mujoco.MjData, id_info: Id
448
450
449
451
450
452
def body_id_to_label (body_id , id_info : IdInfo ):
451
- if id_info .myo_body_range [0 ] < body_id < id_info .myo_body_range [1 ]:
453
+ if id_info .myo_body_range [0 ] <= body_id <= id_info .myo_body_range [1 ]:
452
454
return ObjLabels .MYO
453
- elif id_info .prosth_body_range [0 ] < body_id < id_info .prosth_body_range [1 ]:
455
+ elif id_info .prosth_body_range [0 ] <= body_id <= id_info .prosth_body_range [1 ]:
454
456
return ObjLabels .PROSTH
455
457
elif body_id == id_info .start_id :
456
458
return ObjLabels .START
@@ -474,5 +476,5 @@ def evaluate_contact_trajectory(contact_trajectory: List[set]):
474
476
return ContactTrajIssue .PROSTH_SHORT
475
477
476
478
# Check if only goal was touching object for the last CONTACT_TRAJ_MIN_LENGTH frames
477
- elif not np .all ([{ObjLabels .GOAL } == s for s in contact_trajectory [- CONTACT_TRAJ_MIN_LENGTH :]]):
479
+ elif not np .all ([{ObjLabels .GOAL } == s for s in contact_trajectory [- GOAL_CONTACT + 2 :]]): # Subtract 2 from the calculation to maintain a buffer zone around trajectory boundaries for safety/accuracy.
478
480
return ContactTrajIssue .NO_GOAL
0 commit comments