diff --git a/.devcontainer b/.devcontainer new file mode 160000 index 0000000..0ddb12b --- /dev/null +++ b/.devcontainer @@ -0,0 +1 @@ +Subproject commit 0ddb12bd5a2381d9c588f5fc6f96c876716087f4 diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..87ae6dc --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +--- +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" + - package-ecosystem: "gitsubmodule" + directory: "/" + schedule: + interval: "daily" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..2b8ec6b --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,17 @@ +name: CI + +on: + push: + branches: [main] + tags: [v*.*.*] + + pull_request: + branches: [ "main" ] + types: + - synchronize + - opened + - reopened + +jobs: + call_ci: + uses: EffectiveRange/ci-workflows/.github/workflows/python-ci.yaml@latest-python diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..f1b0753 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule ".devcontainer"] + path = .devcontainer + url = https://github.com/EffectiveRange/devcontainer-defs diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..ca1b149 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# IntelliJ IDEA folder-specific ignored files +/shelf/ +/workspace.xml +/queries/ +/dataSources/ +/dataSources.local.xml +/httpRequests/ +copilot.* \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..30ff3a3 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,9 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..e83f687 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/python-hello.iml b/.idea/python-hello.iml new file mode 100644 index 0000000..a2043e1 --- /dev/null +++ b/.idea/python-hello.iml @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..5831342 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..d29106d --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,35 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "args": [ + "--backend=scipy" + ], + "console": "integratedTerminal" + }, + { + "name": "Run All Tests (pytest)", + "type": "debugpy", + "request": "launch", + "module": "pytest", + "args": [ + "tests" + ], + "console": "integratedTerminal" + }, + { + "name": "Run All Tests with Coverage (pytest-cov)", + "type": "debugpy", + "request": "launch", + "module": "pytest", + "args": [ + "--cov=hello", "tests" + ], + "console": "integratedTerminal" + } + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..ded9998 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,20 @@ +{ + "python.venvPath": "${workspaceFolder}/.venv", + "python.testing.unittestArgs": [ + "-v", + "-s", + "./tests", + "-p", + "*Test.py" + ], + "python.testing.pytestEnabled": true, + "python.testing.unittestEnabled": false, + "black-formatter.interpreter": [ + "${workspaceFolder}/.venv/bin/python3" + ], + "black-formatter.args": [ + "--config=setup.cfg" + ], + "python.analysis.typeCheckingMode": "standard", + "python.testing.pytestArgs": [] +} diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000..328d1b1 --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,15 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Create Python venv", + "type": "shell", + "command": "if [ -d /var/chroot/buildroot ];then dpkgdeps -v --arch $(grep TARGET_ARCH /home/crossbuilder/target/target | cut -d'=' -f2 | tr -d \\') .;else dpkgdeps -v .;fi && rm -rf .venv && python3 -m venv --system-site-packages .venv && .venv/bin/pip install -e . && .venv/bin/python3 -m mypy --non-interactive --install-types && .venv/bin/pip install pytest-cov || true", + "group": "build", + "detail": "Creates a Python virtual environment in the .venv folder", + "problemMatcher": [ + "$eslint-compact" + ] + } + ] +} diff --git a/README.md b/README.md index db85593..2d7f101 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,5 @@ +[![CI](https://github.com/EffectiveRange/python-hello/actions/workflows/ci.yaml/badge.svg)](https://github.com/EffectiveRange/python-hello/actions/workflows/ci.yaml) +[![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/EffectiveRange/python-hello/python-coverage-comment-action-data/endpoint.json)](https://htmlpreview.github.io/?https://github.com/EffectiveRange/python-hello/blob/python-coverage-comment-action-data/htmlcov/index.html) + # python-hello -A service advertizer library using ZeroMQ +A service advertizer/discovery protocol library using ZeroMQ diff --git a/deps.json b/deps.json new file mode 100644 index 0000000..a6fe6ec --- /dev/null +++ b/deps.json @@ -0,0 +1,20 @@ +{ + "deps": [ + { + "name": "libzmq-drafts", + "hostinstall": true + }, + { + "name": "python3-zmq-drafts", + "hostinstall": true + }, + { + "name": "python3-context-logger", + "hostinstall": true + }, + { + "name": "python3-common-utility", + "hostinstall": true + } + ] +} diff --git a/hello/__init__.py b/hello/__init__.py new file mode 100644 index 0000000..c360ae2 --- /dev/null +++ b/hello/__init__.py @@ -0,0 +1,7 @@ +from .group import * +from .sender import * +from .receiver import * +from .service import * +from .advertizer import * +from .discoverer import * +from .api import * diff --git a/hello/advertizer.py b/hello/advertizer.py new file mode 100644 index 0000000..13e2bea --- /dev/null +++ b/hello/advertizer.py @@ -0,0 +1,130 @@ +import random +import time +from typing import Any + +from common_utility import IReusableTimer +from context_logger import get_logger + +from hello import ServiceInfo, Group, Sender, GroupAccess, Receiver, ServiceMatcher, ServiceQuery + +log = get_logger('Advertizer') + + +class Advertizer: + + def start(self, address: str, group: Group, info: ServiceInfo | None = None) -> None: + raise NotImplementedError() + + def stop(self) -> None: + raise NotImplementedError() + + def advertise(self, info: ServiceInfo | None = None) -> None: + raise NotImplementedError() + + +class DefaultAdvertizer(Advertizer): + + def __init__(self, sender: Sender) -> None: + self._sender = sender + self._group: Group | None = None + self._info: ServiceInfo | None = None + + def __enter__(self) -> Advertizer: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.stop() + + def start(self, address: str, group: Group, info: ServiceInfo | None = None) -> None: + self._sender.start(GroupAccess(address, group.hello())) + self._group = group + self._info = info + + def stop(self) -> None: + self._group = None + self._sender.stop() + + def advertise(self, info: ServiceInfo | None = None) -> None: + if self._group: + if info: + self._info = info + if self._info: + self._sender.send(self._info) + log.info('Service advertised', service=self._info, group=self._group) + else: + log.warning('Cannot advertise service, advertizer not started', service=info) + + +class RespondingAdvertizer(DefaultAdvertizer): + + def __init__(self, sender: Sender, receiver: Receiver, max_response_delay: float = 0.1) -> None: + super().__init__(sender) + self._receiver = receiver + self._max_delay = max_response_delay + + def start(self, address: str, group: Group, info: ServiceInfo | None = None) -> None: + super().start(address, group, info) + self._receiver.start(GroupAccess(address, group.query())) + self._receiver.register(self._handle_message) + + def stop(self) -> None: + super().stop() + self._receiver.stop() + + def _handle_message(self, message: dict[str, Any]) -> None: + if self._info: + try: + query = ServiceQuery(**message) + log.debug('Query received', group=self._group, query=query) + self._handle_query(query, self._info) + except Exception as error: + log.warning('Invalid query message received', group=self._group, received=message, error=error) + + def _handle_query(self, query: ServiceQuery, info: ServiceInfo) -> None: + matcher = ServiceMatcher(query) + if matcher and matcher.matches(info): + delay = round(self._max_delay * random.random(), 3) + log.info('Responding to query', group=self._group, query=matcher.query, service=info, delay=delay) + time.sleep(delay) + self.advertise(info) + + +class ScheduledAdvertizer(Advertizer): + + def schedule(self, info: ServiceInfo | None = None, interval: float = 10, one_shot: bool = False) -> None: + raise NotImplementedError() + + +class DefaultScheduledAdvertizer(ScheduledAdvertizer): + + def __init__(self, advertizer: Advertizer, timer: IReusableTimer) -> None: + self._advertizer = advertizer + self._timer = timer + + def __enter__(self) -> ScheduledAdvertizer: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.stop() + + def start(self, address: str, group: Group, info: ServiceInfo | None = None) -> None: + self._advertizer.start(address, group, info) + + def stop(self) -> None: + self._timer.cancel() + self._advertizer.stop() + + def advertise(self, info: ServiceInfo | None = None) -> None: + self._advertizer.advertise(info) + + def schedule(self, info: ServiceInfo | None = None, interval: float = 60, one_shot: bool = False) -> None: + if one_shot: + self._timer.start(interval, self.advertise, [info]) + log.info('One-shot service advertisement scheduled', service=info, interval=interval) + else: + self._timer.start(interval, self._advertise_and_restart, [info]) + log.info('Periodic service advertisement scheduled', service=info, interval=interval) + + def _advertise_and_restart(self, info: ServiceInfo | None = None) -> None: + self.advertise(info) + self._timer.restart() diff --git a/hello/api.py b/hello/api.py new file mode 100644 index 0000000..57fc5c9 --- /dev/null +++ b/hello/api.py @@ -0,0 +1,44 @@ +from typing import Any + +from common_utility import ReusableTimer +from zmq import Context + +from hello import Advertizer, Discoverer, RadioSender, DishReceiver, DefaultAdvertizer, DefaultDiscoverer, \ + ScheduledAdvertizer, RespondingAdvertizer, DefaultScheduledAdvertizer + + +class Hello: + + def default_advertizer(self, respond: bool = True, delay: float = 0.1) -> Advertizer: + raise NotImplementedError() + + def scheduled_advertizer(self, respond: bool = True, delay: float = 0.1) -> ScheduledAdvertizer: + raise NotImplementedError() + + def discoverer(self) -> Discoverer: + raise NotImplementedError() + + +class DefaultHello(Hello): + + def __init__(self, context: Context[Any] | None = None, max_workers: int = 1, poll_timeout: float = 0.1) -> None: + self._context = context if context else Context() + self._max_workers = max_workers + self._poll_timeout = poll_timeout + + def default_advertizer(self, respond: bool = True, delay: float = 0.1) -> Advertizer: + sender = RadioSender(self._context) + if respond: + receiver = DishReceiver(self._context, self._max_workers, self._poll_timeout) + return RespondingAdvertizer(sender, receiver, delay) + else: + return DefaultAdvertizer(sender) + + def scheduled_advertizer(self, respond: bool = True, delay: float = 0.1) -> ScheduledAdvertizer: + advertizer = self.default_advertizer(respond, delay) + return DefaultScheduledAdvertizer(advertizer, ReusableTimer()) + + def discoverer(self) -> Discoverer: + sender = RadioSender(self._context) + receiver = DishReceiver(self._context, self._max_workers, self._poll_timeout) + return DefaultDiscoverer(sender, receiver) diff --git a/hello/discoverer.py b/hello/discoverer.py new file mode 100644 index 0000000..192f644 --- /dev/null +++ b/hello/discoverer.py @@ -0,0 +1,134 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Protocol + +from context_logger import get_logger + +from hello import Group, ServiceQuery, Sender, Receiver, GroupAccess, ServiceInfo, ServiceMatcher + +log = get_logger('Discoverer') + + +class DiscoveryEventType(Enum): + DISCOVERED = 'discovered' + UPDATED = 'updated' + + +@dataclass +class DiscoveryEvent: + service: ServiceInfo + type: DiscoveryEventType + + +class OnDiscoveryEvent(Protocol): + def __call__(self, event: DiscoveryEvent) -> None: ... + + +class Discoverer: + + def start(self, address: str, group: Group, query: ServiceQuery | None = None) -> None: + raise NotImplementedError() + + def stop(self) -> None: + raise NotImplementedError() + + def discover(self, query: ServiceQuery | None = None) -> None: + raise NotImplementedError() + + def get_services(self) -> dict[str, ServiceInfo]: + raise NotImplementedError() + + def register(self, handler: OnDiscoveryEvent) -> None: + raise NotImplementedError() + + def deregister(self, handler: OnDiscoveryEvent) -> None: + raise NotImplementedError() + + def get_handlers(self) -> list[OnDiscoveryEvent]: + raise NotImplementedError() + + +class DefaultDiscoverer(Discoverer): + + def __init__(self, sender: Sender, receiver: Receiver) -> None: + self._sender = sender + self._receiver = receiver + self._group: Group | None = None + self._matcher: ServiceMatcher | None = None + self._cache: dict[str, ServiceInfo] = {} + self._handlers: list[OnDiscoveryEvent] = [] + + def __enter__(self) -> Discoverer: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.stop() + + def start(self, address: str, group: Group, query: ServiceQuery | None = None) -> None: + self._group = group + if query: + self._matcher = ServiceMatcher(query) + self._sender.start(GroupAccess(address, group.query())) + self._receiver.register(self._handle_message) + self._receiver.start(GroupAccess(address, group.hello())) + + def stop(self) -> None: + self._group = None + self._sender.stop() + self._receiver.stop() + + def discover(self, query: ServiceQuery | None = None) -> None: + if self._group: + if query: + self._matcher = ServiceMatcher(query) + if self._matcher: + self._sender.send(self._matcher.query) + log.info('Service discovery initiated', query=self._matcher.query, group=self._group) + else: + log.warning('Cannot discover services, discoverer not started', query=query) + + def get_services(self) -> dict[str, ServiceInfo]: + return self._cache.copy() + + def register(self, handler: OnDiscoveryEvent) -> None: + self._handlers.append(handler) + + def deregister(self, handler: OnDiscoveryEvent) -> None: + self._handlers.remove(handler) + + def get_handlers(self) -> list[OnDiscoveryEvent]: + return self._handlers.copy() + + def _handle_message(self, message: dict[str, Any]) -> None: + try: + service = ServiceInfo(**message) + self._handle_service(service) + except Exception as error: + log.warn('Failed to handle received message', data=message, error=error) + + def _handle_service(self, service: ServiceInfo) -> None: + if self._matcher and self._matcher.matches(service): + cached = self._cache.get(service.name) + + if event := self._create_event(cached, service): + self._handle_event(event) + + def _create_event(self, cached: ServiceInfo | None, service: ServiceInfo) -> DiscoveryEvent | None: + if cached: + if cached != service: + log.info('Service updated', old_service=cached, new_service=service) + return DiscoveryEvent(service, DiscoveryEventType.UPDATED) + else: + return None + else: + log.info('Service discovered', service=service) + return DiscoveryEvent(service, DiscoveryEventType.DISCOVERED) + + def _handle_event(self, event: DiscoveryEvent) -> None: + service = event.service + self._cache[service.name] = service + for callback in self._handlers: + try: + callback(event) + except Exception as error: + log.warn('Error in event handler execution', event=event, error=error) diff --git a/hello/group.py b/hello/group.py new file mode 100644 index 0000000..95b0654 --- /dev/null +++ b/hello/group.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from enum import Enum + + +class GroupPrefix(Enum): + HELLO = 'hello' + QUERY = 'query' + + +class IGroup: + + def hello(self) -> str: + raise NotImplementedError() + + def query(self) -> str: + raise NotImplementedError() + + +class Group(IGroup): + def __init__(self, name: str) -> None: + self.name = name + + def hello(self) -> str: + return self._prefix(GroupPrefix.HELLO) + + def query(self) -> str: + return self._prefix(GroupPrefix.QUERY) + + def _prefix(self, group_type: GroupPrefix) -> str: + return f'{group_type.value}:{self.name}' + + def __repr__(self) -> str: + return self.name + + +@dataclass +class GroupAccess: + access_url: str + full_group: str diff --git a/hello/py.typed b/hello/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/hello/receiver.py b/hello/receiver.py new file mode 100644 index 0000000..b9e0068 --- /dev/null +++ b/hello/receiver.py @@ -0,0 +1,100 @@ +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Protocol + +from context_logger import get_logger +from zmq import DISH, Poller, POLLIN, Context + +from hello import GroupAccess + +log = get_logger('Receiver') + + +class OnMessage(Protocol): + def __call__(self, message: dict[str, Any]) -> None: ... + + +class Receiver: + + def start(self, source: GroupAccess) -> None: + raise NotImplementedError() + + def stop(self) -> None: + raise NotImplementedError() + + def register(self, handler: OnMessage) -> None: + raise NotImplementedError() + + def deregister(self, handler: OnMessage) -> None: + raise NotImplementedError() + + def get_handlers(self) -> list[OnMessage]: + raise NotImplementedError() + + +class DishReceiver(Receiver): + + def __init__(self, context: Context[Any], max_workers: int = 1, poll_timeout: float = 0.1) -> None: + self._context = context + self._dish = self._context.socket(DISH) + self._poller = Poller() + self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._poll_timeout = int(poll_timeout * 1000) + self._group: str | None = None + self._handlers: list[OnMessage] = [] + + def __enter__(self) -> Receiver: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.stop() + + def start(self, source: GroupAccess) -> None: + try: + if self._group: + raise RuntimeError('Receiver already started') + self._poller.register(self._dish, POLLIN) + self._dish.bind(source.access_url) + self._dish.join(source.full_group) + self._group = source.full_group + self._executor.submit(self._receive_loop) + log.debug('Receiver started', address=source.access_url, group=source.full_group) + except Exception as error: + log.error('Failed to start receiver', address=source.access_url, group=source.full_group, error=error) + raise error + + def stop(self) -> None: + try: + self._group = None + self._executor.shutdown() + self._dish.close() + log.debug('Receiver stopped') + except Exception as error: + log.error('Failed to stop receiver', error=error) + raise error + + def register(self, handler: OnMessage) -> None: + self._handlers.append(handler) + + def deregister(self, handler: OnMessage) -> None: + self._handlers.remove(handler) + + def get_handlers(self) -> list[OnMessage]: + return self._handlers.copy() + + def _receive_loop(self) -> None: + while self._group: + sockets = dict(self._poller.poll(timeout=self._poll_timeout)) + if self._dish in sockets and sockets[self._dish] == POLLIN: + try: + data = self._dish.recv_json() + log.debug('Message received', data=data, group=self._group) + self._handle_message(data) + except Exception as error: + log.error('Failed to receive message', group=self._group, error=error) + + def _handle_message(self, message: dict[str, Any]) -> None: + for handler in self._handlers: + try: + handler(message) + except Exception as error: + log.warn('Error in message handler execution', data=message, group=self._group, error=error) diff --git a/hello/sender.py b/hello/sender.py new file mode 100644 index 0000000..57574e5 --- /dev/null +++ b/hello/sender.py @@ -0,0 +1,77 @@ +from typing import Any, cast + +from context_logger import get_logger +from zmq import Context, RADIO, Socket + +from hello import GroupAccess + +log = get_logger('Sender') + + +class Sender: + + def start(self, target: GroupAccess) -> None: + raise NotImplementedError() + + def stop(self) -> None: + raise NotImplementedError() + + def send(self, data: object) -> None: + raise NotImplementedError() + + +class RadioSender(Sender): + + def __init__(self, context: Context[Any]) -> None: + self._context = context + self._radio: Socket[bytes] = self._context.socket(RADIO) + self._group: str | None = None + + def __enter__(self) -> Sender: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.stop() + + def start(self, target: GroupAccess) -> None: + try: + if self._group: + raise RuntimeError('Sender already started') + self._radio.connect(target.access_url) + self._group = target.full_group + log.debug('Sender started', address=target.access_url, group=target.full_group) + except Exception as error: + log.error('Failed to start sender', address=target.access_url, group=target.full_group, error=error) + raise error + + def stop(self) -> None: + try: + self._group = None + self._radio.close() + log.debug('Sender stopped') + except Exception as error: + log.error('Failed to stop sender', error=error) + raise error + + def send(self, data: Any) -> None: + if self._group: + if message := self._convert_to_dict(data): + self._send_json(message) + else: + log.warning('Unsupported message type', data=data, group=self._group) + else: + log.warning('Cannot send message, sender not started', data=data) + + def _convert_to_dict(self, data: Any) -> dict[str, Any] | None: + if isinstance(data, dict): + return data + elif hasattr(data, '__dict__'): + return cast(dict[str, Any], data.__dict__) + return None + + def _send_json(self, data: dict[str, Any]) -> None: + try: + self._radio.send_json(data, group=self._group) + log.debug('Message sent', data=data, group=self._group) + except Exception as error: + log.error('Failed to send message', data=data, group=self._group, error=error) diff --git a/hello/service.py b/hello/service.py new file mode 100644 index 0000000..99f909e --- /dev/null +++ b/hello/service.py @@ -0,0 +1,28 @@ +import re +from dataclasses import dataclass, field + + +@dataclass +class ServiceInfo: + name: str + role: str + urls: dict[str, str] = field(default_factory=dict) + + +@dataclass +class ServiceQuery(object): + name: str + role: str + + +class ServiceMatcher(object): + + def __init__(self, query: ServiceQuery) -> None: + self.query = query + self._name_matcher = re.compile(self.query.name) + self._role_matcher = re.compile(self.query.role) + + def matches(self, info: ServiceInfo) -> bool: + name_match = self._name_matcher.match(info.name) + role_match = self._role_matcher.match(info.role) + return bool(name_match and role_match) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8622581 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,27 @@ +[project] +name = "python-hello" +description = "A service advertizer/discovery protocol library using ZeroMQ" +authors = [ + { name = "Ferenc Nandor Janky & Attila Gombos", email = "info@effective-range.com" } +] +dependencies = [ + "pyzmq @ git+https://github.com/EffectiveRange/pyzmq.git@v27.1.1", + "python-context-logger @ git+https://github.com/EffectiveRange/python-context-logger.git@latest", + "python-common-utility @ git+https://github.com/EffectiveRange/python-common-utility.git@latest" +] +dynamic = ["version"] + +[tool.setuptools] +package-dir = {"" = "."} +packages = ["hello"] + +[tool.setuptools.package-data] +hello = ["py.typed"] + +[build-system] +requires = ["setuptools>=61", "setuptools_scm"] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] +version_scheme = "guess-next-dev" +local_scheme = "node-and-date" diff --git a/python-hello.code-workspace b/python-hello.code-workspace new file mode 100644 index 0000000..a083c85 --- /dev/null +++ b/python-hello.code-workspace @@ -0,0 +1,15 @@ +{ + "folders": [ + { + "path": "." + } + ], + "settings": { + "python.formatting.provider": "black", + "editor.formatOnSave": true, + "mypy-type-checker.ignorePatterns": [ + ], + "flake8.ignorePatterns": [ + ] + } +} diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..ce39c45 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,57 @@ +[pack-python] +packaging = + wheel + fpm-deb + +[mypy] +packages = hello +strict = True + +[flake8] +exclude = build,dist,.eggs,.venv +max-line-length = 120 +max-complexity = 10 +count = True +statistics = True +show-source = True +per-file-ignores = + # F401: imported but unused + # F403: import * used; unable to detect undefined names + __init__.py: F401,F403 + +[tool:pytest] +addopts = --capture=no --verbose +python_files = *Test.py +python_classes = *Test + +[coverage:run] +relative_files = true +branch = True +source = hello + +[coverage:report] +; Regexes for lines to exclude from consideration +exclude_also = + ; Don't complain about missing debug-only code: + def __repr__ + if self\.debug + + ; Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + + ; Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: + + ; Don't complain about abstract methods, they aren't run: + @(abc\.)?abstractmethod + +ignore_errors = True +skip_empty = True + +[coverage:html] +directory = coverage/html + +[coverage:json] +output = coverage/coverage.json diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/advertizerIntegrationTest.py b/tests/advertizerIntegrationTest.py new file mode 100644 index 0000000..74554da --- /dev/null +++ b/tests/advertizerIntegrationTest.py @@ -0,0 +1,151 @@ +import unittest +from unittest import TestCase + +from common_utility import ReusableTimer +from context_logger import setup_logging +from test_utility import wait_for_assertion +from zmq import Context + +from hello import DefaultAdvertizer, ServiceInfo, Group, RadioSender, DishReceiver, GroupAccess, \ + RespondingAdvertizer, ServiceQuery, DefaultScheduledAdvertizer + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) + + +class AdvertizerIntegrationTest(TestCase): + SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + self.SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) + + def test_sends_hello_when_advertises_service(self): + # Given + context = Context() + sender = RadioSender(context) + messages = [] + + with DefaultAdvertizer(sender) as advertizer, DishReceiver(context) as test_receiver: + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.hello())) + test_receiver.register(lambda message: messages.append(message)) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.advertise(self.SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(messages))) + + # Then + self.assertEqual([self.SERVICE_INFO.__dict__], messages) + + def test_sends_hello_when_query_received(self): + # Given + context = Context() + sender = RadioSender(context) + receiver = DishReceiver(context) + messages = [] + + with (RespondingAdvertizer(sender, receiver, 0.01) as advertizer, + RadioSender(context) as test_sender, + DishReceiver(context) as test_receiver): + test_sender.start(GroupAccess(ACCESS_URL, GROUP.query())) + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.hello())) + test_receiver.register(lambda message: messages.append(message)) + + advertizer.start(ACCESS_URL, GROUP, self.SERVICE_INFO) + + # When + test_sender.send(ServiceQuery('test-service', 'test-role')) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(messages))) + + # Then + self.assertEqual([self.SERVICE_INFO.__dict__], messages) + + def test_sends_hello_when_info_changed_and_query_received(self): + # Given + context = Context() + sender = RadioSender(context) + receiver = DishReceiver(context) + messages = [] + + with (RespondingAdvertizer(sender, receiver, 0.01) as advertizer, + RadioSender(context) as test_sender, + DishReceiver(context) as test_receiver): + test_sender.start(GroupAccess(ACCESS_URL, GROUP.query())) + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.hello())) + test_receiver.register(lambda message: messages.append(message)) + + advertizer.start(ACCESS_URL, GROUP) + advertizer.advertise(self.SERVICE_INFO) + + query = ServiceQuery('test-service', 'test-role') + test_sender.send(query) + + wait_for_assertion(0.1, lambda: self.assertEqual(2, len(messages))) + + self.SERVICE_INFO.urls['test'] = 'http://localhost:9090' + advertizer.advertise(self.SERVICE_INFO) + + # When + test_sender.send(query) + + wait_for_assertion(0.1, lambda: self.assertEqual(4, len(messages))) + + # Then + self.assertEqual([ + {'name': 'test-service', 'role': 'test-role', 'urls': {'test': 'http://localhost:8080'}}, + {'name': 'test-service', 'role': 'test-role', 'urls': {'test': 'http://localhost:8080'}}, + {'name': 'test-service', 'role': 'test-role', 'urls': {'test': 'http://localhost:9090'}}, + {'name': 'test-service', 'role': 'test-role', 'urls': {'test': 'http://localhost:9090'}} + ], messages) + + def test_sends_hello_when_schedules_advertisement_once(self): + # Given + context = Context() + sender = RadioSender(context) + _advertizer = DefaultAdvertizer(sender) + timer = ReusableTimer() + messages = [] + + with DefaultScheduledAdvertizer(_advertizer, timer) as advertizer, DishReceiver(context) as test_receiver: + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.hello())) + test_receiver.register(lambda message: messages.append(message)) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.schedule(self.SERVICE_INFO, interval=0.01, one_shot=True) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(messages))) + + # Then + self.assertEqual([self.SERVICE_INFO.__dict__], messages) + + def test_sends_hello_when_schedules_advertisement_periodically(self): + # Given + context = Context() + sender = RadioSender(context) + _advertizer = DefaultAdvertizer(sender) + timer = ReusableTimer() + messages = [] + + with DefaultScheduledAdvertizer(_advertizer, timer) as advertizer, DishReceiver(context) as test_receiver: + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.hello())) + test_receiver.register(lambda message: messages.append(message)) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.schedule(self.SERVICE_INFO, interval=0.01) + + # Then + wait_for_assertion(0.1, lambda: self.assertEqual(5, len(messages))) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/apiIntegrationTest.py b/tests/apiIntegrationTest.py new file mode 100644 index 0000000..5ea16f6 --- /dev/null +++ b/tests/apiIntegrationTest.py @@ -0,0 +1,96 @@ +import unittest +from unittest import TestCase + +from context_logger import setup_logging +from test_utility import wait_for_assertion +from zmq import Context + +from hello import ServiceInfo, Group, DefaultHello, ServiceQuery + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) +SERVICE_QUERY = ServiceQuery('test-service', 'test-role') + + +class ApiIntegrationTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_discoverer_caches_advertised_service(self): + # Given + context = Context() + hello = DefaultHello(context) + + with hello.default_advertizer(respond=False) as advertizer, hello.discoverer() as discoverer: + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + advertizer.advertise() + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) + + # Then + self.assertEqual({SERVICE_INFO.name: SERVICE_INFO}, discoverer.get_services()) + + def test_discoverer_caches_advertised_service_when_scheduled_once(self): + # Given + context = Context() + hello = DefaultHello(context) + + with hello.scheduled_advertizer(respond=False) as advertizer, hello.discoverer() as discoverer: + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + advertizer.schedule(interval=0.01, one_shot=True) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) + + # Then + self.assertEqual({SERVICE_INFO.name: SERVICE_INFO}, discoverer.get_services()) + + def test_discoverer_caches_advertised_service_when_scheduled_periodically(self): + # Given + context = Context() + hello = DefaultHello(context) + + with hello.scheduled_advertizer() as advertizer, hello.discoverer() as discoverer: + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + advertizer.schedule(interval=0.01) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) + + # Then + self.assertEqual({SERVICE_INFO.name: SERVICE_INFO}, discoverer.get_services()) + + def test_discoverer_caches_discovery_response_service(self): + # Given + context = Context() + hello = DefaultHello(context) + + with hello.default_advertizer() as advertizer, hello.discoverer() as discoverer: + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + discoverer.discover() + + wait_for_assertion(0.2, lambda: self.assertEqual(1, len(discoverer.get_services()))) + + # Then + self.assertEqual({SERVICE_INFO.name: SERVICE_INFO}, discoverer.get_services()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/defaultAdvertizerTest.py b/tests/defaultAdvertizerTest.py new file mode 100644 index 0000000..474e3b3 --- /dev/null +++ b/tests/defaultAdvertizerTest.py @@ -0,0 +1,120 @@ +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +from context_logger import setup_logging + +from hello import ServiceInfo, Group, Sender, DefaultAdvertizer, GroupAccess + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) + + +class DefaultAdvertizerTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_stops_sender_on_exit(self): + # Given + sender = MagicMock(spec=Sender) + + with DefaultAdvertizer(sender) as advertizer: + advertizer.start(ACCESS_URL, GROUP) + + # When + + # Then + sender.stop.assert_called_once() + + def test_stops_sender_when_stopped(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.stop() + + # Then + sender.stop.assert_called_once() + + def test_starts_sender_when_started(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + + # When + advertizer.start(ACCESS_URL, GROUP) + + # Then + sender.start.assert_called_once_with(GroupAccess(ACCESS_URL, GROUP.hello())) + + def test_sends_info_when_passed_at_start(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + + # When + advertizer.advertise() + + # Then + sender.send.assert_called_once_with(SERVICE_INFO) + + def test_sends_info_when_passed_at_advertise(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.advertise(SERVICE_INFO) + + # Then + sender.send.assert_called_once_with(SERVICE_INFO) + + def test_sends_last_info_when_passed_at_start_and_at_advertise(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + advertizer.start(ACCESS_URL, GROUP, ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:9090'})) + + # When + advertizer.advertise(SERVICE_INFO) + + # Then + sender.send.assert_called_once_with(SERVICE_INFO) + + def test_does_not_send_info_when_no_info_provided(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.advertise() + + # Then + sender.send.assert_not_called() + + def test_does_not_send_info_when_not_started(self): + # Given + sender = MagicMock(spec=Sender) + advertizer = DefaultAdvertizer(sender) + + # When + advertizer.advertise(SERVICE_INFO) + + # Then + sender.send.assert_not_called() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/defaultDiscovererTest.py b/tests/defaultDiscovererTest.py new file mode 100644 index 0000000..70fd70b --- /dev/null +++ b/tests/defaultDiscovererTest.py @@ -0,0 +1,270 @@ +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +from context_logger import setup_logging + +from hello import ServiceInfo, Group, GroupAccess, \ + ServiceQuery, DefaultDiscoverer, Sender, Receiver, OnDiscoveryEvent, DiscoveryEventType, DiscoveryEvent + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_QUERY = ServiceQuery('test-.*', 'test-.*') +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) + + +class DefaultDiscovererTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_stops_sender_and_receiver_on_exit(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + + with DefaultDiscoverer(sender, receiver) as discoverer: + discoverer.start(ACCESS_URL, GROUP) + + # When + + # Then + sender.stop.assert_called_once() + receiver.stop.assert_called_once() + + def test_stops_sender_and_receiver_when_stopped(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP) + + # When + discoverer.stop() + + # Then + sender.stop.assert_called_once() + receiver.stop.assert_called_once() + + def test_starts_sender_and_receiver_when_started(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + + # When + discoverer.start(ACCESS_URL, GROUP) + + # Then + sender.start.assert_called_once_with(GroupAccess(ACCESS_URL, GROUP.query())) + receiver.start.assert_called_once_with(GroupAccess(ACCESS_URL, GROUP.hello())) + + def test_registers_event_handler(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + handler = MagicMock(spec=OnDiscoveryEvent) + + # When + discoverer.register(handler) + + # Then + self.assertIn(handler, discoverer.get_handlers()) + + def test_deregisters_event_handler(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + handler = MagicMock(spec=OnDiscoveryEvent) + discoverer.register(handler) + + # When + discoverer.deregister(handler) + + # Then + self.assertNotIn(handler, discoverer.get_handlers()) + + def test_caches_service_and_calls_handler_when_receives_matching_info(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + handler = MagicMock(spec=OnDiscoveryEvent) + discoverer.register(handler) + + # When + discoverer._handle_message(SERVICE_INFO.__dict__) + + # Then + self.assertEqual({SERVICE_INFO.name: SERVICE_INFO}, discoverer.get_services()) + handler.assert_called_once_with(DiscoveryEvent(SERVICE_INFO, DiscoveryEventType.DISCOVERED)) + + def test_updates_service_and_calls_handler_when_receives_matching_info(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + handler = MagicMock(spec=OnDiscoveryEvent) + discoverer.register(handler) + discoverer._handle_message(SERVICE_INFO.__dict__) + handler.reset_mock() + new_service_info = ServiceInfo(SERVICE_INFO.name, SERVICE_INFO.role, {'test': 'http://localhost:9090'}) + + # When + discoverer._handle_message(new_service_info.__dict__) + + # Then + self.assertEqual({SERVICE_INFO.name: new_service_info}, discoverer.get_services()) + handler.assert_called_once_with(DiscoveryEvent(new_service_info, DiscoveryEventType.UPDATED)) + + def test_does_not_call_handler_when_service_info_not_changed(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + handler = MagicMock(spec=OnDiscoveryEvent) + discoverer.register(handler) + discoverer._handle_message(SERVICE_INFO.__dict__) + handler.reset_mock() + + # When + discoverer._handle_message(SERVICE_INFO.__dict__) + + # Then + handler.assert_not_called() + + def test_handles_handler_error_gracefully(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + handler = MagicMock(spec=OnDiscoveryEvent) + handler.side_effect = Exception("Handler error") + discoverer.register(handler) + + # When + discoverer._handle_message(SERVICE_INFO.__dict__) + + # Then + self.assertEqual({SERVICE_INFO.name: SERVICE_INFO}, discoverer.get_services()) + handler.assert_called_once_with(DiscoveryEvent(SERVICE_INFO, DiscoveryEventType.DISCOVERED)) + + def test_handles_invalid_message_gracefully(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + discoverer._handle_message({'invalid': 'message'}) + + # Then + self.assertEqual({}, discoverer.get_services()) + + def test_does_not_cache_service_when_info_not_matching_query(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + non_matching_info = ServiceInfo('other-service', 'test-role', {'test': 'http://localhost:8080'}) + + # When + discoverer._handle_message(non_matching_info.__dict__) + + # Then + self.assertEqual({}, discoverer.get_services()) + + def test_does_not_cache_service_when_no_query_set(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP) + + # When + discoverer._handle_message(SERVICE_INFO.__dict__) + + # Then + self.assertEqual({}, discoverer.get_services()) + + def test_sends_query_when_passed_at_start(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + discoverer.discover() + + # Then + sender.send.assert_called_once_with(SERVICE_QUERY) + + def test_sends_query_when_passed_at_discover(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP) + + # When + discoverer.discover(SERVICE_QUERY) + + # Then + sender.send.assert_called_once_with(SERVICE_QUERY) + + def test_sends_last_query_when_passed_at_start_and_at_discover(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP, ServiceQuery('other-.*', 'test-.*')) + + # When + discoverer.discover(SERVICE_QUERY) + + # Then + sender.send.assert_called_once_with(SERVICE_QUERY) + + def test_does_not_send_query_when_no_query_provided(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + discoverer.start(ACCESS_URL, GROUP) + + # When + discoverer.discover() + + # Then + sender.send.assert_not_called() + + def test_does_not_send_query_when_not_started(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + discoverer = DefaultDiscoverer(sender, receiver) + + # When + discoverer.discover(SERVICE_QUERY) + + # Then + sender.send.assert_not_called() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/defaultScheduledAdvertizerTest.py b/tests/defaultScheduledAdvertizerTest.py new file mode 100644 index 0000000..29ac55a --- /dev/null +++ b/tests/defaultScheduledAdvertizerTest.py @@ -0,0 +1,119 @@ +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +from common_utility import IReusableTimer +from context_logger import setup_logging + +from hello import ServiceInfo, Group, DefaultScheduledAdvertizer, Advertizer + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) + + +class DefaultScheduledAdvertizerTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_stops_timer_and_advertizer_on_exit(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + + with DefaultScheduledAdvertizer(_advertizer, timer) as advertizer: + advertizer.start(ACCESS_URL, GROUP) + + # When + + # Then + timer.cancel.assert_called_once() + _advertizer.stop.assert_called_once() + + def test_stops_timer_and_advertizer_when_stopped(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + advertizer = DefaultScheduledAdvertizer(_advertizer, timer) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.stop() + + # Then + timer.cancel.assert_called_once() + _advertizer.stop.assert_called_once() + + def test_starts_advertizer_when_started(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + advertizer = DefaultScheduledAdvertizer(_advertizer, timer) + + # When + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + + # Then + _advertizer.start.assert_called_once_with(ACCESS_URL, GROUP, SERVICE_INFO) + + def test_sends_service_info(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + advertizer = DefaultScheduledAdvertizer(_advertizer, timer) + + # When + advertizer.advertise(SERVICE_INFO) + + # Then + _advertizer.advertise.assert_called_once_with(SERVICE_INFO) + + def test_schedules_advertise_once(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + advertizer = DefaultScheduledAdvertizer(_advertizer, timer) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.schedule(SERVICE_INFO, 60, True) + + # Then + timer.start.assert_called_once_with(60, advertizer.advertise, [SERVICE_INFO]) + + def test_schedules_periodic_advertise(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + advertizer = DefaultScheduledAdvertizer(_advertizer, timer) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.schedule(SERVICE_INFO, 60, False) + + # Then + timer.start.assert_called_once_with(60, advertizer._advertise_and_restart, [SERVICE_INFO]) + + def test_advertise_and_restart_calls_advertise_and_restarts_timer(self): + # Given + _advertizer = MagicMock(spec=Advertizer) + timer = MagicMock(spec=IReusableTimer) + advertizer = DefaultScheduledAdvertizer(_advertizer, timer) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer._advertise_and_restart(SERVICE_INFO) + + # Then + _advertizer.advertise.assert_called_once_with(SERVICE_INFO) + timer.restart.assert_called_once() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/discovererIntegrationTest.py b/tests/discovererIntegrationTest.py new file mode 100644 index 0000000..7d4cbbf --- /dev/null +++ b/tests/discovererIntegrationTest.py @@ -0,0 +1,97 @@ +import unittest +from unittest import TestCase + +from context_logger import setup_logging +from test_utility import wait_for_assertion +from zmq import Context + +from hello import ServiceInfo, Group, RadioSender, DishReceiver, GroupAccess, \ + ServiceQuery, DefaultDiscoverer + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_QUERY = ServiceQuery('test-service', 'test-role') + + +class DiscovererIntegrationTest(TestCase): + SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + self.SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) + + def test_discovers_service_when_hello_received(self): + # Given + context = Context() + sender = RadioSender(context) + receiver = DishReceiver(context) + + with DefaultDiscoverer(sender, receiver) as discoverer, RadioSender(context) as test_sender: + test_sender.start(GroupAccess(ACCESS_URL, GROUP.hello())) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + # When + test_sender.send(self.SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) + + # Then + self.assertEqual({self.SERVICE_INFO.name: self.SERVICE_INFO}, discoverer.get_services()) + + def test_updates_service_when_info_changed(self): + # Given + context = Context() + sender = RadioSender(context) + receiver = DishReceiver(context) + + with DefaultDiscoverer(sender, receiver) as discoverer, RadioSender(context) as test_sender: + test_sender.start(GroupAccess(ACCESS_URL, GROUP.hello())) + discoverer.start(ACCESS_URL, GROUP, SERVICE_QUERY) + + test_sender.send(self.SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(discoverer.get_services()))) + + # When + self.SERVICE_INFO.urls['test'] = 'http://localhost:9090' + test_sender.send(self.SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual( + 'http://localhost:9090', + discoverer.get_services()[self.SERVICE_INFO.name].urls['test'] + )) + + # Then + self.assertEqual( + 'http://localhost:9090', + discoverer.get_services()[self.SERVICE_INFO.name].urls['test'] + ) + + def test_sends_query(self): + # Given + context = Context() + sender = RadioSender(context) + receiver = DishReceiver(context) + messages = [] + + with DefaultDiscoverer(sender, receiver) as discoverer, DishReceiver(context) as test_receiver: + test_receiver.register(lambda message: messages.append(message)) + test_receiver.start(GroupAccess(ACCESS_URL, GROUP.query())) + discoverer.start(ACCESS_URL, GROUP) + + # When + discoverer.discover(SERVICE_QUERY) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(messages))) + + # Then + self.assertEqual([SERVICE_QUERY.__dict__], messages) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/dishReceiverTest.py b/tests/dishReceiverTest.py new file mode 100644 index 0000000..5852f9c --- /dev/null +++ b/tests/dishReceiverTest.py @@ -0,0 +1,176 @@ +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +from context_logger import setup_logging +from test_utility import wait_for_assertion +from zmq import Context, ZMQError, Poller, POLLIN + +from hello import ServiceInfo, Group, GroupAccess, DishReceiver, OnMessage + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) + + +class DishReceiverTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_raises_error_when_restarted(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + + with DishReceiver(context) as receiver: + receiver.start(group_access) + + # When, Then + with self.assertRaises(RuntimeError): + receiver.start(group_access) + + def test_raises_error_when_fails_to_bind_socket(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.bind.side_effect = ZMQError(1, "Bind failed") + receiver = DishReceiver(context) + + # When, Then + with self.assertRaises(ZMQError): + receiver.start(group_access) + + def test_closes_socket_on_exit(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + + with DishReceiver(context) as receiver: + receiver.start(group_access) + + # When + + # Then + context.socket.return_value.close.assert_called_once() + + def test_closes_socket_when_stopped(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + receiver = DishReceiver(context) + receiver.start(group_access) + + # When + receiver.stop() + + # Then + context.socket.return_value.close.assert_called_once() + + def test_raises_error_when_fails_to_close_socket_on_stop(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.close.side_effect = [ZMQError(1, "Close failed"), None] + + with DishReceiver(context) as receiver: + receiver.start(group_access) + + # When, Then + with self.assertRaises(ZMQError): + receiver.stop() + + def test_registers_handler(self): + # Given + context = MagicMock(spec=Context) + receiver = DishReceiver(context) + handler = MagicMock(spec=OnMessage) + + # When + receiver.register(handler) + + # Then + self.assertIn(handler, receiver.get_handlers()) + + def test_deregisters_handler(self): + # Given + context = MagicMock(spec=Context) + receiver = DishReceiver(context) + handler = MagicMock(spec=OnMessage) + receiver.register(handler) + + # When + receiver.deregister(handler) + + # Then + self.assertNotIn(handler, receiver.get_handlers()) + + def test_calls_registered_handler_on_message(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.recv_json.return_value = SERVICE_INFO.__dict__ + handler = MagicMock(spec=OnMessage) + + with DishReceiver(context) as receiver: + receiver._poller = MagicMock(spec=Poller) + receiver._poller.poll.side_effect = [ + {context.socket.return_value: POLLIN}, + ] + receiver.register(handler) + + # When + receiver.start(group_access) + + # Then + wait_for_assertion(0.1, lambda: handler.assert_called_once_with(SERVICE_INFO.__dict__)) + + def test_handles_message_receive_error_gracefully(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.recv_json.side_effect = ZMQError(1, "Receive failed") + handler = MagicMock(spec=OnMessage) + + with DishReceiver(context) as receiver: + receiver._poller = MagicMock(spec=Poller) + receiver._poller.poll.side_effect = [ + {context.socket.return_value: POLLIN}, + ] + receiver.register(handler) + + # When + receiver.start(group_access) + + # Then + handler.assert_not_called() + + def test_handles_handler_execution_error_gracefully(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.recv_json.return_value = SERVICE_INFO.__dict__ + handler = MagicMock(spec=OnMessage) + handler.side_effect = Exception("Execution failed") + + with DishReceiver(context) as receiver: + receiver._poller = MagicMock(spec=Poller) + receiver._poller.poll.side_effect = [ + {context.socket.return_value: POLLIN}, + ] + receiver.register(handler) + + # When + receiver.start(group_access) + + # Then + wait_for_assertion(0.1, lambda: handler.assert_called_once_with(SERVICE_INFO.__dict__)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/radioSenderTest.py b/tests/radioSenderTest.py new file mode 100644 index 0000000..d1a285b --- /dev/null +++ b/tests/radioSenderTest.py @@ -0,0 +1,152 @@ +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +from context_logger import setup_logging +from zmq import Context, ZMQError + +from hello import ServiceInfo, Group, GroupAccess +from hello.sender import RadioSender + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) + + +class RadioSenderTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_raises_error_when_restarted(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + sender = RadioSender(context) + sender.start(group_access) + + # When, Then + with self.assertRaises(RuntimeError): + sender.start(group_access) + + def test_raises_error_when_fails_to_connect_socket(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.connect.side_effect = ZMQError(1, "Connect failed") + sender = RadioSender(context) + + # When, Then + with self.assertRaises(ZMQError): + sender.start(group_access) + + def test_closes_socket_on_exit(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + + with RadioSender(context) as sender: + sender.start(group_access) + + # When + + # Then + context.socket.return_value.close.assert_called_once() + + def test_closes_socket_when_stopped(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + sender = RadioSender(context) + sender.start(group_access) + + # When + sender.stop() + + # Then + context.socket.return_value.close.assert_called_once() + + def test_raises_error_when_fails_to_close_socket_on_stop(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + context.socket.return_value.close.side_effect = ZMQError(1, "Close failed") + sender = RadioSender(context) + sender.start(group_access) + + # When, Then + with self.assertRaises(ZMQError): + sender.stop() + + def test_sends_message_when_convertible_to_dict(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + sender = RadioSender(context) + sender.start(group_access) + + # When + sender.send(SERVICE_INFO) + + # Then + context.socket.return_value.send_json.assert_called_with(SERVICE_INFO.__dict__, group='hello:test-group') + + def test_sends_message_when_type_is_dict(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + sender = RadioSender(context) + sender.start(group_access) + + # When + sender.send(SERVICE_INFO.__dict__) + + # Then + context.socket.return_value.send_json.assert_called_with(SERVICE_INFO.__dict__, group='hello:test-group') + + def test_does_not_send_message_when_not_serializable(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + sender = RadioSender(context) + sender.start(group_access) + + # When + sender.send("not serializable message") + + # Then + context.socket.return_value.send_json.assert_not_called() + + def test_does_not_send_message_when_not_started(self): + # Given + context = MagicMock(spec=Context) + sender = RadioSender(context) + + # When + sender.send(SERVICE_INFO) + + # Then + context.socket.return_value.send_json.assert_not_called() + + def test_handles_send_message_error_gracefully(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = MagicMock(spec=Context) + sender = RadioSender(context) + sender.start(group_access) + context.socket.return_value.send_json.side_effect = ZMQError(1, "Send failed") + + # When + sender.send(SERVICE_INFO) + + # Then + context.socket.return_value.send_json.assert_called_once_with(SERVICE_INFO.__dict__, group='hello:test-group') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/receiverIntegrationTest.py b/tests/receiverIntegrationTest.py new file mode 100644 index 0000000..b9ccd34 --- /dev/null +++ b/tests/receiverIntegrationTest.py @@ -0,0 +1,58 @@ +import unittest +from unittest import TestCase + +from context_logger import setup_logging +from test_utility import wait_for_assertion +from zmq import Context + +from hello import ServiceInfo, Group, DishReceiver, GroupAccess, RadioSender + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) + + +class ReceiverIntegrationTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_raises_error_when_restarted(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = Context() + + with DishReceiver(context) as receiver: + receiver.start(group_access) + + # When, Then + with self.assertRaises(RuntimeError): + receiver.start(group_access) + + def test_receives_message(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = Context() + messages = [] + + with DishReceiver(context) as receiver, RadioSender(context) as test_sender: + receiver.register(lambda message: messages.append(message)) + receiver.start(group_access) + test_sender.start(group_access) + + # When + test_sender.send(SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(messages))) + + # Then + self.assertEqual([SERVICE_INFO.__dict__], messages) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/respondingAdvertizerTest.py b/tests/respondingAdvertizerTest.py new file mode 100644 index 0000000..d37581e --- /dev/null +++ b/tests/respondingAdvertizerTest.py @@ -0,0 +1,119 @@ +import unittest +from unittest import TestCase +from unittest.mock import MagicMock + +from context_logger import setup_logging + +from hello import ServiceInfo, Group, Sender, GroupAccess, Receiver, RespondingAdvertizer, ServiceQuery + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) + + +class RespondingAdvertizerTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_stops_sender_and_receiver_on_exit(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + + with RespondingAdvertizer(sender, receiver) as advertizer: + advertizer.start(ACCESS_URL, GROUP) + + # When + + # Then + sender.stop.assert_called_once() + receiver.stop.assert_called_once() + + def test_stops_sender_and_receiver_when_stopped(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + advertizer = RespondingAdvertizer(sender, receiver) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer.stop() + + # Then + sender.stop.assert_called_once() + receiver.stop.assert_called_once() + + def test_starts_sender_and_receiver_when_started(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + advertizer = RespondingAdvertizer(sender, receiver) + + # When + advertizer.start(ACCESS_URL, GROUP) + + # Then + sender.start.assert_called_once_with(GroupAccess(ACCESS_URL, GROUP.hello())) + receiver.start.assert_called_once_with(GroupAccess(ACCESS_URL, GROUP.query())) + + def test_sends_service_info_when_receives_matching_query(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + advertizer = RespondingAdvertizer(sender, receiver) + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + + # When + advertizer._handle_message(ServiceQuery('test-.*', 'test-.*').__dict__) + + # Then + sender.send.assert_called_once_with(SERVICE_INFO) + + def test_does_not_send_service_info_when_receives_non_matching_query(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + advertizer = RespondingAdvertizer(sender, receiver) + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + + # When + advertizer._handle_message(ServiceQuery('other-.*', 'test-.*').__dict__) + + # Then + sender.send.assert_not_called() + + def test_does_not_send_service_info_when_no_service_info_set(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + advertizer = RespondingAdvertizer(sender, receiver) + advertizer.start(ACCESS_URL, GROUP) + + # When + advertizer._handle_message(ServiceQuery('test-.*', 'test-.*').__dict__) + + # Then + sender.send.assert_not_called() + + def test_handles_invalid_message_gracefully(self): + # Given + sender = MagicMock(spec=Sender) + receiver = MagicMock(spec=Receiver) + advertizer = RespondingAdvertizer(sender, receiver) + advertizer.start(ACCESS_URL, GROUP, SERVICE_INFO) + + # When + advertizer._handle_message({'invalid': 'message'}) + + # Then + sender.send.assert_not_called() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/senderIntegrationTest.py b/tests/senderIntegrationTest.py new file mode 100644 index 0000000..c5b3c46 --- /dev/null +++ b/tests/senderIntegrationTest.py @@ -0,0 +1,47 @@ +import unittest +from unittest import TestCase + +from context_logger import setup_logging +from test_utility import wait_for_assertion +from zmq import Context + +from hello import ServiceInfo, Group, GroupAccess, DishReceiver +from hello.sender import RadioSender + +ACCESS_URL = 'udp://239.0.0.1:5555' +GROUP_NAME = 'test-group' +GROUP = Group(GROUP_NAME) +SERVICE_INFO = ServiceInfo('test-service', 'test-role', {'test': 'http://localhost:8080'}) + + +class SenderIntegrationTest(TestCase): + + @classmethod + def setUpClass(cls): + setup_logging('hello', 'DEBUG', warn_on_overwrite=False) + + def setUp(self): + print() + + def test_sends_message(self): + # Given + group_access = GroupAccess(ACCESS_URL, GROUP.hello()) + context = Context() + messages = [] + + with RadioSender(context) as sender, DishReceiver(context) as test_receiver: + test_receiver.register(lambda message: messages.append(message)) + test_receiver.start(group_access) + sender.start(group_access) + + # When + sender.send(SERVICE_INFO) + + wait_for_assertion(0.1, lambda: self.assertEqual(1, len(messages))) + + # Then + self.assertEqual([SERVICE_INFO.__dict__], messages) + + +if __name__ == '__main__': + unittest.main()