Weird async magic

This commit is contained in:
MrLetsplay 2024-02-02 23:43:04 +01:00
parent 6914f23313
commit 4b505c26e1
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
4 changed files with 950 additions and 757 deletions

1403
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -7,8 +7,7 @@ edition = "2021"
[dependencies] [dependencies]
anyhow = "1.0.79" anyhow = "1.0.79"
iced = "0.9.0" iced = { version = "0.10.0", features = ["tokio"] }
iced_web = "0.4.0"
reqwest = { version = "0.11.24", features = ["stream"] } reqwest = { version = "0.11.24", features = ["stream"] }
serde = { version = "1.0.196", features = ["derive"] } serde = { version = "1.0.196", features = ["derive"] }
serde_json = "1.0.113" serde_json = "1.0.113"

View File

@ -1,27 +1,29 @@
use std::sync::Arc;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::io::AsyncBufReadExt; use tokio::io::AsyncBufReadExt;
use tokio::sync::RwLock;
use tokio_stream::wrappers::LinesStream; use tokio_stream::wrappers::LinesStream;
use tokio_stream::{Stream, StreamExt}; use tokio_stream::{Stream, StreamExt};
use tokio_util::io::StreamReader; use tokio_util::io::StreamReader;
pub struct OllamaAPI { pub struct OllamaAPI {
api_url: &'static str, pub api_url: String,
client: Client, client: Client,
pub chats: Vec<Arc<RwLock<OllamaChat>>>,
} }
#[derive(Serialize)] #[derive(Serialize)]
pub struct OllamaChat<'a> { pub struct OllamaChat {
#[serde(skip_serializing)] #[serde(skip_serializing)]
api: &'a OllamaAPI, pub history_size: u32,
#[serde(skip_serializing)]
history_size: u32,
model: String, pub model: String,
messages: Vec<ChatMessage>, pub messages: Vec<ChatMessage>,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Clone)]
pub enum Role { pub enum Role {
#[serde(rename = "system")] #[serde(rename = "system")]
System, System,
@ -31,10 +33,10 @@ pub enum Role {
Assistant, Assistant,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Clone)]
pub struct ChatMessage { pub struct ChatMessage {
role: Role, pub role: Role,
content: String, pub content: String,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -53,16 +55,16 @@ fn deserialize_chunk(chunk: Result<String, std::io::Error>) -> anyhow::Result<St
} }
impl OllamaAPI { impl OllamaAPI {
pub async fn complete<'a>( pub async fn complete(
&self, &self,
chat: &OllamaChat<'a>, chat: &OllamaChat,
) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> { ) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> {
let body = serde_json::to_string(&chat)?; let body = serde_json::to_string(&chat)?;
println!("Sending: {}", body); println!("Sending: {}", body);
let request = self let request = self
.client .client
.post(String::from(self.api_url) + "/api/chat") .post(self.api_url.clone() + "/api/chat")
.body(serde_json::to_string(&chat)?) .body(serde_json::to_string(&chat)?)
.build()?; .build()?;
let stream = self let stream = self
@ -78,25 +80,28 @@ impl OllamaAPI {
Ok(LinesStream::new(lines).map(deserialize_chunk)) Ok(LinesStream::new(lines).map(deserialize_chunk))
} }
pub fn create(api_url: &'static str) -> anyhow::Result<Self> { pub fn create(api_url: &str) -> anyhow::Result<Self> {
let client = Client::builder().build()?; let client = Client::builder().build()?;
Ok(Self { Ok(Self {
api_url: api_url, api_url: String::from(api_url),
client, client,
chats: Vec::new(),
}) })
} }
pub fn create_chat(&self, model: &str) -> OllamaChat { pub fn create_chat(&mut self, model: &str) -> Arc<RwLock<OllamaChat>> {
OllamaChat { let chat = Arc::new(RwLock::new(OllamaChat {
api: &self,
history_size: 0, history_size: 0,
model: String::from(model), model: String::from(model),
messages: Vec::new(), messages: Vec::new(),
} }));
self.chats.push(chat.clone());
chat
} }
} }
impl<'a> OllamaChat<'a> { impl OllamaChat {
pub fn send_message(&mut self, role: Role, content: &str) { pub fn send_message(&mut self, role: Role, content: &str) {
self.messages.push(ChatMessage { self.messages.push(ChatMessage {
role, role,
@ -115,8 +120,4 @@ impl<'a> OllamaChat<'a> {
pub fn send_assistant(&mut self, content: &str) { pub fn send_assistant(&mut self, content: &str) {
self.send_message(Role::Assistant, content); self.send_message(Role::Assistant, content);
} }
pub async fn complete(&self) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> {
self.api.complete(&self).await
}
} }

View File

@ -1,38 +1,59 @@
mod api; mod api;
use api::OllamaAPI; use std::{sync::Arc};
use api::{OllamaAPI, OllamaChat};
use iced::{ use iced::{
alignment::{Horizontal, Vertical},
futures::StreamExt, futures::StreamExt,
widget::{button, column, container, row, scrollable, text}, widget::{button, column, container, row, scrollable, text, text_input},
window::{self}, window::{self},
Application, Command, Settings, Theme, Application, Command, Length, Settings, Theme,
}; };
use tokio::sync::{RwLock};
enum UI { enum UI {
Loading, Loading,
Loaded, Error(String),
Chats(UIState),
}
struct Chat {
ollama_chat: OllamaChat,
}
struct UIState {
ollama_api: Arc<RwLock<OllamaAPI>>,
active_chat: Option<usize>,
chat_input: String,
busy: bool,
}
impl UIState {
fn create(api_url: &str) -> anyhow::Result<Self> {
Ok(Self {
ollama_api: Arc::new(RwLock::new(OllamaAPI::create(api_url)?)),
active_chat: None,
chat_input: String::from(""),
busy: false,
})
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
enum UIMessage { enum UIMessage {
Nop,
LoadDone, LoadDone,
CreateChat,
OpenChat(usize),
ChatInput(String),
ChatSubmit,
DoneGenerating,
} }
#[tokio::main] fn main() -> anyhow::Result<()> {
async fn main() -> anyhow::Result<()> {
println!("Making request"); println!("Making request");
let api = OllamaAPI::create("http://localhost:11434")?;
let mut chat = api.create_chat("dolphin-mixtral");
chat.send_system("You are a bot that only replies with \"Hello\" and nothing else. You must never reply with anything other than \"Hello\"");
chat.send_user("Who are you?");
// println!("{}", amogus.get().await?);
let mut a = chat.complete().await?;
while let Some(v) = a.next().await {
println!("GOT = {:?}", v);
}
UI::run(Settings { UI::run(Settings {
window: window::Settings { window: window::Settings {
size: (720, 480), size: (720, 480),
@ -53,7 +74,21 @@ impl Application for UI {
fn new(_flags: Self::Flags) -> (Self, iced::Command<Self::Message>) { fn new(_flags: Self::Flags) -> (Self, iced::Command<Self::Message>) {
( (
UI::Loading, UI::Loading,
Command::perform(async {}, |_| UIMessage::LoadDone), Command::perform(
async {
/*let api = OllamaAPI::create("http://localhost:11434").unwrap();
let mut chat = api.create_chat("dolphin-mixtral");
chat.send_system("You are a bot that only replies with \"Hello\" and nothing else. You must never reply with anything other than \"Hello\"");
chat.send_user("Who are you?");
// println!("{}", amogus.get().await?);
let mut a = chat.complete().await.unwrap();
while let Some(v) = a.next().await {
println!("GOT = {:?}", v);
}*/
},
|_| UIMessage::LoadDone,
),
) )
} }
@ -62,28 +97,175 @@ impl Application for UI {
} }
fn update(&mut self, message: Self::Message) -> iced::Command<Self::Message> { fn update(&mut self, message: Self::Message) -> iced::Command<Self::Message> {
match message { match self {
UI::Loading => match message {
UIMessage::LoadDone => { UIMessage::LoadDone => {
*self = UI::Loaded; let state = UIState::create("http://localhost:11434");
if let Ok(state) = state {
*self = UI::Chats(state);
} else {
*self = UI::Error(String::from("Failed to initialize Ollama API"));
}
Command::none() Command::none()
} }
_ => Command::none(),
},
UI::Chats(state) => match message {
UIMessage::CreateChat => {
let chat = state
.ollama_api
.blocking_write()
.create_chat("dolphin-mixtral");
chat.blocking_write().send_system("Hello World!");
Command::none()
}
UIMessage::OpenChat(index) => {
println!("Open chat {}", index);
state.active_chat = Some(index);
Command::none()
}
UIMessage::ChatInput(text) => {
println!("Input: {}", text);
state.chat_input = text;
Command::none()
}
UIMessage::ChatSubmit => {
println!("Submit");
if let Some(index) = state.active_chat {
let content = state.chat_input.clone();
state.chat_input = String::from("");
state.busy = true;
let ollama_api = state.ollama_api.clone();
return Command::perform(
async move {
let stream;
{
let api = ollama_api.read().await;
let chat = &api.chats[index];
{
let mut chat = chat.write().await;
chat.send_user(&content);
}
let chat = chat.read().await;
stream = api.complete(&chat).await;
}
if let Ok(mut stream) = stream {
while let Some(content) = stream.next().await {
if content.is_err() {
break;
}
print!("{}", content.unwrap());
}
}
//state.ollama_api.complete(&x);
},
|_| UIMessage::DoneGenerating,
);
}
Command::none()
}
UIMessage::DoneGenerating => {
state.busy = false;
Command::none()
}
_ => Command::none(),
},
UI::Error(_) => Command::none(),
} }
} }
fn view(&self) -> iced::Element<'_, Self::Message, iced::Renderer<Self::Theme>> { fn view(&self) -> iced::Element<'_, Self::Message, iced::Renderer<Self::Theme>> {
let content = match self { match self {
UI::Loading => row![text("This is amogus sussy baka!")], UI::Loading => text("Loading")
UI::Loaded => row![button(text("Hello Wordl!")).on_press(UIMessage::LoadDone)], .width(Length::Fill)
}; .height(Length::Fill)
.vertical_alignment(Vertical::Center)
.horizontal_alignment(Horizontal::Center)
.into(),
UI::Error(message) => text(String::from("Error: ") + message)
.width(Length::Fill)
.height(Length::Fill)
.vertical_alignment(Vertical::Center)
.horizontal_alignment(Horizontal::Center)
.into(),
UI::Chats(state) => {
println!("Blocking");
let api = state.ollama_api.blocking_read();
println!("Blocing done!");
let amogus = text("Hello!"); let mut chats: Vec<iced::Element<'_, Self::Message, iced::Renderer<Self::Theme>>> =
api.chats
.iter()
.enumerate()
.map(|c| {
button("amogus")
.on_press(UIMessage::OpenChat(c.0))
.width(Length::Fill)
.into()
})
.collect();
scrollable( chats.insert(
0,
button("Create Chat")
.on_press(UIMessage::CreateChat)
.width(Length::Fill)
.into(),
);
let chats_list = container(column(chats).padding(10.0).spacing(10))
.height(Length::Fill)
.width(Length::Fixed(300.0));
let messages;
if let Some(chat_index) = state.active_chat {
let chat = &api.chats[chat_index];
messages = chat
.blocking_read()
.messages
.iter()
.map(|x| x.content.clone())
.map(|x| text(x).into())
.collect();
} else {
messages = Vec::new();
}
let mut input_box = text_input("Your message", &state.chat_input);
if !state.busy {
input_box = input_box
.on_input(|x| UIMessage::ChatInput(x))
.on_submit(UIMessage::ChatSubmit)
}
let active_chat = container(
column![
scrollable(container(column(messages).spacing(10)))
.width(Length::Fill)
.height(Length::Fill),
input_box
]
.padding(10)
.spacing(10),
)
.width(Length::Fill)
.height(Length::Fill);
container(row![chats_list, active_chat]).into()
/*scrollable(
container(column![content, amogus].spacing(20).max_width(480)) container(column![content, amogus].spacing(20).max_width(480))
.padding(40) .padding(40)
.center_x(), .center_x(),
) )
.into() .into()*/
}
}
} }
fn theme(&self) -> Self::Theme { fn theme(&self) -> Self::Theme {