diff --git a/tests/envs/test_base.py b/tests/envs/test_base.py index 8eaf76bc..a3a1382a 100644 --- a/tests/envs/test_base.py +++ b/tests/envs/test_base.py @@ -7,6 +7,7 @@ from tests.helpers.test_case import TestCase from tests.helpers.modular_microgrid import get_modular_microgrid +from pymgrid.modules import BatteryModule from pymgrid.envs import DiscreteMicrogridEnv, ContinuousMicrogridEnv, NetLoadContinuousMicrogridEnv from pymgrid.envs.base import BaseMicrogridEnv @@ -131,6 +132,65 @@ class ObsKeysWithNetLoadParent(ObsKeysNoNetLoadParent): observation_keys = ['net_load', 'soc', 'load_current', 'export_price_current'] +class ObsKeysDuplicateKeysParent(ObsKeysNoNetLoadParent): + observation_keys = ['net_load', 'soc', 'load_current', 'load_current', 'export_price_current'] + + @pass_if_parent + def test_get_obs_correct_keys_in_modules(self): + env = deepcopy(self.env) + obs = env._get_obs() + + unique_obs_keys = pd.Index(self.observation_keys).drop_duplicates().tolist() + + for module in env.modules.iterlist(): + module_state_dict = module.state_dict(normalized=True) + matching_keys = [obs_key for obs_key in unique_obs_keys if obs_key in module.state_dict().keys()] + matching_values = [module_state_dict[k] for k in matching_keys] + + with self.subTest(module=module.name, keys=matching_keys): + self.assertEqual(obs[np.isin(unique_obs_keys, matching_keys)], matching_values) + + +class ObsKeysDuplicateModulesParent(Parent): + + @pass_if_parent + def setUp(self) -> None: + second_battery = BatteryModule( + min_capacity=0, + max_capacity=1000, + max_charge=500, + max_discharge=500, + efficiency=1.0, + init_soc=0.5, + normalized_action_bounds=(0, 1)) + + microgrid = get_modular_microgrid( + additional_modules=[second_battery], + ) + + self.env = self.env_class.from_microgrid(microgrid, observation_keys=self.observation_keys) + + @pass_if_parent + def test_pre_reset_state_series_invariant_to_observation_keys(self): + env = deepcopy(self.env) + + self.assertEqual(env.state_series().shape, (15, )) + + @pass_if_parent + def test_state_series_values(self): + env = deepcopy(self.env) + + expected_state_series = np.array([10., -60., 50., 1., 1., 0., 0., 0.5, 50., 0.5, 500, 1., 1., 1., 1.]) + self.assertEqual(env.state_series(normalized=False).values, expected_state_series) + + @pass_if_parent + def test_state_series_values_normalized(self): + env = deepcopy(self.env) + + expected_state_series = np.array([1/6., 0., 1., 1., 1., 0., 0., 0.5, 0.5, 0.5, 0.5, 0., 0., 0., 0.]) + self.assertEqual(env.state_series(normalized=True).values, expected_state_series) + + class TestDiscrete(Parent): env_class = DiscreteMicrogridEnv @@ -155,9 +215,34 @@ class TestNetLoadContinuousObsKeysNoNetLoad(ObsKeysNoNetLoadParent): env_class = NetLoadContinuousMicrogridEnv +class TestDiscreteObsDuplicateKeys(ObsKeysDuplicateKeysParent): + env_class = DiscreteMicrogridEnv + + +class TestContinuousObsDuplicateKeys(ObsKeysDuplicateKeysParent): + env_class = ContinuousMicrogridEnv + + +class TestNetLoadContinuousObsDuplicateKeys(ObsKeysDuplicateKeysParent): + env_class = NetLoadContinuousMicrogridEnv + + +class TestDiscreteDuplicateModules(ObsKeysDuplicateModulesParent): + env_class = DiscreteMicrogridEnv + + +class TestContinuousDuplicateModules(ObsKeysDuplicateModulesParent): + env_class = ContinuousMicrogridEnv + + +class TestNetLoadContinuousDuplicateModules(ObsKeysDuplicateModulesParent): + env_class = NetLoadContinuousMicrogridEnv + + def flatten_nested_dict(nested_dict): def extract_list(l): - assert len(l) == 1, 'reduction only works with length 1 lists' - return l[0].tolist() + # assert len(l) == 1, 'reduction only works with length 1 lists' + # return l[0].tolist() + return sum([_l.tolist() for _l in l], []) return functools.reduce(lambda x, y: x + extract_list(y), nested_dict.values(), [])