|
@@ -95,17 +95,163 @@ class PG_VectorStore(VannaBase):
|
|
|
case _:
|
|
|
raise ValueError("Specified collection does not exist.")
|
|
|
|
|
|
+ # def get_similar_question_sql(self, question: str) -> list:
|
|
|
+ # documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
|
|
|
+ # return [ast.literal_eval(document.page_content) for document in documents]
|
|
|
+
|
|
|
+ # 在原来的基础之上,增加相似度的值。
|
|
|
def get_similar_question_sql(self, question: str) -> list:
|
|
|
- documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
|
|
|
- return [ast.literal_eval(document.page_content) for document in documents]
|
|
|
+ docs_with_scores = self.sql_collection.similarity_search_with_score(
|
|
|
+ query=question,
|
|
|
+ k=self.n_results
|
|
|
+ )
|
|
|
+
|
|
|
+ results = []
|
|
|
+ for doc, score in docs_with_scores:
|
|
|
+ # 将文档内容转换为 dict
|
|
|
+ base = ast.literal_eval(doc.page_content)
|
|
|
+
|
|
|
+ # 计算相似度
|
|
|
+ similarity = round(1 - score, 4)
|
|
|
+
|
|
|
+ # 每条记录单独打印
|
|
|
+ print(f"[DEBUG] SQL Match: {base.get('question', '')} | similarity: {similarity}")
|
|
|
+
|
|
|
+ # 添加 similarity 字段
|
|
|
+ base["similarity"] = similarity
|
|
|
+ results.append(base)
|
|
|
+
|
|
|
+ # 应用阈值过滤
|
|
|
+ filtered_results = self._apply_score_threshold_filter(
|
|
|
+ results,
|
|
|
+ "RESULT_VECTOR_SQL_SCORE_THRESHOLD",
|
|
|
+ "SQL"
|
|
|
+ )
|
|
|
+
|
|
|
+ return filtered_results
|
|
|
|
|
|
def get_related_ddl(self, question: str, **kwargs) -> list:
|
|
|
- documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
|
|
|
- return [document.page_content for document in documents]
|
|
|
+ docs_with_scores = self.ddl_collection.similarity_search_with_score(
|
|
|
+ query=question,
|
|
|
+ k=self.n_results
|
|
|
+ )
|
|
|
+
|
|
|
+ results = []
|
|
|
+ for doc, score in docs_with_scores:
|
|
|
+ # 计算相似度
|
|
|
+ similarity = round(1 - score, 4)
|
|
|
+
|
|
|
+ # 每条记录单独打印
|
|
|
+ print(f"[DEBUG] DDL Match: {doc.page_content[:50]}... | similarity: {similarity}")
|
|
|
+
|
|
|
+ # 添加 similarity 字段
|
|
|
+ result = {
|
|
|
+ "content": doc.page_content,
|
|
|
+ "similarity": similarity
|
|
|
+ }
|
|
|
+ results.append(result)
|
|
|
+
|
|
|
+ # 应用阈值过滤
|
|
|
+ filtered_results = self._apply_score_threshold_filter(
|
|
|
+ results,
|
|
|
+ "RESULT_VECTOR_DDL_SCORE_THRESHOLD",
|
|
|
+ "DDL"
|
|
|
+ )
|
|
|
+
|
|
|
+ return filtered_results
|
|
|
|
|
|
def get_related_documentation(self, question: str, **kwargs) -> list:
|
|
|
- documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
|
|
|
- return [document.page_content for document in documents]
|
|
|
+ docs_with_scores = self.documentation_collection.similarity_search_with_score(
|
|
|
+ query=question,
|
|
|
+ k=self.n_results
|
|
|
+ )
|
|
|
+
|
|
|
+ results = []
|
|
|
+ for doc, score in docs_with_scores:
|
|
|
+ # 计算相似度
|
|
|
+ similarity = round(1 - score, 4)
|
|
|
+
|
|
|
+ # 每条记录单独打印
|
|
|
+ print(f"[DEBUG] Doc Match: {doc.page_content[:50]}... | similarity: {similarity}")
|
|
|
+
|
|
|
+ # 添加 similarity 字段
|
|
|
+ result = {
|
|
|
+ "content": doc.page_content,
|
|
|
+ "similarity": similarity
|
|
|
+ }
|
|
|
+ results.append(result)
|
|
|
+
|
|
|
+ # 应用阈值过滤
|
|
|
+ filtered_results = self._apply_score_threshold_filter(
|
|
|
+ results,
|
|
|
+ "RESULT_VECTOR_DOC_SCORE_THRESHOLD",
|
|
|
+ "DOC"
|
|
|
+ )
|
|
|
+
|
|
|
+ return filtered_results
|
|
|
+
|
|
|
+ def _apply_score_threshold_filter(self, results: list, threshold_config_key: str, result_type: str) -> list:
|
|
|
+ """
|
|
|
+ 应用相似度阈值过滤逻辑
|
|
|
+
|
|
|
+ Args:
|
|
|
+ results: 原始结果列表,每个元素包含 similarity 字段
|
|
|
+ threshold_config_key: 配置中的阈值参数名
|
|
|
+ result_type: 结果类型(用于日志)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 过滤后的结果列表
|
|
|
+ """
|
|
|
+ if not results:
|
|
|
+ return results
|
|
|
+
|
|
|
+ # 导入配置
|
|
|
+ try:
|
|
|
+ import app_config
|
|
|
+ enable_threshold = getattr(app_config, 'ENABLE_RESULT_VECTOR_SCORE_THRESHOLD', False)
|
|
|
+ threshold = getattr(app_config, threshold_config_key, 0.65)
|
|
|
+ except (ImportError, AttributeError) as e:
|
|
|
+ print(f"[WARNING] 无法加载阈值配置: {e},使用默认值")
|
|
|
+ enable_threshold = False
|
|
|
+ threshold = 0.65
|
|
|
+
|
|
|
+ # 如果未启用阈值过滤,直接返回原结果
|
|
|
+ if not enable_threshold:
|
|
|
+ print(f"[DEBUG] {result_type} 阈值过滤未启用,返回全部 {len(results)} 条结果")
|
|
|
+ return results
|
|
|
+
|
|
|
+ total_count = len(results)
|
|
|
+ min_required = max((total_count + 1) // 2, 1)
|
|
|
+
|
|
|
+ print(f"[DEBUG] {result_type} 阈值过滤: 总数={total_count}, 阈值={threshold}, 最少保留={min_required}")
|
|
|
+
|
|
|
+ # 按相似度降序排序(确保最相似的在前面)
|
|
|
+ sorted_results = sorted(results, key=lambda x: x.get('similarity', 0), reverse=True)
|
|
|
+
|
|
|
+ # 找出满足阈值的结果
|
|
|
+ above_threshold = [r for r in sorted_results if r.get('similarity', 0) >= threshold]
|
|
|
+
|
|
|
+ # 应用过滤逻辑
|
|
|
+ if len(above_threshold) >= min_required:
|
|
|
+ # 情况1: 满足阈值的结果数量 >= 最少保留数量,返回满足阈值的结果
|
|
|
+ filtered_results = above_threshold
|
|
|
+ filtered_count = len(above_threshold)
|
|
|
+ print(f"[DEBUG] {result_type} 过滤结果: 保留 {filtered_count} 条, 过滤掉 {total_count - filtered_count} 条 (全部满足阈值)")
|
|
|
+ else:
|
|
|
+ # 情况2: 满足阈值的结果数量 < 最少保留数量,强制保留前 min_required 条
|
|
|
+ filtered_results = sorted_results[:min_required]
|
|
|
+ above_count = len(above_threshold)
|
|
|
+ below_count = min_required - above_count
|
|
|
+ filtered_count = min_required
|
|
|
+ print(f"[DEBUG] {result_type} 过滤结果: 保留 {filtered_count} 条, 过滤掉 {total_count - filtered_count} 条 (满足阈值: {above_count}, 强制保留: {below_count})")
|
|
|
+
|
|
|
+ # 打印过滤详情
|
|
|
+ for i, result in enumerate(filtered_results):
|
|
|
+ similarity = result.get('similarity', 0)
|
|
|
+ status = "✓" if similarity >= threshold else "✗"
|
|
|
+ print(f"[DEBUG] {result_type} 保留 {i+1}: similarity={similarity} {status}")
|
|
|
+
|
|
|
+ return filtered_results
|
|
|
|
|
|
def train(
|
|
|
self,
|