veretube_bot_lib/veretube_bot/async_bot.py
2026-05-31 22:50:04 +02:00

269 lines
8.9 KiB
Python

import asyncio
import inspect
import logging
from collections import defaultdict
from typing import Callable
import socketio
from ._api import BotAPI
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Socket events for which AsyncBot keeps no local state: each one is simply
# re-dispatched to the handlers registered under the same event name.
_PASSTHROUGH_EVENTS = (
    # chat traffic
    "chatMsg", "pm",
    # server notices
    "errorMsg", "kick", "announcement", "clearchat",
    # emote changes
    "updateEmote", "removeEmote",
)
class AsyncBot:
def __init__(
self,
token: str,
channel: str,
socket_url: str,
api_url: str,
transports: list[str] | None = None,
reconnection: bool = False,
reconnection_delay: int = 3,
):
if not token.startswith("cbt_"):
raise ValueError("token must start with 'cbt_'")
self.token = token
self.channel = channel
self.socket_url = socket_url
self.transports = transports
self.api = BotAPI(api_url, channel, token)
self.users: list[dict] = []
self.now_playing: dict | None = None
self.playlist: list[dict] = []
self.channel_opts: dict = {}
self.last_disconnect_reason = None
self._handlers: dict[str, list[Callable]] = defaultdict(list)
self._sio = socketio.AsyncClient(
reconnection=reconnection,
reconnection_delay=reconnection_delay,
logger=False,
engineio_logger=False,
)
self._wire_sio()
def _wire_sio(self):
sio = self._sio
@sio.on("connect")
async def _connect():
logger.info("connected to %s", self.socket_url)
await self._fire("connect", None)
@sio.on("disconnect")
async def _disconnect(*args):
reason = args[0] if args else None
self.last_disconnect_reason = reason
self.users = []
self.playlist = []
logger.info("disconnected: %s", reason)
await self._fire("disconnect", reason)
@sio.on("connect_error")
async def _connect_error(err):
logger.error("connection error: %s", err)
await self._fire("connect_error", err)
@sio.on("login")
async def _login(data):
if data.get("success"):
logger.info("authenticated as %s", data.get("name"))
try:
await sio.emit("joinChannel", {"name": self.channel}, namespace="/")
except Exception:
logger.exception("failed to emit joinChannel on '/' namespace")
else:
logger.error("authentication failed: %s", data.get("error", "unknown"))
await self._fire("login", data)
@sio.on("userlist")
async def _userlist(users):
self.users = list(users)
await self._fire("userlist", users)
@sio.on("addUser")
async def _add_user(user):
self.users = [u for u in self.users if u["name"] != user["name"]]
self.users.append(user)
await self._fire("addUser", user)
@sio.on("userLeave")
async def _user_leave(data):
self.users = [u for u in self.users if u["name"] != data["name"]]
await self._fire("userLeave", data)
@sio.on("setUserMeta")
async def _user_meta(data):
for user in self.users:
if user["name"] == data["name"]:
user.setdefault("meta", {}).update(data.get("meta", {}))
break
await self._fire("setUserMeta", data)
@sio.on("setUserRank")
async def _user_rank(data):
for user in self.users:
if user["name"] == data["name"]:
user["rank"] = data["rank"]
break
await self._fire("setUserRank", data)
@sio.on("changeMedia")
async def _change_media(data):
self.now_playing = data
await self._fire("changeMedia", data)
@sio.on("playlist")
async def _playlist(items):
self.playlist = list(items)
await self._fire("playlist", items)
@sio.on("queue")
async def _queue(data):
item = data.get("item")
if item:
self.playlist.append(item)
await self._fire("queue", data)
@sio.on("delete")
async def _delete(data):
uid = data.get("uid")
self.playlist = [i for i in self.playlist if i.get("uid") != uid]
await self._fire("delete", data)
@sio.on("channelOpts")
async def _channel_opts(opts):
self.channel_opts = opts
await self._fire("channelOpts", opts)
for _event in _PASSTHROUGH_EVENTS:
def _make(ev: str):
@sio.on(ev)
async def _handler(data=None):
await self._fire(ev, data)
_make(_event)
async def _fire(self, event: str, data):
for handler in self._handlers[event]:
try:
result = handler(data)
if inspect.isawaitable(result):
await result
except Exception:
logger.exception("unhandled exception in %s handler", event)
def on(self, event: str) -> Callable:
def decorator(fn: Callable) -> Callable:
self._handlers[event].append(fn)
return fn
return decorator
async def _wait_for_namespace(self, namespace: str = "/", timeout: float = 10.0):
deadline = asyncio.get_running_loop().time() + timeout
while namespace not in self._sio.namespaces:
if asyncio.get_running_loop().time() >= deadline:
raise TimeoutError(
f"namespace {namespace!r} did not connect within {timeout:.1f}s"
)
await self._sio.sleep(0.05)
async def _emit_guarded(self, event: str, data: dict, namespace: str = "/"):
if not self._sio.connected:
raise RuntimeError(
f"socket is not connected; last_disconnect_reason={self.last_disconnect_reason!r}"
)
if namespace not in self._sio.namespaces:
raise RuntimeError(
f"namespace {namespace!r} is not connected; call connect() and wait for readiness"
)
await self._sio.emit(event, data, namespace=namespace)
async def connect(self, timeout: float = 10.0):
await self._sio.connect(
self.socket_url,
auth={"token": self.token},
transports=self.transports,
namespaces=["/"],
wait=True,
wait_timeout=timeout,
)
await self._wait_for_namespace("/", timeout=timeout)
async def wait(self):
await self._sio.wait()
async def run(self, timeout: float = 10.0):
await self.connect(timeout=timeout)
await self.wait()
async def disconnect(self):
await self._sio.disconnect()
async def send_message(self, msg: str, to: str | None = None):
if to:
msg = f"{to}: {msg}"
await self._emit_guarded("chatMsg", {"msg": msg, "meta": {}})
async def send_action(self, text: str):
await self._emit_guarded("chatMsg", {"msg": f"/me {text}", "meta": {}})
async def send_pm(self, to: str, msg: str):
await self._emit_guarded("pm", {"to": to, "msg": msg, "meta": {}})
async def queue(self, id: str, type: str, pos: str = "end"):
await asyncio.to_thread(self.api.add_to_playlist, id, type, pos)
async def delete_item(self, uid: int):
await asyncio.to_thread(self.api.delete_playlist_item, uid)
async def skip_to(self, uid: int):
await asyncio.to_thread(self.api.skip_to, uid)
async def skip(self):
data = await asyncio.to_thread(self.api.get_playlist)
items = data.get("items", [])
idx = data.get("currentIndex", -1)
if 0 <= idx < len(items) - 1:
await asyncio.to_thread(self.api.skip_to, items[idx + 1]["uid"])
async def shuffle_playlist(self):
await asyncio.to_thread(self.api.shuffle_playlist)
async def clear_playlist(self):
await asyncio.to_thread(self.api.clear_playlist)
async def list_shows(self) -> list:
return await asyncio.to_thread(self.api.list_shows)
async def list_public_shows(self) -> list:
return await asyncio.to_thread(self.api.list_public_shows)
async def get_show(self, show_id: int | str) -> dict:
return await asyncio.to_thread(self.api.get_show, show_id)
async def create_show(self, payload: dict) -> dict:
return await asyncio.to_thread(self.api.create_show, payload)
async def update_show(self, show_id: int | str, payload: dict) -> dict:
return await asyncio.to_thread(self.api.update_show, show_id, payload)
async def delete_show(self, show_id: int | str):
await asyncio.to_thread(self.api.delete_show, show_id)
async def show_action(self, show_id: int | str, action: str) -> dict:
return await asyncio.to_thread(self.api.show_action, show_id, action)