diff --git a/interpreter/core/async_core.py b/interpreter/core/async_core.py index 8c6096598c..df56f3725b 100644 --- a/interpreter/core/async_core.py +++ b/interpreter/core/async_core.py @@ -17,8 +17,18 @@ try: import janus import uvicorn - from fastapi import APIRouter, FastAPI, File, Form, UploadFile, WebSocket - from fastapi.responses import PlainTextResponse, StreamingResponse + from fastapi import ( + APIRouter, + FastAPI, + File, + Form, + HTTPException, + Request, + UploadFile, + WebSocket, + ) + from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse + from starlette.status import HTTP_403_FORBIDDEN except: # Server dependencies are not required by the main package. pass @@ -204,6 +214,24 @@ def accumulate(self, chunk): self.messages[-1]["content"] += chunk +def authenticate_function(key): + """ + This function checks if the provided key is valid for authentication. + + Returns True if the key is valid, False otherwise. + """ + # Fetch the API key from the environment variables. If it's not set, return True. + api_key = os.getenv("INTERPRETER_API_KEY", None) + + # If the API key is not set in the environment variables, return True. + # Otherwise, check if the provided key matches the fetched API key. + # Return True if they match, False otherwise. + if api_key is None: + return True + else: + return key == api_key + + def create_router(async_interpreter): router = APIRouter() @@ -226,6 +254,7 @@ async def home(): +
@@ -338,13 +375,30 @@ async def home(): @router.websocket("/") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() + try: async def receive_input(): + authenticated = False while True: try: data = await websocket.receive() + if not authenticated: + if "text" in data: + data = json.loads(data["text"]) + if "auth" in data: + if async_interpreter.server.authenticate( + data["auth"] + ): + authenticated = True + await websocket.send_text( + json.dumps({"auth": True}) + ) + if not authenticated: + await websocket.send_text(json.dumps({"auth": False})) + continue + if data.get("type") == "websocket.receive": if "text" in data: data = json.loads(data["text"]) @@ -474,19 +528,6 @@ async def post_input(payload: Dict[str, Any]): except Exception as e: return {"error": str(e)}, 500 - @router.post("/run") - async def run_code(payload: Dict[str, Any]): - language, code = payload.get("language"), payload.get("code") - if not (language and code): - return {"error": "Both 'language' and 'code' are required."}, 400 - try: - print(f"Running {language}:", code) - output = async_interpreter.computer.run(language, code) - print("Output:", output) - return {"output": output} - except Exception as e: - return {"error": str(e)}, 500 - @router.post("/settings") async def set_settings(payload: Dict[str, Any]): for key, value in payload.items(): @@ -520,23 +561,38 @@ async def get_setting(setting: str): else: return json.dumps({"error": "Setting not found"}), 404 - @router.post("/upload") - async def upload_file(file: UploadFile = File(...), path: str = Form(...)): - try: - with open(path, "wb") as output_file: - shutil.copyfileobj(file.file, output_file) - return {"status": "success"} - except Exception as e: - return {"error": str(e)}, 500 + if os.getenv("INTERPRETER_INSECURE_ROUTES", "").lower() == "true": - @router.get("/download/{filename}") - async def download_file(filename: str): - try: - return StreamingResponse( - open(filename, "rb"), media_type="application/octet-stream" - ) - except Exception as e: - return {"error": str(e)}, 500 + @router.post("/run") + async def run_code(payload: Dict[str, Any]): + language, code = payload.get("language"), payload.get("code") + if not (language and code): + return {"error": "Both 'language' and 'code' are required."}, 400 + try: + print(f"Running {language}:", code) + output = async_interpreter.computer.run(language, code) + print("Output:", output) + return {"output": output} + except Exception as e: + return {"error": str(e)}, 500 + + @router.post("/upload") + async def upload_file(file: UploadFile = File(...), path: str = Form(...)): + try: + with open(path, "wb") as output_file: + shutil.copyfileobj(file.file, output_file) + return {"status": "success"} + except Exception as e: + return {"error": str(e)}, 500 + + @router.get("/download/{filename}") + async def download_file(filename: str): + try: + return StreamingResponse( + open(filename, "rb"), media_type="application/octet-stream" + ) + except Exception as e: + return {"error": str(e)}, 500 ### OPENAI COMPATIBLE ENDPOINT @@ -648,6 +704,21 @@ class Server: def __init__(self, async_interpreter, host="127.0.0.1", port=8000): self.app = FastAPI() router = create_router(async_interpreter) + self.authenticate = authenticate_function + + # Add authentication middleware + @self.app.middleware("http") + async def validate_api_key(request: Request, call_next): + api_key = request.headers.get("X-API-KEY") + if self.authenticate(api_key): + response = await call_next(request) + return response + else: + return JSONResponse( + status_code=HTTP_403_FORBIDDEN, + content={"detail": "Authentication failed"}, + ) + self.app.include_router(router) self.config = uvicorn.Config(app=self.app, host=host, port=port) self.uvicorn_server = uvicorn.Server(self.config) diff --git a/tests/test_interpreter.py b/tests/test_interpreter.py index c842bc6535..ee8c3aabdd 100644 --- a/tests/test_interpreter.py +++ b/tests/test_interpreter.py @@ -53,6 +53,9 @@ async def test_fastapi_server(): # Connect to the websocket print("Connected to WebSocket") + # Sending message via WebSocket + await websocket.send(json.dumps({"auth": "dummy-api-key"})) + # Sending POST request post_url = "http://localhost:8000/settings" settings = {