Daniel.y committed on
Commit
8365801
·
unverified ·
2 Parent(s): 6e4a904 e097a6a

Merge pull request #1654 from a-bruhn/azure-env-vars

Browse files
Files changed (2) hide show
  1. env.example +2 -0
  2. lightrag/llm/azure_openai.py +52 -28
env.example CHANGED
@@ -103,6 +103,8 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
103
  ### Optional for Azure
104
  # AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
105
  # AZURE_EMBEDDING_API_VERSION=2023-05-15
 
 
106
 
107
  ### Data storage selection
108
  # LIGHTRAG_KV_STORAGE=PGKVStorage
 
103
  ### Optional for Azure
104
  # AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
105
  # AZURE_EMBEDDING_API_VERSION=2023-05-15
106
+ # AZURE_EMBEDDING_ENDPOINT=your_endpoint
107
+ # AZURE_EMBEDDING_API_KEY=your_api_key
108
 
109
  ### Data storage selection
110
  # LIGHTRAG_KV_STORAGE=PGKVStorage
lightrag/llm/azure_openai.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import pipmaster as pm # Pipmaster for dynamic library install
3
 
@@ -11,6 +12,8 @@ from openai import (
11
  RateLimitError,
12
  APITimeoutError,
13
  )
 
 
14
  from tenacity import (
15
  retry,
16
  stop_after_attempt,
@@ -37,31 +40,38 @@ import numpy as np
37
  async def azure_openai_complete_if_cache(
38
  model,
39
  prompt,
40
- system_prompt=None,
41
- history_messages=[],
42
- base_url=None,
43
- api_key=None,
44
- api_version=None,
45
  **kwargs,
46
  ):
47
- if api_key:
48
- os.environ["AZURE_OPENAI_API_KEY"] = api_key
49
- if base_url:
50
- os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
51
- if api_version:
52
- os.environ["AZURE_OPENAI_API_VERSION"] = api_version
 
 
 
 
 
 
53
 
54
  openai_async_client = AsyncAzureOpenAI(
55
- azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
56
  azure_deployment=model,
57
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
58
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
59
  )
60
  kwargs.pop("hashing_kv", None)
61
  messages = []
62
  if system_prompt:
63
  messages.append({"role": "system", "content": system_prompt})
64
- messages.extend(history_messages)
 
65
  if prompt is not None:
66
  messages.append({"role": "user", "content": prompt})
67
 
@@ -121,23 +131,37 @@ async def azure_openai_complete(
121
  )
122
  async def azure_openai_embed(
123
  texts: list[str],
124
- model: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
125
- base_url: str = None,
126
- api_key: str = None,
127
- api_version: str = None,
128
  ) -> np.ndarray:
129
- if api_key:
130
- os.environ["AZURE_OPENAI_API_KEY"] = api_key
131
- if base_url:
132
- os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
133
- if api_version:
134
- os.environ["AZURE_OPENAI_API_VERSION"] = api_version
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  openai_async_client = AsyncAzureOpenAI(
137
- azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
138
  azure_deployment=model,
139
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
140
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
141
  )
142
 
143
  response = await openai_async_client.embeddings.create(
 
1
+ from collections.abc import Iterable
2
  import os
3
  import pipmaster as pm # Pipmaster for dynamic library install
4
 
 
12
  RateLimitError,
13
  APITimeoutError,
14
  )
15
+ from openai.types.chat import ChatCompletionMessageParam
16
+
17
  from tenacity import (
18
  retry,
19
  stop_after_attempt,
 
40
  async def azure_openai_complete_if_cache(
41
  model,
42
  prompt,
43
+ system_prompt: str | None = None,
44
+ history_messages: Iterable[ChatCompletionMessageParam] | None = None,
45
+ base_url: str | None = None,
46
+ api_key: str | None = None,
47
+ api_version: str | None = None,
48
  **kwargs,
49
  ):
50
+ model = model or os.getenv("AZURE_OPENAI_DEPLOYMENT") or os.getenv("LLM_MODEL")
51
+ base_url = (
52
+ base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
53
+ )
54
+ api_key = (
55
+ api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
56
+ )
57
+ api_version = (
58
+ api_version
59
+ or os.getenv("AZURE_OPENAI_API_VERSION")
60
+ or os.getenv("OPENAI_API_VERSION")
61
+ )
62
 
63
  openai_async_client = AsyncAzureOpenAI(
64
+ azure_endpoint=base_url,
65
  azure_deployment=model,
66
+ api_key=api_key,
67
+ api_version=api_version,
68
  )
69
  kwargs.pop("hashing_kv", None)
70
  messages = []
71
  if system_prompt:
72
  messages.append({"role": "system", "content": system_prompt})
73
+ if history_messages:
74
+ messages.extend(history_messages)
75
  if prompt is not None:
76
  messages.append({"role": "user", "content": prompt})
77
 
 
131
  )
132
  async def azure_openai_embed(
133
  texts: list[str],
134
+ model: str | None = None,
135
+ base_url: str | None = None,
136
+ api_key: str | None = None,
137
+ api_version: str | None = None,
138
  ) -> np.ndarray:
139
+ model = (
140
+ model
141
+ or os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
142
+ or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
143
+ )
144
+ base_url = (
145
+ base_url
146
+ or os.getenv("AZURE_EMBEDDING_ENDPOINT")
147
+ or os.getenv("EMBEDDING_BINDING_HOST")
148
+ )
149
+ api_key = (
150
+ api_key
151
+ or os.getenv("AZURE_EMBEDDING_API_KEY")
152
+ or os.getenv("EMBEDDING_BINDING_API_KEY")
153
+ )
154
+ api_version = (
155
+ api_version
156
+ or os.getenv("AZURE_EMBEDDING_API_VERSION")
157
+ or os.getenv("OPENAI_API_VERSION")
158
+ )
159
 
160
  openai_async_client = AsyncAzureOpenAI(
161
+ azure_endpoint=base_url,
162
  azure_deployment=model,
163
+ api_key=api_key,
164
+ api_version=api_version,
165
  )
166
 
167
  response = await openai_async_client.embeddings.create(