Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 46 additions & 30 deletions src/agents/evaluator_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import shlex
import subprocess
from abc import ABC
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from time import sleep
from typing import Any
Expand Down Expand Up @@ -290,10 +291,14 @@ def multi_eval(
return per_env_results_last_reward, per_env_results_rewards


@contextmanager
def start_server(
agent_name: str, kwargs: dict[str, Any], port=8080, host="localhost", python_path: str = "python"
) -> subprocess.Popen:
"""Start the agent server in a subprocess.
agent_name: str, kwargs: dict[str, Any], port: int = 8080, host: str = "localhost", python_path: str = "python"
):
"""Start the agent server in a subprocess as a context manager.

This ensures that the server is properly stopped when exiting the context and
that all logs are printed to the console.

Args:
agent_name (str): Name of the agent to start.
Expand All @@ -303,27 +308,36 @@ def start_server(
python_path (str): Path to the Python interpreter to use. If you use conda, you can look up the path with `conda info --envs`.
It can also be a format string that will be formatted with the agent_name, e.g. "conda run -n {agent_name} python".
Defaults to "python".
Returns:
subprocess.Popen: The process running the server.
"""

logging.info(
f"Server starting with command: {python_path.format(agent_name=agent_name)} -m agents start-server {agent_name} --port={port} --host={host} --kwargs={json.dumps(kwargs)}"
)
p = subprocess.Popen(
[
python_path.format(agent_name=agent_name),
"-m",
"agents",
"start-server",
f"{agent_name}",
f"--port={port}",
f"--host={host}",
f"--kwargs={json.dumps(kwargs)}",
]
)
logging.info("successfully started")
return p
cmd = [
python_path.format(agent_name=agent_name),
"-m",
"agents",
"start-server",
f"{agent_name}",
f"--port={port}",
f"--host={host}",
f"--kwargs={json.dumps(kwargs)}",
]
logging.info("Server starting: %s", " ".join(cmd))
p = subprocess.Popen(cmd)
sleep(5)
try:
yield p
finally:
# Stop the server no matter how we exit the with-block (success or exception).
try:
p.send_signal(subprocess.signal.SIGINT)
p.wait(timeout=5)
except Exception:
pass
if p.poll() is None:
p.terminate()
try:
p.wait(timeout=3)
except subprocess.TimeoutExpired:
p.kill()
logging.info("Server stopped")


def evaluation(
Expand All @@ -334,13 +348,15 @@ def evaluation(
):
per_process_cache.clear()
logging.info(f"Starting evaluation with {agent_cfg.agent_name} and {agent_cfg.agent_kwargs}")
with start_server(
agent_cfg.agent_name, agent_cfg.agent_kwargs, agent_cfg.port, agent_cfg.host, agent_cfg.python_path
) as p:
res = multi_eval(agent_cfg, eval_cfgs, episodes, n_processes)
logging.info("Evaluation finished")
# Send Ctrl+C (SIGINT) to the server process.
p.send_signal(subprocess.signal.SIGINT)
try:
with start_server(
agent_cfg.agent_name, agent_cfg.agent_kwargs, agent_cfg.port, agent_cfg.host, agent_cfg.python_path
):
res = multi_eval(agent_cfg, eval_cfgs, episodes, n_processes)
except Exception:
# Ensures you SEE the client's stack trace and any logged errors.
logging.exception("Client failed")
raise

logging.info(f"Results (success, reward, steps) for all envs: {res[0].mean(axis=1)}")
logging.info(
Expand Down