From e1eb7ef0cf5ab2efbdd14a8d67bd4154dad19c8c Mon Sep 17 00:00:00 2001 From: MrLetsplay Date: Sat, 13 Jan 2024 19:45:41 +0100 Subject: [PATCH] initial commit --- .gitignore | 5 + .project | 23 ++++ pom.xml | 32 ++++++ src/Ollama.java | 28 +++++ src/me/mrletsplay/ollama/api/Chat.java | 83 ++++++++++++++ src/me/mrletsplay/ollama/api/ChatMessage.java | 18 ++++ src/me/mrletsplay/ollama/api/OllamaAPI.java | 101 ++++++++++++++++++ 7 files changed, 290 insertions(+) create mode 100644 .gitignore create mode 100644 .project create mode 100644 pom.xml create mode 100644 src/Ollama.java create mode 100644 src/me/mrletsplay/ollama/api/Chat.java create mode 100644 src/me/mrletsplay/ollama/api/ChatMessage.java create mode 100644 src/me/mrletsplay/ollama/api/OllamaAPI.java diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..746a89c --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/.settings +/bin +/.classpath +/TEST +/target/ diff --git a/.project b/.project new file mode 100644 index 0000000..87484d9 --- /dev/null +++ b/.project @@ -0,0 +1,23 @@ + + + OllamaAPI + + + + + + org.eclipse.jdt.core.javabuilder + + + + + org.eclipse.m2e.core.maven2Builder + + + + + + org.eclipse.m2e.core.maven2Nature + org.eclipse.jdt.core.javanature + + diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..7321a23 --- /dev/null +++ b/pom.xml @@ -0,0 +1,32 @@ + + 4.0.0 + Ollama + Ollama + 0.0.1-SNAPSHOT + + src + + + maven-compiler-plugin + 3.8.1 + + 17 + + + + + + + + Graphite + https://maven.graphite-official.com/releases + + + + + me.mrletsplay + MrCore + 4.5 + + + \ No newline at end of file diff --git a/src/Ollama.java b/src/Ollama.java new file mode 100644 index 0000000..d55b1f0 --- /dev/null +++ b/src/Ollama.java @@ -0,0 +1,28 @@ +import java.io.IOException; +import java.util.Scanner; + +import me.mrletsplay.ollama.api.Chat; +import me.mrletsplay.ollama.api.OllamaAPI; + +public class Ollama { + + public static void main(String[] args) throws IOException, InterruptedException { + OllamaAPI api = new OllamaAPI("http://localhost:11434"); + String systemPrompt = """ + You are AmogusBot, a bot based on the video game \"Among Us\". You don't talk about anything other than Among Us and constantly make references to the game Among Us. + """; + + Chat chat = api.startChat("dolphin-mixtral", systemPrompt); + + try(Scanner s = new Scanner(System.in)) { + while(true) { + String p = s.nextLine(); + chat.sendUser(p); + System.err.println("Bot is thinking..."); + chat.doCompletion().forEach(System.out::print); + System.out.println(); + } + } + } + +} diff --git a/src/me/mrletsplay/ollama/api/Chat.java b/src/me/mrletsplay/ollama/api/Chat.java new file mode 100644 index 0000000..a5ed2d7 --- /dev/null +++ b/src/me/mrletsplay/ollama/api/Chat.java @@ -0,0 +1,83 @@ +package me.mrletsplay.ollama.api; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import me.mrletsplay.mrcore.json.JSONArray; +import me.mrletsplay.mrcore.json.JSONObject; +import me.mrletsplay.ollama.api.ChatMessage.Role; + +public class Chat { + + public static final int DEFAULT_HISTORY_SIZE = -1; + + private OllamaAPI api; + private String model; + private int historySize; + + private List messages; + + public Chat(OllamaAPI api, String model) { + this.api = api; + this.model = model; + this.historySize = DEFAULT_HISTORY_SIZE; + this.messages = new ArrayList<>(); + } + + /** + * -1 == unlimited + * @param historySize + */ + public void setHistorySize(int historySize) { + this.historySize = historySize; + } + + public int getHistorySize() { + return historySize; + } + + public void sendSystem(String message) { + this.messages.add(new ChatMessage(Role.SYSTEM, message)); + } + + public void sendAssistant(String message) { + this.messages.add(new ChatMessage(Role.ASSISTANT, message)); + } + + public void sendUser(String message) { + this.messages.add(new ChatMessage(Role.USER, message)); + } + + public Stream doCompletion() throws IOException, InterruptedException { + StringBuffer message = new StringBuffer(); + return api.prompt(this) + .map(o -> { + String content = o.getJSONObject("message").getString("content"); + message.append(content); + + if(o.getBoolean("done")) { + sendAssistant(message.toString()); + } + + return content; + }); + } + + public JSONObject toJSON() { + JSONObject obj = new JSONObject(); + obj.put("model", model); + + JSONArray messages = new JSONArray(); + // TODO: keep system messages + this.messages.stream().skip(historySize == -1 ? 0 : Math.max(0, messages.size() - historySize)).forEach(m -> messages.add(m.toJSON())); + obj.put("messages", messages); + +// JSONObject options = new JSONObject(); +// options.put("num_thread", 16); +// obj.put("options", options); + return obj; + } + +} diff --git a/src/me/mrletsplay/ollama/api/ChatMessage.java b/src/me/mrletsplay/ollama/api/ChatMessage.java new file mode 100644 index 0000000..c02260d --- /dev/null +++ b/src/me/mrletsplay/ollama/api/ChatMessage.java @@ -0,0 +1,18 @@ +package me.mrletsplay.ollama.api; + +import me.mrletsplay.mrcore.json.JSONObject; + +public record ChatMessage(Role role, String content) { + + public static enum Role { + SYSTEM, ASSISTANT, USER; + } + + public JSONObject toJSON() { + JSONObject obj = new JSONObject(); + obj.put("role", role.name().toLowerCase()); + obj.put("content", content); + return obj; + } + +} diff --git a/src/me/mrletsplay/ollama/api/OllamaAPI.java b/src/me/mrletsplay/ollama/api/OllamaAPI.java new file mode 100644 index 0000000..2dc586d --- /dev/null +++ b/src/me/mrletsplay/ollama/api/OllamaAPI.java @@ -0,0 +1,101 @@ +package me.mrletsplay.ollama.api; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpResponse.BodyHandlers; +import java.util.Iterator; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +import me.mrletsplay.mrcore.json.JSONException; +import me.mrletsplay.mrcore.json.JSONObject; + +public class OllamaAPI { + + private String baseURL; + + private HttpClient client; + + public OllamaAPI(String baseURL) { + this.baseURL = baseURL; + this.client = HttpClient.newHttpClient(); + } + + public Stream prompt(String model, String systemPrompt, String prompt) throws IOException, InterruptedException { + Chat c = startChat(model, systemPrompt); + c.sendUser(prompt); + return c.doCompletion(); + } + + public Stream prompt(String model, String prompt) throws IOException, InterruptedException { + Chat c = startChat(model); + c.sendUser(prompt); + return c.doCompletion(); + } + + public Stream prompt(Chat chat) throws IOException, InterruptedException { + JSONObject request = chat.toJSON(); + InputStream in = client.send(HttpRequest.newBuilder(URI.create(baseURL + "/api/chat")) + .POST(BodyPublishers.ofString(request.toFancyString())) + .build(), BodyHandlers.ofInputStream()).body(); + + return createJSONStream(in); + } + + public Chat startChat(String model) { + return new Chat(this, model); + } + + public Chat startChat(String model, String systemPrompt) { + Chat chat = startChat(model); + chat.sendSystem(systemPrompt); + return chat; + } + + private static Stream createJSONStream(InputStream in) { + BufferedReader r = new BufferedReader(new InputStreamReader(in)); + Iterator it = new Iterator<>() { + + private JSONObject nextObj; + + private JSONObject tryRead() { + if(nextObj != null) return null; + + try { + String line = r.readLine(); + if(line == null) { + r.close(); + return null; + } + return nextObj = new JSONObject(line); + } catch (IOException | JSONException e) { + throw new RuntimeException("Failed to read object", e); + } + } + + @Override + public JSONObject next() { + tryRead(); + JSONObject toReturn = nextObj; + nextObj = null; + return toReturn; + } + + @Override + public boolean hasNext() { + return tryRead() != null; + } + }; + return StreamSupport.stream(Spliterators.spliteratorUnknownSize(it, Spliterator.ORDERED), false); + } + + + +}