| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- """
- Cohere API 使用示例 - 展示如何使用 Bearer Token
- 演示如何使用 Cohere API Key 作为 Bearer token 调用 Cohere API
- """
- import os
- import sys
- from typing import Any, Dict, List, Optional
- import requests
- from loguru import logger
- # 配置日志
- logger.remove()
- logger.add(
- sys.stdout,
- level="INFO",
- format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}",
- )
- class CohereAPIClient:
- """Cohere API 客户端 - 使用 Bearer Token 认证"""
- def __init__(self, api_key: Optional[str] = None):
- """
- 初始化 Cohere API 客户端
- Args:
- api_key: Cohere API Key(如果不提供,从环境变量读取)
- """
- self.api_key = api_key or os.environ.get("COHERE_API_KEY")
- if not self.api_key:
- raise ValueError("请提供 Cohere API Key 或设置 COHERE_API_KEY 环境变量")
- self.base_url = "https://api.cohere.ai/v1"
- self.headers = {
- "Authorization": f"Bearer {self.api_key}",
- "Content-Type": "application/json",
- }
- def list_models(self) -> Dict[str, Any]:
- """
- 获取可用的模型列表
- Returns:
- 模型列表
- """
- logger.info("获取模型列表...")
- response = requests.get(
- f"{self.base_url}/models",
- headers=self.headers,
- timeout=10,
- )
- response.raise_for_status()
- return response.json()
- def rerank(
- self,
- query: str,
- documents: List[str],
- model: str = "rerank-multilingual-v3.0",
- top_n: int = 3,
- ) -> Dict[str, Any]:
- """
- 使用 Rerank API 对文档进行重排序
- Args:
- query: 查询文本
- documents: 文档列表
- model: 使用的模型(默认: rerank-multilingual-v3.0)
- top_n: 返回前 N 个结果
- Returns:
- 重排序结果
- """
- logger.info(f"使用模型 {model} 对 {len(documents)} 个文档进行重排序...")
- data = {
- "model": model,
- "query": query,
- "documents": documents,
- "top_n": top_n,
- }
- response = requests.post(
- f"{self.base_url}/rerank",
- headers=self.headers,
- json=data,
- timeout=30,
- )
- response.raise_for_status()
- return response.json()
- def test_connection(self) -> bool:
- """
- 测试 API 连接是否正常
- Returns:
- 连接是否成功
- """
- try:
- result = self.list_models()
- logger.success(f"连接成功!可用模型数量: {len(result.get('models', []))}")
- return True
- except Exception as e:
- logger.error(f"连接失败: {str(e)}")
- return False
- def main():
- """主函数 - 演示如何使用 Cohere API"""
- # 从命令行参数或环境变量获取 API Key
- api_key = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("COHERE_API_KEY")
- if not api_key:
- logger.error("请提供 Cohere API Key")
- logger.info("使用方法: python cohere_api_example.py <API_KEY>")
- logger.info("或设置环境变量: COHERE_API_KEY=<API_KEY>")
- return 1
- # 创建客户端
- client = CohereAPIClient(api_key=api_key)
- # 测试连接
- logger.info("=" * 60)
- logger.info("测试 API 连接...")
- logger.info("=" * 60)
- if not client.test_connection():
- return 1
- # 示例 1: 获取模型列表
- logger.info("\n" + "=" * 60)
- logger.info("示例 1: 获取模型列表")
- logger.info("=" * 60)
- try:
- models = client.list_models()
- logger.info(f"可用模型: {len(models.get('models', []))} 个")
- # 显示前几个模型
- for model in models.get("models", [])[:5]:
- logger.info(f" - {model.get('name', 'N/A')}")
- except Exception as e:
- logger.error(f"获取模型列表失败: {str(e)}")
- # 示例 2: 使用 Rerank API
- logger.info("\n" + "=" * 60)
- logger.info("示例 2: 使用 Rerank API 重排序文档")
- logger.info("=" * 60)
- query = "什么是人工智能?"
- documents = [
- "人工智能是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。",
- "机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习。",
- "深度学习是机器学习的一个子集,使用神经网络来模拟人脑的工作方式。",
- "自然语言处理是人工智能的一个领域,专注于计算机与人类语言之间的交互。",
- ]
- try:
- result = client.rerank(
- query=query,
- documents=documents,
- model="rerank-multilingual-v3.0",
- top_n=2,
- )
- logger.success("重排序完成!")
- logger.info(f"查询: {query}")
- logger.info(f"返回了 {len(result.get('results', []))} 个结果:\n")
- for idx, item in enumerate(result.get("results", []), 1):
- doc_index = item.get("index", 0)
- relevance_score = item.get("relevance_score", 0)
- document = documents[doc_index] if doc_index < len(documents) else "N/A"
- logger.info(f"{idx}. 文档 {doc_index + 1} (相关性: {relevance_score:.4f})")
- logger.info(f" {document[:80]}...")
- except Exception as e:
- logger.error(f"Rerank API 调用失败: {str(e)}")
- return 1
- logger.info("\n" + "=" * 60)
- logger.success("所有示例执行完成!")
- logger.info("=" * 60)
- return 0
- if __name__ == "__main__":
- sys.exit(main())
|