diff --git a/src/api.rs b/src/api.rs index 8e9c39b..e174d6c 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,26 +1,35 @@ +use std::ops::{Deref, DerefMut}; use std::sync::Arc; +use iced::futures::StreamExt; use reqwest::Client; use serde::{Deserialize, Serialize}; use tokio::io::AsyncBufReadExt; use tokio::sync::RwLock; use tokio_stream::wrappers::LinesStream; -use tokio_stream::{Stream, StreamExt}; +use tokio_stream::Stream; use tokio_util::io::StreamReader; pub struct OllamaAPI { pub api_url: String, - client: Client, pub chats: Vec>>, } #[derive(Serialize)] pub struct OllamaChat { + #[serde(skip_serializing)] + client: Client, + #[serde(skip_serializing)] + api_url: String, + #[serde(skip_serializing)] pub history_size: u32, pub model: String, pub messages: Vec, + + #[serde(skip_serializing)] + pub current_message: Option>>, } #[derive(Serialize, Deserialize, Clone)] @@ -55,49 +64,26 @@ fn deserialize_chunk(chunk: Result) -> anyhow::Result anyhow::Result>> { - let body = serde_json::to_string(&chat)?; - println!("Sending: {}", body); - - let request = self - .client - .post(self.api_url.clone() + "/api/chat") - .body(serde_json::to_string(&chat)?) - .build()?; - let stream = self - .client - .execute(request) - .await? - .bytes_stream() - .map(|x| match x { - Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)), - Ok(bytes) => Ok(bytes), - }); - let lines = StreamReader::new(stream).lines(); - Ok(LinesStream::new(lines).map(deserialize_chunk)) - } - pub fn create(api_url: &str) -> anyhow::Result { - let client = Client::builder().build()?; Ok(Self { api_url: String::from(api_url), - client, chats: Vec::new(), }) } - pub fn create_chat(&mut self, model: &str) -> Arc> { + pub fn create_chat(&mut self, model: &str) -> anyhow::Result>> { + let client = Client::builder().build()?; let chat = Arc::new(RwLock::new(OllamaChat { history_size: 0, + client, + api_url: self.api_url.clone(), model: String::from(model), messages: Vec::new(), + current_message: None, })); self.chats.push(chat.clone()); - chat + Ok(chat) } } @@ -120,4 +106,32 @@ impl OllamaChat { pub fn send_assistant(&mut self, content: &str) { self.send_message(Role::Assistant, content); } + + pub async fn complete( + chat: Arc>, + ) -> anyhow::Result>> { + let mut chat = chat.write().await; + let body = serde_json::to_string(chat.deref())?; + println!("Sending: {}", body); + + let msg = Arc::new(RwLock::new(String::from(""))); + chat.current_message = Some(msg.clone()); + + let request = chat + .client + .post(chat.api_url.clone() + "/api/chat") + .body(serde_json::to_string(chat.deref())?) + .build()?; + let stream = chat + .client + .execute(request) + .await? + .bytes_stream() + .map(|x| match x { + Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)), + Ok(bytes) => Ok(bytes), + }); + let lines = StreamReader::new(stream).lines(); + Ok(LinesStream::new(lines).map(deserialize_chunk)) + } } diff --git a/src/main.rs b/src/main.rs index 83e9f5d..f84a40f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,11 @@ mod api; -use std::{sync::Arc}; +use std::{ + io::{self, Write}, + sync::Arc, +}; -use api::{OllamaAPI, OllamaChat}; +use api::OllamaAPI; use iced::{ alignment::{Horizontal, Vertical}, futures::StreamExt, @@ -10,7 +13,9 @@ use iced::{ window::{self}, Application, Command, Length, Settings, Theme, }; -use tokio::sync::{RwLock}; +use tokio::sync::RwLock; + +use crate::api::OllamaChat; enum UI { Loading, @@ -18,10 +23,6 @@ enum UI { Chats(UIState), } -struct Chat { - ollama_chat: OllamaChat, -} - struct UIState { ollama_api: Arc>, active_chat: Option, @@ -116,7 +117,9 @@ impl Application for UI { .ollama_api .blocking_write() .create_chat("dolphin-mixtral"); - chat.blocking_write().send_system("Hello World!"); + if let Ok(chat) = chat { + chat.blocking_write().send_system("Hello World!"); + } Command::none() } UIMessage::OpenChat(index) => { @@ -140,16 +143,17 @@ impl Application for UI { return Command::perform( async move { let stream; + let chat; { - let api = ollama_api.read().await; - let chat = &api.chats[index]; { + let api = ollama_api.read().await; + chat = api.chats[index].clone(); let mut chat = chat.write().await; chat.send_user(&content); } - let chat = chat.read().await; - stream = api.complete(&chat).await; + //let mut chat = chat.write().await; + stream = OllamaChat::complete(chat.clone()).await; } if let Ok(mut stream) = stream { @@ -158,7 +162,20 @@ impl Application for UI { break; } - print!("{}", content.unwrap()); + let content = content.unwrap(); + + let chat = chat.clone(); + let mut chat = chat.write().await; + let msg = chat.current_message.clone(); + if let Some(msg) = msg { + msg.write().await.push_str(&content); + } else { + chat.current_message = + Some(Arc::new(RwLock::new(String::from(&content)))); + } + + print!("{}", content); + let _ = io::stdout().flush(); } } //state.ollama_api.complete(&x); @@ -222,16 +239,21 @@ impl Application for UI { .height(Length::Fill) .width(Length::Fixed(300.0)); - let messages; + let mut messages: Vec< + iced::Element<'_, Self::Message, iced::Renderer>, + >; if let Some(chat_index) = state.active_chat { - let chat = &api.chats[chat_index]; + let chat = &api.chats[chat_index].blocking_read(); messages = chat - .blocking_read() .messages .iter() .map(|x| x.content.clone()) .map(|x| text(x).into()) .collect(); + + if let Some(msg) = &chat.current_message { + messages.push(text(msg.blocking_read()).into()); + } } else { messages = Vec::new(); }