app.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  import logging
  import os

  from fastapi import FastAPI, File, UploadFile
  from fastapi import Request
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.responses import JSONResponse
  from langchain_community.document_loaders import TextLoader
  from langchain_text_splitters import RecursiveCharacterTextSplitter

  from chat.file_chat import allowed_file, UPLOAD_FOLDER
  from qa_chain.Chat_QA_chain_self import Chat_QA_chain_self  # Replace with your module name
  from qa_chain.get_vectordb import get_vectordb
  from qa_chain.result import success, failed
  app = FastAPI()

  # Cross-origin (CORS) policy for browser clients.
  app.add_middleware(
      CORSMiddleware,
      # Any origin may call the API. Because the origin list is the "*"
      # wildcard, credentialed (cookie) requests must stay disabled.
      allow_origins=["*"],
      allow_credentials=False,
      # HTTP methods accepted on cross-origin requests.
      allow_methods=["GET", 'POST', 'PUT', 'PATCH', 'DELETE', 'HEAD', 'OPTIONS'],
      # Any request header may be sent.
      allow_headers=["*"],
      # expose_headers and max_age are intentionally left at library defaults.
  )
  # Folder holding the knowledge-base source documents.
  DEFAULT_DB_PATH = "knowledge_db"
  # Persist location handed to the chain for its Chroma vector store.
  DEFAULT_PERSIST_PATH = os.path.join("vector_db", "chroma")

  # One module-level chain instance serves every request handler in this file.
  # NOTE(review): the chat-history methods are called on this single instance,
  # so unless Chat_QA_chain_self keys history per client, all callers share it.
  chat_qa_chain = Chat_QA_chain_self(
      file_path=DEFAULT_DB_PATH,
      persist_path=DEFAULT_PERSIST_PATH,
  )
  @app.post('/clear/history')
  async def clear_history():
      """Clear the chat history stored on ``chat_qa_chain``.

      Returns:
          JSONResponse wrapping ``success(...)`` with a confirmation message, or
          ``failed(...)`` with an error description if clearing raises.
      """
      try:
          chat_qa_chain.clear_chat_history()
          return JSONResponse(success({"message": "Chat history has been cleared."}, "success"))
      except Exception:
          # Request boundary: keep the traceback in the logs and report the
          # failure as a failure, not as a success carrying an unrelated
          # "try rephrasing your question" message.
          logging.getLogger(__name__).exception("Failed to clear chat history")
          return JSONResponse(failed({}, {"error": "Failed to clear chat history"}))
  @app.post('/view/history')
  async def view_history():
      """Return the stored chat history as a list of ``{"Question", "Answer"}`` objects.

      Element 0 of each history item becomes ``"Question"`` and element 1
      becomes ``"Answer"``.

      Returns:
          JSONResponse wrapping ``success({"history": [...]})``, or ``failed(...)``
          if reading or serialising the history raises.
      """
      try:
          history = chat_qa_chain.get_chat_history()
          # NOTE(review): assumes each item is an indexable (question, answer)
          # pair; confirm against Chat_QA_chain_self.get_chat_history.
          formatted_history = [
              {"Question": item[0], "Answer": item[1]} for item in history
          ]
          return JSONResponse(success({"history": formatted_history}, "success"))
      except Exception:
          logging.getLogger(__name__).exception("Failed to read chat history")
          return JSONResponse(failed({}, {"error": "Failed to read chat history"}))
  async def _answer_question(request: Request, build_chain, *, greet: bool = False):
      """Shared implementation of the ``/ask*`` endpoints.

      Reads ``question`` and optional ``chat_history`` from the JSON body, runs
      the chain produced by ``build_chain()`` and records the exchange with
      ``chat_qa_chain.add_to_chat_history``.

      Args:
          request: Incoming request whose JSON body holds ``question`` and,
              optionally, ``chat_history`` (defaults to ``[]``).
          build_chain: Zero-argument factory, called once per request, whose
              result is invoked with ``{"question": ..., "chat_history": ...}``.
          greet: When true, a question containing "你好" ("hello") gets a canned
              greeting instead of a chain call.

      Returns:
          JSONResponse with ``success({"response": ...})`` on success;
          ``failed(...)`` when ``question`` is missing, empty or not a string;
          and the catch-all "please try rephrasing" reply if anything else raises.
      """
      try:
          data = await request.json()
          question = data.get('question')
          chat_history = data.get('chat_history', [])
          # Reject an unusable question before touching the chain, so the chain
          # never receives None and ``response`` is always bound below.
          if not isinstance(question, str) or not question:
              return JSONResponse(failed({}, {"error": "Missing or invalid 'question'"}))
          if greet and '你好' in question:
              response = '你好!请提出相应问题'
          else:
              # NOTE(review): chain.invoke is synchronous and runs on the event
              # loop, blocking other requests until it returns. Consider
              # fastapi.concurrency.run_in_threadpool once the chain is known to
              # be safe to call concurrently.
              chain = build_chain()
              response = chain.invoke({"question": question, "chat_history": chat_history})
          # Record the new exchange in the server-side chat history.
          chat_qa_chain.add_to_chat_history(question, response)
          return JSONResponse(success({"response": response}, "success"))
      except Exception:
          logging.getLogger(__name__).exception("Failed to answer question")
          return JSONResponse(success({"response": "请尝试换个方式提问。"}, "success"))


  @app.post('/ask')
  async def ask(request: Request):
      """Answer a question with the chain from ``chat_qa_chain.build_chain``."""
      return await _answer_question(request, chat_qa_chain.build_chain, greet=True)


  @app.post('/ask/rag')
  async def ask_rag(request: Request):
      """Answer a question with the chain from ``chat_qa_chain.build_rag_chain``."""
      return await _answer_question(request, chat_qa_chain.build_rag_chain)


  @app.post('/ask/unstructure/rag')
  async def ask_unstructured_rag(request: Request):
      """Answer a question with the chain from ``chat_qa_chain.build_text_chain``.

      Has its own name so it does not shadow the ``/ask/rag`` handler
      ``ask_rag`` at module level.
      """
      return await _answer_question(request, chat_qa_chain.build_text_chain)
  @app.post('/file/list')
  async def file_list():
      """List the entries of the knowledge-base folder ``DEFAULT_DB_PATH``.

      Uses the same folder that ``chat_qa_chain`` was built from, so the two
      cannot drift apart.

      Returns:
          ``success({"files": [...]})`` with entry names sorted for a stable
          order; ``failed(...)`` if the folder does not exist or cannot be read.
      """
      folder = DEFAULT_DB_PATH
      try:
          if not os.path.isdir(folder):
              return JSONResponse(failed({}, {"error": "Folder does not exist"}))
          # os.listdir returns entries in arbitrary order; sort for clients.
          files = sorted(os.listdir(folder))
          return JSONResponse(success({"files": files}, "success"))
      except Exception:
          logging.getLogger(__name__).exception("Failed to list folder %s", folder)
          return JSONResponse(failed({}, {"error": "Failed to list files"}))
  @app.post("/file/upload")
  async def upload_file(file: UploadFile = File(...)):
      """Save an uploaded text file and add its chunks to the vector store.

      The file is written to ``UPLOAD_FOLDER`` under its base name, loaded with
      ``TextLoader`` as UTF-8 and split with ``RecursiveCharacterTextSplitter``
      (``chunk_size=500``, ``chunk_overlap=150``). The chunks are then added to
      the store returned by ``get_vectordb``, using the same paths that
      ``chat_qa_chain`` was built with.

      Returns:
          ``success(...)`` once the file is saved and indexed; ``failed(...)``
          if the name is missing, the type is not allowed, or saving/indexing
          raises.
      """
      try:
          # The client controls the file name: drop any directory components
          # (either separator) so a name like "../../x.txt" cannot escape
          # UPLOAD_FOLDER. A missing name becomes "".
          filename = os.path.basename((file.filename or "").replace("\\", "/"))
          if filename in ("", ".", ".."):
              return JSONResponse(failed({}, {"error": "Missing file name"}))
          if not allowed_file(filename):
              return JSONResponse(failed({}, {"error": "File type not allowed"}))
          # Ensure the upload folder exists.
          os.makedirs(UPLOAD_FOLDER, exist_ok=True)
          file_path = os.path.join(UPLOAD_FOLDER, filename)
          with open(file_path, "wb") as buffer:
              buffer.write(await file.read())
          docs = TextLoader(file_path, encoding='utf8').load()
          text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=150)
          split_docs = text_splitter.split_documents(docs)
          # NOTE(review): a new store handle is opened on every upload; confirm
          # that chat_qa_chain's retriever sees documents added through it.
          vector_db = get_vectordb(
              file_path=DEFAULT_DB_PATH,
              persist_path=DEFAULT_PERSIST_PATH,
              embedding="m3e",
          )
          vector_db.add_documents(split_docs)
          return JSONResponse(success({"response": "File uploaded successfully"}, "success"))
      except Exception:
          logging.getLogger(__name__).exception("File upload failed for %r", file.filename)
          return JSONResponse(failed({}, {"error": "File upload failed"}))
  if __name__ == '__main__':
      # Development entry point: serve the app with uvicorn on every
      # interface, port 3004.
      from uvicorn import run as serve

      serve(app, host='0.0.0.0', port=3004)