app.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import os
  2. from fastapi import Request
  3. from fastapi.responses import JSONResponse
  4. from langchain_community.document_loaders import TextLoader
  5. from langchain_text_splitters import RecursiveCharacterTextSplitter
  6. from chat.file_chat import allowed_file, UPLOAD_FOLDER
  7. from qa_chain.Chat_QA_chain_self import Chat_QA_chain_self # 替换为你的模块名
  8. from qa_chain.get_vectordb import get_vectordb
  9. from qa_chain.result import success,failed
  10. from fastapi.middleware.cors import CORSMiddleware
  11. from fastapi import FastAPI, File, UploadFile
  12. app = FastAPI()
  13. app.add_middleware(
  14. CORSMiddleware,
  15. # 允许跨域的源列表,例如 ["http://www.example.org"] 等等,["*"] 表示允许任何源
  16. allow_origins=["*"],
  17. # 跨域请求是否支持 cookie,默认是 False,如果为 True,allow_origins 必须为具体的源,不可以是 ["*"]
  18. allow_credentials=False,
  19. # 允许跨域请求的 HTTP 方法列表,默认是 ["GET"]
  20. allow_methods=["GET",'POST','PUT','PATCH','DELETE','HEAD','OPTIONS'],
  21. # 允许跨域请求的 HTTP 请求头列表,默认是 [],可以使用 ["*"] 表示允许所有的请求头
  22. # 当然 Accept、Accept-Language、Content-Language 以及 Content-Type 总之被允许的
  23. allow_headers=["*"],
  24. # 可以被浏览器访问的响应头, 默认是 [],一般很少指定
  25. # expose_headers=["*"]
  26. # 设定浏览器缓存 CORS 响应的最长时间,单位是秒。默认为 600,一般也很少指定
  27. # max_age=1000
  28. )
  29. # 定义默认路径
  30. DEFAULT_DB_PATH = os.path.join(".", "knowledge_db")
  31. DEFAULT_PERSIST_PATH = os.path.join(".", "vector_db", "chroma")
  32. # 初始化你的Chat_QA_chain_self实例
  33. chat_qa_chain = Chat_QA_chain_self(file_path=DEFAULT_DB_PATH, persist_path=DEFAULT_PERSIST_PATH,embedding='m3e')
  34. @app.post('/clear/history')
  35. async def clear_history():
  36. try:
  37. chat_qa_chain.clear_chat_history()
  38. res = success({"message": "Chat history has been cleared."}, "success")
  39. return JSONResponse(res)
  40. except Exception as e:
  41. res = failed({}, {"error": f"{e}"})
  42. return JSONResponse(res)
  43. @app.post('/view/history')
  44. async def view_history():
  45. try:
  46. history = chat_qa_chain.get_chat_history()
  47. formatted_history = []
  48. for item in history:
  49. formatted_item = {
  50. "Question": item[0],
  51. "Answer": item[1]
  52. }
  53. formatted_history.append(formatted_item)
  54. res = success({"history": formatted_history}, "success")
  55. return JSONResponse(res)
  56. except Exception as e:
  57. res = failed({}, {"error": f"{e}"})
  58. return JSONResponse(res)
  59. @app.post('/ask')
  60. async def ask(request: Request):
  61. try:
  62. data = await request.json()
  63. question = data.get('question')
  64. chat_history = data.get('chat_history', [])
  65. if question:
  66. if '你好' in question:
  67. response = '你好!请提出相应问题'
  68. else:
  69. chain = chat_qa_chain.build_chain()
  70. response = chain.invoke({"question": question, "chat_history": chat_history})
  71. # 添加新的聊天记录到历史记录中
  72. chat_qa_chain.add_to_chat_history(question, response)
  73. res = success({"response": response}, "success")
  74. return JSONResponse(res)
  75. except Exception as e:
  76. res = failed({}, {"error": f"{e}"})
  77. return JSONResponse(res)
  78. @app.post('/ask/rag')
  79. async def ask_rag(request: Request):
  80. try:
  81. data = await request.json()
  82. question = data.get('question')
  83. chat_history = data.get('chat_history', [])
  84. chain = chat_qa_chain.build_rag_chain()
  85. response = chain.invoke({"question": question, "chat_history": chat_history})
  86. # 添加新的聊天记录到历史记录中
  87. chat_qa_chain.add_to_chat_history(question, response)
  88. res = success({"response": response}, "success")
  89. return JSONResponse(res)
  90. except Exception as e:
  91. res = failed({}, {"error": f"{e}"})
  92. return JSONResponse(res)
  93. # build_text_chain
  94. @app.post('/ask/unstructure/rag')
  95. async def ask_rag(request: Request):
  96. try:
  97. data = await request.json()
  98. question = data.get('question')
  99. chat_history = data.get('chat_history', [])
  100. chain = chat_qa_chain.build_text_chain()
  101. response = chain.invoke({"question": question, "chat_history": chat_history})
  102. # 添加新的聊天记录到历史记录中
  103. chat_qa_chain.add_to_chat_history(question, response)
  104. res = success({"response": response}, "success")
  105. return JSONResponse(res)
  106. except Exception as e:
  107. res = failed({}, {"error": f"{e}"})
  108. return JSONResponse(res)
  109. @app.post('/file/list')
  110. async def file_list():
  111. try:
  112. file_path="./knowledge_db"
  113. # 检查文件夹路径是否存在
  114. if not os.path.isdir(file_path):
  115. res = failed({}, {"error": "Folder does not exist"})
  116. return JSONResponse(res)
  117. # 获取文件夹中的文件列表
  118. files = os.listdir(file_path)
  119. res = success({"files": files}, "success")
  120. return JSONResponse(res)
  121. except Exception as e:
  122. res = failed({}, {"error": f"{e}"})
  123. return JSONResponse(res)
  124. @app.post("/file/upload")
  125. async def upload_file(file: UploadFile = File(...)):
  126. try:
  127. if not allowed_file(file.filename):
  128. return JSONResponse(failed({}, {"error": "File type not allowed"}))
  129. # Ensure the upload folder exists
  130. os.makedirs(UPLOAD_FOLDER, exist_ok=True)
  131. file_path = os.path.join(UPLOAD_FOLDER, file.filename)
  132. with open(file_path, "wb") as buffer:
  133. buffer.write(await file.read())
  134. data = []
  135. loader = TextLoader(f'{UPLOAD_FOLDER}/{file.filename}', encoding='utf8')
  136. data.extend(loader.load())
  137. text_splitter = RecursiveCharacterTextSplitter(
  138. chunk_size=500, chunk_overlap=150)
  139. split_docs = text_splitter.split_documents(data)
  140. vector_db = get_vectordb(file_path="./knowledge_db", persist_path="./vector_db/chroma",embedding="m3e")
  141. vector_db.add_documents(split_docs)
  142. return JSONResponse(success({"response": "File uploaded successfully"}, "success"))
  143. except Exception as e:
  144. return JSONResponse(failed({}, {"error": str(e)}))
  145. if __name__ == '__main__':
  146. import uvicorn
  147. uvicorn.run(app, host='0.0.0.0', port=3004)