app.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import chainlit as cl
  2. from chainlit.input_widget import Select
  3. import vanna as vn
  4. import os
  5. vn.set_api_key(os.environ['VANNA_API_KEY'])
  6. vn.set_model('chinook')
  7. vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
  8. @cl.step(root=True, language="sql", name="Vanna")
  9. async def gen_query(human_query: str):
  10. sql_query = vn.generate_sql(human_query)
  11. return sql_query
  12. @cl.step(root=True, name="Vanna")
  13. async def execute_query(query):
  14. current_step = cl.context.current_step
  15. df = vn.run_sql(query)
  16. current_step.output = df.head().to_markdown(index=False)
  17. return df
  18. @cl.step(name="Plot", language="python")
  19. async def plot(human_query, sql, df):
  20. current_step = cl.context.current_step
  21. plotly_code = vn.generate_plotly_code(question=human_query, sql=sql, df=df)
  22. fig = vn.get_plotly_figure(plotly_code=plotly_code, df=df)
  23. current_step.output = plotly_code
  24. return fig
  25. @cl.step(type="run", root=True, name="Vanna")
  26. async def chain(human_query: str):
  27. sql_query = await gen_query(human_query)
  28. df = await execute_query(sql_query)
  29. fig = await plot(human_query, sql_query, df)
  30. elements = [cl.Plotly(name="chart", figure=fig, display="inline")]
  31. await cl.Message(content=human_query, elements=elements, author="Vanna").send()
  32. @cl.on_message
  33. async def main(message: cl.Message):
  34. await chain(message.content)
  35. @cl.on_chat_start
  36. async def setup():
  37. await cl.Avatar(
  38. name="Vanna",
  39. url="https://app.vanna.ai/vanna.svg",
  40. ).send()
  41. settings = await cl.ChatSettings(
  42. [
  43. Select(
  44. id="Model",
  45. label="OpenAI - Model",
  46. values=["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k"],
  47. initial_index=0,
  48. )
  49. ]
  50. ).send()
  51. value = settings["Model"]