Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 102 additions & 31 deletions interpreter/core/async_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,18 @@
try:
import janus
import uvicorn
from fastapi import APIRouter, FastAPI, File, Form, UploadFile, WebSocket
from fastapi.responses import PlainTextResponse, StreamingResponse
from fastapi import (
APIRouter,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
WebSocket,
)
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
from starlette.status import HTTP_403_FORBIDDEN
except:
# Server dependencies are not required by the main package.
pass
Expand Down Expand Up @@ -204,6 +214,24 @@ def accumulate(self, chunk):
self.messages[-1]["content"] += chunk


def authenticate_function(key):
"""
This function checks if the provided key is valid for authentication.

Returns True if the key is valid, False otherwise.
"""
# Fetch the API key from the environment variables. If it's not set, return True.
api_key = os.getenv("INTERPRETER_API_KEY", None)

# If the API key is not set in the environment variables, return True.
# Otherwise, check if the provided key matches the fetched API key.
# Return True if they match, False otherwise.
if api_key is None:
return True
else:
return key == api_key


def create_router(async_interpreter):
router = APIRouter()

Expand All @@ -226,6 +254,7 @@ async def home():
<button>Send</button>
</form>
<button id="approveCodeButton">Approve Code</button>
<button id="authButton">Send Auth</button>
<div id="messages"></div>
<script>
var ws = new WebSocket("ws://"""
Expand All @@ -234,6 +263,7 @@ async def home():
+ str(async_interpreter.server.port)
+ """/");
var lastMessageElement = null;

ws.onmessage = function(event) {

var eventData = JSON.parse(event.data);
Expand Down Expand Up @@ -326,8 +356,15 @@ async def home():
};
ws.send(JSON.stringify(endCommandBlock));
}
function authenticate() {
var authBlock = {
"auth": "dummy-api-key"
};
ws.send(JSON.stringify(authBlock));
}

document.getElementById("approveCodeButton").addEventListener("click", approveCode);
document.getElementById("authButton").addEventListener("click", authenticate);
</script>
</body>
</html>
Expand All @@ -338,13 +375,30 @@ async def home():
@router.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()

try:

async def receive_input():
authenticated = False
while True:
try:
data = await websocket.receive()

if not authenticated:
if "text" in data:
data = json.loads(data["text"])
if "auth" in data:
if async_interpreter.server.authenticate(
data["auth"]
):
authenticated = True
await websocket.send_text(
json.dumps({"auth": True})
)
if not authenticated:
await websocket.send_text(json.dumps({"auth": False}))
continue

if data.get("type") == "websocket.receive":
if "text" in data:
data = json.loads(data["text"])
Expand Down Expand Up @@ -474,19 +528,6 @@ async def post_input(payload: Dict[str, Any]):
except Exception as e:
return {"error": str(e)}, 500

@router.post("/run")
async def run_code(payload: Dict[str, Any]):
language, code = payload.get("language"), payload.get("code")
if not (language and code):
return {"error": "Both 'language' and 'code' are required."}, 400
try:
print(f"Running {language}:", code)
output = async_interpreter.computer.run(language, code)
print("Output:", output)
return {"output": output}
except Exception as e:
return {"error": str(e)}, 500

@router.post("/settings")
async def set_settings(payload: Dict[str, Any]):
for key, value in payload.items():
Expand Down Expand Up @@ -520,23 +561,38 @@ async def get_setting(setting: str):
else:
return json.dumps({"error": "Setting not found"}), 404

@router.post("/upload")
async def upload_file(file: UploadFile = File(...), path: str = Form(...)):
try:
with open(path, "wb") as output_file:
shutil.copyfileobj(file.file, output_file)
return {"status": "success"}
except Exception as e:
return {"error": str(e)}, 500
if os.getenv("INTERPRETER_INSECURE_ROUTES", "").lower() == "true":

@router.get("/download/{filename}")
async def download_file(filename: str):
try:
return StreamingResponse(
open(filename, "rb"), media_type="application/octet-stream"
)
except Exception as e:
return {"error": str(e)}, 500
@router.post("/run")
async def run_code(payload: Dict[str, Any]):
language, code = payload.get("language"), payload.get("code")
if not (language and code):
return {"error": "Both 'language' and 'code' are required."}, 400
try:
print(f"Running {language}:", code)
output = async_interpreter.computer.run(language, code)
print("Output:", output)
return {"output": output}
except Exception as e:
return {"error": str(e)}, 500

@router.post("/upload")
async def upload_file(file: UploadFile = File(...), path: str = Form(...)):
try:
with open(path, "wb") as output_file:
shutil.copyfileobj(file.file, output_file)
return {"status": "success"}
except Exception as e:
return {"error": str(e)}, 500

@router.get("/download/{filename}")
async def download_file(filename: str):
try:
return StreamingResponse(
open(filename, "rb"), media_type="application/octet-stream"
)
except Exception as e:
return {"error": str(e)}, 500

### OPENAI COMPATIBLE ENDPOINT

Expand Down Expand Up @@ -648,6 +704,21 @@ class Server:
def __init__(self, async_interpreter, host="127.0.0.1", port=8000):
self.app = FastAPI()
router = create_router(async_interpreter)
self.authenticate = authenticate_function

# Add authentication middleware
@self.app.middleware("http")
async def validate_api_key(request: Request, call_next):
api_key = request.headers.get("X-API-KEY")
if self.authenticate(api_key):
response = await call_next(request)
return response
else:
return JSONResponse(
status_code=HTTP_403_FORBIDDEN,
content={"detail": "Authentication failed"},
)

self.app.include_router(router)
self.config = uvicorn.Config(app=self.app, host=host, port=port)
self.uvicorn_server = uvicorn.Server(self.config)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ async def test_fastapi_server():
# Connect to the websocket
print("Connected to WebSocket")

# Sending message via WebSocket
await websocket.send(json.dumps({"auth": "dummy-api-key"}))

# Sending POST request
post_url = "http://localhost:8000/settings"
settings = {
Expand Down