WebSocketAiServer.java
3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package com.aigeo.socket;
import com.alibaba.fastjson2.JSONException;
import com.alibaba.fastjson2.JSONObject;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.websocket.*;
import jakarta.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.springframework.stereotype.Component;
/**
 * WebSocket endpoint for AI chat, mapped at {@code /ws/chat/{userId}}.
 *
 * <p>Open sessions are tracked in a static, thread-safe map keyed by the container-assigned
 * {@link Session#getId()}, so they are shared across endpoint instances (the WebSocket container
 * may create a separate instance per connection). The {@code {userId}} path segment is declared
 * in the mapping but is not currently read.
 *
 * <p>Outbound messages are serialized to JSON with a single shared Jackson {@link ObjectMapper}.
 * Sends to the same session are serialized so that concurrent callers of the static send methods
 * do not collide on one connection.
 */
@ServerEndpoint("/ws/chat/{userId}")
@Component
public class WebSocketAiServer {

    private static final Logger LOG = Logger.getLogger(WebSocketAiServer.class.getName());

    /**
     * Payload sent in place of a message that could not be serialized. Note: this is a bare
     * string, not valid JSON; clients that always JSON-parse frames should be aware of it.
     */
    private static final String SERIALIZATION_ERROR = "Error: Failed to serialize message";

    /** Shared mapper; ObjectMapper is thread-safe once configured, so one instance is reused. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    // public static AiSerivce aiSerivce;
    // TokenInfo tokenInfo = TokenInfo.getIntance();

    /** Currently tracked sessions, keyed by {@link Session#getId()}. */
    private static final ConcurrentHashMap<String, Session> sessionMap = new ConcurrentHashMap<>();

    /**
     * Registers a newly opened session so it can be targeted by the static send methods.
     *
     * @param session the session that was just opened
     */
    @OnOpen
    public void onOpen(Session session) {
        sessionMap.put(session.getId(), session);
        LOG.info(() -> "New connection: " + session.getId());
    }

    /**
     * Stops tracking a session once the container reports it closed.
     *
     * @param session the session that was closed
     */
    @OnClose
    public void onClose(Session session) {
        sessionMap.remove(session.getId());
        LOG.info(() -> "Connection closed: " + session.getId());
    }

    /**
     * Handles an inbound text frame. Message handling is currently disabled: the body is a no-op
     * and the intended AI dispatch is kept below for reference.
     *
     * @param message the raw text frame received from the client
     * @param session the session the frame arrived on
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        // NOTE(review): if re-enabled, move parseObject inside the try; as written a malformed
        // payload would throw before the JSONException handler is reached.
        // SocketDTO socketDTO = JSONObject.parseObject(message, SocketDTO.class);
        // socketDTO.setId(session.getId());
        // // Parse the message
        // try {
        //     // Call the reply method to handle the message
        //     aiSerivce.chat(socketDTO);
        // } catch (JSONException e) {
        //     e.printStackTrace();
        //     // Send an error response to the client
        //     DifyResponse difyResponse = new DifyResponse();
        //     difyResponse.setCode(500);
        //     sendMessage(session.getId(), difyResponse);
        // } catch (Exception e) {
        //     e.printStackTrace();
        //     // Send an error response to the client
        //     DifyResponse difyResponse = new DifyResponse();
        //     difyResponse.setCode(500);
        //     sendMessage(session.getId(), difyResponse);
        // }
    }

    /**
     * Logs an error reported by the container for a session, including the full stack trace.
     *
     * @param session the affected session (guarded against null defensively)
     * @param throwable the error reported by the container
     */
    @OnError
    public void onError(Session session, Throwable throwable) {
        String sessionId = session != null ? session.getId() : "unknown";
        // Attach the throwable itself: getMessage() alone may be null and loses the stack trace.
        LOG.log(Level.WARNING, throwable,
                () -> "Error occurred on session " + sessionId + ": " + throwable.getMessage());
    }

    /**
     * Serializes {@code message} to JSON and sends it to the tracked session with the given id.
     * Does nothing if {@code sessionId} is null, unknown, or the session is no longer open.
     *
     * @param sessionId container session id, as passed to {@link #onOpen(Session)}
     * @param message object to serialize with Jackson
     */
    public static void sendMessage(String sessionId, Object message) {
        // ConcurrentHashMap.get(null) throws NPE; treat a null id as "no such session".
        if (sessionId == null) {
            return;
        }
        Session session = sessionMap.get(sessionId);
        if (session == null || !session.isOpen()) {
            return;
        }
        send(session, toJson(message));
    }

    /**
     * Serializes {@code message} to JSON once and sends it to every tracked open session.
     * A send failure on one session is logged and does not stop delivery to the others.
     *
     * @param message object to serialize with Jackson
     */
    public static void sendMessageToAll(Object message) {
        String json = toJson(message);
        // Iterate over all tracked sessions in sessionMap.
        for (Session session : sessionMap.values()) {
            send(session, json);
        }
    }

    /** Serializes a message to JSON, falling back to {@link #SERIALIZATION_ERROR} on failure. */
    private static String toJson(Object message) {
        try {
            return OBJECT_MAPPER.writeValueAsString(message);
        } catch (Exception e) {
            LOG.log(Level.WARNING, "Failed to serialize outbound WebSocket message", e);
            return SERIALIZATION_ERROR;
        }
    }

    /** Sends a text frame to one session if it is open, logging (not propagating) I/O failures. */
    private static void send(Session session, String json) {
        // RemoteEndpoint.Basic does not support concurrent sends on one session (the container may
        // throw IllegalStateException), so writers to the same session take turns here.
        synchronized (session) {
            if (!session.isOpen()) {
                return;
            }
            try {
                session.getBasicRemote().sendText(json);
            } catch (IOException e) {
                LOG.log(Level.WARNING, e, () -> "Failed to send message to session " + session.getId());
            }
        }
    }
}