123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- import requests
- import time
- import numpy as np
- from typing import List, Callable
- class OllamaEmbeddingFunction:
- def __init__(self, model_name: str, base_url: str, embedding_dimension: int):
- self.model_name = model_name
- self.base_url = base_url
- self.embedding_dimension = embedding_dimension
- self.max_retries = 3
- self.retry_interval = 2
- def __call__(self, input) -> List[List[float]]:
- """为文本列表生成嵌入向量"""
- if not isinstance(input, list):
- input = [input]
-
- embeddings = []
- for text in input:
- try:
- embedding = self.generate_embedding(text)
- embeddings.append(embedding)
- except Exception as e:
- print(f"获取embedding时出错: {e}")
- embeddings.append([0.0] * self.embedding_dimension)
-
- return embeddings
-
- def embed_documents(self, texts: List[str]) -> List[List[float]]:
- """为文档列表生成嵌入向量(兼容ChromaDB接口)"""
- return self.__call__(texts)
-
- def embed_query(self, text: str) -> List[float]:
- """为单个查询文本生成嵌入向量(兼容ChromaDB接口)"""
- return self.generate_embedding(text)
-
- def generate_embedding(self, text: str) -> List[float]:
- """为单个文本生成嵌入向量"""
- print(f"生成Ollama嵌入向量,文本长度: {len(text)} 字符")
-
- if not text or len(text.strip()) == 0:
- print("输入文本为空,返回零向量")
- return [0.0] * self.embedding_dimension
- url = f"{self.base_url}/api/embeddings"
- payload = {
- "model": self.model_name,
- "prompt": text
- }
-
- retries = 0
- while retries <= self.max_retries:
- try:
- response = requests.post(
- url,
- json=payload,
- timeout=30
- )
-
- if response.status_code != 200:
- error_msg = f"Ollama API请求错误: {response.status_code}, {response.text}"
- print(error_msg)
-
- if response.status_code in (429, 500, 502, 503, 504):
- retries += 1
- if retries <= self.max_retries:
- wait_time = self.retry_interval * (2 ** (retries - 1))
- print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
- time.sleep(wait_time)
- continue
-
- raise ValueError(error_msg)
-
- result = response.json()
-
- if "embedding" in result:
- vector = result["embedding"]
-
- # 验证向量维度
- actual_dim = len(vector)
- if actual_dim != self.embedding_dimension:
- print(f"向量维度不匹配: 期望 {self.embedding_dimension}, 实际 {actual_dim}")
- # 如果维度不匹配,可以选择截断或填充
- if actual_dim > self.embedding_dimension:
- vector = vector[:self.embedding_dimension]
- else:
- vector.extend([0.0] * (self.embedding_dimension - actual_dim))
-
- # 添加成功生成embedding的debug日志
- print(f"[DEBUG] ✓ 成功生成Ollama embedding向量,维度: {len(vector)}")
- return vector
- else:
- error_msg = f"Ollama API返回格式异常: {result}"
- print(error_msg)
- raise ValueError(error_msg)
-
- except Exception as e:
- print(f"生成Ollama embedding时出错: {str(e)}")
- retries += 1
-
- if retries <= self.max_retries:
- wait_time = self.retry_interval * (2 ** (retries - 1))
- print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
- time.sleep(wait_time)
- else:
- print(f"已达到最大重试次数 ({self.max_retries}),生成embedding失败")
- return [0.0] * self.embedding_dimension
-
- raise RuntimeError("生成Ollama embedding失败")
- def test_connection(self, test_text="测试文本") -> dict:
- """测试Ollama嵌入模型的连接"""
- result = {
- "success": False,
- "model": self.model_name,
- "base_url": self.base_url,
- "message": "",
- "actual_dimension": None,
- "expected_dimension": self.embedding_dimension
- }
-
- try:
- print(f"测试Ollama嵌入模型连接 - 模型: {self.model_name}")
- print(f"Ollama服务地址: {self.base_url}")
-
- vector = self.generate_embedding(test_text)
- actual_dimension = len(vector)
-
- result["success"] = True
- result["actual_dimension"] = actual_dimension
-
- if actual_dimension != self.embedding_dimension:
- result["message"] = f"警告: 模型实际生成的向量维度({actual_dimension})与配置维度({self.embedding_dimension})不一致"
- else:
- result["message"] = f"Ollama连接测试成功,向量维度: {actual_dimension}"
-
- return result
-
- except Exception as e:
- result["message"] = f"Ollama连接测试失败: {str(e)}"
- return result
|