AIGUI/src/main.rs

343 lines
12 KiB
Rust

mod api;
use std::{
cell::RefCell,
io::{self, Write},
sync::{
mpsc::{Receiver, Sender},
Arc,
},
};
use api::{ChatMessage, OllamaAPI};
use iced::{
alignment::{Horizontal, Vertical},
futures::StreamExt,
subscription,
widget::{button, column, container, row, scrollable, text, text_input},
window, Application, Color, Command, Event, Length, Settings, Subscription, Theme,
};
use tokio::sync::RwLock;
use crate::api::{OllamaChat, Role};
/// Top-level application state: which screen the window is currently showing.
enum UI {
    /// Normal operation: the chat sidebar plus the active conversation.
    Chats(UIState),
    /// Start-up failed; the view renders this text after an "Error: " prefix.
    Error(String),
    /// A window close request was received and the window is being closed.
    Closing,
}
/// State backing the main chat screen.
struct UIState {
    /// Index into `ollama_api.chats` of the conversation being displayed, if any.
    active_chat: Option<usize>,
    /// Current contents of the message input box.
    chat_input: String,
    /// Shared Ollama client; background generation tasks also lock it.
    ollama_api: Arc<RwLock<OllamaAPI>>,
    /// Sending half of the wake-up channel, cloned into background tasks.
    send: Sender<UIMessage>,
    /// Receiving half of the wake-up channel; moved out (leaving `None`) the
    /// first time the subscription is built.
    recv: RefCell<Option<Receiver<UIMessage>>>,
}
impl UIState {
    /// Builds the chat-screen state with an Ollama client pointed at `api_url`.
    ///
    /// No chat is selected and the input box starts empty. `send`/`recv` are the
    /// two halves of the channel background tasks use to wake the UI.
    ///
    /// # Errors
    /// Propagates any error returned by `OllamaAPI::create`.
    fn create(
        api_url: &str,
        send: Sender<UIMessage>,
        recv: Receiver<UIMessage>,
    ) -> anyhow::Result<Self> {
        let client = OllamaAPI::create(api_url)?;
        let state = Self {
            active_chat: None,
            chat_input: String::default(),
            ollama_api: Arc::new(RwLock::new(client)),
            recv: RefCell::new(Some(recv)),
            send,
        };
        Ok(state)
    }
}
/// Messages processed by `UI::update`.
#[derive(Clone, Debug)]
enum UIMessage {
    /// No-op; the generation task sends one after each streamed chunk so the
    /// UI receives a message (and re-renders the partial reply).
    Nop,
    /// A raw runtime event from the event subscription.
    EventOccurred(Event),
    /// The "Create Chat" button was pressed.
    CreateChat,
    /// A chat button in the sidebar was pressed; carries the chat's index.
    OpenChat(usize),
    /// The message input changed; carries the new text.
    ChatInput(String),
    /// The message input was submitted.
    ChatSubmit,
    /// A background generation task finished.
    DoneGenerating,
}
fn main() -> anyhow::Result<()> {
let (send, recv) = std::sync::mpsc::channel();
let mut settings = Settings::with_flags(UIFlags { send, recv });
settings.exit_on_close_request = false;
settings.window = window::Settings {
size: (720, 480),
..window::Settings::default()
};
UI::run(settings)?;
Ok(())
}
/// Start-up data handed from `main` to `Application::new`.
struct UIFlags {
    /// Sending half of the UI wake-up channel.
    send: Sender<UIMessage>,
    /// Receiving half of the same channel.
    recv: Receiver<UIMessage>,
}
impl Application for UI {
type Executor = iced::executor::Default;
type Message = UIMessage;
type Theme = Theme;
type Flags = UIFlags;
fn new(flags: Self::Flags) -> (Self, iced::Command<Self::Message>) {
let state: UI;
let state_result = UIState::create("http://localhost:11434", flags.send, flags.recv);
if let Ok(s) = state_result {
state = UI::Chats(s);
} else {
state = UI::Error(String::from("Failed to initialize Ollama API"));
}
(state, Command::none())
}
fn title(&self) -> String {
String::from("AIGUI")
}
fn update(&mut self, message: Self::Message) -> iced::Command<Self::Message> {
if let UIMessage::EventOccurred(event) = &message {
if let Event::Window(window::Event::CloseRequested) = event {
*self = UI::Closing;
return window::close();
} else {
return Command::none();
}
}
match self {
UI::Chats(state) => match message {
UIMessage::CreateChat => {
let chat = state
.ollama_api
.blocking_write()
.create_chat("dolphin-mixtral");
if let Ok(chat) = chat {
chat.blocking_write().send_system("You write stories about various topics.\n\nYou can include images into these stories using the following syntax:\n{IMAGE:description of image}\n\ne.g.:\n{IMAGE:an image of a large church, volumetric lighting, masterpiece, best quality}\n\nMake sure to use at least one image every five sentences.");
}
Command::none()
}
UIMessage::OpenChat(index) => {
println!("Open chat {}", index);
state.active_chat = Some(index);
Command::none()
}
UIMessage::ChatInput(text) => {
println!("Input: {}", text);
state.chat_input = text;
Command::none()
}
UIMessage::ChatSubmit => {
println!("Submit");
if let Some(index) = state.active_chat {
let content = state.chat_input.clone();
state.chat_input = String::new();
let send = state.send.clone();
let ollama_api = state.ollama_api.clone();
return Command::perform(
async move {
/* TODO: Stable diffusion
let sd =
StableDiffusionAPI::create("http://localhost:7860").unwrap();
let req = StableDiffusionRequest {
prompt: String::from("among us"),
steps: 10,
};
let img = sd.generate_image(&req).await.unwrap();
println!("{:?}", img);
std::fs::write("amogus.png", img);*/
let stream;
let chat;
{
{
let api = ollama_api.read().await;
chat = api.chats[index].clone();
let mut chat = chat.write().await;
chat.send_user(&content);
chat.generating_message = Some(String::new());
}
//let mut chat = chat.write().await;
stream = OllamaChat::complete(chat.clone()).await;
}
if let Ok(mut stream) = stream {
while let Some(content) = stream.next().await {
if content.is_err() {
break;
}
let content = content.unwrap();
let chat_arc = chat.clone();
let mut chat = chat_arc.write().await;
let msg = &mut chat.generating_message;
if let Some(msg) = msg {
msg.push_str(&content);
}
print!("{}", content);
let _ = io::stdout().flush();
let _ = send.send(UIMessage::Nop);
}
let mut chat = chat.write().await;
let msg = &chat.generating_message.clone();
if let Some(msg) = msg {
chat.send_assistant(&msg);
chat.generating_message = None;
}
}
},
|_| UIMessage::DoneGenerating,
);
}
Command::none()
}
UIMessage::DoneGenerating => Command::none(),
_ => Command::none(),
},
UI::Error(_) => Command::none(),
UI::Closing => Command::none(),
}
}
fn view(&self) -> iced::Element<'_, Self::Message, iced::Renderer<Self::Theme>> {
match self {
UI::Error(message) => text(String::from("Error: ") + message)
.width(Length::Fill)
.height(Length::Fill)
.vertical_alignment(Vertical::Center)
.horizontal_alignment(Horizontal::Center)
.into(),
UI::Chats(state) => {
let api = state.ollama_api.blocking_read();
let mut chats: Vec<iced::Element<'_, Self::Message, iced::Renderer<Self::Theme>>> =
api.chats
.iter()
.enumerate()
.map(|c| {
button("Chat")
.on_press(UIMessage::OpenChat(c.0))
.width(Length::Fill)
.into()
})
.collect();
chats.insert(
0,
button("Create Chat")
.on_press(UIMessage::CreateChat)
.width(Length::Fill)
.into(),
);
let chats_list = scrollable(container(column(chats).padding(10.0).spacing(10)))
.height(Length::Fill)
.width(Length::Fixed(300.0));
let mut messages: Vec<
iced::Element<'_, Self::Message, iced::Renderer<Self::Theme>>,
>;
let mut input_box = text_input("Your message", &state.chat_input);
if let Some(chat_index) = state.active_chat {
let chat = &api.chats[chat_index].blocking_read();
messages = chat
.messages
.iter()
.map(|x| self.render_message(x))
.collect();
if let Some(msg) = &chat.generating_message {
let mut msg = msg.clone();
msg.push_str("...");
let in_progress_message = ChatMessage {
role: Role::Assistant,
content: msg,
};
messages.push(self.render_message(&in_progress_message));
} else {
input_box = input_box
.on_input(|x| UIMessage::ChatInput(x))
.on_submit(UIMessage::ChatSubmit)
}
} else {
messages = Vec::new();
}
let active_chat = container(
column![
scrollable(container(column(messages).spacing(10)))
.width(Length::Fill)
.height(Length::Fill),
input_box
]
.padding(10)
.spacing(10),
)
.width(Length::Fill)
.height(Length::Fill);
container(row![chats_list, active_chat]).into()
}
UI::Closing => text("Closing").into(),
}
}
fn theme(&self) -> Self::Theme {
Theme::Dark
}
fn subscription(&self) -> iced::Subscription<Self::Message> {
let event_subscription = subscription::events().map(UIMessage::EventOccurred);
match self {
UI::Chats(state) => {
let recv_subscription = iced::subscription::unfold(
"external messages",
state.recv.take(),
move |recv| async move {
let msg = recv.as_ref().unwrap().recv();
if let Err(_) = msg {
return (UIMessage::Nop, recv);
}
let msg = msg.unwrap();
(msg, recv)
},
);
Subscription::batch(vec![event_subscription, recv_subscription])
}
_ => event_subscription,
}
}
}
impl UI {
fn render_message(
&self,
message: &ChatMessage,
) -> iced::Element<'_, UIMessage, iced::Renderer<Theme>> {
let (role_name, role_color) = match &message.role {
Role::User => ("User", Color::from_rgb8(128, 128, 128)),
Role::Assistant => ("Assistant", Color::from_rgb8(100, 100, 255)),
Role::System => ("System", Color::from_rgb8(255, 0, 0)),
};
column![text(role_name).style(role_color), text(&message.content)].into()
}
}