Spaces:
Running
Running
import os | |
import re | |
import logging | |
import aiohttp | |
from pathlib import Path | |
from typing import Optional | |
from open_webui.models.functions import ( | |
FunctionForm, | |
FunctionModel, | |
FunctionResponse, | |
Functions, | |
) | |
from open_webui.utils.plugin import ( | |
load_function_module_by_id, | |
replace_imports, | |
get_function_module_from_cache, | |
) | |
from open_webui.config import CACHE_DIR | |
from open_webui.constants import ERROR_MESSAGES | |
from fastapi import APIRouter, Depends, HTTPException, Request, status | |
from open_webui.utils.auth import get_admin_user, get_verified_user | |
from open_webui.env import SRC_LOG_LEVELS | |
from pydantic import BaseModel, HttpUrl | |
# API router for this module.
router = APIRouter()

# Module logger, using the level configured for the "MAIN" source.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
############################ | |
# GetFunctions | |
############################ | |
async def get_functions(user=Depends(get_verified_user)):
    """Return all function records from ``Functions.get_functions()``.

    ``user`` is resolved through ``get_verified_user``; it is not otherwise
    used by this handler.
    """
    all_functions = Functions.get_functions()
    return all_functions
############################ | |
# ExportFunctions | |
############################ | |
async def get_functions(user=Depends(get_admin_user)):
    """Return all function records (admin-only export variant).

    NOTE(review): this rebinds the module-level name ``get_functions`` that the
    verified-user handler also uses. Unless a route decorator (not visible
    here) registers each handler before the name is rebound, the earlier
    definition is shadowed — confirm.
    """
    exported = Functions.get_functions()
    return exported
############################ | |
# LoadFunctionFromLink | |
############################ | |
# Request body for loading a function's source from a remote URL.
class LoadUrlForm(BaseModel):
    # Validated by pydantic as an HTTP(S) URL before the handler runs.
    url: HttpUrl
def github_url_to_raw_url(url: str) -> str:
    """Convert a GitHub ``tree``/``blob`` page URL into a raw-content URL.

    * ``https://github.com/{org}/{repo}/blob/{ref}/{path}`` becomes
      ``https://raw.githubusercontent.com/{org}/{repo}/{ref}/{path}``.
    * ``https://github.com/{org}/{repo}/tree/{ref}[/{dir}]`` (a folder) becomes
      the raw URL of ``{dir}/main.py`` (or ``main.py`` at the repo root).

    ``{ref}`` is used directly rather than under ``refs/heads/`` so that tag
    and commit-SHA (permalink) URLs resolve as well as branch URLs. Query
    strings and fragments (e.g. ``#L10-L20``) are dropped. ``http://`` and
    ``www.github.com`` forms are accepted.

    Any other URL, including a ``blob`` URL with no file path, is returned
    unchanged.

    >>> github_url_to_raw_url("https://github.com/o/r/tree/main/fn/")
    'https://raw.githubusercontent.com/o/r/main/fn/main.py'
    """
    match = re.match(
        r"https?://(?:www\.)?github\.com/([^/]+)/([^/]+)/(tree|blob)/([^/?#]+)(?:/([^?#]*))?",
        url,
    )
    if not match:
        # Not a GitHub tree/blob page URL; return as-is.
        return url

    org, repo, kind, ref, path = match.groups()
    # Normalise slashes so a trailing "/" never produces "//main.py".
    path = (path or "").strip("/")

    if kind == "tree":
        # Folder URL: the function source is expected in main.py.
        path = f"{path}/main.py" if path else "main.py"
    elif not path:
        # A blob URL without a file path cannot be mapped to raw content.
        return url

    return f"https://raw.githubusercontent.com/{org}/{repo}/{ref}/{path}"
async def load_function_from_url(
    request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
):
    """Fetch a function's source from ``form_data.url`` and return it unsaved.

    GitHub ``tree``/``blob`` page URLs are first rewritten to raw-content URLs
    by ``github_url_to_raw_url``. The suggested name is the file stem for a
    ``*.py`` file other than ``main.py``/``index.py``/``__init__.py``.
    Otherwise it is the parent path segment.

    Returns:
        ``{"name": <derived name>, "content": <response text>}``.

    Raises:
        HTTPException: 400 for an empty URL or an empty response body; the
            upstream status code when the fetch does not return 200; 500 for
            any other failure (network error, decoding error, ...).
    """
    # NOTE: This is NOT a SSRF vulnerability:
    # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
    # and does NOT accept untrusted user input. Access is enforced by authentication.
    url = str(form_data.url)
    if not url:
        raise HTTPException(status_code=400, detail="Please enter a valid URL")

    url = github_url_to_raw_url(url)

    url_parts = url.rstrip("/").split("/")
    file_name = url_parts[-1]
    # Entry-point style file names say nothing about the function, so fall back
    # to the containing folder's name for those (and for non-.py URLs).
    if file_name.endswith(".py") and file_name not in (
        "main.py",
        "index.py",
        "__init__.py",
    ):
        function_name = file_name[:-3]
    elif len(url_parts) > 1:
        function_name = url_parts[-2]
    else:
        function_name = "function"

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                url, headers={"Content-Type": "application/json"}
            ) as resp:
                if resp.status != 200:
                    raise HTTPException(
                        status_code=resp.status, detail="Failed to fetch the function"
                    )
                data = await resp.text()
    except HTTPException:
        # Our own, deliberately chosen status/detail must not be re-wrapped as 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error importing function: {e}"
        ) from e

    if not data:
        raise HTTPException(status_code=400, detail="No data received from the URL")

    return {
        "name": function_name,
        "content": data,
    }
############################ | |
# SyncFunctions | |
############################ | |
# Request body for the sync endpoint: the full list of function records to sync.
# NOTE(review): this subclasses FunctionForm, so any required FunctionForm fields
# (not visible here) are also required in a sync request, even though only
# `functions` is read by sync_functions — confirm this inheritance is intended.
class SyncFunctionsForm(FunctionForm):
    # Pydantic copies mutable defaults per instance, so `[]` is not shared.
    functions: list[FunctionModel] = []
async def sync_functions(
    request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
):
    """Sync the submitted function list on behalf of the admin ``user``.

    Delegates to ``Functions.sync_functions`` with the caller's id and the
    form's ``functions`` list, and returns its result unchanged.
    """
    owner_id = user.id
    submitted = form_data.functions
    return Functions.sync_functions(owner_id, submitted)
############################ | |
# CreateNewFunction | |
############################ | |
async def create_new_function(
    request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
):
    """Validate, load and persist a new function owned by ``user``.

    The id must satisfy ``str.isidentifier()`` and is lower-cased before use.
    The source has its imports rewritten by ``replace_imports`` and is loaded
    with ``load_function_module_by_id``. The returned frontmatter is stored as
    ``form_data.meta.manifest``.

    After a successful DB insert, the loaded module is registered in
    ``request.app.state.FUNCTIONS`` and ``CACHE_DIR / "functions" / <id>`` is
    created. Registering only after the insert keeps the in-memory cache from
    holding a module with no DB row when the insert fails.

    Raises:
        HTTPException: 400 if the id is not an identifier, is already taken,
            loading/inserting fails, or the insert returns nothing.
    """
    if not form_data.id.isidentifier():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only alphanumeric characters and underscores are allowed in the id",
        )

    form_data.id = form_data.id.lower()

    if Functions.get_function_by_id(form_data.id) is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ID_TAKEN,
        )

    try:
        form_data.content = replace_imports(form_data.content)
        function_module, function_type, frontmatter = load_function_module_by_id(
            form_data.id,
            content=form_data.content,
        )
        form_data.meta.manifest = frontmatter

        function = Functions.insert_new_function(user.id, function_type, form_data)
        if not function:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
            )

        FUNCTIONS = request.app.state.FUNCTIONS
        FUNCTIONS[form_data.id] = function_module

        function_cache_dir = CACHE_DIR / "functions" / form_data.id
        function_cache_dir.mkdir(parents=True, exist_ok=True)

        return function
    except HTTPException:
        # Already carries the intended status/detail; don't re-wrap it.
        raise
    except Exception as e:
        log.exception(f"Failed to create a new function: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
############################ | |
# GetFunctionById | |
############################ | |
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
    """Return the function record stored under ``id``.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.NOT_FOUND`` when no record is
            found. NOTE(review): 404 would be the conventional status; it is
            kept as 401 here to match the rest of this module — confirm
            clients before changing.
    """
    function = Functions.get_function_by_id(id)
    if not function:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )
    return function
############################ | |
# ToggleFunctionById | |
############################ | |
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
    """Flip the ``is_active`` flag of function ``id`` and return the updated record.

    Raises:
        HTTPException: 401 (``NOT_FOUND``) if the function does not exist;
            400 if the update returns nothing.
    """
    existing = Functions.get_function_by_id(id)
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    toggled = Functions.update_function_by_id(
        id, {"is_active": not existing.is_active}
    )
    if not toggled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
        )
    return toggled
############################ | |
# ToggleGlobalById | |
############################ | |
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
    """Flip the ``is_global`` flag of function ``id`` and return the updated record.

    Raises:
        HTTPException: 401 (``NOT_FOUND``) if the function does not exist;
            400 if the update returns nothing.
    """
    existing = Functions.get_function_by_id(id)
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    toggled = Functions.update_function_by_id(
        id, {"is_global": not existing.is_global}
    )
    if not toggled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
        )
    return toggled
############################ | |
# UpdateFunctionById | |
############################ | |
async def update_function_by_id(
    request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
):
    """Reload function ``id`` from new source and persist the changes.

    The source has its imports rewritten by ``replace_imports`` and is loaded
    with ``load_function_module_by_id``. Its frontmatter becomes
    ``meta.manifest``. Every form field except ``id``, plus the detected
    ``type``, is written via ``Functions.update_function_by_id``.

    The loaded module replaces the entry in ``request.app.state.FUNCTIONS``
    only after the DB update succeeds. A failed update (or an unknown ``id``)
    therefore no longer leaves the in-memory module out of sync with the DB.

    Raises:
        HTTPException: 400 if loading fails or the update returns nothing.
    """
    try:
        form_data.content = replace_imports(form_data.content)
        function_module, function_type, frontmatter = load_function_module_by_id(
            id, content=form_data.content
        )
        form_data.meta.manifest = frontmatter

        updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
        log.debug(updated)

        function = Functions.update_function_by_id(id, updated)
        if not function:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
            )

        FUNCTIONS = request.app.state.FUNCTIONS
        FUNCTIONS[id] = function_module
        return function
    except HTTPException:
        # Already carries the intended status/detail; don't re-wrap it.
        raise
    except Exception as e:
        log.exception(f"Failed to update function {id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
############################ | |
# DeleteFunctionById | |
############################ | |
async def delete_function_by_id(
    request: Request, id: str, user=Depends(get_admin_user)
):
    """Delete function ``id`` and drop its loaded module from the app cache.

    Returns whatever ``Functions.delete_function_by_id`` returned. The entry
    in ``request.app.state.FUNCTIONS`` is removed only when that result is
    truthy.
    """
    deleted = Functions.delete_function_by_id(id)
    if not deleted:
        return deleted

    functions_cache = request.app.state.FUNCTIONS
    if id in functions_cache:
        del functions_cache[id]
    return deleted
############################ | |
# GetFunctionValves | |
############################ | |
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
    """Return the stored valve values of function ``id``.

    Raises:
        HTTPException: 401 (``NOT_FOUND``) if the function does not exist;
            400 if reading the valves raises.
    """
    if not Functions.get_function_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    try:
        return Functions.get_function_valves_by_id(id)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
############################ | |
# GetFunctionValvesSpec | |
############################ | |
async def get_function_valves_spec_by_id(
    request: Request, id: str, user=Depends(get_admin_user)
):
    """Return the JSON schema of the ``Valves`` class of function ``id``.

    The module comes from ``get_function_module_from_cache``. Returns ``None``
    when that module does not define ``Valves``.

    Raises:
        HTTPException: 401 (``NOT_FOUND``) if no function with ``id`` exists.
    """
    function = Functions.get_function_by_id(id)
    if not function:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    function_module, _, _ = get_function_module_from_cache(request, id)
    if not hasattr(function_module, "Valves"):
        return None

    valves_cls = function_module.Valves
    # Pydantic v2 deprecates ``schema()`` in favour of ``model_json_schema()``
    # (same output, no DeprecationWarning). ``schema()`` remains as a fallback
    # for classes that only expose the v1-style API.
    if hasattr(valves_cls, "model_json_schema"):
        return valves_cls.model_json_schema()
    return valves_cls.schema()
############################ | |
# UpdateFunctionValves | |
############################ | |
async def update_function_valves_by_id(
    request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
):
    """Validate and store new valve values for function ``id``.

    ``None`` entries in ``form_data`` are dropped before the remaining values
    are passed to the module's ``Valves`` class (presumably so that class's
    defaults apply to them). The validated values are persisted and returned
    as a plain dict via ``model_dump()``.

    Raises:
        HTTPException: 401 (``NOT_FOUND``) if the function does not exist or
            its module defines no ``Valves``; 400 if validation or persistence
            fails.
    """
    function = Functions.get_function_by_id(id)
    if not function:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    function_module, _, _ = get_function_module_from_cache(request, id)
    if not hasattr(function_module, "Valves"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    valves_cls = function_module.Valves
    try:
        form_data = {k: v for k, v in form_data.items() if v is not None}
        valves = valves_cls(**form_data)
        Functions.update_function_valves_by_id(id, valves.model_dump())
        return valves.model_dump()
    except Exception as e:
        log.exception(f"Error updating function valves by id {id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
############################ | |
# FunctionUserValves | |
############################ | |
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
    """Return the calling user's stored valve values for function ``id``.

    Raises:
        HTTPException: 401 (``NOT_FOUND``) if the function does not exist;
            400 if reading the user valves raises.
    """
    if not Functions.get_function_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    try:
        return Functions.get_user_valves_by_id_and_user_id(id, user.id)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
async def get_function_user_valves_spec_by_id(
    request: Request, id: str, user=Depends(get_verified_user)
):
    """Return the JSON schema of the ``UserValves`` class of function ``id``.

    The module comes from ``get_function_module_from_cache``. Returns ``None``
    when that module does not define ``UserValves``.

    Raises:
        HTTPException: 401 (``NOT_FOUND``) if no function with ``id`` exists.
    """
    function = Functions.get_function_by_id(id)
    if not function:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    function_module, _, _ = get_function_module_from_cache(request, id)
    if not hasattr(function_module, "UserValves"):
        return None

    user_valves_cls = function_module.UserValves
    # Pydantic v2 deprecates ``schema()`` in favour of ``model_json_schema()``
    # (same output, no DeprecationWarning). ``schema()`` remains as a fallback
    # for classes that only expose the v1-style API.
    if hasattr(user_valves_cls, "model_json_schema"):
        return user_valves_cls.model_json_schema()
    return user_valves_cls.schema()
async def update_function_user_valves_by_id(
    request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
    """Validate and store the calling user's valve values for function ``id``.

    ``None`` entries in ``form_data`` are dropped before the remaining values
    are passed to the module's ``UserValves`` class (presumably so that class's
    defaults apply to them). The validated values are persisted for
    ``user.id`` and returned as a plain dict via ``model_dump()``.

    Raises:
        HTTPException: 401 (``NOT_FOUND``) if the function does not exist or
            its module defines no ``UserValves``; 400 if validation or
            persistence fails.
    """
    if not Functions.get_function_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    module, _type, _frontmatter = get_function_module_from_cache(request, id)
    if not hasattr(module, "UserValves"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    user_valves_cls = module.UserValves
    try:
        provided = {key: value for key, value in form_data.items() if value is not None}
        user_valves = user_valves_cls(**provided)
        Functions.update_user_valves_by_id_and_user_id(
            id, user.id, user_valves.model_dump()
        )
        return user_valves.model_dump()
    except Exception as e:
        log.exception(f"Error updating function user valves by id {id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )