// WebSocketAiServer.java 3.4 KB
package com.aigeo.socket;

import com.alibaba.fastjson2.JSONException;
import com.alibaba.fastjson2.JSONObject;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.websocket.*;
import jakarta.websocket.server.ServerEndpoint;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * WebSocket endpoint mounted at {@code /ws/chat/{userId}}.
 *
 * <p>Every opened {@link Session} is registered in a static map keyed by
 * {@link Session#getId()} and removed again when the connection closes. The
 * static {@link #sendMessage(String, Object)} and
 * {@link #sendMessageToAll(Object)} helpers serialize a payload to JSON and
 * push it to one or to all registered sessions.
 *
 * <p>The {@code userId} path segment is declared in the endpoint path but is
 * not read by any handler in this class; sessions are addressed by session id.
 */
@ServerEndpoint("/ws/chat/{userId}")
@Component
public class WebSocketAiServer {

//    public static AiSerivce aiSerivce;

//    TokenInfo tokenInfo = TokenInfo.getIntance();

    private static final Logger LOGGER = Logger.getLogger(WebSocketAiServer.class.getName());

    /** Text sent in place of the JSON payload when the payload cannot be serialized. */
    private static final String SERIALIZATION_ERROR_MESSAGE = "Error: Failed to serialize message";

    /** Shared mapper: building one per message is wasteful, and a configured instance can be reused. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /** Currently registered sessions, keyed by session id. */
    private static final ConcurrentHashMap<String, Session> sessionMap = new ConcurrentHashMap<>();


    /**
     * Registers a newly opened session so it can be addressed by its id.
     *
     * @param session the session that was just opened
     */
    @OnOpen
    public void onOpen(Session session) {
        sessionMap.put(session.getId(), session);
        LOGGER.info("New connection: " + session.getId());
    }

    /**
     * Unregisters a session whose connection has closed.
     *
     * @param session the session that was closed
     */
    @OnClose
    public void onClose(Session session) {
        sessionMap.remove(session.getId());
        LOGGER.info("Connection closed: " + session.getId());
    }

    /**
     * Handles an inbound text message. Currently a no-op: the dispatch to the
     * AI service below is commented out, so incoming messages are ignored.
     *
     * @param message the raw text frame received from the client
     * @param session the session the message arrived on
     */
    @OnMessage
    public void onMessage(String message, Session session) {
//        SocketDTO socketDTO = JSONObject.parseObject(message, SocketDTO.class);
//        socketDTO.setId(session.getId());
//        // Parse the message
//        try {
//            // Call the reply method to handle the message
//            aiSerivce.chat(socketDTO);
//        } catch (JSONException e) {
//            e.printStackTrace();
//            // Send the error information to the client
//            DifyResponse difyResponse = new DifyResponse();
//            difyResponse.setCode(500);
//            sendMessage(session.getId(), difyResponse);
//        } catch (Exception e) {
//            e.printStackTrace();
//            // Send the error information to the client
//            DifyResponse difyResponse = new DifyResponse();
//            difyResponse.setCode(500);
//            sendMessage(session.getId(), difyResponse);
//        }
    }

    /**
     * Logs an error raised for a session and drops the session from the
     * registry if the error left it closed.
     *
     * @param session   the session the error relates to
     * @param throwable the error that occurred
     */
    @OnError
    public void onError(Session session, Throwable throwable) {
        String sessionId = (session != null) ? session.getId() : "unknown";
        // Log the throwable itself, not just its message, so the stack trace is kept.
        LOGGER.log(Level.WARNING, "Error occurred on session " + sessionId, throwable);
        if (session != null && !session.isOpen()) {
            // Two-argument remove: only evict this exact session object.
            sessionMap.remove(session.getId(), session);
        }
    }

    /**
     * Serializes {@code message} to JSON and sends it to the session with the
     * given id. Does nothing when the id is {@code null}, unknown, or refers
     * to a session that is no longer open. Send failures are logged, not thrown.
     *
     * @param sessionId id of the target session, as returned by {@link Session#getId()}
     * @param message   payload to serialize and send
     */
    public static void sendMessage(String sessionId, Object message) {
        if (sessionId == null) {
            // ConcurrentHashMap rejects null keys with a NullPointerException.
            return;
        }
        Session session = sessionMap.get(sessionId);
        if (session == null || !session.isOpen()) {
            return;
        }
        sendText(session, toJson(message));
    }

    /**
     * Serializes {@code message} to JSON once and sends it to every registered
     * session that is still open. A failure on one session is logged and does
     * not stop delivery to the remaining sessions.
     *
     * @param message payload to serialize and broadcast
     */
    public static void sendMessageToAll(Object message) {
        String jsonMessage = toJson(message);

        // Iterate over every Session in sessionMap
        for (Session session : sessionMap.values()) {
            if (session.isOpen()) {
                sendText(session, jsonMessage);
            }
        }
    }

    /**
     * Converts a payload to its JSON text, falling back to a fixed error
     * string when serialization fails.
     */
    private static String toJson(Object message) {
        try {
            return OBJECT_MAPPER.writeValueAsString(message);
        } catch (Exception e) {
            // Deliberately broad: any serialization failure degrades to the fallback text.
            LOGGER.log(Level.WARNING, "Failed to serialize outbound message", e);
            return SERIALIZATION_ERROR_MESSAGE;
        }
    }

    /**
     * Sends one text frame to a session, logging instead of propagating failures.
     */
    private static void sendText(Session session, String text) {
        try {
            // The blocking remote endpoint may throw IllegalStateException when two
            // threads send on the same session at once, so sends are serialized per session.
            synchronized (session) {
                session.getBasicRemote().sendText(text);
            }
        } catch (IOException | IllegalStateException e) {
            LOGGER.log(Level.WARNING, "Failed to send message to session " + session.getId(), e);
        }
    }
}