From 094f4196aad4667cdffa26eb3ea24d41dd31103b Mon Sep 17 00:00:00 2001 From: Attila Gombos Date: Sun, 18 Jan 2026 18:52:52 +0100 Subject: [PATCH 1/6] Initial commit WIP --- .devcontainer | 1 + .github/dependabot.yml | 11 ++ .github/workflows/ci.yaml | 17 +++ .gitmodules | 3 + .idea/.gitignore | 7 ++ .idea/copilot.data.migration.ask2agent.xml | 6 + .idea/misc.xml | 9 ++ .idea/modules.xml | 8 ++ .idea/python-hello.iml | 14 +++ .idea/vcs.xml | 7 ++ .vscode/launch.json | 35 ++++++ .vscode/settings.json | 20 +++ .vscode/tasks.json | 15 +++ README.md | 5 +- deps.json | 20 +++ hello/__init__.py | 7 ++ hello/advertizer.py | 125 +++++++++++++++++++ hello/discoverer.py | 119 ++++++++++++++++++ hello/group.py | 39 ++++++ hello/hello.py | 38 ++++++ hello/py.typed | 0 hello/receiver.py | 88 +++++++++++++ hello/sender.py | 73 +++++++++++ hello/service.py | 28 +++++ pyproject.toml | 27 ++++ python-hello.code-workspace | 15 +++ setup.cfg | 57 +++++++++ tests/__init__.py | 0 tests/advertizerIntegrationTest.py | 136 +++++++++++++++++++++ tests/discovererIntegrationTest.py | 97 +++++++++++++++ tests/receiverIntegrationTest.py | 58 +++++++++ tests/senderIntegrationTest.py | 66 ++++++++++ 32 files changed, 1150 insertions(+), 1 deletion(-) create mode 160000 .devcontainer create mode 100644 .github/dependabot.yml create mode 100644 .github/workflows/ci.yaml create mode 100644 .gitmodules create mode 100644 .idea/.gitignore create mode 100644 .idea/copilot.data.migration.ask2agent.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/python-hello.iml create mode 100644 .idea/vcs.xml create mode 100644 .vscode/launch.json create mode 100644 .vscode/settings.json create mode 100644 .vscode/tasks.json create mode 100644 deps.json create mode 100644 hello/__init__.py create mode 100644 hello/advertizer.py create mode 100644 hello/discoverer.py create mode 100644 hello/group.py create mode 100644 hello/hello.py create mode 100644 hello/py.typed create mode 100644 hello/receiver.py create mode 100644 hello/sender.py create mode 100644 hello/service.py create mode 100644 pyproject.toml create mode 100644 python-hello.code-workspace create mode 100644 setup.cfg create mode 100644 tests/__init__.py create mode 100644 tests/advertizerIntegrationTest.py create mode 100644 tests/discovererIntegrationTest.py create mode 100644 tests/receiverIntegrationTest.py create mode 100644 tests/senderIntegrationTest.py diff --git a/.devcontainer b/.devcontainer new file mode 160000 index 0000000..0ddb12b --- /dev/null +++ b/.devcontainer @@ -0,0 +1 @@ +Subproject commit 0ddb12bd5a2381d9c588f5fc6f96c876716087f4 diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..87ae6dc --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +--- +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" + - package-ecosystem: "gitsubmodule" + directory: "/" + schedule: + interval: "daily" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..2b8ec6b --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,17 @@ +name: CI + +on: + push: + branches: [main] + tags: [v*.*.*] + + pull_request: + branches: [ "main" ] + types: + - synchronize + - opened + - reopened + +jobs: + call_ci: + uses: EffectiveRange/ci-workflows/.github/workflows/python-ci.yaml@latest-python diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..f1b0753 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule ".devcontainer"] + path = .devcontainer + url = https://github.com/EffectiveRange/devcontainer-defs diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..d98e7a2 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,7 @@ +# IntelliJ IDEA folder-specific ignored files +/shelf/ +/workspace.xml +/queries/ +/dataSources/ +/dataSources.local.xml +/httpRequests/ diff --git a/.idea/copilot.data.migration.ask2agent.xml b/.idea/copilot.data.migration.ask2agent.xml new file mode 100644 index 0000000..1f2ea11 --- /dev/null +++ b/.idea/copilot.data.migration.ask2agent.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..30ff3a3 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,9 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..e83f687 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/python-hello.iml b/.idea/python-hello.iml new file mode 100644 index 0000000..6d790ec --- /dev/null +++ b/.idea/python-hello.iml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..8306744 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..d29106d --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,35 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "args": [ + "--backend=scipy" + ], + "console": "integratedTerminal" + }, + { + "name": "Run All Tests (pytest)", + "type": "debugpy", + "request": "launch", + "module": "pytest", + "args": [ + "tests" + ], + "console": "integratedTerminal" + }, + { + "name": "Run All Tests with Coverage (pytest-cov)", + "type": "debugpy", + "request": "launch", + "module": "pytest", + "args": [ + "--cov=hello", "tests" + ], + "console": "integratedTerminal" + } + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..ded9998 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,20 @@ +{ + "python.venvPath": "${workspaceFolder}/.venv", + "python.testing.unittestArgs": [ + "-v", + "-s", + "./tests", + "-p", + "*Test.py" + ], + "python.testing.pytestEnabled": true, + "python.testing.unittestEnabled": false, + "black-formatter.interpreter": [ + "${workspaceFolder}/.venv/bin/python3" + ], + "black-formatter.args": [ + "--config=setup.cfg" + ], + "python.analysis.typeCheckingMode": "standard", + "python.testing.pytestArgs": [] +} diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000..328d1b1 --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,15 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Create Python venv", + "type": "shell", + "command": "if [ -d /var/chroot/buildroot ];then dpkgdeps -v --arch $(grep TARGET_ARCH /home/crossbuilder/target/target | cut -d'=' -f2 | tr -d \\') .;else dpkgdeps -v .;fi && rm -rf .venv && python3 -m venv --system-site-packages .venv && .venv/bin/pip install -e . && .venv/bin/python3 -m mypy --non-interactive --install-types && .venv/bin/pip install pytest-cov || true", + "group": "build", + "detail": "Creates a Python virtual environment in the .venv folder", + "problemMatcher": [ + "$eslint-compact" + ] + } + ] +} diff --git a/README.md b/README.md index db85593..2d7f101 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,5 @@ +[![CI](https://github.com/EffectiveRange/python-hello/actions/workflows/ci.yaml/badge.svg)](https://github.com/EffectiveRange/python-hello/actions/workflows/ci.yaml) +[![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/EffectiveRange/python-hello/python-coverage-comment-action-data/endpoint.json)](https://htmlpreview.github.io/?https://github.com/EffectiveRange/python-hello/blob/python-coverage-comment-action-data/htmlcov/index.html) + # python-hello -A service advertizer library using ZeroMQ +A service advertizer/discovery protocol library using ZeroMQ diff --git a/deps.json b/deps.json new file mode 100644 index 0000000..a6fe6ec --- /dev/null +++ b/deps.json @@ -0,0 +1,20 @@ +{ + "deps": [ + { + "name": "libzmq-drafts", + "hostinstall": true + }, + { + "name": "python3-zmq-drafts", + "hostinstall": true + }, + { + "name": "python3-context-logger", + "hostinstall": true + }, + { + "name": "python3-common-utility", + "hostinstall": true + } + ] +} diff --git a/hello/__init__.py b/hello/__init__.py new file mode 100644 index 0000000..7556868 --- /dev/null +++ b/hello/__init__.py @@ -0,0 +1,7 @@ +from .group import * +from .sender import * +from .receiver import * +from .service import * +from .advertizer import * +from .discoverer import * +from .hello import * diff --git a/hello/advertizer.py b/hello/advertizer.py new file mode 100644 index 0000000..fefd755 --- /dev/null +++ b/hello/advertizer.py @@ -0,0 +1,125 @@ +from typing import Any + +from common_utility import IReusableTimer +from context_logger import get_logger + +from hello import ServiceInfo, Group, Sender, GroupAccess, Receiver, ServiceMatcher, ServiceQuery + +log = get_logger('Advertizer') + + +class Advertizer: + + def start(self, address: str, group: Group, info: ServiceInfo | None = None) -> None: + raise NotImplementedError() + + def stop(self) -> None: + raise NotImplementedError() + + def advertise(self, info: ServiceInfo | None = None) -> None: + raise NotImplementedError() + + +class DefaultAdvertizer(Advertizer): + + def __init__(self, sender: Sender) -> None: + self._sender = sender + self._group: Group | None = None + self._info: ServiceInfo | None = None + + def __enter__(self) -> Advertizer: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.stop() + + def start(self, address: str, group: Group, info: ServiceInfo | None = None) -> None: + self._sender.start(GroupAccess(address, group.hello())) + self._group = group + self._info = info + + def stop(self) -> None: + self._group = None + self._sender.stop() + + def advertise(self, info: ServiceInfo | None = None) -> None: + if self._group: + if info: + self._info = info + if self._info: + self._sender.send(info) + log.info('Service advertised', service=self._info, group=self._group) + else: + log.warning('Cannot advertise service, advertizer not started', service=info) + + +class RespondingAdvertizer(DefaultAdvertizer): + + def __init__(self, sender: Sender, receiver: Receiver) -> None: + super().__init__(sender) + self._receiver = receiver + + def start(self, address: str, group: Group, info: ServiceInfo | None = None) -> None: + super().start(address, group, info) + self._receiver.start(GroupAccess(address, group.query())) + self._receiver.register(self._handle_query) + + def stop(self) -> None: + super().stop() + self._receiver.stop() + + def _handle_query(self, data: dict[str, str]) -> None: + if self._info: + matcher: ServiceMatcher | None = None + + try: + query = ServiceQuery(**data) + matcher = ServiceMatcher(query) + log.debug('Hail received', group=self._group, query=query) + except Exception as error: + log.warning('Invalid query message received', group=self._group, received=data, error=error) + + if matcher and matcher.matches(self._info): + log.info('Hail matches service', group=self._group, query=matcher.query, service=self._info) + self.advertise(self._info) + + +class ScheduledAdvertizer(Advertizer): + + def schedule(self, info: ServiceInfo, interval: float, one_shot: bool = False) -> None: + raise NotImplementedError() + + +class DefaultScheduledAdvertizer(ScheduledAdvertizer): + + def __init__(self, advertizer: Advertizer, timer: IReusableTimer) -> None: + self._advertizer = advertizer + self._timer = timer + + def __enter__(self) -> ScheduledAdvertizer: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.stop() + + def start(self, address: str, group: Group, info: ServiceInfo | None = None) -> None: + self._advertizer.start(address, group, info) + + def stop(self) -> None: + self._timer.cancel() + self._advertizer.stop() + + def advertise(self, info: ServiceInfo | None = None) -> None: + self._advertizer.advertise(info) + + def schedule(self, info: ServiceInfo, interval: float, one_shot: bool = False) -> None: + if one_shot: + self._timer.start(interval, self.advertise, [info]) + log.info('One-shot service advertisement scheduled', service=info, interval=interval) + else: + def periodic_advertise() -> None: + self.advertise(info) + self._timer.restart() + + self._timer.start(interval, periodic_advertise) + log.info('Periodic service advertisement scheduled', service=info, interval=interval) diff --git a/hello/discoverer.py b/hello/discoverer.py new file mode 100644 index 0000000..7b28223 --- /dev/null +++ b/hello/discoverer.py @@ -0,0 +1,119 @@ +from typing import Any, Callable + +from context_logger import get_logger + +from hello import Group, ServiceQuery, Sender, Receiver, GroupAccess, ServiceInfo, ServiceMatcher + +log = get_logger('Discoverer') + + +class Discoverer: + + def start(self, address: str, group: Group, query: ServiceQuery | None = None) -> None: + raise NotImplementedError() + + def stop(self) -> None: + raise NotImplementedError() + + def discover(self, query: ServiceQuery | None = None) -> None: + raise NotImplementedError() + + def get_services(self) -> dict[str, ServiceInfo]: + raise NotImplementedError() + + def register(self, callback: Callable[[Any], None]) -> None: + raise NotImplementedError() + + def deregister(self, callback: Callable[[Any], None]) -> None: + raise NotImplementedError() + + +class DefaultDiscoverer(Discoverer): + + def __init__(self, sender: Sender, receiver: Receiver) -> None: + self._sender = sender + self._receiver = receiver + self._group: Group | None = None + self._matcher: ServiceMatcher | None = None + self._services: dict[str, ServiceInfo] = {} + self._callbacks: list[Callable[[ServiceInfo], None]] = [] + + def __enter__(self) -> Discoverer: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.stop() + + def start(self, address: str, group: Group, query: ServiceQuery | None = None) -> None: + self._group = group + if query: + self._matcher = ServiceMatcher(query) + self._sender.start(GroupAccess(address, group.query())) + self._receiver.register(self._handle_message) + self._receiver.start(GroupAccess(address, group.hello())) + + def stop(self) -> None: + self._group = None + self._sender.stop() + self._receiver.stop() + + def discover(self, query: ServiceQuery | None = None) -> None: + if self._group: + if query: + self._matcher = ServiceMatcher(query) + if self._matcher: + self._sender.send(self._matcher.query) + log.info('Service discovery initiated', query=self._matcher.query, group=self._group) + else: + log.warning('Cannot initiate service discovery, discoverer not started', query=query) + + def get_services(self) -> dict[str, ServiceInfo]: + return self._services.copy() + + def register(self, callback: Callable[[Any], None]) -> None: + self._callbacks.append(callback) + + def deregister(self, callback: Callable[[Any], None]) -> None: + self._callbacks.remove(callback) + + def _handle_message(self, data: dict[str, Any]) -> None: + service: ServiceInfo | None = None + + try: + service = ServiceInfo(**data) + except Exception as error: + log.warn('Failed to handle received message', data=data, error=error) + + if service: + self._handle_service(service) + + def _handle_service(self, service: ServiceInfo) -> None: + if self._matcher and self._matcher.matches(service): + cached = self._services.get(service.name) + + if self._is_update_needed(cached, service): + self._services[service.name] = service + for callback in self._callbacks: + try: + callback(service) + except Exception as error: + log.warn('Error in callback execution', service=service, error=error) + + def _is_update_needed(self, cached: ServiceInfo | None, service: ServiceInfo) -> bool: + if cached: + if cached != service: + log.info('Service updated', old_service=cached, new_service=service) + return True + else: + log.info('Service discovered', service=service) + return True + + return False + + def _handle_update(self, service: ServiceInfo) -> None: + self._services[service.name] = service + for callback in self._callbacks: + try: + callback(service) + except Exception as error: + log.warn('Error in callback execution', service=service, error=error) diff --git a/hello/group.py b/hello/group.py new file mode 100644 index 0000000..95b0654 --- /dev/null +++ b/hello/group.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from enum import Enum + + +class GroupPrefix(Enum): + HELLO = 'hello' + QUERY = 'query' + + +class IGroup: + + def hello(self) -> str: + raise NotImplementedError() + + def query(self) -> str: + raise NotImplementedError() + + +class Group(IGroup): + def __init__(self, name: str) -> None: + self.name = name + + def hello(self) -> str: + return self._prefix(GroupPrefix.HELLO) + + def query(self) -> str: + return self._prefix(GroupPrefix.QUERY) + + def _prefix(self, group_type: GroupPrefix) -> str: + return f'{group_type.value}:{self.name}' + + def __repr__(self) -> str: + return self.name + + +@dataclass +class GroupAccess: + access_url: str + full_group: str diff --git a/hello/hello.py b/hello/hello.py new file mode 100644 index 0000000..ba7157f --- /dev/null +++ b/hello/hello.py @@ -0,0 +1,38 @@ +from typing import Any + +from common_utility import ReusableTimer, IReusableTimer +from zmq import Context + +from hello import Advertizer, Discoverer, RadioSender, DishReceiver, DefaultAdvertizer, DefaultDiscoverer, \ + ScheduledAdvertizer, RespondingAdvertizer, DefaultScheduledAdvertizer + + +class Hello: + + def default_advertizer(self, respond: bool = True) -> Advertizer: + raise NotImplementedError() + + def scheduled_advertizer(self, timer: IReusableTimer | None = None, respond: bool = True) -> ScheduledAdvertizer: + raise NotImplementedError() + + def discoverer(self) -> Discoverer: + raise NotImplementedError() + + +class DefaultHello(Hello): + + def __init__(self, context: Context[Any] | None = None) -> None: + self._context = context if context else Context() + self._sender = RadioSender(self._context) + self._receiver = DishReceiver(self._context) + + def default_advertizer(self, respond: bool = True) -> Advertizer: + return RespondingAdvertizer(self._sender, self._receiver) if respond else DefaultAdvertizer(self._sender) + + def scheduled_advertizer(self, timer: IReusableTimer | None = None, respond: bool = True) -> ScheduledAdvertizer: + advertizer = self.default_advertizer(respond) + reusable_timer = timer if timer else ReusableTimer() + return DefaultScheduledAdvertizer(advertizer, reusable_timer) + + def discoverer(self) -> Discoverer: + return DefaultDiscoverer(self._sender, self._receiver) diff --git a/hello/py.typed b/hello/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/hello/receiver.py b/hello/receiver.py new file mode 100644 index 0000000..daedc5c --- /dev/null +++ b/hello/receiver.py @@ -0,0 +1,88 @@ +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable + +from context_logger import get_logger +from zmq import DISH, Poller, POLLIN, POLLOUT, Context + +from hello import GroupAccess + +log = get_logger('Receiver') + + +class Receiver: + + def start(self, source: GroupAccess) -> None: + raise NotImplementedError() + + def stop(self) -> None: + raise NotImplementedError() + + def register(self, callback: Callable[[Any], None]) -> None: + raise NotImplementedError() + + def deregister(self, callback: Callable[[Any], None]) -> None: + raise NotImplementedError() + + +class DishReceiver(Receiver): + + def __init__(self, context: Context[Any]) -> None: + self._context = context + self._dish = self._context.socket(DISH) + self._poller = Poller() + self._executor = ThreadPoolExecutor(max_workers=1) + self._group: str | None = None + self._callbacks: list[Callable[[Any], None]] = [] + + def __enter__(self) -> Receiver: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.stop() + + def start(self, source: GroupAccess) -> None: + try: + self._poller.register(self._dish, POLLIN) + self._dish.bind(source.access_url) + self._dish.join(source.full_group) + self._group = source.full_group + self._executor.submit(self._handle_messages) + log.info('Receiver started', address=source.access_url, group=source.full_group) + except Exception as error: + log.error('Failed to start receiver', address=source.access_url, group=source.full_group, error=error) + raise error + + def stop(self) -> None: + try: + self._group = None + self._poller.register(self._dish, POLLOUT) + self._executor.shutdown() + self._dish.close() + log.info('Receiver stopped') + except Exception as error: + log.error('Failed to stop receiver', error=error) + raise error + + def register(self, callback: Callable[[Any], None]) -> None: + self._callbacks.append(callback) + + def deregister(self, callback: Callable[[Any], None]) -> None: + self._callbacks.remove(callback) + + def _handle_messages(self) -> None: + while self._group: + sockets = dict(self._poller.poll(timeout=100)) + if self._dish in sockets and sockets[self._dish] == POLLIN: + try: + data = self._dish.recv_json() + log.debug('Message received', data=data, group=self._group) + self._handle_message(data) + except Exception as error: + log.error('Failed to receive message', group=self._group, error=error) + + def _handle_message(self, data: dict[str, str]) -> None: + for callback in self._callbacks: + try: + callback(data) + except Exception as error: + log.warn('Error in callback execution', data=data, group=self._group, error=error) diff --git a/hello/sender.py b/hello/sender.py new file mode 100644 index 0000000..3a562b4 --- /dev/null +++ b/hello/sender.py @@ -0,0 +1,73 @@ +from typing import Any, cast + +from context_logger import get_logger +from zmq import Context, RADIO, Socket + +from hello import GroupAccess + +log = get_logger('Sender') + + +class Sender: + + def start(self, target: GroupAccess) -> None: + raise NotImplementedError() + + def stop(self) -> None: + raise NotImplementedError() + + def send(self, data: object) -> None: + raise NotImplementedError() + + +class RadioSender(Sender): + + def __init__(self, context: Context[Any]) -> None: + self._context = context + self._radio: Socket[bytes] = self._context.socket(RADIO) + self._group: str | None = None + + def __enter__(self) -> Sender: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.stop() + + def start(self, target: GroupAccess) -> None: + try: + self._radio.connect(target.access_url) + self._group = target.full_group + log.info('Sender started', address=target.access_url, group=target.full_group) + except Exception as error: + log.error('Failed to start sender', address=target.access_url, group=target.full_group, error=error) + raise error + + def stop(self) -> None: + try: + self._group = None + self._radio.close() + log.info('Sender stopped') + except Exception as error: + log.error('Failed to stop sender', error=error) + raise error + + def send(self, data: Any) -> None: + if self._group: + if data := self._convert_to_dict(data): + self._send_json(data) + else: + log.warning('Unsupported message type', data=data, group=self._group) + + def _convert_to_dict(self, data: Any) -> dict[str, Any] | None: + if isinstance(data, dict): + return data + elif hasattr(data, '__dict__'): + return cast(dict[str, Any], data.__dict__) + return None + + def _send_json(self, data: dict[str, Any]) -> None: + try: + self._radio.send_json(data, group=self._group) + log.debug('Message sent', data=data, group=self._group) + except Exception as error: + log.error('Failed to send message', data=data, group=self._group, error=error) diff --git a/hello/service.py b/hello/service.py new file mode 100644 index 0000000..374b595 --- /dev/null +++ b/hello/service.py @@ -0,0 +1,28 @@ +import re +from dataclasses import dataclass + + +@dataclass +class ServiceInfo: + name: str + role: str + url: str + + +@dataclass +class ServiceQuery(object): + name: str + role: str + + +class ServiceMatcher(object): + + def __init__(self, query: ServiceQuery) -> None: + self.query = query + self._name_matcher = re.compile(self.query.name) + self._role_matcher = re.compile(self.query.role) + + def matches(self, info: ServiceInfo) -> bool: + name_match = self._name_matcher.match(info.name) + role_match = self._role_matcher.match(info.role) + return bool(name_match and role_match) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8622581 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,27 @@ +[project] +name = "python-hello" +description = "A service advertizer/discovery protocol library using ZeroMQ" +authors = [ + { name = "Ferenc Nandor Janky & Attila Gombos", email = "info@effective-range.com" } +] +dependencies = [ + "pyzmq @ git+https://github.com/EffectiveRange/pyzmq.git@v27.1.1", + "python-context-logger @ git+https://github.com/EffectiveRange/python-context-logger.git@latest", + "python-common-utility @ git+https://github.com/EffectiveRange/python-common-utility.git@latest" +] +dynamic = ["version"] + +[tool.setuptools] +package-dir = {"" = "."} +packages = ["hello"] + +[tool.setuptools.package-data] +hello = ["py.typed"] + +[build-system] +requires = ["setuptools>=61", "setuptools_scm"] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] +version_scheme = "guess-next-dev" +local_scheme = "node-and-date" diff --git a/python-hello.code-workspace b/python-hello.code-workspace new file mode 100644 index 0000000..a083c85 --- /dev/null +++ b/python-hello.code-workspace @@ -0,0 +1,15 @@ +{ + "folders": [ + { + "path": "." + } + ], + "settings": { + "python.formatting.provider": "black", + "editor.formatOnSave": true, + "mypy-type-checker.ignorePatterns": [ + ], + "flake8.ignorePatterns": [ + ] + } +} diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..ce39c45 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,57 @@ +[pack-python] +packaging = + wheel + fpm-deb + +[mypy] +packages = hello +strict = True + +[flake8] +exclude = build,dist,.eggs,.venv +max-line-length = 120 +max-complexity = 10 +count = True +statistics = True +show-source = True +per-file-ignores = + # F401: imported but unused + # F403: import * used; unable to detect undefined names + __init__.py: F401,F403 + +[tool:pytest] +addopts = --capture=no --verbose +python_files = *Test.py +python_classes = *Test + +[coverage:run] +relative_files = true +branch = True +source = hello + +[coverage:report] +; Regexes for lines to exclude from consideration +exclude_also = + ; Don't complain about missing debug-only code: + def __repr__ + if self\.debug + + ; Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + + ; Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: + + ; Don't complain about abstract methods, they aren't run: + @(abc\.)?abstractmethod + +ignore_errors = True +skip_empty = True + +[coverage:html] +directory = coverage/html + +[coverage:json] +output = coverage/coverage.json diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/advertizerIntegrationTest.py b/tests/advertizerIntegrationTest.py new file mode 100644 index 0000000..d22630e --- /dev/null +++ b/tests/advertizerIntegrationTest.py @@ -0,0 +1,136 @@ +import unittest +from unittest import TestCase + +from context_logger import setup_logging +from test_utility import wait_for_assertion +from zmq import Context + +from hello import DefaultAdvertizer, ServiceInfo, Group, RadioSender, DishReceiver, GroupAccess, \ + RespondingAdvertizer, ServiceQuery + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) + + +class AdvertizerTest(TestCase): + SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + self.SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + def test_sends_hello_when_advertises_service(self): + # Given + context = Context() + sender = RadioSender(context) + messages = [] + + with DefaultAdvertizer(sender) as advertizer, DishReceiver(context) as test_receiver: + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.hello())) + test_receiver.register(lambda message: messages.append(message)) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.advertise(self.SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(messages))) + + # Then + self.assertEqual([self.SERVICE_INFO.__dict__], messages) + + def test_sends_hello_when_advertises_service_and_info_changed(self): + # Given + context = Context() + sender = RadioSender(context) + messages = [] + + with DefaultAdvertizer(sender) as advertizer, DishReceiver(context) as test_receiver: + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.hello())) + test_receiver.register(lambda message: messages.append(message)) + advertizer.start(ACCESS_URL, GROUP) + + advertizer.advertise(self.SERVICE_INFO) + + self.SERVICE_INFO.url = 'http://localhost:9090' + + # When + advertizer.advertise(self.SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual(2, len(messages))) + + # Then + self.assertEqual([ + {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:8080'}, + {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:9090'} + ], messages) + + def test_sends_hello_when_query_received(self): + # Given + context = Context() + sender = RadioSender(context) + receiver = DishReceiver(context) + messages = [] + + with (RespondingAdvertizer(sender, receiver) as advertizer, + RadioSender(context) as test_sender, + DishReceiver(context) as test_receiver): + test_sender.start(GroupAccess(ACCESS_URL, GROUP.query())) + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.hello())) + test_receiver.register(lambda message: messages.append(message)) + + advertizer.start(ACCESS_URL, GROUP, self.SERVICE_INFO) + + # When + test_sender.send(ServiceQuery('test-service', 'test-role')) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(messages))) + + # Then + self.assertEqual([self.SERVICE_INFO.__dict__], messages) + + def test_sends_hello_when_info_changed_and_query_received(self): + # Given + context = Context() + sender = RadioSender(context) + receiver = DishReceiver(context) + messages = [] + + with (RespondingAdvertizer(sender, receiver) as advertizer, + RadioSender(context) as test_sender, + DishReceiver(context) as test_receiver): + test_sender.start(GroupAccess(ACCESS_URL, GROUP.query())) + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.hello())) + test_receiver.register(lambda message: messages.append(message)) + + advertizer.start(ACCESS_URL, GROUP) + advertizer.advertise(self.SERVICE_INFO) + + query = ServiceQuery('test-service', 'test-role') + test_sender.send(query) + + wait_for_assertion(0.1, lambda: self.assertEqual(2, len(messages))) + + self.SERVICE_INFO.url = 'http://localhost:9090' + advertizer.advertise(self.SERVICE_INFO) + + # When + test_sender.send(query) + + wait_for_assertion(0.1, lambda: self.assertEqual(4, len(messages))) + + # Then + self.assertEqual([ + {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:8080'}, + {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:8080'}, + {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:9090'}, + {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:9090'} + ], messages) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/discovererIntegrationTest.py b/tests/discovererIntegrationTest.py new file mode 100644 index 0000000..770b673 --- /dev/null +++ b/tests/discovererIntegrationTest.py @@ -0,0 +1,97 @@ +import unittest +from unittest import TestCase + +from context_logger import setup_logging +from test_utility import wait_for_assertion +from zmq import Context + +from hello import ServiceInfo, Group, RadioSender, DishReceiver, GroupAccess, \ + ServiceQuery, DefaultDiscoverer + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_QUERY = ServiceQuery('test-service', 'test-role') + + +class AdvertizerTest(TestCase): + SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + self.SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + def test_discovers_service_when_hello_received(self): + # Given + context = Context() + sender = RadioSender(context) + receiver = DishReceiver(context) + + with DefaultDiscoverer(sender, receiver) as discoverer, RadioSender(context) as test_sender: + test_sender.start(GroupAccess(ACCESS_URL, GROUP.hello())) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + test_sender.send(self.SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) + + # Then + self.assertEqual({self.SERVICE_INFO.name: self.SERVICE_INFO}, discoverer.get_services()) + + def test_updates_service_when_info_changed(self): + # Given + context = Context() + sender = RadioSender(context) + receiver = DishReceiver(context) + + with DefaultDiscoverer(sender, receiver) as discoverer, RadioSender(context) as test_sender: + test_sender.start(GroupAccess(ACCESS_URL, GROUP.hello())) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + test_sender.send(self.SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) + + # When + self.SERVICE_INFO.url = 'http://localhost:9090' + test_sender.send(self.SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual( + 'http://localhost:9090', + discoverer.get_services()[self.SERVICE_INFO.name].url + )) + + # Then + self.assertEqual( + 'http://localhost:9090', + discoverer.get_services()[self.SERVICE_INFO.name].url + ) + + def test_sends_query(self): + # Given + context = Context() + sender = RadioSender(context) + receiver = DishReceiver(context) + messages = [] + + with DefaultDiscoverer(sender, receiver) as discoverer, DishReceiver(context) as test_receiver: + test_receiver.register(lambda message: messages.append(message)) + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.query())) + discoverer.start(ACCESS_URL, GROUP) + + # When + discoverer.discover(SERVICE_QUERY) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(messages))) + + # Then + self.assertEqual([SERVICE_QUERY.__dict__], messages) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/receiverIntegrationTest.py b/tests/receiverIntegrationTest.py new file mode 100644 index 0000000..a580022 --- /dev/null +++ b/tests/receiverIntegrationTest.py @@ -0,0 +1,58 @@ +import unittest +from unittest import TestCase + +from context_logger import setup_logging +from test_utility import wait_for_assertion +from zmq import Context, ZMQError + +from hello import ServiceInfo, Group, DishReceiver, GroupAccess, RadioSender + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + +class ReceiverTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_raises_error_when_restarted(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = Context() + + with DishReceiver(context) as receiver: + receiver.start(group_access) + + # When, Then + with self.assertRaises(ZMQError): + receiver.start(group_access) + + def test_receives_message(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = Context() + messages = [] + + with DishReceiver(context) as receiver, RadioSender(context) as test_sender: + receiver.register(lambda message: messages.append(message)) + receiver.start(group_access) + test_sender.start(group_access) + + # When + test_sender.send(SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(messages))) + + # Then + self.assertEqual([SERVICE_INFO.__dict__], messages) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/senderIntegrationTest.py b/tests/senderIntegrationTest.py new file mode 100644 index 0000000..ae44256 --- /dev/null +++ b/tests/senderIntegrationTest.py @@ -0,0 +1,66 @@ +import unittest +from time import sleep +from unittest import TestCase + +from context_logger import setup_logging +from test_utility import wait_for_assertion +from zmq import Context + +from hello import ServiceInfo, Group, GroupAccess, DishReceiver +from hello.sender import RadioSender + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + +class SenderTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_sends_message(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = Context() + messages = [] + + with RadioSender(context) as sender, DishReceiver(context) as test_receiver: + test_receiver.register(lambda message: messages.append(message)) + test_receiver.start(group_access) + sender.start(group_access) + + # When + sender.send(SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(messages))) + + # Then + self.assertEqual([SERVICE_INFO.__dict__], messages) + + def test_skips_sending_message_when_not_serializable(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = Context() + messages = [] + + with RadioSender(context) as sender, DishReceiver(context) as test_receiver: + test_receiver.register(lambda message: messages.append(message)) + test_receiver.start(group_access) + sender.start(group_access) + + # When + sender.send('not serializable message') + + sleep(0.1) + + self.assertEqual(0, len(messages)) + + +if __name__ == '__main__': + unittest.main() From 48c0079f456c28bae7b0d84d890c8ce7c63e517f Mon Sep 17 00:00:00 2001 From: Attila Gombos Date: Sun, 18 Jan 2026 22:39:25 +0100 Subject: [PATCH 2/6] Initial commit WIP --- .idea/.gitignore | 1 + .idea/copilot.data.migration.ask2agent.xml | 6 ------ .idea/python-hello.iml | 2 ++ .idea/vcs.xml | 1 + hello/advertizer.py | 4 ++-- 5 files changed, 6 insertions(+), 8 deletions(-) delete mode 100644 .idea/copilot.data.migration.ask2agent.xml diff --git a/.idea/.gitignore b/.idea/.gitignore index d98e7a2..ca1b149 100644 --- a/.idea/.gitignore +++ b/.idea/.gitignore @@ -5,3 +5,4 @@ /dataSources/ /dataSources.local.xml /httpRequests/ +copilot.* \ No newline at end of file diff --git a/.idea/copilot.data.migration.ask2agent.xml b/.idea/copilot.data.migration.ask2agent.xml deleted file mode 100644 index 1f2ea11..0000000 --- a/.idea/copilot.data.migration.ask2agent.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - \ No newline at end of file diff --git a/.idea/python-hello.iml b/.idea/python-hello.iml index 6d790ec..a2043e1 100644 --- a/.idea/python-hello.iml +++ b/.idea/python-hello.iml @@ -3,10 +3,12 @@ + + diff --git a/.idea/vcs.xml b/.idea/vcs.xml index 8306744..5831342 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -3,5 +3,6 @@ + \ No newline at end of file diff --git a/hello/advertizer.py b/hello/advertizer.py index fefd755..a255dc2 100644 --- a/hello/advertizer.py +++ b/hello/advertizer.py @@ -75,12 +75,12 @@ def _handle_query(self, data: dict[str, str]) -> None: try: query = ServiceQuery(**data) matcher = ServiceMatcher(query) - log.debug('Hail received', group=self._group, query=query) + log.debug('Query received', group=self._group, query=query) except Exception as error: log.warning('Invalid query message received', group=self._group, received=data, error=error) if matcher and matcher.matches(self._info): - log.info('Hail matches service', group=self._group, query=matcher.query, service=self._info) + log.info('Query matches service', group=self._group, query=matcher.query, service=self._info) self.advertise(self._info) From da6db0aed3acb1b7a51264699d1917be931313ce Mon Sep 17 00:00:00 2001 From: Attila Gombos Date: Sun, 18 Jan 2026 23:04:54 +0100 Subject: [PATCH 3/6] Initial commit WIP --- hello/__init__.py | 2 +- hello/{hello.py => api.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename hello/{hello.py => api.py} (100%) diff --git a/hello/__init__.py b/hello/__init__.py index 7556868..c360ae2 100644 --- a/hello/__init__.py +++ b/hello/__init__.py @@ -4,4 +4,4 @@ from .service import * from .advertizer import * from .discoverer import * -from .hello import * +from .api import * diff --git a/hello/hello.py b/hello/api.py similarity index 100% rename from hello/hello.py rename to hello/api.py From e7dbc2c698fb176dc68879ad803db615b5dd774b Mon Sep 17 00:00:00 2001 From: Attila Gombos Date: Mon, 19 Jan 2026 09:55:50 +0100 Subject: [PATCH 4/6] Initial commit WIP --- hello/advertizer.py | 33 +++++----- hello/api.py | 32 ++++++---- hello/discoverer.py | 59 +++++++++++------- hello/receiver.py | 35 ++++++----- hello/sender.py | 8 ++- tests/advertizerIntegrationTest.py | 47 ++++++++++++++- tests/apiIntegrationTest.py | 97 ++++++++++++++++++++++++++++++ tests/discovererIntegrationTest.py | 2 +- tests/receiverIntegrationTest.py | 4 +- tests/senderIntegrationTest.py | 13 ++++ 10 files changed, 258 insertions(+), 72 deletions(-) create mode 100644 tests/apiIntegrationTest.py diff --git a/hello/advertizer.py b/hello/advertizer.py index a255dc2..efb9922 100644 --- a/hello/advertizer.py +++ b/hello/advertizer.py @@ -1,3 +1,5 @@ +import random +import time from typing import Any from common_utility import IReusableTimer @@ -47,7 +49,7 @@ def advertise(self, info: ServiceInfo | None = None) -> None: if info: self._info = info if self._info: - self._sender.send(info) + self._sender.send(self._info) log.info('Service advertised', service=self._info, group=self._group) else: log.warning('Cannot advertise service, advertizer not started', service=info) @@ -55,38 +57,41 @@ def advertise(self, info: ServiceInfo | None = None) -> None: class RespondingAdvertizer(DefaultAdvertizer): - def __init__(self, sender: Sender, receiver: Receiver) -> None: + def __init__(self, sender: Sender, receiver: Receiver, max_response_delay: float = 0.1) -> None: super().__init__(sender) self._receiver = receiver + self._max_delay = max_response_delay def start(self, address: str, group: Group, info: ServiceInfo | None = None) -> None: super().start(address, group, info) self._receiver.start(GroupAccess(address, group.query())) - self._receiver.register(self._handle_query) + self._receiver.register(self._handle_message) def stop(self) -> None: super().stop() self._receiver.stop() - def _handle_query(self, data: dict[str, str]) -> None: + def _handle_message(self, message: dict[str, Any]) -> None: if self._info: - matcher: ServiceMatcher | None = None - try: - query = ServiceQuery(**data) - matcher = ServiceMatcher(query) + query = ServiceQuery(**message) log.debug('Query received', group=self._group, query=query) + self._handle_query(query, self._info) except Exception as error: - log.warning('Invalid query message received', group=self._group, received=data, error=error) + log.warning('Invalid query message received', group=self._group, received=message, error=error) - if matcher and matcher.matches(self._info): - log.info('Query matches service', group=self._group, query=matcher.query, service=self._info) - self.advertise(self._info) + def _handle_query(self, query: ServiceQuery, info: ServiceInfo) -> None: + matcher = ServiceMatcher(query) + if matcher and matcher.matches(info): + delay = round(self._max_delay * random.random(), 3) + log.info('Responding to query', group=self._group, query=matcher.query, service=info, delay=delay) + time.sleep(delay) + self.advertise(info) class ScheduledAdvertizer(Advertizer): - def schedule(self, info: ServiceInfo, interval: float, one_shot: bool = False) -> None: + def schedule(self, info: ServiceInfo | None = None, interval: float = 10, one_shot: bool = False) -> None: raise NotImplementedError() @@ -112,7 +117,7 @@ def stop(self) -> None: def advertise(self, info: ServiceInfo | None = None) -> None: self._advertizer.advertise(info) - def schedule(self, info: ServiceInfo, interval: float, one_shot: bool = False) -> None: + def schedule(self, info: ServiceInfo | None = None, interval: float = 10, one_shot: bool = False) -> None: if one_shot: self._timer.start(interval, self.advertise, [info]) log.info('One-shot service advertisement scheduled', service=info, interval=interval) diff --git a/hello/api.py b/hello/api.py index ba7157f..57fc5c9 100644 --- a/hello/api.py +++ b/hello/api.py @@ -1,6 +1,6 @@ from typing import Any -from common_utility import ReusableTimer, IReusableTimer +from common_utility import ReusableTimer from zmq import Context from hello import Advertizer, Discoverer, RadioSender, DishReceiver, DefaultAdvertizer, DefaultDiscoverer, \ @@ -9,10 +9,10 @@ class Hello: - def default_advertizer(self, respond: bool = True) -> Advertizer: + def default_advertizer(self, respond: bool = True, delay: float = 0.1) -> Advertizer: raise NotImplementedError() - def scheduled_advertizer(self, timer: IReusableTimer | None = None, respond: bool = True) -> ScheduledAdvertizer: + def scheduled_advertizer(self, respond: bool = True, delay: float = 0.1) -> ScheduledAdvertizer: raise NotImplementedError() def discoverer(self) -> Discoverer: @@ -21,18 +21,24 @@ def discoverer(self) -> Discoverer: class DefaultHello(Hello): - def __init__(self, context: Context[Any] | None = None) -> None: + def __init__(self, context: Context[Any] | None = None, max_workers: int = 1, poll_timeout: float = 0.1) -> None: self._context = context if context else Context() - self._sender = RadioSender(self._context) - self._receiver = DishReceiver(self._context) + self._max_workers = max_workers + self._poll_timeout = poll_timeout - def default_advertizer(self, respond: bool = True) -> Advertizer: - return RespondingAdvertizer(self._sender, self._receiver) if respond else DefaultAdvertizer(self._sender) + def default_advertizer(self, respond: bool = True, delay: float = 0.1) -> Advertizer: + sender = RadioSender(self._context) + if respond: + receiver = DishReceiver(self._context, self._max_workers, self._poll_timeout) + return RespondingAdvertizer(sender, receiver, delay) + else: + return DefaultAdvertizer(sender) - def scheduled_advertizer(self, timer: IReusableTimer | None = None, respond: bool = True) -> ScheduledAdvertizer: - advertizer = self.default_advertizer(respond) - reusable_timer = timer if timer else ReusableTimer() - return DefaultScheduledAdvertizer(advertizer, reusable_timer) + def scheduled_advertizer(self, respond: bool = True, delay: float = 0.1) -> ScheduledAdvertizer: + advertizer = self.default_advertizer(respond, delay) + return DefaultScheduledAdvertizer(advertizer, ReusableTimer()) def discoverer(self) -> Discoverer: - return DefaultDiscoverer(self._sender, self._receiver) + sender = RadioSender(self._context) + receiver = DishReceiver(self._context, self._max_workers, self._poll_timeout) + return DefaultDiscoverer(sender, receiver) diff --git a/hello/discoverer.py b/hello/discoverer.py index 7b28223..bccb0e0 100644 --- a/hello/discoverer.py +++ b/hello/discoverer.py @@ -1,4 +1,6 @@ -from typing import Any, Callable +from dataclasses import dataclass +from enum import Enum +from typing import Any, Protocol from context_logger import get_logger @@ -7,6 +9,21 @@ log = get_logger('Discoverer') +class DiscoveryEventType(Enum): + DISCOVERED = 'discovered' + UPDATED = 'updated' + + +@dataclass +class DiscoveryEvent: + service: ServiceInfo + type: DiscoveryEventType + + +class OnDiscoveryEvent(Protocol): + def __call__(self, event: DiscoveryEvent) -> None: ... + + class Discoverer: def start(self, address: str, group: Group, query: ServiceQuery | None = None) -> None: @@ -21,10 +38,10 @@ def discover(self, query: ServiceQuery | None = None) -> None: def get_services(self) -> dict[str, ServiceInfo]: raise NotImplementedError() - def register(self, callback: Callable[[Any], None]) -> None: + def register(self, callback: OnDiscoveryEvent) -> None: raise NotImplementedError() - def deregister(self, callback: Callable[[Any], None]) -> None: + def deregister(self, callback: OnDiscoveryEvent) -> None: raise NotImplementedError() @@ -36,7 +53,7 @@ def __init__(self, sender: Sender, receiver: Receiver) -> None: self._group: Group | None = None self._matcher: ServiceMatcher | None = None self._services: dict[str, ServiceInfo] = {} - self._callbacks: list[Callable[[ServiceInfo], None]] = [] + self._callbacks: list[OnDiscoveryEvent] = [] def __enter__(self) -> Discoverer: return self @@ -65,24 +82,24 @@ def discover(self, query: ServiceQuery | None = None) -> None: self._sender.send(self._matcher.query) log.info('Service discovery initiated', query=self._matcher.query, group=self._group) else: - log.warning('Cannot initiate service discovery, discoverer not started', query=query) + log.warning('Cannot discover services, discoverer not started', query=query) def get_services(self) -> dict[str, ServiceInfo]: return self._services.copy() - def register(self, callback: Callable[[Any], None]) -> None: + def register(self, callback: OnDiscoveryEvent) -> None: self._callbacks.append(callback) - def deregister(self, callback: Callable[[Any], None]) -> None: + def deregister(self, callback: OnDiscoveryEvent) -> None: self._callbacks.remove(callback) - def _handle_message(self, data: dict[str, Any]) -> None: + def _handle_message(self, message: dict[str, Any]) -> None: service: ServiceInfo | None = None try: - service = ServiceInfo(**data) + service = ServiceInfo(**message) except Exception as error: - log.warn('Failed to handle received message', data=data, error=error) + log.warn('Failed to handle received message', data=message, error=error) if service: self._handle_service(service) @@ -91,29 +108,25 @@ def _handle_service(self, service: ServiceInfo) -> None: if self._matcher and self._matcher.matches(service): cached = self._services.get(service.name) - if self._is_update_needed(cached, service): - self._services[service.name] = service - for callback in self._callbacks: - try: - callback(service) - except Exception as error: - log.warn('Error in callback execution', service=service, error=error) + if event := self._create_event(cached, service): + self._handle_event(event) - def _is_update_needed(self, cached: ServiceInfo | None, service: ServiceInfo) -> bool: + def _create_event(self, cached: ServiceInfo | None, service: ServiceInfo) -> DiscoveryEvent | None: if cached: if cached != service: log.info('Service updated', old_service=cached, new_service=service) - return True + return DiscoveryEvent(service, DiscoveryEventType.UPDATED) else: log.info('Service discovered', service=service) - return True + return DiscoveryEvent(service, DiscoveryEventType.DISCOVERED) - return False + return None - def _handle_update(self, service: ServiceInfo) -> None: + def _handle_event(self, event: DiscoveryEvent) -> None: + service = event.service self._services[service.name] = service for callback in self._callbacks: try: - callback(service) + callback(event) except Exception as error: log.warn('Error in callback execution', service=service, error=error) diff --git a/hello/receiver.py b/hello/receiver.py index daedc5c..e7c2b87 100644 --- a/hello/receiver.py +++ b/hello/receiver.py @@ -1,5 +1,5 @@ from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable +from typing import Any, Protocol from context_logger import get_logger from zmq import DISH, Poller, POLLIN, POLLOUT, Context @@ -9,6 +9,10 @@ log = get_logger('Receiver') +class OnMessage(Protocol): + def __call__(self, message: dict[str, Any]) -> None: ... + + class Receiver: def start(self, source: GroupAccess) -> None: @@ -17,22 +21,23 @@ def start(self, source: GroupAccess) -> None: def stop(self) -> None: raise NotImplementedError() - def register(self, callback: Callable[[Any], None]) -> None: + def register(self, callback: OnMessage) -> None: raise NotImplementedError() - def deregister(self, callback: Callable[[Any], None]) -> None: + def deregister(self, callback: OnMessage) -> None: raise NotImplementedError() class DishReceiver(Receiver): - def __init__(self, context: Context[Any]) -> None: + def __init__(self, context: Context[Any], max_workers: int = 1, poll_timeout: float = 0.1) -> None: self._context = context self._dish = self._context.socket(DISH) self._poller = Poller() - self._executor = ThreadPoolExecutor(max_workers=1) + self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._poll_timeout = int(poll_timeout * 1000) self._group: str | None = None - self._callbacks: list[Callable[[Any], None]] = [] + self._callbacks: list[OnMessage] = [] def __enter__(self) -> Receiver: return self @@ -42,12 +47,14 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def start(self, source: GroupAccess) -> None: try: + if self._group: + raise RuntimeError('Receiver already started') self._poller.register(self._dish, POLLIN) self._dish.bind(source.access_url) self._dish.join(source.full_group) self._group = source.full_group self._executor.submit(self._handle_messages) - log.info('Receiver started', address=source.access_url, group=source.full_group) + log.debug('Receiver started', address=source.access_url, group=source.full_group) except Exception as error: log.error('Failed to start receiver', address=source.access_url, group=source.full_group, error=error) raise error @@ -58,20 +65,20 @@ def stop(self) -> None: self._poller.register(self._dish, POLLOUT) self._executor.shutdown() self._dish.close() - log.info('Receiver stopped') + log.debug('Receiver stopped') except Exception as error: log.error('Failed to stop receiver', error=error) raise error - def register(self, callback: Callable[[Any], None]) -> None: + def register(self, callback: OnMessage) -> None: self._callbacks.append(callback) - def deregister(self, callback: Callable[[Any], None]) -> None: + def deregister(self, callback: OnMessage) -> None: self._callbacks.remove(callback) def _handle_messages(self) -> None: while self._group: - sockets = dict(self._poller.poll(timeout=100)) + sockets = dict(self._poller.poll(timeout=self._poll_timeout)) if self._dish in sockets and sockets[self._dish] == POLLIN: try: data = self._dish.recv_json() @@ -80,9 +87,9 @@ def _handle_messages(self) -> None: except Exception as error: log.error('Failed to receive message', group=self._group, error=error) - def _handle_message(self, data: dict[str, str]) -> None: + def _handle_message(self, message: dict[str, Any]) -> None: for callback in self._callbacks: try: - callback(data) + callback(message) except Exception as error: - log.warn('Error in callback execution', data=data, group=self._group, error=error) + log.warn('Error in callback execution', data=message, group=self._group, error=error) diff --git a/hello/sender.py b/hello/sender.py index 3a562b4..c5b7ece 100644 --- a/hello/sender.py +++ b/hello/sender.py @@ -35,9 +35,11 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def start(self, target: GroupAccess) -> None: try: + if self._group: + raise RuntimeError('Sender already started') self._radio.connect(target.access_url) self._group = target.full_group - log.info('Sender started', address=target.access_url, group=target.full_group) + log.debug('Sender started', address=target.access_url, group=target.full_group) except Exception as error: log.error('Failed to start sender', address=target.access_url, group=target.full_group, error=error) raise error @@ -46,7 +48,7 @@ def stop(self) -> None: try: self._group = None self._radio.close() - log.info('Sender stopped') + log.debug('Sender stopped') except Exception as error: log.error('Failed to stop sender', error=error) raise error @@ -57,6 +59,8 @@ def send(self, data: Any) -> None: self._send_json(data) else: log.warning('Unsupported message type', data=data, group=self._group) + else: + log.warning('Cannot send message, sender not started', data=data) def _convert_to_dict(self, data: Any) -> dict[str, Any] | None: if isinstance(data, dict): diff --git a/tests/advertizerIntegrationTest.py b/tests/advertizerIntegrationTest.py index d22630e..0b14455 100644 --- a/tests/advertizerIntegrationTest.py +++ b/tests/advertizerIntegrationTest.py @@ -1,12 +1,13 @@ import unittest from unittest import TestCase +from common_utility import ReusableTimer from context_logger import setup_logging from test_utility import wait_for_assertion from zmq import Context from hello import DefaultAdvertizer, ServiceInfo, Group, RadioSender, DishReceiver, GroupAccess, \ - RespondingAdvertizer, ServiceQuery + RespondingAdvertizer, ServiceQuery, DefaultScheduledAdvertizer ACCESS_URL = 'udp://239.0.0.1:5555' GROUP_NAME = 'test-group' @@ -76,7 +77,7 @@ def test_sends_hello_when_query_received(self): receiver = DishReceiver(context) messages = [] - with (RespondingAdvertizer(sender, receiver) as advertizer, + with (RespondingAdvertizer(sender, receiver, 0.01) as advertizer, RadioSender(context) as test_sender, DishReceiver(context) as test_receiver): test_sender.start(GroupAccess(ACCESS_URL, GROUP.query())) @@ -100,7 +101,7 @@ def test_sends_hello_when_info_changed_and_query_received(self): receiver = DishReceiver(context) messages = [] - with (RespondingAdvertizer(sender, receiver) as advertizer, + with (RespondingAdvertizer(sender, receiver, 0.01) as advertizer, RadioSender(context) as test_sender, DishReceiver(context) as test_receiver): test_sender.start(GroupAccess(ACCESS_URL, GROUP.query())) @@ -131,6 +132,46 @@ def test_sends_hello_when_info_changed_and_query_received(self): {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:9090'} ], messages) + def test_sends_hello_when_schedules_advertisement_once(self): + # Given + context = Context() + sender = RadioSender(context) + _advertizer = DefaultAdvertizer(sender) + timer = ReusableTimer() + messages = [] + + with DefaultScheduledAdvertizer(_advertizer, timer) as advertizer, DishReceiver(context) as test_receiver: + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.hello())) + test_receiver.register(lambda message: messages.append(message)) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.schedule(self.SERVICE_INFO, interval=0.01, one_shot=True) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(messages))) + + # Then + self.assertEqual([self.SERVICE_INFO.__dict__], messages) + + def test_sends_hello_when_schedules_advertisement_periodically(self): + # Given + context = Context() + sender = RadioSender(context) + _advertizer = DefaultAdvertizer(sender) + timer = ReusableTimer() + messages = [] + + with DefaultScheduledAdvertizer(_advertizer, timer) as advertizer, DishReceiver(context) as test_receiver: + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.hello())) + test_receiver.register(lambda message: messages.append(message)) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.schedule(self.SERVICE_INFO, interval=0.01) + + # Then + wait_for_assertion(0.1, lambda: self.assertEqual(5, len(messages))) + if __name__ == '__main__': unittest.main() diff --git a/tests/apiIntegrationTest.py b/tests/apiIntegrationTest.py new file mode 100644 index 0000000..e4662a4 --- /dev/null +++ b/tests/apiIntegrationTest.py @@ -0,0 +1,97 @@ +import unittest +from unittest import TestCase + +from context_logger import setup_logging +from test_utility import wait_for_assertion +from zmq import Context + +from hello import ServiceInfo, Group, DefaultHello, ServiceQuery + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') +SERVICE_QUERY = ServiceQuery('test-service', 'test-role') + + +class ApiTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + self.SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + def test_discoverer_caches_advertised_service(self): + # Given + context = Context() + hello = DefaultHello(context) + + with hello.default_advertizer() as advertizer, hello.discoverer() as discoverer: + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + advertizer.advertise() + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) + + # Then + self.assertEqual({self.SERVICE_INFO.name: self.SERVICE_INFO}, discoverer.get_services()) + + def test_discoverer_caches_advertised_service_when_scheduled_once(self): + # Given + context = Context() + hello = DefaultHello(context) + + with hello.scheduled_advertizer() as advertizer, hello.discoverer() as discoverer: + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + advertizer.schedule(interval=0.01, one_shot=True) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) + + # Then + self.assertEqual({self.SERVICE_INFO.name: self.SERVICE_INFO}, discoverer.get_services()) + + def test_discoverer_caches_advertised_service_when_scheduled_periodically(self): + # Given + context = Context() + hello = DefaultHello(context) + + with hello.scheduled_advertizer() as advertizer, hello.discoverer() as discoverer: + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + advertizer.schedule(interval=0.01) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) + + # Then + self.assertEqual({self.SERVICE_INFO.name: self.SERVICE_INFO}, discoverer.get_services()) + + def test_discoverer_caches_discovery_response_service(self): + # Given + context = Context() + hello = DefaultHello(context) + + with hello.default_advertizer() as advertizer, hello.discoverer() as discoverer: + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + discoverer.discover() + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) + + # Then + self.assertEqual({self.SERVICE_INFO.name: self.SERVICE_INFO}, discoverer.get_services()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/discovererIntegrationTest.py b/tests/discovererIntegrationTest.py index 770b673..1af32dd 100644 --- a/tests/discovererIntegrationTest.py +++ b/tests/discovererIntegrationTest.py @@ -14,7 +14,7 @@ SERVICE_QUERY = ServiceQuery('test-service', 'test-role') -class AdvertizerTest(TestCase): +class DiscovererTest(TestCase): SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') @classmethod diff --git a/tests/receiverIntegrationTest.py b/tests/receiverIntegrationTest.py index a580022..a427814 100644 --- a/tests/receiverIntegrationTest.py +++ b/tests/receiverIntegrationTest.py @@ -3,7 +3,7 @@ from context_logger import setup_logging from test_utility import wait_for_assertion -from zmq import Context, ZMQError +from zmq import Context from hello import ServiceInfo, Group, DishReceiver, GroupAccess, RadioSender @@ -31,7 +31,7 @@ def test_raises_error_when_restarted(self): receiver.start(group_access) # When, Then - with self.assertRaises(ZMQError): + with self.assertRaises(RuntimeError): receiver.start(group_access) def test_receives_message(self): diff --git a/tests/senderIntegrationTest.py b/tests/senderIntegrationTest.py index ae44256..b549b88 100644 --- a/tests/senderIntegrationTest.py +++ b/tests/senderIntegrationTest.py @@ -24,6 +24,19 @@ def setUpClass(cls): def setUp(self): print() + def test_raises_error_when_restarted(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = Context() + + with RadioSender(context) as sender: + sender.start(group_access) + + # When, Then + with self.assertRaises(RuntimeError): + sender.start(group_access) + + def test_sends_message(self): # Given group_access = GroupAccess(ACCESS_URL, GROUP.hello()) From 8077171a7fad5c85f73591eb55121272b8307239 Mon Sep 17 00:00:00 2001 From: Attila Gombos Date: Mon, 19 Jan 2026 14:50:30 +0100 Subject: [PATCH 5/6] Initial commit WIP --- hello/advertizer.py | 12 +- hello/discoverer.py | 42 ++-- hello/receiver.py | 33 +-- hello/sender.py | 4 +- tests/advertizerIntegrationTest.py | 28 +-- tests/apiIntegrationTest.py | 8 +- tests/defaultAdvertizerTest.py | 120 +++++++++++ tests/defaultDiscovererTest.py | 270 ++++++++++++++++++++++++ tests/defaultScheduledAdvertizerTest.py | 119 +++++++++++ tests/discovererIntegrationTest.py | 2 +- tests/dishReceiverTest.py | 176 +++++++++++++++ tests/radioSenderTest.py | 152 +++++++++++++ tests/receiverIntegrationTest.py | 2 +- tests/respondingAdvertizerTest.py | 119 +++++++++++ tests/senderIntegrationTest.py | 34 +-- 15 files changed, 1013 insertions(+), 108 deletions(-) create mode 100644 tests/defaultAdvertizerTest.py create mode 100644 tests/defaultDiscovererTest.py create mode 100644 tests/defaultScheduledAdvertizerTest.py create mode 100644 tests/dishReceiverTest.py create mode 100644 tests/radioSenderTest.py create mode 100644 tests/respondingAdvertizerTest.py diff --git a/hello/advertizer.py b/hello/advertizer.py index efb9922..13e2bea 100644 --- a/hello/advertizer.py +++ b/hello/advertizer.py @@ -117,14 +117,14 @@ def stop(self) -> None: def advertise(self, info: ServiceInfo | None = None) -> None: self._advertizer.advertise(info) - def schedule(self, info: ServiceInfo | None = None, interval: float = 10, one_shot: bool = False) -> None: + def schedule(self, info: ServiceInfo | None = None, interval: float = 60, one_shot: bool = False) -> None: if one_shot: self._timer.start(interval, self.advertise, [info]) log.info('One-shot service advertisement scheduled', service=info, interval=interval) else: - def periodic_advertise() -> None: - self.advertise(info) - self._timer.restart() - - self._timer.start(interval, periodic_advertise) + self._timer.start(interval, self._advertise_and_restart, [info]) log.info('Periodic service advertisement scheduled', service=info, interval=interval) + + def _advertise_and_restart(self, info: ServiceInfo | None = None) -> None: + self.advertise(info) + self._timer.restart() diff --git a/hello/discoverer.py b/hello/discoverer.py index bccb0e0..192f644 100644 --- a/hello/discoverer.py +++ b/hello/discoverer.py @@ -38,10 +38,13 @@ def discover(self, query: ServiceQuery | None = None) -> None: def get_services(self) -> dict[str, ServiceInfo]: raise NotImplementedError() - def register(self, callback: OnDiscoveryEvent) -> None: + def register(self, handler: OnDiscoveryEvent) -> None: raise NotImplementedError() - def deregister(self, callback: OnDiscoveryEvent) -> None: + def deregister(self, handler: OnDiscoveryEvent) -> None: + raise NotImplementedError() + + def get_handlers(self) -> list[OnDiscoveryEvent]: raise NotImplementedError() @@ -52,8 +55,8 @@ def __init__(self, sender: Sender, receiver: Receiver) -> None: self._receiver = receiver self._group: Group | None = None self._matcher: ServiceMatcher | None = None - self._services: dict[str, ServiceInfo] = {} - self._callbacks: list[OnDiscoveryEvent] = [] + self._cache: dict[str, ServiceInfo] = {} + self._handlers: list[OnDiscoveryEvent] = [] def __enter__(self) -> Discoverer: return self @@ -85,28 +88,27 @@ def discover(self, query: ServiceQuery | None = None) -> None: log.warning('Cannot discover services, discoverer not started', query=query) def get_services(self) -> dict[str, ServiceInfo]: - return self._services.copy() + return self._cache.copy() - def register(self, callback: OnDiscoveryEvent) -> None: - self._callbacks.append(callback) + def register(self, handler: OnDiscoveryEvent) -> None: + self._handlers.append(handler) - def deregister(self, callback: OnDiscoveryEvent) -> None: - self._callbacks.remove(callback) + def deregister(self, handler: OnDiscoveryEvent) -> None: + self._handlers.remove(handler) - def _handle_message(self, message: dict[str, Any]) -> None: - service: ServiceInfo | None = None + def get_handlers(self) -> list[OnDiscoveryEvent]: + return self._handlers.copy() + def _handle_message(self, message: dict[str, Any]) -> None: try: service = ServiceInfo(**message) + self._handle_service(service) except Exception as error: log.warn('Failed to handle received message', data=message, error=error) - if service: - self._handle_service(service) - def _handle_service(self, service: ServiceInfo) -> None: if self._matcher and self._matcher.matches(service): - cached = self._services.get(service.name) + cached = self._cache.get(service.name) if event := self._create_event(cached, service): self._handle_event(event) @@ -116,17 +118,17 @@ def _create_event(self, cached: ServiceInfo | None, service: ServiceInfo) -> Dis if cached != service: log.info('Service updated', old_service=cached, new_service=service) return DiscoveryEvent(service, DiscoveryEventType.UPDATED) + else: + return None else: log.info('Service discovered', service=service) return DiscoveryEvent(service, DiscoveryEventType.DISCOVERED) - return None - def _handle_event(self, event: DiscoveryEvent) -> None: service = event.service - self._services[service.name] = service - for callback in self._callbacks: + self._cache[service.name] = service + for callback in self._handlers: try: callback(event) except Exception as error: - log.warn('Error in callback execution', service=service, error=error) + log.warn('Error in event handler execution', event=event, error=error) diff --git a/hello/receiver.py b/hello/receiver.py index e7c2b87..b9e0068 100644 --- a/hello/receiver.py +++ b/hello/receiver.py @@ -2,7 +2,7 @@ from typing import Any, Protocol from context_logger import get_logger -from zmq import DISH, Poller, POLLIN, POLLOUT, Context +from zmq import DISH, Poller, POLLIN, Context from hello import GroupAccess @@ -21,10 +21,13 @@ def start(self, source: GroupAccess) -> None: def stop(self) -> None: raise NotImplementedError() - def register(self, callback: OnMessage) -> None: + def register(self, handler: OnMessage) -> None: raise NotImplementedError() - def deregister(self, callback: OnMessage) -> None: + def deregister(self, handler: OnMessage) -> None: + raise NotImplementedError() + + def get_handlers(self) -> list[OnMessage]: raise NotImplementedError() @@ -37,7 +40,7 @@ def __init__(self, context: Context[Any], max_workers: int = 1, poll_timeout: fl self._executor = ThreadPoolExecutor(max_workers=max_workers) self._poll_timeout = int(poll_timeout * 1000) self._group: str | None = None - self._callbacks: list[OnMessage] = [] + self._handlers: list[OnMessage] = [] def __enter__(self) -> Receiver: return self @@ -53,7 +56,7 @@ def start(self, source: GroupAccess) -> None: self._dish.bind(source.access_url) self._dish.join(source.full_group) self._group = source.full_group - self._executor.submit(self._handle_messages) + self._executor.submit(self._receive_loop) log.debug('Receiver started', address=source.access_url, group=source.full_group) except Exception as error: log.error('Failed to start receiver', address=source.access_url, group=source.full_group, error=error) @@ -62,7 +65,6 @@ def start(self, source: GroupAccess) -> None: def stop(self) -> None: try: self._group = None - self._poller.register(self._dish, POLLOUT) self._executor.shutdown() self._dish.close() log.debug('Receiver stopped') @@ -70,13 +72,16 @@ def stop(self) -> None: log.error('Failed to stop receiver', error=error) raise error - def register(self, callback: OnMessage) -> None: - self._callbacks.append(callback) + def register(self, handler: OnMessage) -> None: + self._handlers.append(handler) + + def deregister(self, handler: OnMessage) -> None: + self._handlers.remove(handler) - def deregister(self, callback: OnMessage) -> None: - self._callbacks.remove(callback) + def get_handlers(self) -> list[OnMessage]: + return self._handlers.copy() - def _handle_messages(self) -> None: + def _receive_loop(self) -> None: while self._group: sockets = dict(self._poller.poll(timeout=self._poll_timeout)) if self._dish in sockets and sockets[self._dish] == POLLIN: @@ -88,8 +93,8 @@ def _handle_messages(self) -> None: log.error('Failed to receive message', group=self._group, error=error) def _handle_message(self, message: dict[str, Any]) -> None: - for callback in self._callbacks: + for handler in self._handlers: try: - callback(message) + handler(message) except Exception as error: - log.warn('Error in callback execution', data=message, group=self._group, error=error) + log.warn('Error in message handler execution', data=message, group=self._group, error=error) diff --git a/hello/sender.py b/hello/sender.py index c5b7ece..57574e5 100644 --- a/hello/sender.py +++ b/hello/sender.py @@ -55,8 +55,8 @@ def stop(self) -> None: def send(self, data: Any) -> None: if self._group: - if data := self._convert_to_dict(data): - self._send_json(data) + if message := self._convert_to_dict(data): + self._send_json(message) else: log.warning('Unsupported message type', data=data, group=self._group) else: diff --git a/tests/advertizerIntegrationTest.py b/tests/advertizerIntegrationTest.py index 0b14455..69822e5 100644 --- a/tests/advertizerIntegrationTest.py +++ b/tests/advertizerIntegrationTest.py @@ -14,7 +14,7 @@ GROUP = Group(GROUP_NAME) -class AdvertizerTest(TestCase): +class AdvertizerIntegrationTest(TestCase): SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') @classmethod @@ -44,32 +44,6 @@ def test_sends_hello_when_advertises_service(self): # Then self.assertEqual([self.SERVICE_INFO.__dict__], messages) - def test_sends_hello_when_advertises_service_and_info_changed(self): - # Given - context = Context() - sender = RadioSender(context) - messages = [] - - with DefaultAdvertizer(sender) as advertizer, DishReceiver(context) as test_receiver: - test_receiver.start(GroupAccess(ACCESS_URL, GROUP.hello())) - test_receiver.register(lambda message: messages.append(message)) - advertizer.start(ACCESS_URL, GROUP) - - advertizer.advertise(self.SERVICE_INFO) - - self.SERVICE_INFO.url = 'http://localhost:9090' - - # When - advertizer.advertise(self.SERVICE_INFO) - - wait_for_assertion(0.1, lambda: self.assertEqual(2, len(messages))) - - # Then - self.assertEqual([ - {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:8080'}, - {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:9090'} - ], messages) - def test_sends_hello_when_query_received(self): # Given context = Context() diff --git a/tests/apiIntegrationTest.py b/tests/apiIntegrationTest.py index e4662a4..cc86505 100644 --- a/tests/apiIntegrationTest.py +++ b/tests/apiIntegrationTest.py @@ -14,7 +14,7 @@ SERVICE_QUERY = ServiceQuery('test-service', 'test-role') -class ApiTest(TestCase): +class ApiIntegrationTest(TestCase): @classmethod def setUpClass(cls): @@ -29,7 +29,7 @@ def test_discoverer_caches_advertised_service(self): context = Context() hello = DefaultHello(context) - with hello.default_advertizer() as advertizer, hello.discoverer() as discoverer: + with hello.default_advertizer(respond=False) as advertizer, hello.discoverer() as discoverer: advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) @@ -46,7 +46,7 @@ def test_discoverer_caches_advertised_service_when_scheduled_once(self): context = Context() hello = DefaultHello(context) - with hello.scheduled_advertizer() as advertizer, hello.discoverer() as discoverer: + with hello.scheduled_advertizer(respond=False) as advertizer, hello.discoverer() as discoverer: advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) @@ -87,7 +87,7 @@ def test_discoverer_caches_discovery_response_service(self): # When discoverer.discover() - wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) + wait_for_assertion(0.2, lambda: self.assertEqual(1, len(discoverer.get_services()))) # Then self.assertEqual({self.SERVICE_INFO.name: self.SERVICE_INFO}, discoverer.get_services()) diff --git a/tests/defaultAdvertizerTest.py b/tests/defaultAdvertizerTest.py new file mode 100644 index 0000000..a057b3f --- /dev/null +++ b/tests/defaultAdvertizerTest.py @@ -0,0 +1,120 @@ +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +from context_logger import setup_logging + +from hello import ServiceInfo, Group, Sender, DefaultAdvertizer, GroupAccess + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + +class DefaultAdvertizerTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_stops_sender_on_exit(self): + # Given + sender = MagicMock(spec=Sender) + + with DefaultAdvertizer(sender) as advertizer: + advertizer.start(ACCESS_URL, GROUP) + + # When + + # Then + sender.stop.assert_called_once() + + def test_stops_sender_when_stopped(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.stop() + + # Then + sender.stop.assert_called_once() + + def test_starts_sender_when_started(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + + # When + advertizer.start(ACCESS_URL, GROUP) + + # Then + sender.start.assert_called_once_with(GroupAccess(ACCESS_URL, GROUP.hello())) + + def test_sends_info_when_passed_at_start(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + + # When + advertizer.advertise() + + # Then + sender.send.assert_called_once_with(SERVICE_INFO) + + def test_sends_info_when_passed_at_advertise(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.advertise(SERVICE_INFO) + + # Then + sender.send.assert_called_once_with(SERVICE_INFO) + + def test_sends_last_info_when_passed_at_start_and_at_advertise(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + advertizer.start(ACCESS_URL, GROUP, ServiceInfo('test-service', 'test-role', 'http://localhost:9090')) + + # When + advertizer.advertise(SERVICE_INFO) + + # Then + sender.send.assert_called_once_with(SERVICE_INFO) + + def test_does_not_send_info_when_no_info_provided(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.advertise() + + # Then + sender.send.assert_not_called() + + def test_does_not_send_info_when_not_started(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + + # When + advertizer.advertise(SERVICE_INFO) + + # Then + sender.send.assert_not_called() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/defaultDiscovererTest.py b/tests/defaultDiscovererTest.py new file mode 100644 index 0000000..1bc583e --- /dev/null +++ b/tests/defaultDiscovererTest.py @@ -0,0 +1,270 @@ +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +from context_logger import setup_logging + +from hello import ServiceInfo, Group, GroupAccess, \ + ServiceQuery, DefaultDiscoverer, Sender, Receiver, OnDiscoveryEvent, DiscoveryEventType, DiscoveryEvent + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_QUERY = ServiceQuery('test-.*', 'test-.*') +SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + +class DefaultDiscovererTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_stops_sender_and_receiver_on_exit(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + + with DefaultDiscoverer(sender, receiver) as discoverer: + discoverer.start(ACCESS_URL, GROUP) + + # When + + # Then + sender.stop.assert_called_once() + receiver.stop.assert_called_once() + + def test_stops_sender_and_receiver_when_stopped(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP) + + # When + discoverer.stop() + + # Then + sender.stop.assert_called_once() + receiver.stop.assert_called_once() + + def test_starts_sender_and_receiver_when_started(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + + # When + discoverer.start(ACCESS_URL, GROUP) + + # Then + sender.start.assert_called_once_with(GroupAccess(ACCESS_URL, GROUP.query())) + receiver.start.assert_called_once_with(GroupAccess(ACCESS_URL, GROUP.hello())) + + def test_registers_event_handler(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + handler = MagicMock(spec=OnDiscoveryEvent) + + # When + discoverer.register(handler) + + # Then + self.assertIn(handler, discoverer.get_handlers()) + + def test_deregisters_event_handler(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + handler = MagicMock(spec=OnDiscoveryEvent) + discoverer.register(handler) + + # When + discoverer.deregister(handler) + + # Then + self.assertNotIn(handler, discoverer.get_handlers()) + + def test_caches_service_and_calls_handler_when_receives_matching_info(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + handler = MagicMock(spec=OnDiscoveryEvent) + discoverer.register(handler) + + # When + discoverer._handle_message(SERVICE_INFO.__dict__) + + # Then + self.assertEqual({SERVICE_INFO.name: SERVICE_INFO}, discoverer.get_services()) + handler.assert_called_once_with(DiscoveryEvent(SERVICE_INFO, DiscoveryEventType.DISCOVERED)) + + def test_updates_service_and_calls_handler_when_receives_matching_info(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + handler = MagicMock(spec=OnDiscoveryEvent) + discoverer.register(handler) + discoverer._handle_message(SERVICE_INFO.__dict__) + handler.reset_mock() + new_service_info = ServiceInfo(SERVICE_INFO.name, SERVICE_INFO.role, 'http://localhost:9090') + + # When + discoverer._handle_message(new_service_info.__dict__) + + # Then + self.assertEqual({SERVICE_INFO.name: new_service_info}, discoverer.get_services()) + handler.assert_called_once_with(DiscoveryEvent(new_service_info, DiscoveryEventType.UPDATED)) + + def test_does_not_call_handler_when_service_info_not_changed(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + handler = MagicMock(spec=OnDiscoveryEvent) + discoverer.register(handler) + discoverer._handle_message(SERVICE_INFO.__dict__) + handler.reset_mock() + + # When + discoverer._handle_message(SERVICE_INFO.__dict__) + + # Then + handler.assert_not_called() + + def test_handles_handler_error_gracefully(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + handler = MagicMock(spec=OnDiscoveryEvent) + handler.side_effect = Exception("Handler error") + discoverer.register(handler) + + # When + discoverer._handle_message(SERVICE_INFO.__dict__) + + # Then + self.assertEqual({SERVICE_INFO.name: SERVICE_INFO}, discoverer.get_services()) + handler.assert_called_once_with(DiscoveryEvent(SERVICE_INFO, DiscoveryEventType.DISCOVERED)) + + def test_handles_invalid_message_gracefully(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + discoverer._handle_message({'invalid': 'message'}) + + # Then + self.assertEqual({}, discoverer.get_services()) + + def test_does_not_cache_service_when_info_not_matching_query(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + non_matching_info = ServiceInfo('other-service', 'test-role', 'http://localhost:8080') + + # When + discoverer._handle_message(non_matching_info.__dict__) + + # Then + self.assertEqual({}, discoverer.get_services()) + + def test_does_not_cache_service_when_no_query_set(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP) + + # When + discoverer._handle_message(SERVICE_INFO.__dict__) + + # Then + self.assertEqual({}, discoverer.get_services()) + + def test_sends_query_when_passed_at_start(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + discoverer.discover() + + # Then + sender.send.assert_called_once_with(SERVICE_QUERY) + + def test_sends_query_when_passed_at_discover(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP) + + # When + discoverer.discover(SERVICE_QUERY) + + # Then + sender.send.assert_called_once_with(SERVICE_QUERY) + + def test_sends_last_query_when_passed_at_start_and_at_discover(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, ServiceQuery('other-.*', 'test-.*')) + + # When + discoverer.discover(SERVICE_QUERY) + + # Then + sender.send.assert_called_once_with(SERVICE_QUERY) + + def test_does_not_send_query_when_no_query_provided(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP) + + # When + discoverer.discover() + + # Then + sender.send.assert_not_called() + + def test_does_not_send_query_when_not_started(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + + # When + discoverer.discover(SERVICE_QUERY) + + # Then + sender.send.assert_not_called() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/defaultScheduledAdvertizerTest.py b/tests/defaultScheduledAdvertizerTest.py new file mode 100644 index 0000000..cdb42bf --- /dev/null +++ b/tests/defaultScheduledAdvertizerTest.py @@ -0,0 +1,119 @@ +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +from common_utility import IReusableTimer +from context_logger import setup_logging + +from hello import ServiceInfo, Group, DefaultScheduledAdvertizer, Advertizer + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + +class DefaultScheduledAdvertizerTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_stops_timer_and_advertizer_on_exit(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + + with DefaultScheduledAdvertizer(_advertizer, timer) as advertizer: + advertizer.start(ACCESS_URL, GROUP) + + # When + + # Then + timer.cancel.assert_called_once() + _advertizer.stop.assert_called_once() + + def test_stops_timer_and_advertizer_when_stopped(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + advertizer = DefaultScheduledAdvertizer(_advertizer, timer) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.stop() + + # Then + timer.cancel.assert_called_once() + _advertizer.stop.assert_called_once() + + def test_starts_advertizer_when_started(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + advertizer = DefaultScheduledAdvertizer(_advertizer, timer) + + # When + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + + # Then + _advertizer.start.assert_called_once_with(ACCESS_URL, GROUP, SERVICE_INFO) + + def test_sends_service_info(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + advertizer = DefaultScheduledAdvertizer(_advertizer, timer) + + # When + advertizer.advertise(SERVICE_INFO) + + # Then + _advertizer.advertise.assert_called_once_with(SERVICE_INFO) + + def test_schedules_advertise_once(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + advertizer = DefaultScheduledAdvertizer(_advertizer, timer) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.schedule(SERVICE_INFO, 60, True) + + # Then + timer.start.assert_called_once_with(60, advertizer.advertise, [SERVICE_INFO]) + + def test_schedules_periodic_advertise(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + advertizer = DefaultScheduledAdvertizer(_advertizer, timer) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.schedule(SERVICE_INFO, 60, False) + + # Then + timer.start.assert_called_once_with(60, advertizer._advertise_and_restart, [SERVICE_INFO]) + + def test_advertise_and_restart_calls_advertise_and_restarts_timer(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + advertizer = DefaultScheduledAdvertizer(_advertizer, timer) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer._advertise_and_restart(SERVICE_INFO) + + # Then + _advertizer.advertise.assert_called_once_with(SERVICE_INFO) + timer.restart.assert_called_once() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/discovererIntegrationTest.py b/tests/discovererIntegrationTest.py index 1af32dd..5d5b0a6 100644 --- a/tests/discovererIntegrationTest.py +++ b/tests/discovererIntegrationTest.py @@ -14,7 +14,7 @@ SERVICE_QUERY = ServiceQuery('test-service', 'test-role') -class DiscovererTest(TestCase): +class DiscovererIntegrationTest(TestCase): SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') @classmethod diff --git a/tests/dishReceiverTest.py b/tests/dishReceiverTest.py new file mode 100644 index 0000000..6e18543 --- /dev/null +++ b/tests/dishReceiverTest.py @@ -0,0 +1,176 @@ +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +from context_logger import setup_logging +from test_utility import wait_for_assertion +from zmq import Context, ZMQError, Poller, POLLIN + +from hello import ServiceInfo, Group, GroupAccess, DishReceiver, OnMessage + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + +class DishReceiverTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_raises_error_when_restarted(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + + with DishReceiver(context) as receiver: + receiver.start(group_access) + + # When, Then + with self.assertRaises(RuntimeError): + receiver.start(group_access) + + def test_raises_error_when_fails_to_bind_socket(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.bind.side_effect = ZMQError(1, "Bind failed") + receiver = DishReceiver(context) + + # When, Then + with self.assertRaises(ZMQError): + receiver.start(group_access) + + def test_closes_socket_on_exit(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + + with DishReceiver(context) as receiver: + receiver.start(group_access) + + # When + + # Then + context.socket.return_value.close.assert_called_once() + + def test_closes_socket_when_stopped(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + receiver = DishReceiver(context) + receiver.start(group_access) + + # When + receiver.stop() + + # Then + context.socket.return_value.close.assert_called_once() + + def test_raises_error_when_fails_to_close_socket_on_stop(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.close.side_effect = [ZMQError(1, "Close failed"), None] + + with DishReceiver(context) as receiver: + receiver.start(group_access) + + # When, Then + with self.assertRaises(ZMQError): + receiver.stop() + + def test_registers_handler(self): + # Given + context = MagicMock(spec=Context) + receiver = DishReceiver(context) + handler = MagicMock(spec=OnMessage) + + # When + receiver.register(handler) + + # Then + self.assertIn(handler, receiver.get_handlers()) + + def test_deregisters_handler(self): + # Given + context = MagicMock(spec=Context) + receiver = DishReceiver(context) + handler = MagicMock(spec=OnMessage) + receiver.register(handler) + + # When + receiver.deregister(handler) + + # Then + self.assertNotIn(handler, receiver.get_handlers()) + + def test_calls_registered_handler_on_message(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.recv_json.return_value = SERVICE_INFO.__dict__ + handler = MagicMock(spec=OnMessage) + + with DishReceiver(context) as receiver: + receiver._poller = MagicMock(spec=Poller) + receiver._poller.poll.side_effect = [ + {context.socket.return_value: POLLIN}, + ] + receiver.register(handler) + + # When + receiver.start(group_access) + + # Then + wait_for_assertion(0.1, lambda: handler.assert_called_once_with(SERVICE_INFO.__dict__)) + + def test_handles_message_receive_error_gracefully(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.recv_json.side_effect = ZMQError(1, "Receive failed") + handler = MagicMock(spec=OnMessage) + + with DishReceiver(context) as receiver: + receiver._poller = MagicMock(spec=Poller) + receiver._poller.poll.side_effect = [ + {context.socket.return_value: POLLIN}, + ] + receiver.register(handler) + + # When + receiver.start(group_access) + + # Then + handler.assert_not_called() + + def test_handles_handler_execution_error_gracefully(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.recv_json.return_value = SERVICE_INFO.__dict__ + handler = MagicMock(spec=OnMessage) + handler.side_effect = Exception("Execution failed") + + with DishReceiver(context) as receiver: + receiver._poller = MagicMock(spec=Poller) + receiver._poller.poll.side_effect = [ + {context.socket.return_value: POLLIN}, + ] + receiver.register(handler) + + # When + receiver.start(group_access) + + # Then + wait_for_assertion(0.1, lambda: handler.assert_called_once_with(SERVICE_INFO.__dict__)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/radioSenderTest.py b/tests/radioSenderTest.py new file mode 100644 index 0000000..d059a0c --- /dev/null +++ b/tests/radioSenderTest.py @@ -0,0 +1,152 @@ +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +from context_logger import setup_logging +from zmq import Context, ZMQError + +from hello import ServiceInfo, Group, GroupAccess +from hello.sender import RadioSender + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + +class RadioSenderTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_raises_error_when_restarted(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + sender = RadioSender(context) + sender.start(group_access) + + # When, Then + with self.assertRaises(RuntimeError): + sender.start(group_access) + + def test_raises_error_when_fails_to_connect_socket(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.connect.side_effect = ZMQError(1, "Connect failed") + sender = RadioSender(context) + + # When, Then + with self.assertRaises(ZMQError): + sender.start(group_access) + + def test_closes_socket_on_exit(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + + with RadioSender(context) as sender: + sender.start(group_access) + + # When + + # Then + context.socket.return_value.close.assert_called_once() + + def test_closes_socket_when_stopped(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + sender = RadioSender(context) + sender.start(group_access) + + # When + sender.stop() + + # Then + context.socket.return_value.close.assert_called_once() + + def test_raises_error_when_fails_to_close_socket_on_stop(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.close.side_effect = ZMQError(1, "Close failed") + sender = RadioSender(context) + sender.start(group_access) + + # When, Then + with self.assertRaises(ZMQError): + sender.stop() + + def test_sends_message_when_convertible_to_dict(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + sender = RadioSender(context) + sender.start(group_access) + + # When + sender.send(SERVICE_INFO) + + # Then + context.socket.return_value.send_json.assert_called_with(SERVICE_INFO.__dict__, group='hello:test-group') + + def test_sends_message_when_type_is_dict(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + sender = RadioSender(context) + sender.start(group_access) + + # When + sender.send(SERVICE_INFO.__dict__) + + # Then + context.socket.return_value.send_json.assert_called_with(SERVICE_INFO.__dict__, group='hello:test-group') + + def test_does_not_send_message_when_not_serializable(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + sender = RadioSender(context) + sender.start(group_access) + + # When + sender.send("not serializable message") + + # Then + context.socket.return_value.send_json.assert_not_called() + + def test_does_not_send_message_when_not_started(self): + # Given + context = MagicMock(spec=Context) + sender = RadioSender(context) + + # When + sender.send(SERVICE_INFO) + + # Then + context.socket.return_value.send_json.assert_not_called() + + def test_handles_send_message_error_gracefully(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + sender = RadioSender(context) + sender.start(group_access) + context.socket.return_value.send_json.side_effect = ZMQError(1, "Send failed") + + # When + sender.send(SERVICE_INFO) + + # Then + context.socket.return_value.send_json.assert_called_once_with(SERVICE_INFO.__dict__, group='hello:test-group') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/receiverIntegrationTest.py b/tests/receiverIntegrationTest.py index a427814..e1c30f8 100644 --- a/tests/receiverIntegrationTest.py +++ b/tests/receiverIntegrationTest.py @@ -13,7 +13,7 @@ SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') -class ReceiverTest(TestCase): +class ReceiverIntegrationTest(TestCase): @classmethod def setUpClass(cls): diff --git a/tests/respondingAdvertizerTest.py b/tests/respondingAdvertizerTest.py new file mode 100644 index 0000000..a048c6e --- /dev/null +++ b/tests/respondingAdvertizerTest.py @@ -0,0 +1,119 @@ +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +from context_logger import setup_logging + +from hello import ServiceInfo, Group, Sender, GroupAccess, Receiver, RespondingAdvertizer, ServiceQuery + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + + +class RespondingAdvertizerTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_stops_sender_and_receiver_on_exit(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + + with RespondingAdvertizer(sender, receiver) as advertizer: + advertizer.start(ACCESS_URL, GROUP) + + # When + + # Then + sender.stop.assert_called_once() + receiver.stop.assert_called_once() + + def test_stops_sender_and_receiver_when_stopped(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + advertizer = RespondingAdvertizer(sender, receiver) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.stop() + + # Then + sender.stop.assert_called_once() + receiver.stop.assert_called_once() + + def test_starts_sender_and_receiver_when_started(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + advertizer = RespondingAdvertizer(sender, receiver) + + # When + advertizer.start(ACCESS_URL, GROUP) + + # Then + sender.start.assert_called_once_with(GroupAccess(ACCESS_URL, GROUP.hello())) + receiver.start.assert_called_once_with(GroupAccess(ACCESS_URL, GROUP.query())) + + def test_sends_service_info_when_receives_matching_query(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + advertizer = RespondingAdvertizer(sender, receiver) + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + + # When + advertizer._handle_message(ServiceQuery('test-.*', 'test-.*').__dict__) + + # Then + sender.send.assert_called_once_with(SERVICE_INFO) + + def test_does_not_send_service_info_when_receives_non_matching_query(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + advertizer = RespondingAdvertizer(sender, receiver) + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + + # When + advertizer._handle_message(ServiceQuery('other-.*', 'test-.*').__dict__) + + # Then + sender.send.assert_not_called() + + def test_does_not_send_service_info_when_no_service_info_set(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + advertizer = RespondingAdvertizer(sender, receiver) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer._handle_message(ServiceQuery('test-.*', 'test-.*').__dict__) + + # Then + sender.send.assert_not_called() + + def test_handles_invalid_message_gracefully(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + advertizer = RespondingAdvertizer(sender, receiver) + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + + # When + advertizer._handle_message({'invalid': 'message'}) + + # Then + sender.send.assert_not_called() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/senderIntegrationTest.py b/tests/senderIntegrationTest.py index b549b88..4596256 100644 --- a/tests/senderIntegrationTest.py +++ b/tests/senderIntegrationTest.py @@ -1,5 +1,4 @@ import unittest -from time import sleep from unittest import TestCase from context_logger import setup_logging @@ -15,7 +14,7 @@ SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') -class SenderTest(TestCase): +class SenderIntegrationTest(TestCase): @classmethod def setUpClass(cls): @@ -24,19 +23,6 @@ def setUpClass(cls): def setUp(self): print() - def test_raises_error_when_restarted(self): - # Given - group_access = GroupAccess(ACCESS_URL, GROUP.hello()) - context = Context() - - with RadioSender(context) as sender: - sender.start(group_access) - - # When, Then - with self.assertRaises(RuntimeError): - sender.start(group_access) - - def test_sends_message(self): # Given group_access = GroupAccess(ACCESS_URL, GROUP.hello()) @@ -56,24 +42,6 @@ def test_sends_message(self): # Then self.assertEqual([SERVICE_INFO.__dict__], messages) - def test_skips_sending_message_when_not_serializable(self): - # Given - group_access = GroupAccess(ACCESS_URL, GROUP.hello()) - context = Context() - messages = [] - - with RadioSender(context) as sender, DishReceiver(context) as test_receiver: - test_receiver.register(lambda message: messages.append(message)) - test_receiver.start(group_access) - sender.start(group_access) - - # When - sender.send('not serializable message') - - sleep(0.1) - - self.assertEqual(0, len(messages)) - if __name__ == '__main__': unittest.main() From 72e4cd0eaa43bb80bc9a1c83197fd7d55c65876b Mon Sep 17 00:00:00 2001 From: Attila Gombos Date: Mon, 19 Jan 2026 15:55:44 +0100 Subject: [PATCH 6/6] Initial commit WIP --- hello/service.py | 4 ++-- tests/advertizerIntegrationTest.py | 14 +++++++------- tests/apiIntegrationTest.py | 11 +++++------ tests/defaultAdvertizerTest.py | 4 ++-- tests/defaultDiscovererTest.py | 6 +++--- tests/defaultScheduledAdvertizerTest.py | 2 +- tests/discovererIntegrationTest.py | 10 +++++----- tests/dishReceiverTest.py | 2 +- tests/radioSenderTest.py | 2 +- tests/receiverIntegrationTest.py | 2 +- tests/respondingAdvertizerTest.py | 2 +- tests/senderIntegrationTest.py | 2 +- 12 files changed, 30 insertions(+), 31 deletions(-) diff --git a/hello/service.py b/hello/service.py index 374b595..99f909e 100644 --- a/hello/service.py +++ b/hello/service.py @@ -1,12 +1,12 @@ import re -from dataclasses import dataclass +from dataclasses import dataclass, field @dataclass class ServiceInfo: name: str role: str - url: str + urls: dict[str, str] = field(default_factory=dict) @dataclass diff --git a/tests/advertizerIntegrationTest.py b/tests/advertizerIntegrationTest.py index 69822e5..74554da 100644 --- a/tests/advertizerIntegrationTest.py +++ b/tests/advertizerIntegrationTest.py @@ -15,7 +15,7 @@ class AdvertizerIntegrationTest(TestCase): - SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) @classmethod def setUpClass(cls): @@ -23,7 +23,7 @@ def setUpClass(cls): def setUp(self): print() - self.SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + self.SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) def test_sends_hello_when_advertises_service(self): # Given @@ -90,7 +90,7 @@ def test_sends_hello_when_info_changed_and_query_received(self): wait_for_assertion(0.1, lambda: self.assertEqual(2, len(messages))) - self.SERVICE_INFO.url = 'http://localhost:9090' + self.SERVICE_INFO.urls['test'] = 'http://localhost:9090' advertizer.advertise(self.SERVICE_INFO) # When @@ -100,10 +100,10 @@ def test_sends_hello_when_info_changed_and_query_received(self): # Then self.assertEqual([ - {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:8080'}, - {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:8080'}, - {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:9090'}, - {'name': 'test-service', 'role': 'test-role', 'url': 'http://localhost:9090'} + {'name': 'test-service', 'role': 'test-role', 'urls': {'test': 'http://localhost:8080'}}, + {'name': 'test-service', 'role': 'test-role', 'urls': {'test': 'http://localhost:8080'}}, + {'name': 'test-service', 'role': 'test-role', 'urls': {'test': 'http://localhost:9090'}}, + {'name': 'test-service', 'role': 'test-role', 'urls': {'test': 'http://localhost:9090'}} ], messages) def test_sends_hello_when_schedules_advertisement_once(self): diff --git a/tests/apiIntegrationTest.py b/tests/apiIntegrationTest.py index cc86505..5ea16f6 100644 --- a/tests/apiIntegrationTest.py +++ b/tests/apiIntegrationTest.py @@ -10,7 +10,7 @@ ACCESS_URL = 'udp://239.0.0.1:5555' GROUP_NAME = 'test-group' GROUP = Group(GROUP_NAME) -SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) SERVICE_QUERY = ServiceQuery('test-service', 'test-role') @@ -22,7 +22,6 @@ def setUpClass(cls): def setUp(self): print() - self.SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') def test_discoverer_caches_advertised_service(self): # Given @@ -39,7 +38,7 @@ def test_discoverer_caches_advertised_service(self): wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) # Then - self.assertEqual({self.SERVICE_INFO.name: self.SERVICE_INFO}, discoverer.get_services()) + self.assertEqual({SERVICE_INFO.name: SERVICE_INFO}, discoverer.get_services()) def test_discoverer_caches_advertised_service_when_scheduled_once(self): # Given @@ -56,7 +55,7 @@ def test_discoverer_caches_advertised_service_when_scheduled_once(self): wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) # Then - self.assertEqual({self.SERVICE_INFO.name: self.SERVICE_INFO}, discoverer.get_services()) + self.assertEqual({SERVICE_INFO.name: SERVICE_INFO}, discoverer.get_services()) def test_discoverer_caches_advertised_service_when_scheduled_periodically(self): # Given @@ -73,7 +72,7 @@ def test_discoverer_caches_advertised_service_when_scheduled_periodically(self): wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) # Then - self.assertEqual({self.SERVICE_INFO.name: self.SERVICE_INFO}, discoverer.get_services()) + self.assertEqual({SERVICE_INFO.name: SERVICE_INFO}, discoverer.get_services()) def test_discoverer_caches_discovery_response_service(self): # Given @@ -90,7 +89,7 @@ def test_discoverer_caches_discovery_response_service(self): wait_for_assertion(0.2, lambda: self.assertEqual(1, len(discoverer.get_services()))) # Then - self.assertEqual({self.SERVICE_INFO.name: self.SERVICE_INFO}, discoverer.get_services()) + self.assertEqual({SERVICE_INFO.name: SERVICE_INFO}, discoverer.get_services()) if __name__ == '__main__': diff --git a/tests/defaultAdvertizerTest.py b/tests/defaultAdvertizerTest.py index a057b3f..474e3b3 100644 --- a/tests/defaultAdvertizerTest.py +++ b/tests/defaultAdvertizerTest.py @@ -9,7 +9,7 @@ ACCESS_URL = 'udp://239.0.0.1:5555' GROUP_NAME = 'test-group' GROUP = Group(GROUP_NAME) -SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) class DefaultAdvertizerTest(TestCase): @@ -84,7 +84,7 @@ def test_sends_last_info_when_passed_at_start_and_at_advertise(self): # Given sender = MagicMock(spec=Sender) advertizer = DefaultAdvertizer(sender) - advertizer.start(ACCESS_URL, GROUP, ServiceInfo('test-service', 'test-role', 'http://localhost:9090')) + advertizer.start(ACCESS_URL, GROUP, ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:9090'})) # When advertizer.advertise(SERVICE_INFO) diff --git a/tests/defaultDiscovererTest.py b/tests/defaultDiscovererTest.py index 1bc583e..70fd70b 100644 --- a/tests/defaultDiscovererTest.py +++ b/tests/defaultDiscovererTest.py @@ -11,7 +11,7 @@ GROUP_NAME = 'test-group' GROUP = Group(GROUP_NAME) SERVICE_QUERY = ServiceQuery('test-.*', 'test-.*') -SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) class DefaultDiscovererTest(TestCase): @@ -117,7 +117,7 @@ def test_updates_service_and_calls_handler_when_receives_matching_info(self): discoverer.register(handler) discoverer._handle_message(SERVICE_INFO.__dict__) handler.reset_mock() - new_service_info = ServiceInfo(SERVICE_INFO.name, SERVICE_INFO.role, 'http://localhost:9090') + new_service_info = ServiceInfo(SERVICE_INFO.name, SERVICE_INFO.role, {'test': 'http://localhost:9090'}) # When discoverer._handle_message(new_service_info.__dict__) @@ -180,7 +180,7 @@ def test_does_not_cache_service_when_info_not_matching_query(self): discoverer = DefaultDiscoverer(sender, receiver) discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) - non_matching_info = ServiceInfo('other-service', 'test-role', 'http://localhost:8080') + non_matching_info = ServiceInfo('other-service', 'test-role', {'test': 'http://localhost:8080'}) # When discoverer._handle_message(non_matching_info.__dict__) diff --git a/tests/defaultScheduledAdvertizerTest.py b/tests/defaultScheduledAdvertizerTest.py index cdb42bf..29ac55a 100644 --- a/tests/defaultScheduledAdvertizerTest.py +++ b/tests/defaultScheduledAdvertizerTest.py @@ -10,7 +10,7 @@ ACCESS_URL = 'udp://239.0.0.1:5555' GROUP_NAME = 'test-group' GROUP = Group(GROUP_NAME) -SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) class DefaultScheduledAdvertizerTest(TestCase): diff --git a/tests/discovererIntegrationTest.py b/tests/discovererIntegrationTest.py index 5d5b0a6..7d4cbbf 100644 --- a/tests/discovererIntegrationTest.py +++ b/tests/discovererIntegrationTest.py @@ -15,7 +15,7 @@ class DiscovererIntegrationTest(TestCase): - SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) @classmethod def setUpClass(cls): @@ -23,7 +23,7 @@ def setUpClass(cls): def setUp(self): print() - self.SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') + self.SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) def test_discovers_service_when_hello_received(self): # Given @@ -58,18 +58,18 @@ def test_updates_service_when_info_changed(self): wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) # When - self.SERVICE_INFO.url = 'http://localhost:9090' + self.SERVICE_INFO.urls['test'] = 'http://localhost:9090' test_sender.send(self.SERVICE_INFO) wait_for_assertion(0.1, lambda: self.assertEqual( 'http://localhost:9090', - discoverer.get_services()[self.SERVICE_INFO.name].url + discoverer.get_services()[self.SERVICE_INFO.name].urls['test'] )) # Then self.assertEqual( 'http://localhost:9090', - discoverer.get_services()[self.SERVICE_INFO.name].url + discoverer.get_services()[self.SERVICE_INFO.name].urls['test'] ) def test_sends_query(self): diff --git a/tests/dishReceiverTest.py b/tests/dishReceiverTest.py index 6e18543..5852f9c 100644 --- a/tests/dishReceiverTest.py +++ b/tests/dishReceiverTest.py @@ -11,7 +11,7 @@ ACCESS_URL = 'udp://239.0.0.1:5555' GROUP_NAME = 'test-group' GROUP = Group(GROUP_NAME) -SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) class DishReceiverTest(TestCase): diff --git a/tests/radioSenderTest.py b/tests/radioSenderTest.py index d059a0c..d1a285b 100644 --- a/tests/radioSenderTest.py +++ b/tests/radioSenderTest.py @@ -11,7 +11,7 @@ ACCESS_URL = 'udp://239.0.0.1:5555' GROUP_NAME = 'test-group' GROUP = Group(GROUP_NAME) -SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) class RadioSenderTest(TestCase): diff --git a/tests/receiverIntegrationTest.py b/tests/receiverIntegrationTest.py index e1c30f8..b9ccd34 100644 --- a/tests/receiverIntegrationTest.py +++ b/tests/receiverIntegrationTest.py @@ -10,7 +10,7 @@ ACCESS_URL = 'udp://239.0.0.1:5555' GROUP_NAME = 'test-group' GROUP = Group(GROUP_NAME) -SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) class ReceiverIntegrationTest(TestCase): diff --git a/tests/respondingAdvertizerTest.py b/tests/respondingAdvertizerTest.py index a048c6e..d37581e 100644 --- a/tests/respondingAdvertizerTest.py +++ b/tests/respondingAdvertizerTest.py @@ -9,7 +9,7 @@ ACCESS_URL = 'udp://239.0.0.1:5555' GROUP_NAME = 'test-group' GROUP = Group(GROUP_NAME) -SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) class RespondingAdvertizerTest(TestCase): diff --git a/tests/senderIntegrationTest.py b/tests/senderIntegrationTest.py index 4596256..c5b3c46 100644 --- a/tests/senderIntegrationTest.py +++ b/tests/senderIntegrationTest.py @@ -11,7 +11,7 @@ ACCESS_URL = 'udp://239.0.0.1:5555' GROUP_NAME = 'test-group' GROUP = Group(GROUP_NAME) -SERVICE_INFO = ServiceInfo('test-service', 'test-role', 'http://localhost:8080') +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) class SenderIntegrationTest(TestCase):