File size: 2,483 Bytes
48b647e
6c6abf4
 
 
48b647e
 
 
 
6c6abf4
 
 
 
 
 
48b647e
6c6abf4
e9d3b63
48b647e
 
 
 
6c6abf4
48b647e
 
 
6c6abf4
48b647e
 
 
 
 
 
6c6abf4
 
48b647e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c6abf4
e9d3b63
 
6c6abf4
 
48b647e
0f764c5
 
 
 
 
48b647e
6c6abf4
 
0f764c5
6c6abf4
48b647e
0f764c5
6c6abf4
 
48b647e
6c6abf4
 
0f764c5
 
6c6abf4
 
 
 
48b647e
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import Optional, Any

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from joblib import load

from models.iris import Iris


class Model(BaseModel):
    """Metadata record for a single hosted ML model.

    Attributes:
        id: Unique integer identifier of the model.
        name: Human-readable model name (e.g. "Iris").
        param_count: Number of trainable parameters, if known.
    """

    id: int
    name: str
    param_count: Optional[int] = None
    # Underscore-prefixed attributes are treated as private by pydantic and are
    # excluded from validation/serialization; the lifespan hook attaches the
    # loaded estimator here at startup.
    _model: Optional[Any] = None


# Registry of hosted models. All lookup tables below are derived from this
# single source list so the entries cannot drift out of sync (resolves the
# old "fix this mess" TODO). Every table shares the same Model instances.
_hosted = [
    Model(id=0, name="CNN"),
    Model(id=1, name="Transformer"),
    Model(id=2, name="Iris"),
]

# String-keyed registry returned verbatim by GET /hosted.
models = {str(m.id): m for m in _hosted}
# Integer-id lookup used by GET /hosted/id/{model_id}.
id_2_hosted_models = {m.id: m for m in _hosted}
# Case-insensitive name -> id lookup used by GET /hosted/name/{model_name}.
model_names_2_id = {m.name.lower(): m.id for m in _hosted}
# Exact-name lookup; the lifespan hook stores the loaded estimator on these.
ml_models = {m.name: m for m in _hosted}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: attach the serialized scikit-learn estimator to the "Iris"
    # registry entry so request handlers can call predict()/predict_proba().
    # NOTE(review): the path is relative to the process CWD — confirm the app
    # is always launched from the project root.
    ml_models["Iris"]._model = load('models/iris_v1.joblib')
    yield
    # Shutdown: drop all registry entries, releasing the loaded estimator.
    ml_models.clear()


################################################################


# Application instance; the lifespan context manager handles model
# loading at startup and cleanup at shutdown.
app = FastAPI(
    title="ML Repository API", 
    description="API for getting predictions from hosted ML models.", 
    version="0.0.1",
    lifespan=lifespan)


@app.get("/")
def greet_json():
    """Root endpoint: return a static welcome payload."""
    greeting = {"Hello World": "Welcome to my ML Repository API!"}
    return greeting


@app.get("/hosted")
def list_models():
    "List all the hosted models."
    # Expose the full registry, keyed by string id.
    return models


@app.get("/hosted/id/{model_id}")
def get_by_id(model_id: int):
    "Get a specific model by its ID."
    # EAFP: attempt the lookup and translate a miss into a 404.
    try:
        return id_2_hosted_models[model_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Model with 'id={model_id}' not found")


@app.get("/hosted/name/{model_name}")
def get_by_name(model_name: str):
    "Get a specific model by its name."
    # Lookups are case-insensitive; keys in model_names_2_id are lowercase.
    key = model_name.lower()
    model_id = model_names_2_id.get(key)
    if model_id is None:
        raise HTTPException(status_code=404, detail=f"Model '{key}' not found")
    return id_2_hosted_models[model_id]


@app.post("/hosted/name/{model_name}/predict", tags=["Predictions"])
async def get_prediction(model_name: str, iris: Iris):
    """Run the hosted Iris classifier on the submitted feature rows.

    Returns a JSON object with the predicted classes and the per-class
    probabilities for each input row. Raises 501 for any model other than
    "iris" and 503 if the estimator has not been loaded at startup.
    """
    if model_name.lower() != "iris":
        raise HTTPException(status_code=501, detail="Not implemented yet.")

    # Hoist the estimator lookup and guard against a missing model instead of
    # letting an AttributeError surface as an opaque 500.
    model = ml_models["Iris"]._model
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded.")

    # Access the field directly rather than round-tripping through dict(iris).
    data = iris.data
    prediction = model.predict(data).tolist()
    # NOTE(review): predict_proba returns probabilities, not log-probabilities;
    # the "log_probs" response key is a misnomer kept for backward compatibility.
    log_probs = model.predict_proba(data).tolist()
    return {"predictions": prediction,
            "log_probs": log_probs}