Add stable diffusion API (WIP)

This commit is contained in:
MrLetsplay 2024-02-11 17:17:08 +01:00
parent 4ae2dc55d8
commit c3fda2ad6e
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
4 changed files with 286 additions and 4 deletions

229
Cargo.lock generated
View File

@ -168,6 +168,12 @@ version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]]
name = "bit_field"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc827186963e592360843fb5ba4b973e145841266c1357f7180c43526f2e5b61"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
@ -231,6 +237,12 @@ dependencies = [
"syn 2.0.48", "syn 2.0.48",
] ]
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "bytes" name = "bytes"
version = "1.5.0" version = "1.5.0"
@ -380,6 +392,12 @@ dependencies = [
"unicode-width", "unicode-width",
] ]
[[package]]
name = "color_quant"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
[[package]] [[package]]
name = "com-rs" name = "com-rs"
version = "0.2.1" version = "0.2.1"
@ -464,6 +482,25 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "crossbeam-deque"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]] [[package]]
name = "crossbeam-utils" name = "crossbeam-utils"
version = "0.8.19" version = "0.8.19"
@ -514,6 +551,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650" checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650"
[[package]]
name = "either"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
[[package]] [[package]]
name = "encoding_rs" name = "encoding_rs"
version = "0.8.33" version = "0.8.33"
@ -568,6 +611,22 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "exr"
version = "1.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "279d3efcc55e19917fff7ab3ddd6c14afb6a90881a0078465196fe2f99d08c56"
dependencies = [
"bit_field",
"flume",
"half",
"lebe",
"miniz_oxide",
"rayon-core",
"smallvec",
"zune-inflate",
]
[[package]] [[package]]
name = "fast-srgb8" name = "fast-srgb8"
version = "1.0.0" version = "1.0.0"
@ -608,6 +667,19 @@ dependencies = [
"miniz_oxide", "miniz_oxide",
] ]
[[package]]
name = "flume"
version = "0.10.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577"
dependencies = [
"futures-core",
"futures-sink",
"nanorand",
"pin-project",
"spin",
]
[[package]] [[package]]
name = "fnv" name = "fnv"
version = "1.0.7" version = "1.0.7"
@ -774,8 +846,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"js-sys",
"libc", "libc",
"wasi", "wasi",
"wasm-bindgen",
]
[[package]]
name = "gif"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80792593675e051cf94a4b111980da2ba60d4a83e43e0048c5693baab3977045"
dependencies = [
"color_quant",
"weezl",
] ]
[[package]] [[package]]
@ -1030,6 +1114,7 @@ dependencies = [
"iced_renderer", "iced_renderer",
"iced_widget", "iced_widget",
"iced_winit", "iced_winit",
"image",
"thiserror", "thiserror",
] ]
@ -1072,6 +1157,8 @@ dependencies = [
"glam", "glam",
"half", "half",
"iced_core", "iced_core",
"image",
"kamadak-exif",
"log", "log",
"raw-window-handle", "raw-window-handle",
"thiserror", "thiserror",
@ -1189,6 +1276,7 @@ name = "icedtest"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"base64",
"iced", "iced",
"reqwest", "reqwest",
"serde", "serde",
@ -1208,6 +1296,24 @@ dependencies = [
"unicode-normalization", "unicode-normalization",
] ]
[[package]]
name = "image"
version = "0.24.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "034bbe799d1909622a74d1193aa50147769440040ff36cb2baa947609b0a4e23"
dependencies = [
"bytemuck",
"byteorder",
"color_quant",
"exr",
"gif",
"jpeg-decoder",
"num-traits",
"png",
"qoi",
"tiff",
]
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "1.9.3" version = "1.9.3"
@ -1278,6 +1384,15 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "jpeg-decoder"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0"
dependencies = [
"rayon",
]
[[package]] [[package]]
name = "js-sys" name = "js-sys"
version = "0.3.68" version = "0.3.68"
@ -1287,6 +1402,15 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "kamadak-exif"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef4fc70d0ab7e5b6bafa30216a6b48705ea964cdfc29c050f2412295eba58077"
dependencies = [
"mutate_once",
]
[[package]] [[package]]
name = "khronos-egl" name = "khronos-egl"
version = "4.1.0" version = "4.1.0"
@ -1313,6 +1437,12 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "lebe"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.153" version = "0.2.153"
@ -1489,6 +1619,12 @@ dependencies = [
"windows-sys 0.48.0", "windows-sys 0.48.0",
] ]
[[package]]
name = "mutate_once"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16cf681a23b4d0a43fc35024c176437f9dcd818db34e0f42ab456a0ee5ad497b"
[[package]] [[package]]
name = "naga" name = "naga"
version = "0.12.3" version = "0.12.3"
@ -1509,6 +1645,15 @@ dependencies = [
"unicode-xid", "unicode-xid",
] ]
[[package]]
name = "nanorand"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
dependencies = [
"getrandom",
]
[[package]] [[package]]
name = "native-tls" name = "native-tls"
version = "0.2.11" version = "0.2.11"
@ -1916,6 +2061,26 @@ dependencies = [
"siphasher", "siphasher",
] ]
[[package]]
name = "pin-project"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
]
[[package]] [[package]]
name = "pin-project-lite" name = "pin-project-lite"
version = "0.2.13" version = "0.2.13"
@ -1992,6 +2157,15 @@ version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f0f7f43585c34e4fdd7497d746bc32e14458cf11c69341cc0587b1d825dde42" checksum = "0f0f7f43585c34e4fdd7497d746bc32e14458cf11c69341cc0587b1d825dde42"
[[package]]
name = "qoi"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f6d64c71eb498fe9eae14ce4ec935c555749aef511cca85b5568910d6e48001"
dependencies = [
"bytemuck",
]
[[package]] [[package]]
name = "quick-xml" name = "quick-xml"
version = "0.28.2" version = "0.28.2"
@ -2067,6 +2241,26 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9"
[[package]]
name = "rayon"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]] [[package]]
name = "read-fonts" name = "read-fonts"
version = "0.15.4" version = "0.15.4"
@ -2446,6 +2640,15 @@ dependencies = [
"x11rb 0.11.1", "x11rb 0.11.1",
] ]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
]
[[package]] [[package]]
name = "spirv" name = "spirv"
version = "0.2.0+1.5.4" version = "0.2.0+1.5.4"
@ -2590,6 +2793,17 @@ dependencies = [
"syn 2.0.48", "syn 2.0.48",
] ]
[[package]]
name = "tiff"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba1310fcea54c6a9a4fd1aad794ecc02c31682f6bfbecdf460bf19533eed1e3e"
dependencies = [
"flate2",
"jpeg-decoder",
"weezl",
]
[[package]] [[package]]
name = "tiny-skia" name = "tiny-skia"
version = "0.8.4" version = "0.8.4"
@ -3228,6 +3442,12 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "weezl"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb"
[[package]] [[package]]
name = "wgpu" name = "wgpu"
version = "0.16.3" version = "0.16.3"
@ -3751,3 +3971,12 @@ dependencies = [
"quote", "quote",
"syn 2.0.48", "syn 2.0.48",
] ]
[[package]]
name = "zune-inflate"
version = "0.2.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73ab332fe2f6680068f3582b16a24f90ad7096d5d39b974d1c0aff0125116f02"
dependencies = [
"simd-adler32",
]

View File

@ -7,7 +7,8 @@ edition = "2021"
[dependencies] [dependencies]
anyhow = "1.0.79" anyhow = "1.0.79"
iced = { version = "0.10.0", features = ["tokio"] } base64 = "0.21.7"
iced = { version = "0.10.0", features = ["tokio", "image"] }
reqwest = { version = "0.11.24", features = ["stream"] } reqwest = { version = "0.11.24", features = ["stream"] }
serde = { version = "1.0.196", features = ["derive"] } serde = { version = "1.0.196", features = ["derive"] }
serde_json = "1.0.113" serde_json = "1.0.113"

View File

@ -1,6 +1,8 @@
use std::ops::Deref; use std::ops::Deref;
use std::sync::Arc; use std::sync::Arc;
use base64::engine::{general_purpose};
use base64::Engine;
use iced::futures::StreamExt; use iced::futures::StreamExt;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -138,3 +140,43 @@ impl OllamaChat {
Ok(LinesStream::new(lines).map(deserialize_chunk)) Ok(LinesStream::new(lines).map(deserialize_chunk))
} }
} }
pub struct StableDiffusionAPI {
pub api_url: String,
client: Client,
}
#[derive(Serialize)]
pub struct StableDiffusionRequest {
pub prompt: String,
pub steps: u16,
}
#[derive(Deserialize)]
struct StableDiffusionResponse {
images: Vec<String>,
}
impl StableDiffusionAPI {
pub fn create(api_url: &str) -> anyhow::Result<Self> {
Ok(Self {
api_url: String::from(api_url),
client: Client::builder().build()?,
})
}
pub async fn generate_image(
&self,
request: &StableDiffusionRequest,
) -> anyhow::Result<Vec<u8>> {
let request = self
.client
.post(self.api_url.clone() + "/sdapi/v1/txt2img")
.body(serde_json::to_string(request)?)
.build()?;
let res = String::from_utf8(self.client.execute(request).await?.bytes().await?.to_vec())?;
let res: StableDiffusionResponse = serde_json::from_str(&res)?;
Ok(general_purpose::STANDARD.decode(&res.images[0])?)
}
}

View File

@ -15,8 +15,7 @@ use iced::{
futures::StreamExt, futures::StreamExt,
subscription, subscription,
widget::{button, column, container, row, scrollable, text, text_input}, widget::{button, column, container, row, scrollable, text, text_input},
window::{self}, window, Application, Color, Command, Event, Length, Settings, Subscription, Theme,
Application, Color, Command, Event, Length, Settings, Subscription, Theme,
}; };
use tokio::sync::RwLock; use tokio::sync::RwLock;
@ -122,7 +121,7 @@ impl Application for UI {
.blocking_write() .blocking_write()
.create_chat("dolphin-mixtral"); .create_chat("dolphin-mixtral");
if let Ok(chat) = chat { if let Ok(chat) = chat {
chat.blocking_write().send_system("Hello World!"); chat.blocking_write().send_system("You write stories about various topics.\n\nYou can include images into these stories using the following syntax:\n{IMAGE:description of image}\n\ne.g.:\n{IMAGE:an image of a large church, volumetric lighting, masterpiece, best quality}\n\nMake sure to use at least one image every five sentences.");
} }
Command::none() Command::none()
} }
@ -146,6 +145,17 @@ impl Application for UI {
let ollama_api = state.ollama_api.clone(); let ollama_api = state.ollama_api.clone();
return Command::perform( return Command::perform(
async move { async move {
/* TODO: Stable diffusion
let sd =
StableDiffusionAPI::create("http://localhost:7860").unwrap();
let req = StableDiffusionRequest {
prompt: String::from("among us"),
steps: 10,
};
let img = sd.generate_image(&req).await.unwrap();
println!("{:?}", img);
std::fs::write("amogus.png", img);*/
let stream; let stream;
let chat; let chat;
{ {