Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files
server/security/notebook_training_gr.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
server/security/prompt_guard.py
CHANGED
|
@@ -25,7 +25,6 @@ def get_embedding(documents: list[str]) -> NDArray[np.float32]:
|
|
| 25 |
return model.encode(documents)
|
| 26 |
|
| 27 |
|
| 28 |
-
|
| 29 |
class Guardrail:
|
| 30 |
"""
|
| 31 |
A class to handle guardrail analysis based on query embeddings.
|
|
@@ -38,11 +37,11 @@ class Guardrail:
|
|
| 38 |
"""
|
| 39 |
Initializes the Guardrail class with a guardrail model instance.
|
| 40 |
"""
|
| 41 |
-
file_path = os.path.join("server","security","storage","guardrail_multi.pkl")
|
| 42 |
with open(file_path, "rb") as f:
|
| 43 |
self.guardrail = load(f)
|
| 44 |
|
| 45 |
-
def analyze_language(self, query:str) -> bool:
|
| 46 |
"""
|
| 47 |
Analyzes the given query to determine what language it is written in and whether it is english, french, german or spanish.
|
| 48 |
|
|
@@ -53,8 +52,8 @@ class Guardrail:
|
|
| 53 |
bool: Returns `False` if the query is not a supported language, `True` otherwise.
|
| 54 |
"""
|
| 55 |
det = detect(query)
|
| 56 |
-
return det in ["en","fr","de","es"]
|
| 57 |
-
|
| 58 |
def analyze_query(self, query: str) -> bool:
|
| 59 |
"""
|
| 60 |
Analyzes the given query to determine if it passes the guardrail check.
|
|
@@ -68,7 +67,6 @@ class Guardrail:
|
|
| 68 |
embed_query = get_embedding(documents=[query])
|
| 69 |
pred = self.guardrail.predict(embed_query.reshape(1, -1))
|
| 70 |
return pred != 1 # Return True if pred is not 1, otherwise False
|
| 71 |
-
|
| 72 |
|
| 73 |
def incremental_learning(self, X_new, y_new):
|
| 74 |
"""
|
|
@@ -80,9 +78,11 @@ class Guardrail:
|
|
| 80 |
"""
|
| 81 |
# Extraction des caractéristiques
|
| 82 |
embedding = model.encode(X_new)
|
| 83 |
-
|
| 84 |
# Mise à jour incrémentale du modèle
|
| 85 |
-
self.guardrail.partial_fit(embedding, y_new, classes=[0, 1])
|
| 86 |
|
| 87 |
-
with open(
|
|
|
|
|
|
|
| 88 |
dump(self.guardrail, f)
|
|
|
|
| 25 |
return model.encode(documents)
|
| 26 |
|
| 27 |
|
|
|
|
| 28 |
class Guardrail:
|
| 29 |
"""
|
| 30 |
A class to handle guardrail analysis based on query embeddings.
|
|
|
|
| 37 |
"""
|
| 38 |
Initializes the Guardrail class with a guardrail model instance.
|
| 39 |
"""
|
| 40 |
+
file_path = os.path.join("server", "security", "storage", "guardrail_multi.pkl")
|
| 41 |
with open(file_path, "rb") as f:
|
| 42 |
self.guardrail = load(f)
|
| 43 |
|
| 44 |
+
def analyze_language(self, query: str) -> bool:
|
| 45 |
"""
|
| 46 |
Analyzes the given query to determine what language it is written in and whether it is english, french, german or spanish.
|
| 47 |
|
|
|
|
| 52 |
bool: Returns `False` if the query is not a supported language, `True` otherwise.
|
| 53 |
"""
|
| 54 |
det = detect(query)
|
| 55 |
+
return det in ["en", "fr", "de", "es"]
|
| 56 |
+
|
| 57 |
def analyze_query(self, query: str) -> bool:
|
| 58 |
"""
|
| 59 |
Analyzes the given query to determine if it passes the guardrail check.
|
|
|
|
| 67 |
embed_query = get_embedding(documents=[query])
|
| 68 |
pred = self.guardrail.predict(embed_query.reshape(1, -1))
|
| 69 |
return pred != 1 # Return True if pred is not 1, otherwise False
|
|
|
|
| 70 |
|
| 71 |
def incremental_learning(self, X_new, y_new):
|
| 72 |
"""
|
|
|
|
| 78 |
"""
|
| 79 |
# Extraction des caractéristiques
|
| 80 |
embedding = model.encode(X_new)
|
| 81 |
+
|
| 82 |
# Mise à jour incrémentale du modèle
|
| 83 |
+
self.guardrail.partial_fit(embedding.reshape(1, -1), y_new, classes=[0, 1])
|
| 84 |
|
| 85 |
+
with open(
|
| 86 |
+
os.path.join("server", "security", "storage", "guardrail_multi.pkl"), "wb"
|
| 87 |
+
) as f:
|
| 88 |
dump(self.guardrail, f)
|