ollama_embedding.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. import requests
  2. import time
  3. import numpy as np
  4. from typing import List, Callable
  5. from core.logging import get_vanna_logger
  6. class OllamaEmbeddingFunction:
  7. def __init__(self, model_name: str, base_url: str, embedding_dimension: int):
  8. self.model_name = model_name
  9. self.base_url = base_url
  10. self.embedding_dimension = embedding_dimension
  11. self.max_retries = 3
  12. self.retry_interval = 2
  13. # 初始化日志
  14. self.logger = get_vanna_logger("OllamaEmbedding")
  15. def __call__(self, input) -> List[List[float]]:
  16. """为文本列表生成嵌入向量"""
  17. if not isinstance(input, list):
  18. input = [input]
  19. embeddings = []
  20. for text in input:
  21. try:
  22. embedding = self.generate_embedding(text)
  23. embeddings.append(embedding)
  24. except Exception as e:
  25. self.logger.error(f"获取embedding时出错: {e}")
  26. embeddings.append([0.0] * self.embedding_dimension)
  27. return embeddings
  28. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  29. """为文档列表生成嵌入向量(兼容ChromaDB接口)"""
  30. return self.__call__(texts)
  31. def embed_query(self, text: str) -> List[float]:
  32. """为单个查询文本生成嵌入向量(兼容ChromaDB接口)"""
  33. return self.generate_embedding(text)
  34. def generate_embedding(self, text: str) -> List[float]:
  35. """为单个文本生成嵌入向量"""
  36. self.logger.debug(f"生成Ollama嵌入向量,文本长度: {len(text)} 字符")
  37. if not text or len(text.strip()) == 0:
  38. self.logger.debug("输入文本为空,返回零向量")
  39. return [0.0] * self.embedding_dimension
  40. url = f"{self.base_url}/api/embeddings"
  41. payload = {
  42. "model": self.model_name,
  43. "prompt": text
  44. }
  45. retries = 0
  46. while retries <= self.max_retries:
  47. try:
  48. response = requests.post(
  49. url,
  50. json=payload,
  51. timeout=30
  52. )
  53. if response.status_code != 200:
  54. error_msg = f"Ollama API请求错误: {response.status_code}, {response.text}"
  55. self.logger.error(error_msg)
  56. if response.status_code in (429, 500, 502, 503, 504):
  57. retries += 1
  58. if retries <= self.max_retries:
  59. wait_time = self.retry_interval * (2 ** (retries - 1))
  60. self.logger.info(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
  61. time.sleep(wait_time)
  62. continue
  63. raise ValueError(error_msg)
  64. result = response.json()
  65. if "embedding" in result:
  66. vector = result["embedding"]
  67. # 验证向量维度
  68. actual_dim = len(vector)
  69. if actual_dim != self.embedding_dimension:
  70. self.logger.debug(f"向量维度不匹配: 期望 {self.embedding_dimension}, 实际 {actual_dim}")
  71. # 如果维度不匹配,可以选择截断或填充
  72. if actual_dim > self.embedding_dimension:
  73. vector = vector[:self.embedding_dimension]
  74. else:
  75. vector.extend([0.0] * (self.embedding_dimension - actual_dim))
  76. # 添加成功生成embedding的debug日志
  77. self.logger.debug(f"✓ 成功生成Ollama embedding向量,维度: {len(vector)}")
  78. return vector
  79. else:
  80. error_msg = f"Ollama API返回格式异常: {result}"
  81. self.logger.error(error_msg)
  82. raise ValueError(error_msg)
  83. except Exception as e:
  84. self.logger.error(f"生成Ollama embedding时出错: {str(e)}")
  85. retries += 1
  86. if retries <= self.max_retries:
  87. wait_time = self.retry_interval * (2 ** (retries - 1))
  88. self.logger.info(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
  89. time.sleep(wait_time)
  90. else:
  91. self.logger.error(f"已达到最大重试次数 ({self.max_retries}),生成embedding失败")
  92. return [0.0] * self.embedding_dimension
  93. raise RuntimeError("生成Ollama embedding失败")
  94. def test_connection(self, test_text="测试文本") -> dict:
  95. """测试Ollama嵌入模型的连接"""
  96. result = {
  97. "success": False,
  98. "model": self.model_name,
  99. "base_url": self.base_url,
  100. "message": "",
  101. "actual_dimension": None,
  102. "expected_dimension": self.embedding_dimension
  103. }
  104. try:
  105. self.logger.info(f"测试Ollama嵌入模型连接 - 模型: {self.model_name}")
  106. self.logger.info(f"Ollama服务地址: {self.base_url}")
  107. vector = self.generate_embedding(test_text)
  108. actual_dimension = len(vector)
  109. result["success"] = True
  110. result["actual_dimension"] = actual_dimension
  111. if actual_dimension != self.embedding_dimension:
  112. result["message"] = f"警告: 模型实际生成的向量维度({actual_dimension})与配置维度({self.embedding_dimension})不一致"
  113. else:
  114. result["message"] = f"Ollama连接测试成功,向量维度: {actual_dimension}"
  115. return result
  116. except Exception as e:
  117. result["message"] = f"Ollama连接测试失败: {str(e)}"
  118. return result