123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345 |
- import requests
- import time
- import numpy as np
- from typing import List, Callable
- class EmbeddingFunction:
- def __init__(self, model_name: str, api_key: str, base_url: str, embedding_dimension: int):
- self.model_name = model_name
- self.api_key = api_key
- self.base_url = base_url
- self.embedding_dimension = embedding_dimension
- self.headers = {
- "Authorization": f"Bearer {api_key}",
- "Content-Type": "application/json"
- }
- self.max_retries = 2 # 设置默认的最大重试次数
- self.retry_interval = 2 # 设置默认的重试间隔秒数
- self.normalize_embeddings = True # 设置默认是否归一化
- def _normalize_vector(self, vector: List[float]) -> List[float]:
- """
- 对向量进行L2归一化
- Args:
- vector: 输入向量
- Returns:
- List[float]: 归一化后的向量
- """
- if not vector:
- return []
- norm = np.linalg.norm(vector)
- if norm == 0:
- return vector
- return (np.array(vector) / norm).tolist()
-
- def __call__(self, input) -> List[List[float]]:
- """
- 为文本列表生成嵌入向量
-
- Args:
- input: 要嵌入的文本或文本列表
-
- Returns:
- List[List[float]]: 嵌入向量列表
- """
- if not isinstance(input, list):
- input = [input]
-
- embeddings = []
- for text in input:
- payload = {
- "model": self.model_name,
- "input": text,
- "encoding_format": "float"
- }
-
- try:
- # 修复URL拼接问题
- url = self.base_url
- if not url.endswith("/embeddings"):
- url = url.rstrip("/") # 移除尾部斜杠,避免双斜杠
- if not url.endswith("/v1/embeddings"):
- url = f"{url}/embeddings"
-
- response = requests.post(url, json=payload, headers=self.headers)
- response.raise_for_status()
-
- result = response.json()
-
- if "data" in result and len(result["data"]) > 0:
- vector = result["data"][0]["embedding"]
- embeddings.append(vector)
- else:
- raise ValueError(f"API返回无效: {result}")
-
- except Exception as e:
- print(f"获取embedding时出错: {e}")
- # 使用实例的 embedding_dimension 来创建零向量
- embeddings.append([0.0] * self.embedding_dimension)
-
- return embeddings
-
- def embed_documents(self, texts: List[str]) -> List[List[float]]:
- """
- 为文档列表生成嵌入向量 (LangChain 接口)
-
- Args:
- texts: 要嵌入的文档列表
-
- Returns:
- List[List[float]]: 嵌入向量列表
- """
- return self.__call__(texts)
-
- def embed_query(self, text: str) -> List[float]:
- """
- 为查询文本生成嵌入向量 (LangChain 接口)
-
- Args:
- text: 要嵌入的查询文本
-
- Returns:
- List[float]: 嵌入向量
- """
- return self.generate_embedding(text)
-
- def generate_embedding(self, text: str) -> List[float]:
- """
- 为单个文本生成嵌入向量
-
- Args:
- text (str): 要嵌入的文本
-
- Returns:
- List[float]: 嵌入向量
- """
- print(f"生成嵌入向量,文本长度: {len(text)} 字符")
-
- # 处理空文本
- if not text or len(text.strip()) == 0:
- print("输入文本为空,返回零向量")
- # self.embedding_dimension 在初始化时已被强制要求
- # 因此不应该为 None 或需要默认值
- if self.embedding_dimension is None:
- # 这个分支理论上不应该被执行,因为工厂函数会确保 embedding_dimension 已设置
- # 但为了健壮性,如果它意外地是 None,则抛出错误
- raise ValueError("Embedding dimension (self.embedding_dimension) 未被正确初始化。")
- return [0.0] * self.embedding_dimension
-
- # 准备请求体
- payload = {
- "model": self.model_name,
- "input": text,
- "encoding_format": "float"
- }
-
- # 添加重试机制
- retries = 0
- while retries <= self.max_retries:
- try:
- # 发送API请求
- url = self.base_url
- if not url.endswith("/embeddings"):
- url = url.rstrip("/") # 移除尾部斜杠,避免双斜杠
- if not url.endswith("/v1/embeddings"):
- url = f"{url}/embeddings"
- print(f"请求URL: {url}")
-
- response = requests.post(
- url,
- json=payload,
- headers=self.headers,
- timeout=30 # 设置超时时间
- )
-
- # 检查响应状态
- if response.status_code != 200:
- error_msg = f"API请求错误: {response.status_code}, {response.text}"
- print(error_msg)
-
- # 根据错误码判断是否需要重试
- if response.status_code in (429, 500, 502, 503, 504):
- retries += 1
- if retries <= self.max_retries:
- wait_time = self.retry_interval * (2 ** (retries - 1)) # 指数退避
- print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
- time.sleep(wait_time)
- continue
-
- raise ValueError(error_msg)
-
- # 解析响应
- result = response.json()
-
- # 提取embedding向量
- if "data" in result and len(result["data"]) > 0 and "embedding" in result["data"][0]:
- vector = result["data"][0]["embedding"]
-
- # 如果是首次调用且未提供维度,则自动设置
- if self.embedding_dimension is None:
- self.embedding_dimension = len(vector)
- print(f"自动设置embedding维度为: {self.embedding_dimension}")
- else:
- # 验证向量维度
- actual_dim = len(vector)
- if actual_dim != self.embedding_dimension:
- print(f"向量维度不匹配: 期望 {self.embedding_dimension}, 实际 {actual_dim}")
-
- # 如果需要归一化
- if self.normalize_embeddings:
- vector = self._normalize_vector(vector)
-
- print(f"成功生成embedding向量,维度: {len(vector)}")
- return vector
- else:
- error_msg = f"API返回格式异常: {result}"
- print(error_msg)
- raise ValueError(error_msg)
-
- except Exception as e:
- print(f"生成embedding时出错: {str(e)}")
- retries += 1
-
- if retries <= self.max_retries:
- wait_time = self.retry_interval * (2 ** (retries - 1)) # 指数退避
- print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
- time.sleep(wait_time)
- else:
- print(f"已达到最大重试次数 ({self.max_retries}),生成embedding失败")
- # 决定是返回零向量还是重新抛出异常
- if self.embedding_dimension:
- print(f"返回零向量 (维度: {self.embedding_dimension})")
- return [0.0] * self.embedding_dimension
- raise
-
- # 这里不应该到达,但为了完整性添加
- raise RuntimeError("生成embedding失败")
- def test_connection(self, test_text="测试文本") -> dict:
- """
- 测试嵌入模型的连接和功能
-
- Args:
- test_text (str): 用于测试的文本
-
- Returns:
- dict: 包含测试结果的字典,包括是否成功、维度信息等
- """
- result = {
- "success": False,
- "model": self.model_name,
- "base_url": self.base_url,
- "message": "",
- "actual_dimension": None,
- "expected_dimension": self.embedding_dimension
- }
-
- try:
- print(f"测试嵌入模型连接 - 模型: {self.model_name}")
- print(f"API服务地址: {self.base_url}")
-
- # 验证配置
- if not self.api_key:
- result["message"] = "API密钥未设置或为空"
- return result
-
- if not self.base_url:
- result["message"] = "API服务地址未设置或为空"
- return result
-
- # 测试生成向量
- vector = self.generate_embedding(test_text)
- actual_dimension = len(vector)
-
- result["success"] = True
- result["actual_dimension"] = actual_dimension
-
- # 检查维度是否一致
- if actual_dimension != self.embedding_dimension:
- result["message"] = f"警告: 模型实际生成的向量维度({actual_dimension})与配置维度({self.embedding_dimension})不一致"
- else:
- result["message"] = f"连接测试成功,向量维度: {actual_dimension}"
-
- return result
-
- except Exception as e:
- result["message"] = f"连接测试失败: {str(e)}"
- return result
- def test_embedding_connection() -> dict:
- """
- 测试嵌入模型连接和配置是否正确
-
- Returns:
- dict: 测试结果,包括成功/失败状态、错误消息等
- """
- try:
- # 获取嵌入函数实例
- embedding_function = get_embedding_function()
-
- # 测试连接
- test_result = embedding_function.test_connection()
-
- if test_result["success"]:
- print(f"嵌入模型连接测试成功!")
- if "警告" in test_result["message"]:
- print(test_result["message"])
- print(f"建议将app_config.py中的EMBEDDING_CONFIG['embedding_dimension']修改为{test_result['actual_dimension']}")
- else:
- print(f"嵌入模型连接测试失败: {test_result['message']}")
-
- return test_result
-
- except Exception as e:
- error_message = f"无法测试嵌入模型连接: {str(e)}"
- print(error_message)
- return {
- "success": False,
- "message": error_message
- }
- def get_embedding_function() -> EmbeddingFunction:
- """
- 从 app_config.py 的 EMBEDDING_CONFIG 字典加载配置并创建 EmbeddingFunction 实例。
- 如果任何必需的配置未找到,则抛出异常。
- Returns:
- EmbeddingFunction: EmbeddingFunction 的实例。
- Raises:
- ImportError: 如果 app_config.py 无法导入。
- AttributeError: 如果 app_config.py 中缺少 EMBEDDING_CONFIG。
- KeyError: 如果 EMBEDDING_CONFIG 字典中缺少任何必要的键。
- """
- try:
- import app_config
- except ImportError:
- raise ImportError("无法导入 app_config.py。请确保该文件存在且在PYTHONPATH中。")
- try:
- embedding_config_dict = app_config.EMBEDDING_CONFIG
- except AttributeError:
- raise AttributeError("app_config.py 中缺少 EMBEDDING_CONFIG 配置字典。")
- try:
- api_key = embedding_config_dict["api_key"]
- model_name = embedding_config_dict["model_name"]
- base_url = embedding_config_dict["base_url"]
- embedding_dimension = embedding_config_dict["embedding_dimension"]
-
- if api_key is None:
- # 明确指出 api_key (可能来自环境变量) 未设置的问题
- raise KeyError("EMBEDDING_CONFIG 中的 'api_key' 未设置 (可能环境变量 EMBEDDING_API_KEY 未定义)。")
-
- except KeyError as e:
- # 将原始的KeyError e 作为原因传递,可以提供更详细的上下文,比如哪个键确实缺失了
- raise KeyError(f"app_config.py 的 EMBEDDING_CONFIG 字典中缺少必要的键或值无效:{e}")
- return EmbeddingFunction(
- model_name=model_name,
- api_key=api_key,
- base_url=base_url,
- embedding_dimension=embedding_dimension
- )
|