mod api; use std::{ cell::RefCell, io::{self, Write}, sync::{ mpsc::{Receiver, Sender}, Arc, }, }; use api::{ChatMessage, OllamaAPI}; use iced::{ alignment::{Horizontal, Vertical}, futures::StreamExt, subscription, widget::{button, column, container, row, scrollable, text, text_input}, window, Application, Color, Command, Event, Length, Settings, Subscription, Theme, }; use tokio::sync::RwLock; use crate::api::{OllamaChat, Role}; enum UI { Error(String), Chats(UIState), Closing, } struct UIState { ollama_api: Arc>, active_chat: Option, chat_input: String, send: Sender, recv: RefCell>>, } impl UIState { fn create( api_url: &str, send: Sender, recv: Receiver, ) -> anyhow::Result { Ok(Self { ollama_api: Arc::new(RwLock::new(OllamaAPI::create(api_url)?)), active_chat: None, chat_input: String::new(), send, recv: RefCell::new(Some(recv)), }) } } #[derive(Debug, Clone)] enum UIMessage { Nop, EventOccurred(Event), CreateChat, OpenChat(usize), ChatInput(String), ChatSubmit, DoneGenerating, } fn main() -> anyhow::Result<()> { let (send, recv) = std::sync::mpsc::channel(); let mut settings = Settings::with_flags(UIFlags { send, recv }); settings.exit_on_close_request = false; settings.window = window::Settings { size: (720, 480), ..window::Settings::default() }; UI::run(settings)?; Ok(()) } struct UIFlags { send: Sender, recv: Receiver, } impl Application for UI { type Executor = iced::executor::Default; type Message = UIMessage; type Theme = Theme; type Flags = UIFlags; fn new(flags: Self::Flags) -> (Self, iced::Command) { let state: UI; let state_result = UIState::create("http://localhost:11434", flags.send, flags.recv); if let Ok(s) = state_result { state = UI::Chats(s); } else { state = UI::Error(String::from("Failed to initialize Ollama API")); } (state, Command::none()) } fn title(&self) -> String { String::from("AIGUI") } fn update(&mut self, message: Self::Message) -> iced::Command { if let UIMessage::EventOccurred(event) = &message { if let Event::Window(window::Event::CloseRequested) = event { *self = UI::Closing; return window::close(); } else { return Command::none(); } } match self { UI::Chats(state) => match message { UIMessage::CreateChat => { let chat = state .ollama_api .blocking_write() .create_chat("dolphin-mixtral"); if let Ok(chat) = chat { chat.blocking_write().send_system("You write stories about various topics.\n\nYou can include images into these stories using the following syntax:\n{IMAGE:description of image}\n\ne.g.:\n{IMAGE:an image of a large church, volumetric lighting, masterpiece, best quality}\n\nMake sure to use at least one image every five sentences."); } Command::none() } UIMessage::OpenChat(index) => { println!("Open chat {}", index); state.active_chat = Some(index); Command::none() } UIMessage::ChatInput(text) => { println!("Input: {}", text); state.chat_input = text; Command::none() } UIMessage::ChatSubmit => { println!("Submit"); if let Some(index) = state.active_chat { let content = state.chat_input.clone(); state.chat_input = String::new(); let send = state.send.clone(); let ollama_api = state.ollama_api.clone(); return Command::perform( async move { /* TODO: Stable diffusion let sd = StableDiffusionAPI::create("http://localhost:7860").unwrap(); let req = StableDiffusionRequest { prompt: String::from("among us"), steps: 10, }; let img = sd.generate_image(&req).await.unwrap(); println!("{:?}", img); std::fs::write("amogus.png", img);*/ let stream; let chat; { { let api = ollama_api.read().await; chat = api.chats[index].clone(); let mut chat = chat.write().await; chat.send_user(&content); chat.generating_message = Some(String::new()); } //let mut chat = chat.write().await; stream = OllamaChat::complete(chat.clone()).await; } if let Ok(mut stream) = stream { while let Some(content) = stream.next().await { if content.is_err() { break; } let content = content.unwrap(); let chat_arc = chat.clone(); let mut chat = chat_arc.write().await; let msg = &mut chat.generating_message; if let Some(msg) = msg { msg.push_str(&content); } print!("{}", content); let _ = io::stdout().flush(); let _ = send.send(UIMessage::Nop); } let mut chat = chat.write().await; let msg = &chat.generating_message.clone(); if let Some(msg) = msg { chat.send_assistant(&msg); chat.generating_message = None; } } }, |_| UIMessage::DoneGenerating, ); } Command::none() } UIMessage::DoneGenerating => Command::none(), _ => Command::none(), }, UI::Error(_) => Command::none(), UI::Closing => Command::none(), } } fn view(&self) -> iced::Element<'_, Self::Message, iced::Renderer> { match self { UI::Error(message) => text(String::from("Error: ") + message) .width(Length::Fill) .height(Length::Fill) .vertical_alignment(Vertical::Center) .horizontal_alignment(Horizontal::Center) .into(), UI::Chats(state) => { let api = state.ollama_api.blocking_read(); let mut chats: Vec>> = api.chats .iter() .enumerate() .map(|c| { button("Chat") .on_press(UIMessage::OpenChat(c.0)) .width(Length::Fill) .into() }) .collect(); chats.insert( 0, button("Create Chat") .on_press(UIMessage::CreateChat) .width(Length::Fill) .into(), ); let chats_list = scrollable(container(column(chats).padding(10.0).spacing(10))) .height(Length::Fill) .width(Length::Fixed(300.0)); let mut messages: Vec< iced::Element<'_, Self::Message, iced::Renderer>, >; let mut input_box = text_input("Your message", &state.chat_input); if let Some(chat_index) = state.active_chat { let chat = &api.chats[chat_index].blocking_read(); messages = chat .messages .iter() .map(|x| self.render_message(x)) .collect(); if let Some(msg) = &chat.generating_message { let mut msg = msg.clone(); msg.push_str("..."); let in_progress_message = ChatMessage { role: Role::Assistant, content: msg, }; messages.push(self.render_message(&in_progress_message)); } else { input_box = input_box .on_input(|x| UIMessage::ChatInput(x)) .on_submit(UIMessage::ChatSubmit) } } else { messages = Vec::new(); } let active_chat = container( column![ scrollable(container(column(messages).spacing(10))) .width(Length::Fill) .height(Length::Fill), input_box ] .padding(10) .spacing(10), ) .width(Length::Fill) .height(Length::Fill); container(row![chats_list, active_chat]).into() } UI::Closing => text("Closing").into(), } } fn theme(&self) -> Self::Theme { Theme::Dark } fn subscription(&self) -> iced::Subscription { let event_subscription = subscription::events().map(UIMessage::EventOccurred); match self { UI::Chats(state) => { let recv_subscription = iced::subscription::unfold( "external messages", state.recv.take(), move |recv| async move { let msg = recv.as_ref().unwrap().recv(); if let Err(_) = msg { return (UIMessage::Nop, recv); } let msg = msg.unwrap(); (msg, recv) }, ); Subscription::batch(vec![event_subscription, recv_subscription]) } _ => event_subscription, } } } impl UI { fn render_message( &self, message: &ChatMessage, ) -> iced::Element<'_, UIMessage, iced::Renderer> { let (role_name, role_color) = match &message.role { Role::User => ("User", Color::from_rgb8(128, 128, 128)), Role::Assistant => ("Assistant", Color::from_rgb8(100, 100, 255)), Role::System => ("System", Color::from_rgb8(255, 0, 0)), }; column![text(role_name).style(role_color), text(&message.content)].into() } }