Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 87 additions & 2 deletions tests/envs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tests.helpers.test_case import TestCase
from tests.helpers.modular_microgrid import get_modular_microgrid

from pymgrid.modules import BatteryModule
from pymgrid.envs import DiscreteMicrogridEnv, ContinuousMicrogridEnv, NetLoadContinuousMicrogridEnv
from pymgrid.envs.base import BaseMicrogridEnv

Expand Down Expand Up @@ -131,6 +132,65 @@ class ObsKeysWithNetLoadParent(ObsKeysNoNetLoadParent):
observation_keys = ['net_load', 'soc', 'load_current', 'export_price_current']


class ObsKeysDuplicateKeysParent(ObsKeysNoNetLoadParent):
observation_keys = ['net_load', 'soc', 'load_current', 'load_current', 'export_price_current']

@pass_if_parent
def test_get_obs_correct_keys_in_modules(self):
env = deepcopy(self.env)
obs = env._get_obs()

unique_obs_keys = pd.Index(self.observation_keys).drop_duplicates().tolist()

for module in env.modules.iterlist():
module_state_dict = module.state_dict(normalized=True)
matching_keys = [obs_key for obs_key in unique_obs_keys if obs_key in module.state_dict().keys()]
matching_values = [module_state_dict[k] for k in matching_keys]

with self.subTest(module=module.name, keys=matching_keys):
self.assertEqual(obs[np.isin(unique_obs_keys, matching_keys)], matching_values)


class ObsKeysDuplicateModulesParent(Parent):

@pass_if_parent
def setUp(self) -> None:
second_battery = BatteryModule(
min_capacity=0,
max_capacity=1000,
max_charge=500,
max_discharge=500,
efficiency=1.0,
init_soc=0.5,
normalized_action_bounds=(0, 1))

microgrid = get_modular_microgrid(
additional_modules=[second_battery],
)

self.env = self.env_class.from_microgrid(microgrid, observation_keys=self.observation_keys)

@pass_if_parent
def test_pre_reset_state_series_invariant_to_observation_keys(self):
env = deepcopy(self.env)

self.assertEqual(env.state_series().shape, (15, ))

@pass_if_parent
def test_state_series_values(self):
env = deepcopy(self.env)

expected_state_series = np.array([10., -60., 50., 1., 1., 0., 0., 0.5, 50., 0.5, 500, 1., 1., 1., 1.])
self.assertEqual(env.state_series(normalized=False).values, expected_state_series)

@pass_if_parent
def test_state_series_values_normalized(self):
env = deepcopy(self.env)

expected_state_series = np.array([1/6., 0., 1., 1., 1., 0., 0., 0.5, 0.5, 0.5, 0.5, 0., 0., 0., 0.])
self.assertEqual(env.state_series(normalized=True).values, expected_state_series)


class TestDiscrete(Parent):
env_class = DiscreteMicrogridEnv

Expand All @@ -155,9 +215,34 @@ class TestNetLoadContinuousObsKeysNoNetLoad(ObsKeysNoNetLoadParent):
env_class = NetLoadContinuousMicrogridEnv


class TestDiscreteObsDuplicateKeys(ObsKeysDuplicateKeysParent):
env_class = DiscreteMicrogridEnv


class TestContinuousObsDuplicateKeys(ObsKeysDuplicateKeysParent):
env_class = ContinuousMicrogridEnv


class TestNetLoadContinuousObsDuplicateKeys(ObsKeysDuplicateKeysParent):
env_class = NetLoadContinuousMicrogridEnv


class TestDiscreteDuplicateModules(ObsKeysDuplicateModulesParent):
env_class = DiscreteMicrogridEnv


class TestContinuousDuplicateModules(ObsKeysDuplicateModulesParent):
env_class = ContinuousMicrogridEnv


class TestNetLoadContinuousDuplicateModules(ObsKeysDuplicateModulesParent):
env_class = NetLoadContinuousMicrogridEnv


def flatten_nested_dict(nested_dict):
def extract_list(l):
assert len(l) == 1, 'reduction only works with length 1 lists'
return l[0].tolist()
# assert len(l) == 1, 'reduction only works with length 1 lists'
# return l[0].tolist()
return sum([_l.tolist() for _l in l], [])

return functools.reduce(lambda x, y: x + extract_list(y), nested_dict.values(), [])
Loading