initial commit

This commit is contained in:
MrLetsplay 2024-01-13 19:45:41 +01:00
commit e1eb7ef0cf
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
7 changed files with 290 additions and 0 deletions

5
.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
/.settings
/bin
/.classpath
/TEST
/target/

23
.project Normal file
View File

@ -0,0 +1,23 @@
<?xml version="1.0" encoding="UTF-8"?>
<projectDescription>
<name>OllamaAPI</name>
<comment></comment>
<projects>
</projects>
<buildSpec>
<buildCommand>
<name>org.eclipse.jdt.core.javabuilder</name>
<arguments>
</arguments>
</buildCommand>
<buildCommand>
<name>org.eclipse.m2e.core.maven2Builder</name>
<arguments>
</arguments>
</buildCommand>
</buildSpec>
<natures>
<nature>org.eclipse.m2e.core.maven2Nature</nature>
<nature>org.eclipse.jdt.core.javanature</nature>
</natures>
</projectDescription>

32
pom.xml Normal file
View File

@ -0,0 +1,32 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>Ollama</groupId>
<artifactId>Ollama</artifactId>
<version>0.0.1-SNAPSHOT</version>
<build>
<sourceDirectory>src</sourceDirectory>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<release>17</release>
</configuration>
</plugin>
</plugins>
</build>
<repositories>
<repository>
<id>Graphite</id>
<url>https://maven.graphite-official.com/releases</url>
</repository>
</repositories>
<dependencies>
<dependency>
<groupId>me.mrletsplay</groupId>
<artifactId>MrCore</artifactId>
<version>4.5</version>
</dependency>
</dependencies>
</project>

28
src/Ollama.java Normal file
View File

@ -0,0 +1,28 @@
import java.io.IOException;
import java.util.Scanner;
import me.mrletsplay.ollama.api.Chat;
import me.mrletsplay.ollama.api.OllamaAPI;
public class Ollama {
public static void main(String[] args) throws IOException, InterruptedException {
OllamaAPI api = new OllamaAPI("http://localhost:11434");
String systemPrompt = """
You are AmogusBot, a bot based on the video game \"Among Us\". You don't talk about anything other than Among Us and constantly make references to the game Among Us.
""";
Chat chat = api.startChat("dolphin-mixtral", systemPrompt);
try(Scanner s = new Scanner(System.in)) {
while(true) {
String p = s.nextLine();
chat.sendUser(p);
System.err.println("Bot is thinking...");
chat.doCompletion().forEach(System.out::print);
System.out.println();
}
}
}
}

View File

@ -0,0 +1,83 @@
package me.mrletsplay.ollama.api;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import me.mrletsplay.mrcore.json.JSONArray;
import me.mrletsplay.mrcore.json.JSONObject;
import me.mrletsplay.ollama.api.ChatMessage.Role;
public class Chat {
public static final int DEFAULT_HISTORY_SIZE = -1;
private OllamaAPI api;
private String model;
private int historySize;
private List<ChatMessage> messages;
public Chat(OllamaAPI api, String model) {
this.api = api;
this.model = model;
this.historySize = DEFAULT_HISTORY_SIZE;
this.messages = new ArrayList<>();
}
/**
* -1 == unlimited
* @param historySize
*/
public void setHistorySize(int historySize) {
this.historySize = historySize;
}
public int getHistorySize() {
return historySize;
}
public void sendSystem(String message) {
this.messages.add(new ChatMessage(Role.SYSTEM, message));
}
public void sendAssistant(String message) {
this.messages.add(new ChatMessage(Role.ASSISTANT, message));
}
public void sendUser(String message) {
this.messages.add(new ChatMessage(Role.USER, message));
}
public Stream<String> doCompletion() throws IOException, InterruptedException {
StringBuffer message = new StringBuffer();
return api.prompt(this)
.map(o -> {
String content = o.getJSONObject("message").getString("content");
message.append(content);
if(o.getBoolean("done")) {
sendAssistant(message.toString());
}
return content;
});
}
public JSONObject toJSON() {
JSONObject obj = new JSONObject();
obj.put("model", model);
JSONArray messages = new JSONArray();
// TODO: keep system messages
this.messages.stream().skip(historySize == -1 ? 0 : Math.max(0, messages.size() - historySize)).forEach(m -> messages.add(m.toJSON()));
obj.put("messages", messages);
// JSONObject options = new JSONObject();
// options.put("num_thread", 16);
// obj.put("options", options);
return obj;
}
}

View File

@ -0,0 +1,18 @@
package me.mrletsplay.ollama.api;
import me.mrletsplay.mrcore.json.JSONObject;
public record ChatMessage(Role role, String content) {
public static enum Role {
SYSTEM, ASSISTANT, USER;
}
public JSONObject toJSON() {
JSONObject obj = new JSONObject();
obj.put("role", role.name().toLowerCase());
obj.put("content", content);
return obj;
}
}

View File

@ -0,0 +1,101 @@
package me.mrletsplay.ollama.api;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpRequest.BodyPublishers;
import java.net.http.HttpResponse.BodyHandlers;
import java.util.Iterator;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import me.mrletsplay.mrcore.json.JSONException;
import me.mrletsplay.mrcore.json.JSONObject;
public class OllamaAPI {
private String baseURL;
private HttpClient client;
public OllamaAPI(String baseURL) {
this.baseURL = baseURL;
this.client = HttpClient.newHttpClient();
}
public Stream<String> prompt(String model, String systemPrompt, String prompt) throws IOException, InterruptedException {
Chat c = startChat(model, systemPrompt);
c.sendUser(prompt);
return c.doCompletion();
}
public Stream<String> prompt(String model, String prompt) throws IOException, InterruptedException {
Chat c = startChat(model);
c.sendUser(prompt);
return c.doCompletion();
}
public Stream<JSONObject> prompt(Chat chat) throws IOException, InterruptedException {
JSONObject request = chat.toJSON();
InputStream in = client.send(HttpRequest.newBuilder(URI.create(baseURL + "/api/chat"))
.POST(BodyPublishers.ofString(request.toFancyString()))
.build(), BodyHandlers.ofInputStream()).body();
return createJSONStream(in);
}
public Chat startChat(String model) {
return new Chat(this, model);
}
public Chat startChat(String model, String systemPrompt) {
Chat chat = startChat(model);
chat.sendSystem(systemPrompt);
return chat;
}
private static Stream<JSONObject> createJSONStream(InputStream in) {
BufferedReader r = new BufferedReader(new InputStreamReader(in));
Iterator<JSONObject> it = new Iterator<>() {
private JSONObject nextObj;
private JSONObject tryRead() {
if(nextObj != null) return null;
try {
String line = r.readLine();
if(line == null) {
r.close();
return null;
}
return nextObj = new JSONObject(line);
} catch (IOException | JSONException e) {
throw new RuntimeException("Failed to read object", e);
}
}
@Override
public JSONObject next() {
tryRead();
JSONObject toReturn = nextObj;
nextObj = null;
return toReturn;
}
@Override
public boolean hasNext() {
return tryRead() != null;
}
};
return StreamSupport.stream(Spliterators.spliteratorUnknownSize(it, Spliterator.ORDERED), false);
}
}