2023-11-28 13:13:13 -08:00
import json
2023-11-28 17:05:43 -08:00
from typing import Any , Dict , Iterator
2023-05-23 14:24:13 -07:00
2023-11-28 13:13:13 -08:00
import bson
2023-07-31 18:05:13 -07:00
from websockets . client import WebSocketClientProtocol , connect as ws_connect_async
from websockets . sync . client import ClientConnection , connect as ws_connect
2023-05-23 14:24:13 -07:00
from . . . client import Client
2023-10-12 09:02:59 -07:00
from . . . models . web_socket_request import WebSocketRequest
2023-11-28 13:13:13 -08:00
from . . . models . web_socket_response import WebSocketResponse
2023-05-23 14:24:13 -07:00
def _get_kwargs(
    fps: int,
    unlocked_framerate: bool,
    video_res_height: int,
    video_res_width: int,
    webrtc: bool,
    *,
    client: Client,
) -> Dict[str, Any]:
    """Assemble the connection arguments for the modeling-commands websocket.

    Every argument that is not None is appended to the endpoint URL as a
    query parameter, in the order fps, unlocked_framerate, video_res_height,
    video_res_width, webrtc. The two boolean flags are rendered in lowercase
    ("true"/"false"); the other values use plain ``str()``.

    Returns:
        A dict with keys ``url``, ``headers``, ``cookies`` and ``timeout``;
        the last three come straight from ``client``.
    """
    url = f"{client.base_url}/ws/modeling/commands"

    # (query name, value, lowercase the rendered value?)
    query_args = (
        ("fps", fps, False),
        ("unlocked_framerate", unlocked_framerate, True),
        ("video_res_height", video_res_height, False),
        ("video_res_width", video_res_width, False),
        ("webrtc", webrtc, True),
    )
    for name, value, lowercase in query_args:
        if value is None:
            continue
        rendered = str(value).lower() if lowercase else str(value)
        # Re-check on each iteration: the first parameter opens the query
        # string with "?", and any later one (or a base URL that already
        # carries a query) is joined with "&".
        separator = "&" if "?" in url else "?"
        url = f"{url}{separator}{name}={rendered}"

    headers: Dict[str, Any] = client.get_headers()
    cookies: Dict[str, Any] = client.get_cookies()
    return {
        "url": url,
        "headers": headers,
        "cookies": cookies,
        "timeout": client.get_timeout(),
    }
2023-07-31 18:05:13 -07:00
def sync(
    fps: int,
    unlocked_framerate: bool,
    video_res_height: int,
    video_res_width: int,
    webrtc: bool,
    *,
    client: Client,
) -> ClientConnection:
    """Pass those commands to the engine via websocket, and pass responses back to the client. Basically, this is a websocket proxy between the frontend/client and the engine.

    Opens a blocking ``websockets.sync`` connection to the URL built by
    ``_get_kwargs``, with the leading ``http``/``https`` scheme rewritten to
    ``ws``/``wss``. Only the URL and headers from ``_get_kwargs`` are used;
    its cookies and timeout are not forwarded.

    The connection is opened with ``close_timeout=None``,
    ``compression=None`` and ``max_size=None`` (no close timeout, no
    compression, no incoming message size limit).
    """  # noqa: E501
    kwargs = _get_kwargs(
        fps=fps,
        unlocked_framerate=unlocked_framerate,
        video_res_height=video_res_height,
        video_res_width=video_res_width,
        webrtc=webrtc,
        client=client,
    )

    url = kwargs["url"]
    # Swap only the scheme prefix ("http" -> "ws", so "https" -> "wss").
    # A blanket str.replace would also rewrite any "http" appearing later in
    # the host name or path.
    if url.startswith("http"):
        url = "ws" + url[len("http"):]

    return ws_connect(  # type: ignore
        url,
        additional_headers=kwargs["headers"],
        close_timeout=None,
        compression=None,
        max_size=None,
    )
2023-05-23 14:24:13 -07:00
2023-07-31 18:05:13 -07:00
async def asyncio(
    fps: int,
    unlocked_framerate: bool,
    video_res_height: int,
    video_res_width: int,
    webrtc: bool,
    *,
    client: Client,
) -> WebSocketClientProtocol:
    """Pass those commands to the engine via websocket, and pass responses back to the client. Basically, this is a websocket proxy between the frontend/client and the engine.

    Opens a connection with ``websockets.client.connect`` to the URL built by
    ``_get_kwargs``, with the leading ``http``/``https`` scheme rewritten to
    ``ws``/``wss``. Only the URL and headers from ``_get_kwargs`` are used;
    its cookies and timeout are not forwarded.

    Compression is disabled and the incoming message size limit is removed,
    matching the options used by the synchronous ``sync`` connector.
    """  # noqa: E501
    kwargs = _get_kwargs(
        fps=fps,
        unlocked_framerate=unlocked_framerate,
        video_res_height=video_res_height,
        video_res_width=video_res_width,
        webrtc=webrtc,
        client=client,
    )

    url = kwargs["url"]
    # Swap only the scheme prefix ("http" -> "ws", so "https" -> "wss").
    # A blanket str.replace would also rewrite any "http" appearing later in
    # the host name or path.
    if url.startswith("http"):
        url = "ws" + url[len("http"):]

    return await ws_connect_async(
        url,
        extra_headers=kwargs["headers"],
        # Same as the sync connector: without max_size=None the library's
        # default message size limit would reject large engine responses
        # that the sync path accepts.
        compression=None,
        max_size=None,
    )
2023-11-28 13:13:13 -08:00
class WebSocket:
    """A websocket connection to the API endpoint.

    Wraps the connection returned by ``sync`` and converts between the
    project's request/response models and the raw websocket messages.
    Usable as a context manager; the connection is closed on exit.
    """

    # Underlying blocking websocket connection, opened in __init__.
    ws: ClientConnection

    def __init__(
        self,
        fps: int,
        unlocked_framerate: bool,
        video_res_height: int,
        video_res_width: int,
        webrtc: bool,
        client: Client,
    ):
        # The connection is opened immediately on construction.
        self.ws = sync(
            fps=fps,
            unlocked_framerate=unlocked_framerate,
            video_res_height=video_res_height,
            video_res_width=video_res_width,
            webrtc=webrtc,
            client=client,
        )

    def __enter__(
        self,
    ):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always close; returning None lets any exception propagate.
        self.close()

    @staticmethod
    def _decode(message) -> WebSocketResponse:
        """Parse one raw JSON message into a WebSocketResponse."""
        return WebSocketResponse.from_dict(json.loads(message))

    def __iter__(self) -> Iterator[WebSocketResponse]:
        """Yield every incoming message, decoded as a WebSocketResponse.

        Each message comes from iterating the underlying connection. The
        loop ends when the connection is closed normally; a protocol error
        or network failure surfaces as
        ``websockets.exceptions.ConnectionClosedError``.
        """
        for raw in self.ws:
            yield self._decode(raw)

    def send(self, data: WebSocketRequest):
        """Serialize ``data`` to a JSON string and send it."""
        payload = json.dumps(data.to_dict())
        self.ws.send(payload)

    def send_binary(self, data: WebSocketRequest):
        """Serialize ``data`` with BSON and send the encoded bytes."""
        encoded = bson.BSON.encode(data.to_dict())
        self.ws.send(encoded)

    def recv(self) -> WebSocketResponse:
        """Receive the next message and decode it as a WebSocketResponse."""
        return self._decode(self.ws.recv())

    def close(self):
        """Close the underlying websocket connection."""
        self.ws.close()