app.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import os
  2. from fastapi import Request
  3. from fastapi.responses import JSONResponse
  4. from langchain_community.document_loaders import TextLoader
  5. from langchain_text_splitters import RecursiveCharacterTextSplitter
  6. from chat.file_chat import allowed_file, UPLOAD_FOLDER
  7. from qa_chain.Chat_QA_chain_self import Chat_QA_chain_self # 替换为你的模块名
  8. from qa_chain.get_vectordb import get_vectordb
  9. from qa_chain.result import success,failed
  10. from fastapi.middleware.cors import CORSMiddleware
  11. from fastapi import FastAPI, File, UploadFile
  12. app = FastAPI()
  13. app.add_middleware(
  14. CORSMiddleware,
  15. # 允许跨域的源列表,例如 ["http://www.example.org"] 等等,["*"] 表示允许任何源
  16. allow_origins=["*"],
  17. # 跨域请求是否支持 cookie,默认是 False,如果为 True,allow_origins 必须为具体的源,不可以是 ["*"]
  18. allow_credentials=False,
  19. # 允许跨域请求的 HTTP 方法列表,默认是 ["GET"]
  20. allow_methods=["GET",'POST','PUT','PATCH','DELETE','HEAD','OPTIONS'],
  21. # 允许跨域请求的 HTTP 请求头列表,默认是 [],可以使用 ["*"] 表示允许所有的请求头
  22. # 当然 Accept、Accept-Language、Content-Language 以及 Content-Type 总之被允许的
  23. allow_headers=["*"],
  24. # 可以被浏览器访问的响应头, 默认是 [],一般很少指定
  25. # expose_headers=["*"]
  26. # 设定浏览器缓存 CORS 响应的最长时间,单位是秒。默认为 600,一般也很少指定
  27. # max_age=1000
  28. )
  29. # 获取当前脚本所在的绝对路径
  30. current_dir = os.path.dirname(os.path.abspath(__file__))
  31. # 构建默认路径
  32. DEFAULT_DB_PATH = os.path.join(current_dir, "knowledge_db")
  33. DEFAULT_PERSIST_PATH = os.path.join(current_dir, "vector_db", "chroma")
  34. # 初始化 QA 链接
  35. chat_qa_chain = Chat_QA_chain_self(file_path=DEFAULT_DB_PATH, persist_path=DEFAULT_PERSIST_PATH, embedding='m3e')
  36. @app.post('/clear/history')
  37. async def clear_history():
  38. try:
  39. chat_qa_chain.clear_chat_history()
  40. res = success({"message": "Chat history has been cleared."}, "success")
  41. return JSONResponse(res)
  42. except Exception as e:
  43. res = failed({}, {"error": f"{e}"})
  44. return JSONResponse(res)
  45. @app.post('/view/history')
  46. async def view_history():
  47. try:
  48. history = chat_qa_chain.get_chat_history()
  49. formatted_history = []
  50. for item in history:
  51. formatted_item = {
  52. "Question": item[0],
  53. "Answer": item[1]
  54. }
  55. formatted_history.append(formatted_item)
  56. res = success({"history": formatted_history}, "success")
  57. return JSONResponse(res)
  58. except Exception as e:
  59. res = failed({}, {"error": f"{e}"})
  60. return JSONResponse(res)
  61. @app.post('/ask')
  62. async def ask(request: Request):
  63. try:
  64. data = await request.json()
  65. question = data.get('question')
  66. chat_history = data.get('chat_history', [])
  67. if question:
  68. if '你好' in question:
  69. response = '你好!请提出相应问题'
  70. else:
  71. chain = chat_qa_chain.build_chain()
  72. response = chain.invoke({"question": question, "chat_history": chat_history})
  73. # 添加新的聊天记录到历史记录中
  74. chat_qa_chain.add_to_chat_history(question, response)
  75. res = success({"response": response}, "success")
  76. return JSONResponse(res)
  77. except Exception as e:
  78. res = failed({}, {"error": f"{e}"})
  79. return JSONResponse(res)
  80. @app.post('/ask/rag')
  81. async def ask_rag(request: Request):
  82. try:
  83. data = await request.json()
  84. question = data.get('question')
  85. chat_history = data.get('chat_history', [])
  86. chain = chat_qa_chain.build_rag_chain()
  87. response = chain.invoke({"question": question, "chat_history": chat_history})
  88. # 添加新的聊天记录到历史记录中
  89. chat_qa_chain.add_to_chat_history(question, response)
  90. res = success({"response": response}, "success")
  91. return JSONResponse(res)
  92. except Exception as e:
  93. res = failed({}, {"error": f"{e}"})
  94. return JSONResponse(res)
  95. # build_text_chain
  96. @app.post('/ask/unstructure/rag')
  97. async def ask_rag(request: Request):
  98. try:
  99. data = await request.json()
  100. question = data.get('question')
  101. chat_history = data.get('chat_history', [])
  102. chain = chat_qa_chain.build_text_chain()
  103. response = chain.invoke({"question": question, "chat_history": chat_history})
  104. # 添加新的聊天记录到历史记录中
  105. chat_qa_chain.add_to_chat_history(question, response)
  106. res = success({"response": response}, "success")
  107. return JSONResponse(res)
  108. except Exception as e:
  109. res = failed({}, {"error": f"{e}"})
  110. return JSONResponse(res)
  111. @app.post('/file/list')
  112. async def file_list():
  113. try:
  114. file_path="./knowledge_db"
  115. # 检查文件夹路径是否存在
  116. if not os.path.isdir(file_path):
  117. res = failed({}, {"error": "Folder does not exist"})
  118. return JSONResponse(res)
  119. # 获取文件夹中的文件列表
  120. files = os.listdir(file_path)
  121. res = success({"files": files}, "success")
  122. return JSONResponse(res)
  123. except Exception as e:
  124. res = failed({}, {"error": f"{e}"})
  125. return JSONResponse(res)
  126. @app.post("/file/upload")
  127. async def upload_file(file: UploadFile = File(...)):
  128. try:
  129. if not allowed_file(file.filename):
  130. return JSONResponse(failed({}, {"error": "File type not allowed"}))
  131. # Ensure the upload folder exists
  132. os.makedirs(UPLOAD_FOLDER, exist_ok=True)
  133. file_path = os.path.join(UPLOAD_FOLDER, file.filename)
  134. with open(file_path, "wb") as buffer:
  135. buffer.write(await file.read())
  136. data = []
  137. loader = TextLoader(f'{UPLOAD_FOLDER}/{file.filename}', encoding='utf8')
  138. data.extend(loader.load())
  139. text_splitter = RecursiveCharacterTextSplitter(
  140. chunk_size=500, chunk_overlap=150)
  141. split_docs = text_splitter.split_documents(data)
  142. vector_db = get_vectordb(file_path="./knowledge_db", persist_path="./vector_db/chroma",embedding="m3e")
  143. vector_db.add_documents(split_docs)
  144. return JSONResponse(success({"response": "File uploaded successfully"}, "success"))
  145. except Exception as e:
  146. return JSONResponse(failed({}, {"error": str(e)}))
  147. if __name__ == '__main__':
  148. import uvicorn
  149. uvicorn.run(app, host='0.0.0.0', port=3004)