embedding_function.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. import requests
  2. import time
  3. import numpy as np
  4. from typing import List, Callable
  5. class EmbeddingFunction:
  6. def __init__(self, model_name: str, api_key: str, base_url: str, embedding_dimension: int):
  7. self.model_name = model_name
  8. self.api_key = api_key
  9. self.base_url = base_url
  10. self.embedding_dimension = embedding_dimension
  11. self.headers = {
  12. "Authorization": f"Bearer {api_key}",
  13. "Content-Type": "application/json"
  14. }
  15. self.max_retries = 2 # 设置默认的最大重试次数
  16. self.retry_interval = 2 # 设置默认的重试间隔秒数
  17. self.normalize_embeddings = True # 设置默认是否归一化
  18. def _normalize_vector(self, vector: List[float]) -> List[float]:
  19. """
  20. 对向量进行L2归一化
  21. Args:
  22. vector: 输入向量
  23. Returns:
  24. List[float]: 归一化后的向量
  25. """
  26. if not vector:
  27. return []
  28. norm = np.linalg.norm(vector)
  29. if norm == 0:
  30. return vector
  31. return (np.array(vector) / norm).tolist()
  32. def __call__(self, input) -> List[List[float]]:
  33. """
  34. 为文本列表生成嵌入向量
  35. Args:
  36. input: 要嵌入的文本或文本列表
  37. Returns:
  38. List[List[float]]: 嵌入向量列表
  39. """
  40. if not isinstance(input, list):
  41. input = [input]
  42. embeddings = []
  43. for text in input:
  44. payload = {
  45. "model": self.model_name,
  46. "input": text,
  47. "encoding_format": "float"
  48. }
  49. try:
  50. # 修复URL拼接问题
  51. url = self.base_url
  52. if not url.endswith("/embeddings"):
  53. url = url.rstrip("/") # 移除尾部斜杠,避免双斜杠
  54. if not url.endswith("/v1/embeddings"):
  55. url = f"{url}/embeddings"
  56. response = requests.post(url, json=payload, headers=self.headers)
  57. response.raise_for_status()
  58. result = response.json()
  59. if "data" in result and len(result["data"]) > 0:
  60. vector = result["data"][0]["embedding"]
  61. embeddings.append(vector)
  62. else:
  63. raise ValueError(f"API返回无效: {result}")
  64. except Exception as e:
  65. print(f"获取embedding时出错: {e}")
  66. # 使用实例的 embedding_dimension 来创建零向量
  67. embeddings.append([0.0] * self.embedding_dimension)
  68. return embeddings
  69. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  70. """
  71. 为文档列表生成嵌入向量 (LangChain 接口)
  72. Args:
  73. texts: 要嵌入的文档列表
  74. Returns:
  75. List[List[float]]: 嵌入向量列表
  76. """
  77. return self.__call__(texts)
  78. def embed_query(self, text: str) -> List[float]:
  79. """
  80. 为查询文本生成嵌入向量 (LangChain 接口)
  81. Args:
  82. text: 要嵌入的查询文本
  83. Returns:
  84. List[float]: 嵌入向量
  85. """
  86. return self.generate_embedding(text)
  87. def generate_embedding(self, text: str) -> List[float]:
  88. """
  89. 为单个文本生成嵌入向量
  90. Args:
  91. text (str): 要嵌入的文本
  92. Returns:
  93. List[float]: 嵌入向量
  94. """
  95. print(f"生成嵌入向量,文本长度: {len(text)} 字符")
  96. # 处理空文本
  97. if not text or len(text.strip()) == 0:
  98. print("输入文本为空,返回零向量")
  99. # self.embedding_dimension 在初始化时已被强制要求
  100. # 因此不应该为 None 或需要默认值
  101. if self.embedding_dimension is None:
  102. # 这个分支理论上不应该被执行,因为工厂函数会确保 embedding_dimension 已设置
  103. # 但为了健壮性,如果它意外地是 None,则抛出错误
  104. raise ValueError("Embedding dimension (self.embedding_dimension) 未被正确初始化。")
  105. return [0.0] * self.embedding_dimension
  106. # 准备请求体
  107. payload = {
  108. "model": self.model_name,
  109. "input": text,
  110. "encoding_format": "float"
  111. }
  112. # 添加重试机制
  113. retries = 0
  114. while retries <= self.max_retries:
  115. try:
  116. # 发送API请求
  117. url = self.base_url
  118. if not url.endswith("/embeddings"):
  119. url = url.rstrip("/") # 移除尾部斜杠,避免双斜杠
  120. if not url.endswith("/v1/embeddings"):
  121. url = f"{url}/embeddings"
  122. print(f"请求URL: {url}")
  123. response = requests.post(
  124. url,
  125. json=payload,
  126. headers=self.headers,
  127. timeout=30 # 设置超时时间
  128. )
  129. # 检查响应状态
  130. if response.status_code != 200:
  131. error_msg = f"API请求错误: {response.status_code}, {response.text}"
  132. print(error_msg)
  133. # 根据错误码判断是否需要重试
  134. if response.status_code in (429, 500, 502, 503, 504):
  135. retries += 1
  136. if retries <= self.max_retries:
  137. wait_time = self.retry_interval * (2 ** (retries - 1)) # 指数退避
  138. print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
  139. time.sleep(wait_time)
  140. continue
  141. raise ValueError(error_msg)
  142. # 解析响应
  143. result = response.json()
  144. # 提取embedding向量
  145. if "data" in result and len(result["data"]) > 0 and "embedding" in result["data"][0]:
  146. vector = result["data"][0]["embedding"]
  147. # 如果是首次调用且未提供维度,则自动设置
  148. if self.embedding_dimension is None:
  149. self.embedding_dimension = len(vector)
  150. print(f"自动设置embedding维度为: {self.embedding_dimension}")
  151. else:
  152. # 验证向量维度
  153. actual_dim = len(vector)
  154. if actual_dim != self.embedding_dimension:
  155. print(f"向量维度不匹配: 期望 {self.embedding_dimension}, 实际 {actual_dim}")
  156. # 如果需要归一化
  157. if self.normalize_embeddings:
  158. vector = self._normalize_vector(vector)
  159. print(f"成功生成embedding向量,维度: {len(vector)}")
  160. return vector
  161. else:
  162. error_msg = f"API返回格式异常: {result}"
  163. print(error_msg)
  164. raise ValueError(error_msg)
  165. except Exception as e:
  166. print(f"生成embedding时出错: {str(e)}")
  167. retries += 1
  168. if retries <= self.max_retries:
  169. wait_time = self.retry_interval * (2 ** (retries - 1)) # 指数退避
  170. print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
  171. time.sleep(wait_time)
  172. else:
  173. print(f"已达到最大重试次数 ({self.max_retries}),生成embedding失败")
  174. # 决定是返回零向量还是重新抛出异常
  175. if self.embedding_dimension:
  176. print(f"返回零向量 (维度: {self.embedding_dimension})")
  177. return [0.0] * self.embedding_dimension
  178. raise
  179. # 这里不应该到达,但为了完整性添加
  180. raise RuntimeError("生成embedding失败")
  181. def test_connection(self, test_text="测试文本") -> dict:
  182. """
  183. 测试嵌入模型的连接和功能
  184. Args:
  185. test_text (str): 用于测试的文本
  186. Returns:
  187. dict: 包含测试结果的字典,包括是否成功、维度信息等
  188. """
  189. result = {
  190. "success": False,
  191. "model": self.model_name,
  192. "base_url": self.base_url,
  193. "message": "",
  194. "actual_dimension": None,
  195. "expected_dimension": self.embedding_dimension
  196. }
  197. try:
  198. print(f"测试嵌入模型连接 - 模型: {self.model_name}")
  199. print(f"API服务地址: {self.base_url}")
  200. # 验证配置
  201. if not self.api_key:
  202. result["message"] = "API密钥未设置或为空"
  203. return result
  204. if not self.base_url:
  205. result["message"] = "API服务地址未设置或为空"
  206. return result
  207. # 测试生成向量
  208. vector = self.generate_embedding(test_text)
  209. actual_dimension = len(vector)
  210. result["success"] = True
  211. result["actual_dimension"] = actual_dimension
  212. # 检查维度是否一致
  213. if actual_dimension != self.embedding_dimension:
  214. result["message"] = f"警告: 模型实际生成的向量维度({actual_dimension})与配置维度({self.embedding_dimension})不一致"
  215. else:
  216. result["message"] = f"连接测试成功,向量维度: {actual_dimension}"
  217. return result
  218. except Exception as e:
  219. result["message"] = f"连接测试失败: {str(e)}"
  220. return result
  221. def test_embedding_connection() -> dict:
  222. """
  223. 测试嵌入模型连接和配置是否正确
  224. Returns:
  225. dict: 测试结果,包括成功/失败状态、错误消息等
  226. """
  227. try:
  228. # 获取嵌入函数实例
  229. embedding_function = get_embedding_function()
  230. # 测试连接
  231. test_result = embedding_function.test_connection()
  232. if test_result["success"]:
  233. print(f"嵌入模型连接测试成功!")
  234. if "警告" in test_result["message"]:
  235. print(test_result["message"])
  236. print(f"建议将app_config.py中的EMBEDDING_CONFIG['embedding_dimension']修改为{test_result['actual_dimension']}")
  237. else:
  238. print(f"嵌入模型连接测试失败: {test_result['message']}")
  239. return test_result
  240. except Exception as e:
  241. error_message = f"无法测试嵌入模型连接: {str(e)}"
  242. print(error_message)
  243. return {
  244. "success": False,
  245. "message": error_message
  246. }
  247. def get_embedding_function():
  248. """
  249. 根据当前配置创建合适的EmbeddingFunction实例
  250. 支持API和Ollama两种提供商
  251. Returns:
  252. EmbeddingFunction或OllamaEmbeddingFunction: 根据配置类型返回相应的实例
  253. Raises:
  254. ImportError: 如果无法导入必要的模块
  255. ValueError: 如果配置无效
  256. """
  257. try:
  258. from common.utils import get_current_embedding_config, is_using_ollama_embedding
  259. except ImportError:
  260. raise ImportError("无法导入 common.utils,请确保该文件存在")
  261. # 获取当前embedding配置
  262. embedding_config = get_current_embedding_config()
  263. if is_using_ollama_embedding():
  264. # 使用Ollama Embedding
  265. try:
  266. from customollama.ollama_embedding import OllamaEmbeddingFunction
  267. except ImportError:
  268. raise ImportError("无法导入 OllamaEmbeddingFunction,请确保 customollama 包存在")
  269. return OllamaEmbeddingFunction(
  270. model_name=embedding_config["model_name"],
  271. base_url=embedding_config["base_url"],
  272. embedding_dimension=embedding_config["embedding_dimension"]
  273. )
  274. else:
  275. # 使用API Embedding
  276. try:
  277. api_key = embedding_config["api_key"]
  278. model_name = embedding_config["model_name"]
  279. base_url = embedding_config["base_url"]
  280. embedding_dimension = embedding_config["embedding_dimension"]
  281. if api_key is None:
  282. raise KeyError("API模式下 'api_key' 未设置 (可能环境变量 EMBEDDING_API_KEY 未定义)")
  283. except KeyError as e:
  284. raise KeyError(f"API Embedding配置中缺少必要的键或值无效:{e}")
  285. return EmbeddingFunction(
  286. model_name=model_name,
  287. api_key=api_key,
  288. base_url=base_url,
  289. embedding_dimension=embedding_dimension
  290. )