123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185 |
import logging
import os

from fastapi import FastAPI, File, UploadFile
from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

from chat.file_chat import allowed_file, UPLOAD_FOLDER
from qa_chain.Chat_QA_chain_self import Chat_QA_chain_self  # replace with your module name
from qa_chain.get_vectordb import get_vectordb
from qa_chain.result import success, failed
app = FastAPI()

# Cross-origin (CORS) policy for browser clients:
#   * any origin may call the API ("*");
#   * credentials (cookies) are disabled -- browsers reject a wildcard origin
#     when credentials are enabled, so these two settings must go together;
#   * the listed HTTP methods and every request header are accepted;
#   * expose_headers / max_age are intentionally left at their defaults.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"],
    allow_headers=["*"],
)

# One QA-chain instance shared by every endpoint in this module, pointed at
# the knowledge-base folder and the vector store persisted under
# ./vector_db/chroma.
chat_qa_chain = Chat_QA_chain_self(
    file_path="./knowledge_db",
    persist_path="./vector_db/chroma",
)
@app.post('/clear/history')
async def clear_history():
    """Clear the chat history stored on ``chat_qa_chain``.

    Returns:
        JSONResponse wrapping ``success`` with a confirmation message, or
        ``failed`` with an error description if clearing raised (the
        traceback is logged).
    """
    try:
        chat_qa_chain.clear_chat_history()
    except Exception:
        # Report the failure instead of a "success" payload carrying an
        # unrelated "please rephrase your question" hint.
        logging.getLogger(__name__).exception("Failed to clear chat history")
        return JSONResponse(failed({}, {"error": "Failed to clear chat history"}))
    return JSONResponse(success({"message": "Chat history has been cleared."}, "success"))
@app.post('/view/history')
async def view_history():
    """Return the stored chat history as a list of Question/Answer objects.

    Each history item is read as ``item[0]`` (question) and ``item[1]``
    (answer) -- presumably (question, answer) pairs; confirm against
    ``Chat_QA_chain_self.get_chat_history``.

    Returns:
        JSONResponse wrapping ``success({"history": [...]})``, or ``failed``
        if reading/formatting the history raised (the traceback is logged).
    """
    try:
        history = chat_qa_chain.get_chat_history()
        formatted_history = [
            {"Question": item[0], "Answer": item[1]} for item in history
        ]
    except Exception:
        # Report the failure instead of a "success" payload carrying an
        # unrelated "please rephrase your question" hint.
        logging.getLogger(__name__).exception("Failed to read chat history")
        return JSONResponse(failed({}, {"error": "Failed to read chat history"}))
    return JSONResponse(success({"history": formatted_history}, "success"))
- # @app.post('/ask')
- # async def ask(request: Request):
- # try:
- # data = await request.json()
- # question = data.get('question')
- # chat_history = data.get('chat_history', [])
- # response = chat_qa_chain.build_chain(question,chat_history)
- # # response = chain.invoke({"question": question, "chat_history": chat_history})
- # # 添加新的聊天记录到历史记录中
- # chat_qa_chain.add_to_chat_history(question, response)
- # res = success({"response": response}, "success")
- # return JSONResponse(res)
- # except Exception as e:
- # res = failed({}, {"error": f"{e}"})
- # return JSONResponse(res)
@app.post('/ask')
async def ask(request: Request):
    """Answer a question with the chain returned by ``build_chain()``.

    Expects a JSON body ``{"question": str, "chat_history": [...]}``;
    ``chat_history`` defaults to an empty list and is passed to the chain
    as-is. A question containing "你好" ("hello") gets a canned greeting
    without invoking the chain. Each answered exchange is appended to the
    history stored on ``chat_qa_chain``.

    Always responds with ``success``. For a missing, blank or non-string
    question, or on any error, the response text is the generic
    "please try rephrasing" hint (a deliberate user-facing fallback); errors
    are logged with their traceback.
    """
    try:
        data = await request.json()
        question = data.get('question')
        chat_history = data.get('chat_history', [])
        # Handle an unusable question explicitly rather than relying on an
        # unbound `response` raising further down.
        if not isinstance(question, str) or not question.strip():
            return JSONResponse(success({"response": "请尝试换个方式提问。"}, "success"))
        if '你好' in question:
            response = '你好!请提出相应问题'
        else:
            chain = chat_qa_chain.build_chain()
            response = chain.invoke({"question": question, "chat_history": chat_history})
        # Record the new exchange in the stored chat history.
        chat_qa_chain.add_to_chat_history(question, response)
        return JSONResponse(success({"response": response}, "success"))
    except Exception:
        logging.getLogger(__name__).exception("/ask request failed")
        return JSONResponse(success({"response": "请尝试换个方式提问。"}, "success"))
@app.post('/ask/rag')
async def ask_rag(request: Request):
    """Answer a question with the chain returned by ``build_rag_chain()``.

    Expects a JSON body ``{"question": str, "chat_history": [...]}``;
    ``chat_history`` defaults to an empty list and is passed to the chain
    as-is. Each answered exchange is appended to the history stored on
    ``chat_qa_chain``.

    Always responds with ``success``. For a missing, blank or non-string
    question, or on any error, the response text is the generic
    "please try rephrasing" hint (a deliberate user-facing fallback); errors
    are logged with their traceback.
    """
    try:
        data = await request.json()
        question = data.get('question')
        chat_history = data.get('chat_history', [])
        # Don't feed None / blank input into the chain.
        if not isinstance(question, str) or not question.strip():
            return JSONResponse(success({"response": "请尝试换个方式提问。"}, "success"))
        chain = chat_qa_chain.build_rag_chain()
        response = chain.invoke({"question": question, "chat_history": chat_history})
        # Record the new exchange in the stored chat history.
        chat_qa_chain.add_to_chat_history(question, response)
        return JSONResponse(success({"response": response}, "success"))
    except Exception:
        logging.getLogger(__name__).exception("/ask/rag request failed")
        return JSONResponse(success({"response": "请尝试换个方式提问。"}, "success"))
@app.post('/ask/unstructure/rag')
async def ask_unstructure_rag(request: Request):
    """Answer a question with the chain returned by ``build_text_chain()``.

    Named distinctly so it no longer redefines (shadows) the module-level
    ``ask_rag`` handler of ``/ask/rag``; the HTTP route is unchanged.

    Expects a JSON body ``{"question": str, "chat_history": [...]}``;
    ``chat_history`` defaults to an empty list and is passed to the chain
    as-is. Each answered exchange is appended to the history stored on
    ``chat_qa_chain``.

    Always responds with ``success``. For a missing, blank or non-string
    question, or on any error, the response text is the generic
    "please try rephrasing" hint (a deliberate user-facing fallback); errors
    are logged with their traceback.
    """
    try:
        data = await request.json()
        question = data.get('question')
        chat_history = data.get('chat_history', [])
        # Don't feed None / blank input into the chain.
        if not isinstance(question, str) or not question.strip():
            return JSONResponse(success({"response": "请尝试换个方式提问。"}, "success"))
        chain = chat_qa_chain.build_text_chain()
        response = chain.invoke({"question": question, "chat_history": chat_history})
        # Record the new exchange in the stored chat history.
        chat_qa_chain.add_to_chat_history(question, response)
        return JSONResponse(success({"response": response}, "success"))
    except Exception:
        logging.getLogger(__name__).exception("/ask/unstructure/rag request failed")
        return JSONResponse(success({"response": "请尝试换个方式提问。"}, "success"))
@app.post('/file/list')
async def file_list():
    """List the entries of the ``./knowledge_db`` folder.

    Returns:
        JSONResponse wrapping ``success({"files": [...]})`` with the entry
        names sorted alphabetically (``os.listdir`` order is arbitrary; the
        list also includes any sub-directory names), or ``failed`` when the
        folder does not exist or cannot be read (the traceback is logged).
    """
    folder = "./knowledge_db"
    # os.path.isdir never raises; it returns False for missing/unreadable paths.
    if not os.path.isdir(folder):
        return JSONResponse(failed({}, {"error": "Folder does not exist"}))
    try:
        files = sorted(os.listdir(folder))
    except OSError:
        # Report the failure instead of a "success" payload carrying an
        # unrelated "please rephrase your question" hint.
        logging.getLogger(__name__).exception("Failed to list %s", folder)
        return JSONResponse(failed({}, {"error": "Failed to list files"}))
    return JSONResponse(success({"files": files}, "success"))
@app.post("/file/upload")
async def upload_file(file: UploadFile = File(...)):
    """Save an uploaded text file into ``UPLOAD_FOLDER`` and index it.

    The saved file is loaded with ``TextLoader`` (utf8), split with
    ``RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=150)``,
    and the chunks are added to the store returned by
    ``get_vectordb(file_path="./knowledge_db",
    persist_path="./vector_db/chroma", embedding="m3e")``.

    Returns:
        JSONResponse wrapping ``success`` on completion; ``failed`` when the
        file name is rejected, or when saving/indexing raises (the traceback
        is logged).
    """
    try:
        # The filename is client-controlled: keep only its final path
        # component so names like "../../x.txt" cannot be written outside
        # UPLOAD_FOLDER. Backslashes are normalised first so Windows-style
        # paths are stripped as well.
        filename = os.path.basename((file.filename or "").replace("\\", "/"))
        if filename in ("", ".", "..") or not allowed_file(filename):
            return JSONResponse(failed({}, {"error": "File type not allowed"}))
        # Ensure the upload folder exists.
        os.makedirs(UPLOAD_FOLDER, exist_ok=True)
        file_path = os.path.join(UPLOAD_FOLDER, filename)
        with open(file_path, "wb") as buffer:
            buffer.write(await file.read())
        # Load from file_path itself so the indexed file is exactly the one
        # just written.
        documents = TextLoader(file_path, encoding='utf8').load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500, chunk_overlap=150)
        split_docs = text_splitter.split_documents(documents)
        vector_db = get_vectordb(file_path="./knowledge_db", persist_path="./vector_db/chroma", embedding="m3e")
        vector_db.add_documents(split_docs)
    except Exception:
        # Report the failure instead of a "success" payload carrying an
        # unrelated "please rephrase your question" hint.
        logging.getLogger(__name__).exception("Failed to upload/index file %r", file.filename)
        return JSONResponse(failed({}, {"error": "File upload failed"}))
    return JSONResponse(success({"response": "File uploaded successfully"}, "success"))
if __name__ == '__main__':
    # uvicorn is only needed when this module is executed directly, so it is
    # imported here rather than at the top of the file.
    from uvicorn import run as serve_app

    # Listen on all interfaces, port 3004.
    serve_app(app, host='0.0.0.0', port=3004)
|