UI pain (mostly broken)

This commit is contained in:
MrLetsplay 2024-02-06 22:27:37 +01:00
parent 4b505c26e1
commit d90151ee5f
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
2 changed files with 83 additions and 47 deletions

View File

@ -1,26 +1,35 @@
use std::ops::{Deref, DerefMut};
use std::sync::Arc; use std::sync::Arc;
use iced::futures::StreamExt;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::io::AsyncBufReadExt; use tokio::io::AsyncBufReadExt;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tokio_stream::wrappers::LinesStream; use tokio_stream::wrappers::LinesStream;
use tokio_stream::{Stream, StreamExt}; use tokio_stream::Stream;
use tokio_util::io::StreamReader; use tokio_util::io::StreamReader;
pub struct OllamaAPI { pub struct OllamaAPI {
pub api_url: String, pub api_url: String,
client: Client,
pub chats: Vec<Arc<RwLock<OllamaChat>>>, pub chats: Vec<Arc<RwLock<OllamaChat>>>,
} }
#[derive(Serialize)] #[derive(Serialize)]
pub struct OllamaChat { pub struct OllamaChat {
#[serde(skip_serializing)]
client: Client,
#[serde(skip_serializing)]
api_url: String,
#[serde(skip_serializing)] #[serde(skip_serializing)]
pub history_size: u32, pub history_size: u32,
pub model: String, pub model: String,
pub messages: Vec<ChatMessage>, pub messages: Vec<ChatMessage>,
#[serde(skip_serializing)]
pub current_message: Option<Arc<RwLock<String>>>,
} }
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
@ -55,49 +64,26 @@ fn deserialize_chunk(chunk: Result<String, std::io::Error>) -> anyhow::Result<St
} }
impl OllamaAPI { impl OllamaAPI {
pub async fn complete(
&self,
chat: &OllamaChat,
) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> {
let body = serde_json::to_string(&chat)?;
println!("Sending: {}", body);
let request = self
.client
.post(self.api_url.clone() + "/api/chat")
.body(serde_json::to_string(&chat)?)
.build()?;
let stream = self
.client
.execute(request)
.await?
.bytes_stream()
.map(|x| match x {
Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)),
Ok(bytes) => Ok(bytes),
});
let lines = StreamReader::new(stream).lines();
Ok(LinesStream::new(lines).map(deserialize_chunk))
}
pub fn create(api_url: &str) -> anyhow::Result<Self> { pub fn create(api_url: &str) -> anyhow::Result<Self> {
let client = Client::builder().build()?;
Ok(Self { Ok(Self {
api_url: String::from(api_url), api_url: String::from(api_url),
client,
chats: Vec::new(), chats: Vec::new(),
}) })
} }
pub fn create_chat(&mut self, model: &str) -> Arc<RwLock<OllamaChat>> { pub fn create_chat(&mut self, model: &str) -> anyhow::Result<Arc<RwLock<OllamaChat>>> {
let client = Client::builder().build()?;
let chat = Arc::new(RwLock::new(OllamaChat { let chat = Arc::new(RwLock::new(OllamaChat {
history_size: 0, history_size: 0,
client,
api_url: self.api_url.clone(),
model: String::from(model), model: String::from(model),
messages: Vec::new(), messages: Vec::new(),
current_message: None,
})); }));
self.chats.push(chat.clone()); self.chats.push(chat.clone());
chat Ok(chat)
} }
} }
@ -120,4 +106,32 @@ impl OllamaChat {
pub fn send_assistant(&mut self, content: &str) { pub fn send_assistant(&mut self, content: &str) {
self.send_message(Role::Assistant, content); self.send_message(Role::Assistant, content);
} }
pub async fn complete(
chat: Arc<RwLock<OllamaChat>>,
) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> {
let mut chat = chat.write().await;
let body = serde_json::to_string(chat.deref())?;
println!("Sending: {}", body);
let msg = Arc::new(RwLock::new(String::from("")));
chat.current_message = Some(msg.clone());
let request = chat
.client
.post(chat.api_url.clone() + "/api/chat")
.body(serde_json::to_string(chat.deref())?)
.build()?;
let stream = chat
.client
.execute(request)
.await?
.bytes_stream()
.map(|x| match x {
Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)),
Ok(bytes) => Ok(bytes),
});
let lines = StreamReader::new(stream).lines();
Ok(LinesStream::new(lines).map(deserialize_chunk))
}
} }

View File

@ -1,8 +1,11 @@
mod api; mod api;
use std::{sync::Arc}; use std::{
io::{self, Write},
sync::Arc,
};
use api::{OllamaAPI, OllamaChat}; use api::OllamaAPI;
use iced::{ use iced::{
alignment::{Horizontal, Vertical}, alignment::{Horizontal, Vertical},
futures::StreamExt, futures::StreamExt,
@ -10,7 +13,9 @@ use iced::{
window::{self}, window::{self},
Application, Command, Length, Settings, Theme, Application, Command, Length, Settings, Theme,
}; };
use tokio::sync::{RwLock}; use tokio::sync::RwLock;
use crate::api::OllamaChat;
enum UI { enum UI {
Loading, Loading,
@ -18,10 +23,6 @@ enum UI {
Chats(UIState), Chats(UIState),
} }
struct Chat {
ollama_chat: OllamaChat,
}
struct UIState { struct UIState {
ollama_api: Arc<RwLock<OllamaAPI>>, ollama_api: Arc<RwLock<OllamaAPI>>,
active_chat: Option<usize>, active_chat: Option<usize>,
@ -116,7 +117,9 @@ impl Application for UI {
.ollama_api .ollama_api
.blocking_write() .blocking_write()
.create_chat("dolphin-mixtral"); .create_chat("dolphin-mixtral");
if let Ok(chat) = chat {
chat.blocking_write().send_system("Hello World!"); chat.blocking_write().send_system("Hello World!");
}
Command::none() Command::none()
} }
UIMessage::OpenChat(index) => { UIMessage::OpenChat(index) => {
@ -140,16 +143,17 @@ impl Application for UI {
return Command::perform( return Command::perform(
async move { async move {
let stream; let stream;
let chat;
{
{ {
let api = ollama_api.read().await; let api = ollama_api.read().await;
let chat = &api.chats[index]; chat = api.chats[index].clone();
{
let mut chat = chat.write().await; let mut chat = chat.write().await;
chat.send_user(&content); chat.send_user(&content);
} }
let chat = chat.read().await;
stream = api.complete(&chat).await; //let mut chat = chat.write().await;
stream = OllamaChat::complete(chat.clone()).await;
} }
if let Ok(mut stream) = stream { if let Ok(mut stream) = stream {
@ -158,7 +162,20 @@ impl Application for UI {
break; break;
} }
print!("{}", content.unwrap()); let content = content.unwrap();
let chat = chat.clone();
let mut chat = chat.write().await;
let msg = chat.current_message.clone();
if let Some(msg) = msg {
msg.write().await.push_str(&content);
} else {
chat.current_message =
Some(Arc::new(RwLock::new(String::from(&content))));
}
print!("{}", content);
let _ = io::stdout().flush();
} }
} }
//state.ollama_api.complete(&x); //state.ollama_api.complete(&x);
@ -222,16 +239,21 @@ impl Application for UI {
.height(Length::Fill) .height(Length::Fill)
.width(Length::Fixed(300.0)); .width(Length::Fixed(300.0));
let messages; let mut messages: Vec<
iced::Element<'_, Self::Message, iced::Renderer<Self::Theme>>,
>;
if let Some(chat_index) = state.active_chat { if let Some(chat_index) = state.active_chat {
let chat = &api.chats[chat_index]; let chat = &api.chats[chat_index].blocking_read();
messages = chat messages = chat
.blocking_read()
.messages .messages
.iter() .iter()
.map(|x| x.content.clone()) .map(|x| x.content.clone())
.map(|x| text(x).into()) .map(|x| text(x).into())
.collect(); .collect();
if let Some(msg) = &chat.current_message {
messages.push(text(msg.blocking_read()).into());
}
} else { } else {
messages = Vec::new(); messages = Vec::new();
} }