"""FastAPI service exposing chat / RAG question answering and knowledge-file upload.

All endpoints share a single ``Chat_QA_chain_self`` instance, and every
response body is wrapped by the project's ``success`` / ``failed`` helpers.
"""
import logging
import os

from fastapi import FastAPI, File, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

from chat.file_chat import allowed_file, UPLOAD_FOLDER
from qa_chain.Chat_QA_chain_self import Chat_QA_chain_self  # replace with your own module name
from qa_chain.get_vectordb import get_vectordb
from qa_chain.result import success, failed

logger = logging.getLogger(__name__)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    # Origins allowed to make cross-origin requests, e.g. ["http://www.example.org"];
    # ["*"] allows any origin.
    allow_origins=["*"],
    # Whether cross-origin requests may carry cookies (default False). If True,
    # allow_origins must list concrete origins and cannot be ["*"].
    allow_credentials=False,
    # HTTP methods allowed for cross-origin requests (default ["GET"]).
    allow_methods=["GET", 'POST', 'PUT', 'PATCH', 'DELETE', 'HEAD', 'OPTIONS'],
    # Request headers allowed for cross-origin requests (default []); ["*"] allows all.
    # Accept, Accept-Language, Content-Language and Content-Type are always allowed.
    allow_headers=["*"],
    # Response headers the browser may read (default []); rarely needed.
    # expose_headers=["*"]
    # How long, in seconds, the browser may cache a CORS preflight response
    # (default 600); rarely needed.
    # max_age=1000
)

# Default locations of the raw knowledge files and of the persisted vector store.
DEFAULT_DB_PATH = os.path.join(".", "knowledge_db")
DEFAULT_PERSIST_PATH = os.path.join(".", "vector_db", "chroma")
# Embedding model name shared by the QA chain and the vector store used for uploads.
DEFAULT_EMBEDDING = "m3e"

# Shared QA-chain instance; the endpoints below use it both to build chains and
# to read / append / clear the server-side chat history.
chat_qa_chain = Chat_QA_chain_self(
    file_path=DEFAULT_DB_PATH,
    persist_path=DEFAULT_PERSIST_PATH,
    embedding=DEFAULT_EMBEDDING,
)


def _error_response(exc: Exception) -> JSONResponse:
    """Log *exc* (with traceback) and wrap its message in a ``failed`` response.

    Must be called from inside an ``except`` block so the traceback is available.
    """
    logger.exception("Request failed: %s", exc)
    return JSONResponse(failed({}, {"error": f"{exc}"}))


def _missing_question_response() -> JSONResponse:
    """Return the ``failed`` response used when a request carries no usable question."""
    return JSONResponse(failed({}, {"error": "Missing or empty 'question' field"}))


async def _read_question(request: Request):
    """Parse the JSON body of *request* and return ``(question, chat_history)``.

    ``question`` is ``None`` when the field is absent, empty, or not a string, so
    callers can reject the request instead of feeding ``None`` into a chain.
    ``chat_history`` defaults to an empty list. A JSON body that is not an object
    is treated as empty. Raises whatever ``request.json()`` raises on malformed JSON.
    """
    data = await request.json()
    if not isinstance(data, dict):
        data = {}
    question = data.get('question')
    if not isinstance(question, str) or not question:
        question = None
    return question, data.get('chat_history', [])


async def _answer_with(request: Request, build_chain) -> JSONResponse:
    """Answer the request's question with the chain returned by *build_chain*.

    *build_chain* is a zero-argument callable (one of ``chat_qa_chain``'s
    ``build_*_chain`` methods). The exchange is appended to the chat history.
    """
    try:
        question, chat_history = await _read_question(request)
        if question is None:
            return _missing_question_response()
        chain = build_chain()
        # NOTE(review): invoke() is LangChain's synchronous entry point, so it runs on
        # the event loop; consider run_in_threadpool if concurrent latency matters.
        response = chain.invoke({"question": question, "chat_history": chat_history})
        # Append the new exchange to the chat history.
        chat_qa_chain.add_to_chat_history(question, response)
        return JSONResponse(success({"response": response}, "success"))
    except Exception as e:
        return _error_response(e)


@app.post('/clear/history')
async def clear_history():
    """Clear the shared chat history."""
    try:
        chat_qa_chain.clear_chat_history()
        return JSONResponse(success({"message": "Chat history has been cleared."}, "success"))
    except Exception as e:
        return _error_response(e)


@app.post('/view/history')
async def view_history():
    """Return the shared chat history as a list of ``{"Question", "Answer"}`` dicts."""
    try:
        history = chat_qa_chain.get_chat_history()
        # Each history entry is indexed as (question, answer).
        formatted_history = [{"Question": item[0], "Answer": item[1]} for item in history]
        return JSONResponse(success({"history": formatted_history}, "success"))
    except Exception as e:
        return _error_response(e)


@app.post('/ask')
async def ask(request: Request):
    """Answer a question with the default chain (``build_chain``).

    Questions containing "你好" ("hello") get a canned greeting without calling
    the chain. Requests without a non-empty string ``question`` are rejected.
    """
    try:
        question, chat_history = await _read_question(request)
        if question is None:
            return _missing_question_response()
        if '你好' in question:
            # Canned reply ("Hello! Please ask your question.") that skips the LLM call.
            response = '你好!请提出相应问题'
        else:
            chain = chat_qa_chain.build_chain()
            response = chain.invoke({"question": question, "chat_history": chat_history})
        # Append the new exchange to the chat history.
        chat_qa_chain.add_to_chat_history(question, response)
        return JSONResponse(success({"response": response}, "success"))
    except Exception as e:
        return _error_response(e)


@app.post('/ask/rag')
async def ask_rag(request: Request):
    """Answer a question with the retrieval chain from ``build_rag_chain``."""
    return await _answer_with(request, chat_qa_chain.build_rag_chain)


@app.post('/ask/unstructure/rag')
async def ask_unstructure_rag(request: Request):
    """Answer a question with the chain from ``build_text_chain``.

    Previously this handler was also named ``ask_rag``, which shadowed the
    ``/ask/rag`` handler at module level; the routes themselves are unchanged.
    """
    return await _answer_with(request, chat_qa_chain.build_text_chain)


@app.post('/file/list')
async def file_list():
    """List the entries of the knowledge folder, sorted by name."""
    try:
        # Check that the knowledge folder exists before listing it.
        if not os.path.isdir(DEFAULT_DB_PATH):
            return JSONResponse(failed({}, {"error": "Folder does not exist"}))
        files = sorted(os.listdir(DEFAULT_DB_PATH))
        return JSONResponse(success({"files": files}, "success"))
    except Exception as e:
        return _error_response(e)


@app.post("/file/upload")
async def upload_file(file: UploadFile = File(...)):
    """Save an uploaded text file, split it into chunks and add them to the vector store."""
    try:
        # Keep only the final path component so a crafted name such as
        # "../../x.txt" cannot write outside UPLOAD_FOLDER.
        filename = os.path.basename(file.filename or "")
        if filename in ("", ".", ".."):
            return JSONResponse(failed({}, {"error": "Missing or invalid file name"}))
        if not allowed_file(filename):
            return JSONResponse(failed({}, {"error": "File type not allowed"}))

        # Ensure the upload folder exists.
        os.makedirs(UPLOAD_FOLDER, exist_ok=True)
        file_path = os.path.join(UPLOAD_FOLDER, filename)
        with open(file_path, "wb") as buffer:
            buffer.write(await file.read())

        documents = TextLoader(file_path, encoding='utf8').load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=150)
        split_docs = text_splitter.split_documents(documents)

        vector_db = get_vectordb(
            file_path=DEFAULT_DB_PATH,
            persist_path=DEFAULT_PERSIST_PATH,
            embedding=DEFAULT_EMBEDDING,
        )
        vector_db.add_documents(split_docs)
        return JSONResponse(success({"response": "File uploaded successfully"}, "success"))
    except Exception as e:
        return _error_response(e)


if __name__ == '__main__':
    import uvicorn

    uvicorn.run(app, host='0.0.0.0', port=3004)