From cd7a496ce73a25c5337c2ae29631151be6d16662 Mon Sep 17 00:00:00 2001 From: ahalev Date: Fri, 20 Dec 2024 18:44:04 -0500 Subject: [PATCH 1/3] update observation keys to be unique --- src/pymgrid/envs/base/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/pymgrid/envs/base/base.py b/src/pymgrid/envs/base/base.py index 8df52e9c..c6f887fa 100644 --- a/src/pymgrid/envs/base/base.py +++ b/src/pymgrid/envs/base/base.py @@ -133,7 +133,12 @@ def _get_observation_space(self): state_series = self.state_series() - observation_keys = self.observation_keys or state_series.index.get_level_values(-1) + if self.observation_keys is None: + observation_keys = state_series.index.get_level_values(-1) + else: + observation_keys = pd.Index(self.observation_keys) + + observation_keys = observation_keys.drop_duplicates() if 'net_load' in observation_keys: obs_space['general'] = Tuple([Box(low=-np.inf, high=1, shape=(1, ), dtype=np.float64)]) From 7738060eaee12d479138febc56090122a0f8e2a0 Mon Sep 17 00:00:00 2001 From: ahalev Date: Fri, 20 Dec 2024 19:57:14 -0500 Subject: [PATCH 2/3] check for empty observation keys --- src/pymgrid/envs/base/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pymgrid/envs/base/base.py b/src/pymgrid/envs/base/base.py index c6f887fa..25fe846e 100644 --- a/src/pymgrid/envs/base/base.py +++ b/src/pymgrid/envs/base/base.py @@ -133,7 +133,7 @@ def _get_observation_space(self): state_series = self.state_series() - if self.observation_keys is None: + if self.observation_keys is None or len(self.observation_keys) == 0: observation_keys = state_series.index.get_level_values(-1) else: observation_keys = pd.Index(self.observation_keys) From 31ec1751b978126e1eecb5bb2a1e334f4f5ff491 Mon Sep 17 00:00:00 2001 From: ahalev Date: Fri, 20 Dec 2024 20:01:05 -0500 Subject: [PATCH 3/3] remove dupe obs keys --- src/pymgrid/envs/base/base.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/pymgrid/envs/base/base.py b/src/pymgrid/envs/base/base.py index 25fe846e..fcc3a2d0 100644 --- a/src/pymgrid/envs/base/base.py +++ b/src/pymgrid/envs/base/base.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +import warnings from collections import OrderedDict from gym import Env @@ -122,7 +123,19 @@ def _validate_observation_keys(self, keys): if net_load_pos.size: keys[[0, net_load_pos.item()]] = keys[[net_load_pos.item(), 0]] - return keys.tolist() + unique_keys, dupe_keys = [], [] + + for k in keys: + if k in unique_keys: + dupe_keys.append(k) + continue + + unique_keys.append(k) + + if dupe_keys: + warnings.warn(f'Found duplicated keys, will be dropped:\n\t{dupe_keys}') + + return unique_keys @abstractmethod def _get_action_space(self, remove_redundant_actions=False):