From 8fa23d9dff423911aec24557bc81fa3273abf0ed Mon Sep 17 00:00:00 2001 From: MrLetsplay Date: Mon, 4 Dec 2023 21:37:48 +0100 Subject: [PATCH] Update protocol --- .../mrletsplay/shareserver/SessionUser.java | 2 +- .../mrletsplay/shareserver/ShareWSServer.java | 57 ++++++++++++++++--- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/src/main/java/me/mrletsplay/shareserver/SessionUser.java b/src/main/java/me/mrletsplay/shareserver/SessionUser.java index 1151d7f..22865fb 100644 --- a/src/main/java/me/mrletsplay/shareserver/SessionUser.java +++ b/src/main/java/me/mrletsplay/shareserver/SessionUser.java @@ -1,3 +1,3 @@ package me.mrletsplay.shareserver; -public record SessionUser(Session session, int siteID) {} +public record SessionUser(Session session, String username, int siteID) {} diff --git a/src/main/java/me/mrletsplay/shareserver/ShareWSServer.java b/src/main/java/me/mrletsplay/shareserver/ShareWSServer.java index 35ef215..42d2977 100644 --- a/src/main/java/me/mrletsplay/shareserver/ShareWSServer.java +++ b/src/main/java/me/mrletsplay/shareserver/ShareWSServer.java @@ -5,8 +5,8 @@ import java.io.DataOutputStream; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; +import java.util.Comparator; import java.util.List; -import java.util.Objects; import org.java_websocket.WebSocket; import org.java_websocket.framing.CloseFrame; @@ -14,7 +14,7 @@ import org.java_websocket.handshake.ClientHandshake; import org.java_websocket.server.WebSocketServer; import me.mrletsplay.shareclientcore.connection.RemoteConnection; -import me.mrletsplay.shareclientcore.connection.message.ChangeMessage; +import me.mrletsplay.shareclientcore.connection.message.AddressableMessage; import me.mrletsplay.shareclientcore.connection.message.ClientHelloMessage; import me.mrletsplay.shareclientcore.connection.message.Message; import me.mrletsplay.shareclientcore.connection.message.PeerJoinMessage; @@ -66,18 +66,43 @@ public class ShareWSServer extends WebSocketServer { if(m instanceof ClientHelloMessage hello) { Session session = ShareServer.getOrCreateSession(hello.sessionID()); int siteID = session.getNewSiteID(); - conn.setAttachment(new SessionUser(session, siteID)); + conn.setAttachment(new SessionUser(session, hello.username(), siteID)); send(conn, new ServerHelloMessage(RemoteConnection.PROTOCOL_VERSION, siteID)); - getPeers(session).forEach(peer -> send(peer, new PeerJoinMessage(hello.username(), siteID))); + getPeers(session).forEach(peer -> { + if(peer != conn) { + SessionUser user = peer.getAttachment(); + send(conn, new PeerJoinMessage(user.username(), user.siteID())); + } + + send(peer, new PeerJoinMessage(hello.username(), siteID)); + }); }else { conn.close(CloseFrame.POLICY_VALIDATION, "First message must be CLIENT_HELLO"); } return; } - Session session = conn.getAttachment(); - if(m instanceof ChangeMessage change) { - getPeers(session).forEach(peer -> send(peer, m)); + SessionUser user = conn.getAttachment(); + Session session = user.session(); + switch(m.getType()) { + case CHANGE -> getPeers(session).forEach(peer -> send(peer, m)); + case REQUEST_FULL_SYNC, REQUEST_CHECKSUM -> { + AddressableMessage msg = (AddressableMessage) m; + if(msg.siteID() != user.siteID()) { + conn.close(CloseFrame.POLICY_VALIDATION, "Invalid site id"); + return; + } + + send(getHost(getPeers(session)), m); + } + case FULL_SYNC, CHECKSUM -> { + AddressableMessage msg = (AddressableMessage) m; + WebSocket peer = getPeer(getPeers(session), msg.siteID()); + if(peer != null) send(peer, m); + } + default -> { + conn.close(CloseFrame.POLICY_VALIDATION, "Invalid message received"); + } } // System.out.println("Got a message"); @@ -88,10 +113,26 @@ public class ShareWSServer extends WebSocketServer { private List getPeers(Session session) { return getConnections().stream() - .filter(c -> Objects.equals(c.getAttachment(), session)) + .filter(c -> { + SessionUser user = c.getAttachment(); + if(user == null) return false; + return user.session().equals(session); + }) .toList(); } + private WebSocket getHost(List peers) { + return peers.stream() + .min(Comparator.comparing(p -> p.getAttachment().siteID())) + .orElse(null); + } + + private WebSocket getPeer(List peers, int siteID) { + return peers.stream() + .filter(p -> p.getAttachment().siteID() == siteID) + .findFirst().orElse(null); + } + @Override public void onError(WebSocket conn, Exception ex) {