app.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
import logging
import os

from fastapi import FastAPI, File, UploadFile
from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

from chat.file_chat import allowed_file, UPLOAD_FOLDER
from qa_chain.Chat_QA_chain_self import Chat_QA_chain_self  # replace with your module name
from qa_chain.get_vectordb import get_vectordb
from qa_chain.result import success, failed
  12. app = FastAPI()
  13. app.add_middleware(
  14. CORSMiddleware,
  15. # 允许跨域的源列表,例如 ["http://www.example.org"] 等等,["*"] 表示允许任何源
  16. allow_origins=["*"],
  17. # 跨域请求是否支持 cookie,默认是 False,如果为 True,allow_origins 必须为具体的源,不可以是 ["*"]
  18. allow_credentials=False,
  19. # 允许跨域请求的 HTTP 方法列表,默认是 ["GET"]
  20. allow_methods=["GET",'POST','PUT','PATCH','DELETE','HEAD','OPTIONS'],
  21. # 允许跨域请求的 HTTP 请求头列表,默认是 [],可以使用 ["*"] 表示允许所有的请求头
  22. # 当然 Accept、Accept-Language、Content-Language 以及 Content-Type 总之被允许的
  23. allow_headers=["*"],
  24. # 可以被浏览器访问的响应头, 默认是 [],一般很少指定
  25. # expose_headers=["*"]
  26. # 设定浏览器缓存 CORS 响应的最长时间,单位是秒。默认为 600,一般也很少指定
  27. # max_age=1000
  28. )
  29. # 初始化你的Chat_QA_chain_self实例
  30. chat_qa_chain = Chat_QA_chain_self(file_path="./knowledge_db", persist_path="./vector_db/chroma")
  31. @app.post('/clear/history')
  32. async def clear_history():
  33. try:
  34. chat_qa_chain.clear_chat_history()
  35. res = success({"message": "Chat history has been cleared."}, "success")
  36. return JSONResponse(res)
  37. except Exception as e:
  38. res = success({"response": "请尝试换个方式提问。"}, "success")
  39. return JSONResponse(res)
  40. @app.post('/view/history')
  41. async def view_history():
  42. try:
  43. history = chat_qa_chain.get_chat_history()
  44. formatted_history = []
  45. for item in history:
  46. formatted_item = {
  47. "Question": item[0],
  48. "Answer": item[1]
  49. }
  50. formatted_history.append(formatted_item)
  51. res = success({"history": formatted_history}, "success")
  52. return JSONResponse(res)
  53. except Exception as e:
  54. res = success({"response": "请尝试换个方式提问。"}, "success")
  55. return JSONResponse(res)
  56. # @app.post('/ask')
  57. # async def ask(request: Request):
  58. # try:
  59. # data = await request.json()
  60. # question = data.get('question')
  61. # chat_history = data.get('chat_history', [])
  62. # response = chat_qa_chain.build_chain(question,chat_history)
  63. # # response = chain.invoke({"question": question, "chat_history": chat_history})
  64. # # 添加新的聊天记录到历史记录中
  65. # chat_qa_chain.add_to_chat_history(question, response)
  66. # res = success({"response": response}, "success")
  67. # return JSONResponse(res)
  68. # except Exception as e:
  69. # res = failed({}, {"error": f"{e}"})
  70. # return JSONResponse(res)
  71. @app.post('/ask')
  72. async def ask(request: Request):
  73. try:
  74. data = await request.json()
  75. question = data.get('question')
  76. chat_history = data.get('chat_history', [])
  77. if question:
  78. if '你好' in question:
  79. response = '你好!请提出相应问题'
  80. else:
  81. chain = chat_qa_chain.build_chain()
  82. response = chain.invoke({"question": question, "chat_history": chat_history})
  83. # 添加新的聊天记录到历史记录中
  84. chat_qa_chain.add_to_chat_history(question, response)
  85. res = success({"response": response}, "success")
  86. return JSONResponse(res)
  87. except Exception as e:
  88. res = success({"response": "请尝试换个方式提问。"}, "success")
  89. return JSONResponse(res)
  90. @app.post('/ask/rag')
  91. async def ask_rag(request: Request):
  92. try:
  93. data = await request.json()
  94. question = data.get('question')
  95. chat_history = data.get('chat_history', [])
  96. chain = chat_qa_chain.build_rag_chain()
  97. response = chain.invoke({"question": question, "chat_history": chat_history})
  98. # 添加新的聊天记录到历史记录中
  99. chat_qa_chain.add_to_chat_history(question, response)
  100. res = success({"response": response}, "success")
  101. return JSONResponse(res)
  102. except Exception as e:
  103. res = success({"response": "请尝试换个方式提问。"}, "success")
  104. return JSONResponse(res)
  105. # build_text_chain
  106. @app.post('/ask/unstructure/rag')
  107. async def ask_rag(request: Request):
  108. try:
  109. data = await request.json()
  110. question = data.get('question')
  111. chat_history = data.get('chat_history', [])
  112. chain = chat_qa_chain.build_text_chain()
  113. response = chain.invoke({"question": question, "chat_history": chat_history})
  114. # 添加新的聊天记录到历史记录中
  115. chat_qa_chain.add_to_chat_history(question, response)
  116. res = success({"response": response}, "success")
  117. return JSONResponse(res)
  118. except Exception as e:
  119. res = success({"response": "请尝试换个方式提问。"}, "success")
  120. return JSONResponse(res)
  121. @app.post('/file/list')
  122. async def file_list():
  123. try:
  124. file_path="./knowledge_db"
  125. # 检查文件夹路径是否存在
  126. if not os.path.isdir(file_path):
  127. res = failed({}, {"error": "Folder does not exist"})
  128. return JSONResponse(res)
  129. # 获取文件夹中的文件列表
  130. files = os.listdir(file_path)
  131. res = success({"files": files}, "success")
  132. return JSONResponse(res)
  133. except Exception as e:
  134. res = success({"response": "请尝试换个方式提问。"}, "success")
  135. return JSONResponse(res)
  136. @app.post("/file/upload")
  137. async def upload_file(file: UploadFile = File(...)):
  138. try:
  139. if not allowed_file(file.filename):
  140. return JSONResponse(failed({}, {"error": "File type not allowed"}))
  141. # Ensure the upload folder exists
  142. os.makedirs(UPLOAD_FOLDER, exist_ok=True)
  143. file_path = os.path.join(UPLOAD_FOLDER, file.filename)
  144. with open(file_path, "wb") as buffer:
  145. buffer.write(await file.read())
  146. data = []
  147. loader = TextLoader(f'{UPLOAD_FOLDER}/{file.filename}', encoding='utf8')
  148. data.extend(loader.load())
  149. text_splitter = RecursiveCharacterTextSplitter(
  150. chunk_size=500, chunk_overlap=150)
  151. split_docs = text_splitter.split_documents(data)
  152. vector_db = get_vectordb(file_path="./knowledge_db", persist_path="./vector_db/chroma",embedding="m3e")
  153. vector_db.add_documents(split_docs)
  154. return JSONResponse(success({"response": "File uploaded successfully"}, "success"))
  155. except Exception as e:
  156. res = success({"response": "请尝试换个方式提问。"}, "success")
  157. return JSONResponse(res)
  158. if __name__ == '__main__':
  159. import uvicorn
  160. uvicorn.run(app, host='0.0.0.0', port=3004)