yt-dlp-bot/start.py

111 lines
3.1 KiB
Python
Raw Normal View History

2023-03-23 05:24:06 +08:00
"""Over-engineered Python 3.10+ version of bash script with netcat (nc) just for fun.
#!/bin/bash
check_reachability() {
while ! nc -z "$1" "${!2}"
do
echo "Waiting for $3 to be reachable on port ${!2}"
sleep 1
done
echo "Connection to $3 on port ${!2} verified"
return 0
}
wait_for_services_to_be_reachable() {
check_reachability rabbitmq RABBITMQ_PORT RabbitMQ
check_reachability postgres POSTGRES_PORT PostgreSQL
}
wait_for_services_to_be_reachable
exit 0
"""
import asyncio
import logging
import os
import sys
from dataclasses import dataclass, field
2023-03-23 06:21:36 +08:00
from typing import Generator, Type
2023-03-23 05:24:06 +08:00
SOCK_CONNECTED = 0
DEFAULT_PORT = 0
DEFAULT_SLEEP_TIME = 1
class ServiceRegistry(type):
REGISTRY: dict[str, type['BaseService']] = {}
def __new__(
mcs: Type['ServiceRegistry'],
name: str,
bases: tuple[type['BaseService']],
attrs: dict,
) -> type['BaseService']:
service_cls: type['BaseService'] = type.__new__(mcs, name, bases, attrs)
mcs.REGISTRY[service_cls.__name__] = service_cls
return service_cls
@classmethod
2023-03-23 06:21:36 +08:00
def get_registry(mcs) -> dict[str, type['BaseService']]:
2023-03-23 05:24:06 +08:00
return mcs.REGISTRY.copy()
2023-03-23 06:21:36 +08:00
@classmethod
def get_instances(mcs) -> Generator['BaseService', None, None]:
return (service_cls() for service_cls in mcs.REGISTRY.values())
2023-03-23 05:24:06 +08:00
@dataclass
class BaseService:
name: str = field(default='', init=False)
host: str = field(default='', init=False)
port: int = field(default=DEFAULT_PORT, init=False)
def __post_init__(self) -> None:
if self.__class__ is BaseService:
raise TypeError('Cannot instantiate abstract class.')
@dataclass
class RabbitMQService(BaseService, metaclass=ServiceRegistry):
name: str = field(default='RabbitMQ')
host: str = field(default=os.getenv('RABBITMQ_HOST'))
2023-03-29 03:26:16 +08:00
port: int = field(default=int(os.getenv('RABBITMQ_PORT', DEFAULT_PORT)))
2023-03-23 05:24:06 +08:00
@dataclass
class PostgreSQLService(BaseService, metaclass=ServiceRegistry):
name: str = field(default='PostgreSQL')
host: str = field(default=os.getenv('POSTGRES_HOST'))
2023-03-29 03:26:16 +08:00
port: int = field(default=int(os.getenv('POSTGRES_PORT', DEFAULT_PORT)))
2023-03-23 05:24:06 +08:00
async def is_port_open(host: str, port: int) -> bool:
2023-03-24 06:06:20 +08:00
try:
reader, writer = await asyncio.open_connection(host, port)
writer.close()
await writer.wait_closed()
return True
except Exception:
return False
2023-03-23 05:24:06 +08:00
async def check_reachability(service: BaseService) -> None:
while True:
2023-07-29 04:25:08 +08:00
print(f'[{service.name}] Waiting to be reachable on port {service.port}')
2023-03-23 05:24:06 +08:00
if await is_port_open(host=service.host, port=service.port):
break
await asyncio.sleep(DEFAULT_SLEEP_TIME)
2023-07-29 04:25:08 +08:00
print(f'[{service.name}] Connection on port {service.port} verified')
2023-03-23 05:24:06 +08:00
async def main() -> None:
logging.getLogger('asyncio').setLevel(logging.ERROR)
2023-03-23 06:21:36 +08:00
coros = [check_reachability(service) for service in ServiceRegistry.get_instances()]
2023-03-23 05:24:06 +08:00
await asyncio.gather(*coros)
if __name__ == '__main__':
sys.exit(asyncio.run(main()))