app.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import chainlit as cl
  2. from chainlit.input_widget import Select
  3. from vanna_llm_factory import create_vanna_instance
  4. import os
  5. # vn.set_api_key(os.environ['VANNA_API_KEY'])
  6. # vn.set_model('chinook')
  7. # vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
  8. vn = create_vanna_instance()
  9. @cl.set_chat_profiles
  10. async def chat_profile():
  11. return [
  12. cl.ChatProfile(
  13. name="Vanna助手",
  14. markdown_description="基于Vanna的智能数据库查询助手,支持自然语言转SQL查询和数据可视化",
  15. icon="./public/avatars/huoche.png",
  16. # 备用在线图标,如果本地图标不显示可以取消注释下面的行
  17. #icon="https://raw.githubusercontent.com/tabler/tabler-icons/master/icons/database.svg",
  18. ),
  19. ]
  20. @cl.step(language="sql", name="Vanna")
  21. async def gen_query(human_query: str):
  22. sql_query = vn.generate_sql(human_query)
  23. return sql_query
  24. @cl.step(name="Vanna")
  25. async def execute_query(query):
  26. current_step = cl.context.current_step
  27. df = vn.run_sql(query)
  28. current_step.output = df.head().to_markdown(index=False)
  29. return df
  30. @cl.step(name="Plot", language="python")
  31. async def plot(human_query, sql, df):
  32. current_step = cl.context.current_step
  33. plotly_code = vn.generate_plotly_code(question=human_query, sql=sql, df=df)
  34. fig = vn.get_plotly_figure(plotly_code=plotly_code, df=df)
  35. current_step.output = plotly_code
  36. return fig
  37. @cl.step(type="run", name="Vanna")
  38. async def chain(human_query: str):
  39. sql_query = await gen_query(human_query)
  40. df = await execute_query(sql_query)
  41. fig = await plot(human_query, sql_query, df)
  42. # 创建表格和图表元素
  43. elements = [
  44. cl.Text(name="data_table", content=df.to_markdown(index=False), display="inline"),
  45. cl.Plotly(name="chart", figure=fig, display="inline")
  46. ]
  47. await cl.Message(content=human_query, elements=elements, author="Vanna助手").send()
  48. @cl.on_message
  49. async def main(message: cl.Message):
  50. await chain(message.content)
  51. @cl.on_chat_start
  52. async def on_chat_start():
  53. # 发送中文欢迎消息
  54. welcome_message = """
  55. 🎉 **欢迎使用智能数据库查询助手!**
  56. 我可以帮助您:
  57. - 🔍 将自然语言问题转换为SQL查询
  58. - 📊 执行数据库查询并展示结果
  59. - 📈 生成数据可视化图表
  60. 请直接输入您想了解的数据问题,例如:
  61. - "交易次数最多的前5位客户是谁?"
  62. - "查看过去30天的交易趋势"
  63. 让我们开始探索数据吧!✨
  64. """
  65. await cl.Message(
  66. content=welcome_message,
  67. author="Vanna助手"
  68. ).send()