better working ws

Signed-off-by: Jess Frazelle <github@jessfraz.com>
This commit is contained in:
Jess Frazelle
2023-11-28 13:13:13 -08:00
parent be246702fd
commit b6aa9ab98b
37 changed files with 771 additions and 593 deletions

View File

@ -1,11 +1,14 @@
from typing import Any, Dict
import json
from typing import Any, Dict, Optional
import bson
from websockets.client import WebSocketClientProtocol, connect as ws_connect_async
from websockets.sync.client import ClientConnection, connect as ws_connect
from ...client import Client
from ...models.error import Error
from ...models.web_socket_request import WebSocketRequest
from ...models.web_socket_response import WebSocketResponse
def _get_kwargs(
@ -14,7 +17,6 @@ def _get_kwargs(
video_res_height: int,
video_res_width: int,
webrtc: bool,
body: WebSocketRequest,
*,
client: Client,
) -> Dict[str, Any]:
@ -28,9 +30,9 @@ def _get_kwargs(
if unlocked_framerate is not None:
if "?" in url:
url = url + "&unlocked_framerate=" + str(unlocked_framerate)
url = url + "&unlocked_framerate=" + str(unlocked_framerate).lower()
else:
url = url + "?unlocked_framerate=" + str(unlocked_framerate)
url = url + "?unlocked_framerate=" + str(unlocked_framerate).lower()
if video_res_height is not None:
if "?" in url:
@ -46,9 +48,9 @@ def _get_kwargs(
if webrtc is not None:
if "?" in url:
url = url + "&webrtc=" + str(webrtc)
url = url + "&webrtc=" + str(webrtc).lower()
else:
url = url + "?webrtc=" + str(webrtc)
url = url + "?webrtc=" + str(webrtc).lower()
headers: Dict[str, Any] = client.get_headers()
cookies: Dict[str, Any] = client.get_cookies()
@ -58,7 +60,6 @@ def _get_kwargs(
"headers": headers,
"cookies": cookies,
"timeout": client.get_timeout(),
"content": body,
}
@ -68,7 +69,6 @@ def sync(
video_res_height: int,
video_res_width: int,
webrtc: bool,
body: WebSocketRequest,
*,
client: Client,
) -> ClientConnection:
@ -80,7 +80,6 @@ def sync(
video_res_height=video_res_height,
video_res_width=video_res_width,
webrtc=webrtc,
body=body,
client=client,
)
@ -100,7 +99,6 @@ async def asyncio(
video_res_height: int,
video_res_width: int,
webrtc: bool,
body: WebSocketRequest,
*,
client: Client,
) -> WebSocketClientProtocol:
@ -112,7 +110,6 @@ async def asyncio(
video_res_height=video_res_height,
video_res_width=video_res_width,
webrtc=webrtc,
body=body,
client=client,
)
@ -123,3 +120,40 @@ async def asyncio(
# Return an error if we got here.
return Error(message="An error occurred while connecting to the websocket.")
class WebSocket:
"""A websocket connection to the API endpoint."""
ws: ClientConnection
def __init__(
self,
fps: int,
unlocked_framerate: bool,
video_res_height: int,
video_res_width: int,
webrtc: bool,
client: Client,
):
self.ws = sync(
fps,
unlocked_framerate,
video_res_height,
video_res_width,
webrtc,
client=client,
)
def send(self, data: WebSocketRequest):
"""Send data to the websocket."""
self.ws.send(json.dumps(data.to_dict()))
def send_binary(self, data: WebSocketRequest):
"""Send data as bson to the websocket."""
self.ws.send(bson.BSON.encode(data.to_dict()))
def recv(self) -> Optional[WebSocketResponse]:
"""Receive data from the websocket."""
message = self.ws.recv()
return Optional[WebSocketResponse].from_dict(json.loads(message))