//! Executes KCL programs.
//! The server reuses the same engine session for each KCL program it receives.

use std::{
    net::SocketAddr,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    time::Duration,
};

use hyper::{
    body::Bytes,
    header::CONTENT_TYPE,
    service::{make_service_fn, service_fn},
    Body, Error, Response, Server,
};
use kcl_lib::{test_server::RequestBody, ExecState, ExecutorContext, Program};
use tokio::{
    sync::{mpsc, oneshot},
    task::JoinHandle,
    time::sleep,
};

#[derive(Debug)]
pub struct ServerArgs {
    /// What address (host and port) this server should listen on.
    pub listen_on: SocketAddr,
    /// How many connections to establish with the engine.
    pub num_engine_conns: u8,
    /// Where to find the engine.
    /// If None, uses the prod engine.
    /// This is useful for testing a local engine instance.
    /// Overridden by the $ZOO_HOST environment variable.
    pub engine_address: Option<String>,
}

impl ServerArgs {
    pub fn parse(mut pargs: pico_args::Arguments) -> Result<Self, pico_args::Error> {
        let mut args = ServerArgs {
            listen_on: pargs
                .opt_value_from_str("--listen-on")?
                .unwrap_or("0.0.0.0:3333".parse().unwrap()),
            num_engine_conns: pargs.opt_value_from_str("--num-engine-conns")?.unwrap_or(1),
            engine_address: pargs.opt_value_from_str("--engine-address")?,
        };
        if let Ok(addr) = std::env::var("ZOO_HOST") {
            println!("Overriding engine address via $ZOO_HOST");
            args.engine_address = Some(addr);
        }
        println!("Config is {args:?}");
        Ok(args)
    }
}

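// Example invocation (a sketch; the binary name is a placeholder, and every flag is optional):
//
//     kcl-test-server --listen-on 0.0.0.0:3333 --num-engine-conns 4 --engine-address <engine-url>
//
// Setting $ZOO_HOST overrides --engine-address.
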
/// Sent from the server to each worker.
struct WorkerReq {
    /// A KCL program, in UTF-8.
    body: Bytes,
    /// A channel to send the HTTP response back.
    resp: oneshot::Sender<Response<Body>>,
}

/// Each worker has a connection to the engine, and accepts
/// KCL programs. When it receives one (over the mpsc channel)
/// it executes it and returns the result via a oneshot channel.
fn start_worker(i: u8, engine_addr: Option<String>) -> mpsc::Sender<WorkerReq> {
    println!("Starting worker {i}");
    // Make a work queue for this worker.
    let (tx, mut rx) = mpsc::channel(1);
    tokio::task::spawn(async move {
        let ctx = ExecutorContext::new_for_unit_test(engine_addr).await.unwrap();
        println!("Worker {i} ready");
        while let Some(req) = rx.recv().await {
            let req: WorkerReq = req;
            let resp = snapshot_endpoint(req.body, ctx.clone()).await;
            // If this fails, the requester dropped its receiver before the
            // response was ready, so there is nothing useful to do with it.
            if req.resp.send(resp).is_err() {
                println!("\tWorker {i} exiting");
            }
        }
        println!("\tWorker {i} exiting");
    });
    tx
}

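// Sketch of how the returned handle is used (this mirrors `handle_request` below);
// the request body here is a placeholder:
//
//     let worker = start_worker(0, None);
//     let (tx, rx) = oneshot::channel();
//     worker.send(WorkerReq { body: Bytes::from("{...}"), resp: tx }).await.unwrap();
//     let response = rx.await.unwrap();
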
/// State shared by every connection handler.
struct ServerState {
    /// One work queue per engine connection.
    workers: Vec<mpsc::Sender<WorkerReq>>,
    /// Counts requests so they can be round-robined across workers.
    req_num: AtomicUsize,
}

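// Dispatch sketch: with `num_engine_conns = 3`, request numbers 0, 1, 2, 3, 4 are served
// by workers 0, 1, 2, 0, 1 (`req_num % workers.len()` in `handle_request` below).
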
pub async fn start_server(args: ServerArgs) -> anyhow::Result<()> {
    let ServerArgs {
        listen_on,
        num_engine_conns,
        engine_address,
    } = args;
    let workers: Vec<_> = (0..num_engine_conns)
        .map(|i| start_worker(i, engine_address.clone()))
        .collect();
    let state = Arc::new(ServerState {
        workers,
        req_num: 0.into(),
    });
    // In hyper, a `MakeService` is basically your server.
    // It makes a `Service` for each connection, which manages the connection.
    let make_service = make_service_fn(
        // This closure is run for each connection.
        move |_conn_info| {
            // eprintln!("Connected to a client");
            let state = state.clone();
            async move {
                // This is the `Service` which handles the connection.
                // `service_fn` converts a function which returns a Response
                // into a `Service`.
                Ok::<_, Error>(service_fn(move |req| {
                    // eprintln!("Received a request");
                    let state = state.clone();
                    async move { handle_request(req, state).await }
                }))
            }
        },
    );
    let server = Server::bind(&listen_on).serve(make_service);
    println!("Listening on {listen_on}");
    println!("PID is {}", std::process::id());
    if let Err(e) = server.await {
        eprintln!("Server error: {e}");
        return Err(e.into());
    }
    Ok(())
}

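// Sketch of a `main` that wires this module into a binary; the `#[tokio::main]`
// entry point is an assumption about the surrounding crate, not part of this module:
//
//     #[tokio::main]
//     async fn main() -> anyhow::Result<()> {
//         let args = ServerArgs::parse(pico_args::Arguments::from_env())?;
//         start_server(args).await
//     }
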
async fn handle_request(req: hyper::Request<Body>, state: Arc<ServerState>) -> Result<Response<Body>, Error> {
    let body = hyper::body::to_bytes(req.into_body()).await?;

    // Round robin requests between each available worker.
    let req_num = state.req_num.fetch_add(1, Ordering::Relaxed);
    let worker_id = req_num % state.workers.len();
    // println!("Sending request {req_num} to worker {worker_id}");
    let worker = state.workers[worker_id].clone();
    let (tx, rx) = oneshot::channel();
    let req_sent = worker.send(WorkerReq { body, resp: tx }).await;
    // These unwraps only panic if the worker task has died and dropped its channels.
    req_sent.unwrap();
    let resp = rx.await.unwrap();
    Ok(resp)
}

/// Execute a KCL program, then respond with a PNG snapshot.
/// KCL errors (from engine or the executor) respond with HTTP Bad Gateway.
/// Malformed requests are HTTP Bad Request.
/// Successful responses contain a PNG as the body.
async fn snapshot_endpoint(body: Bytes, ctxt: ExecutorContext) -> Response<Body> {
    let body = match serde_json::from_slice::<RequestBody>(body.as_ref()) {
        Ok(bd) => bd,
        Err(e) => return bad_request(format!("Invalid request JSON: {e}")),
    };
    let RequestBody { kcl_program, test_name } = body;

    let program = match Program::parse_no_errs(&kcl_program) {
        Ok(pr) => pr,
        Err(e) => return bad_request(format!("Parse error: {e}")),
    };

    eprintln!("Executing {test_name}");
    let mut exec_state = ExecState::new(&ctxt);
    // There's no actual KCL associated with this clear-scene call,
    // so the default source range is the best we can do.
    if let Err(e) = ctxt
        .send_clear_scene(&mut exec_state, kcl_lib::SourceRange::default())
        .await
    {
        return kcl_err(e);
    }
    // Let users know if the test is taking a long time.
    let (done_tx, done_rx) = oneshot::channel::<()>();
    let timer = time_until(done_rx);
    if let Err(e) = ctxt.run(&program, &mut exec_state).await {
        return kcl_err(e);
    }
    let snapshot = match ctxt.prepare_snapshot().await {
        Ok(s) => s,
        Err(e) => return kcl_err(e),
    };
    let _ = done_tx.send(());
    timer.abort();
    eprintln!("\tServing response");
    let png_bytes = snapshot.contents.0;
    let mut resp = Response::new(Body::from(png_bytes));
    resp.headers_mut().insert(CONTENT_TYPE, "image/png".parse().unwrap());
    resp
}

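// Example request (a sketch; the program text and test name are placeholders, and the
// URL assumes the default --listen-on address):
//
//     curl --data '{"kcl_program": "<some KCL>", "test_name": "example"}' http://localhost:3333/
//
// A successful response is a PNG image; parse or KCL errors come back as
// 400 Bad Request / 502 Bad Gateway with the error message as the body.
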
fn bad_request(msg: String) -> Response<Body> {
    eprintln!("\tBad request");
    let mut resp = Response::new(Body::from(msg));
    *resp.status_mut() = hyper::StatusCode::BAD_REQUEST;
    resp
}

fn bad_gateway(msg: String) -> Response<Body> {
    eprintln!("\tBad gateway");
    let mut resp = Response::new(Body::from(msg));
    *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
    resp
}

fn kcl_err(err: impl std::fmt::Display) -> Response<Body> {
    eprintln!("\tBad KCL");
    bad_gateway(format!("{err}"))
}

/// Spawns a timer task that warns the user every 10 seconds (up to three times)
/// while a test is still running. Completing the `done` channel cancels it.
fn time_until(done: oneshot::Receiver<()>) -> JoinHandle<()> {
    tokio::task::spawn(async move {
        let period = 10;
        tokio::pin!(done);
        for i in 1..=3 {
            tokio::select! {
                biased;
                // If the test is done, no need for this timer anymore.
                _ = &mut done => return,
                _ = sleep(Duration::from_secs(period)) => {
                    eprintln!("\tTest has taken {}s", period * i);
                },
            };
        }
    })
}