2023-11-28 13:13:13 -08:00
import json
from typing import Any , Dict , Optional
2023-05-23 14:24:13 -07:00
2023-11-28 13:13:13 -08:00
import bson
2023-07-31 18:05:13 -07:00
from websockets . client import WebSocketClientProtocol , connect as ws_connect_async
from websockets . sync . client import ClientConnection , connect as ws_connect
2023-05-23 14:24:13 -07:00
from . . . client import Client
2023-07-31 18:05:13 -07:00
from . . . models . error import Error
2023-10-12 09:02:59 -07:00
from . . . models . web_socket_request import WebSocketRequest
2023-11-28 13:13:13 -08:00
from . . . models . web_socket_response import WebSocketResponse
2023-05-23 14:24:13 -07:00
def _get_kwargs (
2023-08-16 16:31:50 -07:00
fps : int ,
unlocked_framerate : bool ,
video_res_height : int ,
video_res_width : int ,
2023-08-30 15:59:51 -07:00
webrtc : bool ,
2023-05-23 14:24:13 -07:00
* ,
client : Client ,
) - > Dict [ str , Any ] :
2023-11-27 16:01:20 -08:00
url = " {} /ws/modeling/commands " . format ( client . base_url ) # noqa: E501
2023-08-16 16:31:50 -07:00
if fps is not None :
if " ? " in url :
url = url + " &fps= " + str ( fps )
else :
url = url + " ?fps= " + str ( fps )
2023-11-27 16:01:20 -08:00
2023-08-16 16:31:50 -07:00
if unlocked_framerate is not None :
if " ? " in url :
2023-11-28 13:13:13 -08:00
url = url + " &unlocked_framerate= " + str ( unlocked_framerate ) . lower ( )
2023-08-16 16:31:50 -07:00
else :
2023-11-28 13:13:13 -08:00
url = url + " ?unlocked_framerate= " + str ( unlocked_framerate ) . lower ( )
2023-11-27 16:01:20 -08:00
2023-08-16 16:31:50 -07:00
if video_res_height is not None :
if " ? " in url :
url = url + " &video_res_height= " + str ( video_res_height )
else :
url = url + " ?video_res_height= " + str ( video_res_height )
2023-11-27 16:01:20 -08:00
2023-08-16 16:31:50 -07:00
if video_res_width is not None :
if " ? " in url :
url = url + " &video_res_width= " + str ( video_res_width )
else :
url = url + " ?video_res_width= " + str ( video_res_width )
2023-11-27 16:01:20 -08:00
2023-08-30 15:59:51 -07:00
if webrtc is not None :
if " ? " in url :
2023-11-28 13:13:13 -08:00
url = url + " &webrtc= " + str ( webrtc ) . lower ( )
2023-08-30 15:59:51 -07:00
else :
2023-11-28 13:13:13 -08:00
url = url + " ?webrtc= " + str ( webrtc ) . lower ( )
2023-08-30 15:59:51 -07:00
2023-05-23 14:24:13 -07:00
headers : Dict [ str , Any ] = client . get_headers ( )
cookies : Dict [ str , Any ] = client . get_cookies ( )
return {
" url " : url ,
" headers " : headers ,
" cookies " : cookies ,
" timeout " : client . get_timeout ( ) ,
}
2023-07-31 18:05:13 -07:00
def sync(
    fps: int,
    unlocked_framerate: bool,
    video_res_height: int,
    video_res_width: int,
    webrtc: bool,
    *,
    client: Client,
) -> ClientConnection:
    """Pass those commands to the engine via websocket, and pass responses back to the client. Basically, this is a websocket proxy between the frontend/client and the engine.

    Opens a synchronous websocket connection and returns it still OPEN.
    The caller owns the connection and must close it, either explicitly
    with ``conn.close()`` or via ``with sync(...) as conn:``.

    Errors raised while connecting (e.g. invalid URI, failed handshake)
    propagate to the caller unchanged.
    """  # noqa: E501
    kwargs = _get_kwargs(
        fps=fps,
        unlocked_framerate=unlocked_framerate,
        video_res_height=video_res_height,
        video_res_width=video_res_width,
        webrtc=webrtc,
        client=client,
    )

    # websockets only accepts ws:// or wss:// URIs, so map the HTTP(S)
    # scheme of the client's base URL onto its websocket equivalent.
    url = kwargs["url"]
    if url.startswith("https://"):
        url = "wss://" + url[len("https://"):]
    elif url.startswith("http://"):
        url = "ws://" + url[len("http://"):]

    # Deliberately NOT wrapped in a ``with`` block: exiting the context
    # manager closes the socket, which would hand the caller a dead
    # connection.
    return ws_connect(url, additional_headers=kwargs["headers"])
2023-05-23 14:24:13 -07:00
2023-07-31 18:05:13 -07:00
async def asyncio(
    fps: int,
    unlocked_framerate: bool,
    video_res_height: int,
    video_res_width: int,
    webrtc: bool,
    *,
    client: Client,
) -> WebSocketClientProtocol:
    """Pass those commands to the engine via websocket, and pass responses back to the client. Basically, this is a websocket proxy between the frontend/client and the engine.

    Async variant: opens the websocket and returns it still OPEN. The
    caller owns the connection and must close it (``await conn.close()``).

    Errors raised while connecting (e.g. invalid URI, failed handshake)
    propagate to the caller unchanged.
    """  # noqa: E501
    kwargs = _get_kwargs(
        fps=fps,
        unlocked_framerate=unlocked_framerate,
        video_res_height=video_res_height,
        video_res_width=video_res_width,
        webrtc=webrtc,
        client=client,
    )

    # websockets only accepts ws:// or wss:// URIs, so map the HTTP(S)
    # scheme of the client's base URL onto its websocket equivalent.
    url = kwargs["url"]
    if url.startswith("https://"):
        url = "wss://" + url[len("https://"):]
    elif url.startswith("http://"):
        url = "ws://" + url[len("http://"):]

    # Await the connection instead of using ``async with``: leaving the
    # async context manager would close the socket before the caller
    # could use it.
    return await ws_connect_async(url, extra_headers=kwargs["headers"])
2023-11-28 13:13:13 -08:00
class WebSocket:
    """A websocket connection to the API endpoint.

    Wraps the connection returned by ``sync`` and (de)serializes request and
    response models. Call ``close()`` when finished, or use the instance as
    a context manager so the connection is closed automatically::

        with WebSocket(...) as ws:
            ws.send(request)
            response = ws.recv()
    """

    # Underlying synchronous websocket connection.
    ws: ClientConnection

    def __init__(
        self,
        fps: int,
        unlocked_framerate: bool,
        video_res_height: int,
        video_res_width: int,
        webrtc: bool,
        client: Client,
    ):
        """Open the connection; all arguments are forwarded to ``sync``."""
        self.ws = sync(
            fps,
            unlocked_framerate,
            video_res_height,
            video_res_width,
            webrtc,
            client=client,
        )

    def __enter__(self) -> "WebSocket":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always release the connection; exceptions are not suppressed.
        self.close()

    def send(self, data: WebSocketRequest):
        """Send ``data`` as a JSON-encoded text message."""
        self.ws.send(json.dumps(data.to_dict()))

    def send_binary(self, data: WebSocketRequest):
        """Send ``data`` as a BSON-encoded binary message."""
        self.ws.send(bson.BSON.encode(data.to_dict()))

    def recv(self) -> Optional[WebSocketResponse]:
        """Receive one message, decode it as JSON and build a response model."""
        message = self.ws.recv()
        # ``Optional[WebSocketResponse]`` is a typing alias with no
        # ``from_dict`` attribute (calling it raised AttributeError); parse
        # with the model class itself. NOTE(review): assumes
        # WebSocketResponse exposes the ``from_dict`` classmethod the
        # original call was aiming for.
        return WebSocketResponse.from_dict(json.loads(message))

    def close(self) -> None:
        """Close the underlying websocket connection."""
        self.ws.close()