diff --git a/src/pymgrid/envs/base/base.py b/src/pymgrid/envs/base/base.py index 8df52e9c..fcc3a2d0 100644 --- a/src/pymgrid/envs/base/base.py +++ b/src/pymgrid/envs/base/base.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +import warnings from collections import OrderedDict from gym import Env @@ -122,7 +123,19 @@ def _validate_observation_keys(self, keys): if net_load_pos.size: keys[[0, net_load_pos.item()]] = keys[[net_load_pos.item(), 0]] - return keys.tolist() + unique_keys, dupe_keys = [], [] + + for k in keys: + if k in unique_keys: + dupe_keys.append(k) + continue + + unique_keys.append(k) + + if dupe_keys: + warnings.warn(f'Found duplicated keys, will be dropped:\n\t{dupe_keys}') + + return unique_keys @abstractmethod def _get_action_space(self, remove_redundant_actions=False): @@ -133,7 +146,12 @@ def _get_observation_space(self): state_series = self.state_series() - observation_keys = self.observation_keys or state_series.index.get_level_values(-1) + if self.observation_keys is None or len(self.observation_keys) == 0: + observation_keys = state_series.index.get_level_values(-1) + else: + observation_keys = pd.Index(self.observation_keys) + + observation_keys = observation_keys.drop_duplicates() if 'net_load' in observation_keys: obs_space['general'] = Tuple([Box(low=-np.inf, high=1, shape=(1, ), dtype=np.float64)])