diff --git a/src/main.rs b/src/main.rs index f84a40f..6d1fe79 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,12 @@ mod api; use std::{ + cell::RefCell, io::{self, Write}, - sync::Arc, + sync::{ + mpsc::{Receiver, Sender}, + Arc, + }, }; use api::OllamaAPI; @@ -10,15 +14,13 @@ use iced::{ alignment::{Horizontal, Vertical}, futures::StreamExt, widget::{button, column, container, row, scrollable, text, text_input}, - window::{self}, - Application, Command, Length, Settings, Theme, + window, Application, Command, Length, Settings, Subscription, Theme, }; use tokio::sync::RwLock; use crate::api::OllamaChat; enum UI { - Loading, Error(String), Chats(UIState), } @@ -28,15 +30,23 @@ struct UIState { active_chat: Option, chat_input: String, busy: bool, + send: Sender, + recv: RefCell>>, } impl UIState { - fn create(api_url: &str) -> anyhow::Result { + fn create( + api_url: &str, + send: Sender, + recv: Receiver, + ) -> anyhow::Result { Ok(Self { ollama_api: Arc::new(RwLock::new(OllamaAPI::create(api_url)?)), active_chat: None, chat_input: String::from(""), busy: false, + send, + recv: RefCell::new(Some(recv)), }) } } @@ -44,7 +54,6 @@ impl UIState { #[derive(Debug, Clone)] enum UIMessage { Nop, - LoadDone, CreateChat, OpenChat(usize), ChatInput(String), @@ -55,42 +64,39 @@ enum UIMessage { fn main() -> anyhow::Result<()> { println!("Making request"); - UI::run(Settings { - window: window::Settings { - size: (720, 480), - ..window::Settings::default() - }, - ..Settings::default() - })?; + let (send, recv) = std::sync::mpsc::channel(); + let mut settings = Settings::with_flags(UIFlags { send, recv }); + settings.window = window::Settings { + size: (720, 480), + ..window::Settings::default() + }; + + UI::run(settings)?; Ok(()) } +struct UIFlags { + send: Sender, + recv: Receiver, +} + impl Application for UI { type Executor = iced::executor::Default; type Message = UIMessage; type Theme = Theme; - type Flags = (); + type Flags = UIFlags; - fn new(_flags: Self::Flags) -> (Self, iced::Command) { - ( - UI::Loading, - Command::perform( - async { - /*let api = OllamaAPI::create("http://localhost:11434").unwrap(); - let mut chat = api.create_chat("dolphin-mixtral"); - chat.send_system("You are a bot that only replies with \"Hello\" and nothing else. You must never reply with anything other than \"Hello\""); - chat.send_user("Who are you?"); - // println!("{}", amogus.get().await?); + fn new(flags: Self::Flags) -> (Self, iced::Command) { + let state: UI; + let state_result = UIState::create("http://localhost:11434", flags.send, flags.recv); + if let Ok(s) = state_result { + state = UI::Chats(s); + } else { + state = UI::Error(String::from("Failed to initialize Ollama API")); + } - let mut a = chat.complete().await.unwrap(); - while let Some(v) = a.next().await { - println!("GOT = {:?}", v); - }*/ - }, - |_| UIMessage::LoadDone, - ), - ) + (state, Command::none()) } fn title(&self) -> String { @@ -99,18 +105,6 @@ impl Application for UI { fn update(&mut self, message: Self::Message) -> iced::Command { match self { - UI::Loading => match message { - UIMessage::LoadDone => { - let state = UIState::create("http://localhost:11434"); - if let Ok(state) = state { - *self = UI::Chats(state); - } else { - *self = UI::Error(String::from("Failed to initialize Ollama API")); - } - Command::none() - } - _ => Command::none(), - }, UI::Chats(state) => match message { UIMessage::CreateChat => { let chat = state @@ -139,6 +133,7 @@ impl Application for UI { let content = state.chat_input.clone(); state.chat_input = String::from(""); state.busy = true; + let send = state.send.clone(); let ollama_api = state.ollama_api.clone(); return Command::perform( async move { @@ -150,6 +145,8 @@ impl Application for UI { chat = api.chats[index].clone(); let mut chat = chat.write().await; chat.send_user(&content); + chat.current_message = + Some(Arc::new(RwLock::new(String::from("")))); } //let mut chat = chat.write().await; @@ -164,18 +161,17 @@ impl Application for UI { let content = content.unwrap(); - let chat = chat.clone(); - let mut chat = chat.write().await; + let chat_arc = chat.clone(); + let chat = chat_arc.read().await; let msg = chat.current_message.clone(); if let Some(msg) = msg { msg.write().await.push_str(&content); - } else { - chat.current_message = - Some(Arc::new(RwLock::new(String::from(&content)))); } print!("{}", content); let _ = io::stdout().flush(); + + let _ = send.send(UIMessage::Nop); } } //state.ollama_api.complete(&x); @@ -198,12 +194,6 @@ impl Application for UI { fn view(&self) -> iced::Element<'_, Self::Message, iced::Renderer> { match self { - UI::Loading => text("Loading") - .width(Length::Fill) - .height(Length::Fill) - .vertical_alignment(Vertical::Center) - .horizontal_alignment(Horizontal::Center) - .into(), UI::Error(message) => text(String::from("Error: ") + message) .width(Length::Fill) .height(Length::Fill) @@ -293,4 +283,24 @@ impl Application for UI { fn theme(&self) -> Self::Theme { Theme::Dark } + + fn subscription(&self) -> iced::Subscription { + match self { + UI::Chats(state) => iced::subscription::unfold( + "external messages", + state.recv.take(), + move |recv| async move { + let msg = recv.as_ref().unwrap().recv(); + + if let Err(_) = msg { + return (UIMessage::Nop, recv); + } + + let msg = msg.unwrap(); + (msg, recv) + }, + ), + _ => Subscription::none(), + } + } }