// AuthServiceImpl.java 6.0 KB
package com.aigeo.auth.service.impl;

import com.aigeo.auth.dto.LoginRequest;
import com.aigeo.auth.dto.LoginResponse;
import com.aigeo.auth.dto.RegisterRequest;
import com.aigeo.auth.dto.RegisterResponse;
import com.aigeo.auth.service.AuthService;
import com.aigeo.company.entity.Company;
import com.aigeo.company.entity.User;
import com.aigeo.company.service.CompanyService;
import com.aigeo.company.service.UserService;
import com.aigeo.common.enums.CompanyStatus;
import com.aigeo.common.enums.UserRole;
import com.aigeo.common.exception.BusinessException;
import com.aigeo.util.JwtUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.stereotype.Service;

import java.time.LocalDateTime;
import java.util.Optional;

/**
 * 认证服务实现类
 */
@Slf4j
@Service
public class AuthServiceImpl implements AuthService {

    private final UserService userService;
    private final CompanyService companyService;
    private final PasswordEncoder passwordEncoder;
    private final JwtUtil jwtUtil;

    public AuthServiceImpl(UserService userService, CompanyService companyService, 
                          PasswordEncoder passwordEncoder, JwtUtil jwtUtil) {
        this.userService = userService;
        this.companyService = companyService;
        this.passwordEncoder = passwordEncoder;
        this.jwtUtil = jwtUtil;
    }

    @Override
    public LoginResponse login(LoginRequest loginRequest) {
        try {
            // 查找用户
            Optional<User> userOptional = userService.findByUsername(loginRequest.getUsername());
            if (!userOptional.isPresent()) {
                // 尝试通过邮箱查找
                userOptional = userService.findByEmail(loginRequest.getUsername());
            }
            
            if (!userOptional.isPresent()) {
                throw new BusinessException(400, "用户名或密码错误");
            }

            User user = userOptional.get();

            // 验证密码
            if (!passwordEncoder.matches(loginRequest.getPassword(), user.getPasswordHash())) {
                throw new BusinessException(400, "用户名或密码错误");
            }

            // 生成JWT token
            String token = jwtUtil.generateToken(user);

            // 构建响应
            LoginResponse response = new LoginResponse();
            response.setToken(token);
            response.setExpiresIn(jwtUtil.getExpirationTimeSeconds());
            response.setUser(user);

            return response;
        } catch (Exception e) {
            log.error("用户登录失败: {}", loginRequest.getUsername(), e);
            throw new BusinessException(500, "登录失败");
        }
    }

    @Override
    public RegisterResponse register(RegisterRequest registerRequest) {
        try {
            // 检查用户名是否已存在
            if (userService.findByUsername(registerRequest.getUsername()).isPresent()) {
                throw new BusinessException(400, "用户名已存在");
            }

            // 检查邮箱是否已存在
            if (userService.findByEmail(registerRequest.getEmail()).isPresent()) {
                throw new BusinessException(400, "邮箱已被注册");
            }

            // 验证密码和确认密码是否一致
            if (!registerRequest.getPassword().equals(registerRequest.getConfirmPassword())) {
                throw new BusinessException(400, "密码和确认密码不一致");
            }
            
            // 确定公司ID
            Integer companyId = registerRequest.getCompanyId();
            if (companyId == null) {
                // 如果没有提供公司ID,则需要提供公司名称来创建新公司
                if (registerRequest.getCompanyName() == null || registerRequest.getCompanyName().isEmpty()) {
                    throw new BusinessException(400, "必须提供公司ID或公司名称");
                }
                
                // 创建新公司
                Company company = new Company();
                company.setName(registerRequest.getCompanyName());
                company.setBillingEmail(registerRequest.getEmail());
                company.setStatus(CompanyStatus.TRIAL);
                company.setTrialExpiryDate(LocalDateTime.now().plusDays(30)); // 30天试用期
                
                Company savedCompany = companyService.save(company);
                companyId = savedCompany.getId();
                log.info("为新用户 {} 创建了新公司 {}, 公司ID: {}", 
                        registerRequest.getUsername(), registerRequest.getCompanyName(), companyId);
            }

            // 创建新用户
            User user = new User();
            user.setUsername(registerRequest.getUsername());
            user.setEmail(registerRequest.getEmail());
            user.setPasswordHash(passwordEncoder.encode(registerRequest.getPassword()));
            user.setFullName(registerRequest.getFullName());
            user.setPhone(registerRequest.getPhone());
            user.setCompanyId(companyId);
            user.setRole(UserRole.EDITOR); // 使用现有的枚举值
            user.setAvatarUrl(registerRequest.getAvatarUrl());
            user.setIsActive(true);

            // 保存用户
            User savedUser = userService.save(user);

            // 生成JWT token
            String token = jwtUtil.generateToken(savedUser);

            // 构建响应
            RegisterResponse response = new RegisterResponse();
            response.setToken(token);
            response.setExpiresIn(jwtUtil.getExpirationTimeSeconds());
            response.setUser(savedUser);

            return response;
        } catch (BusinessException e) {
            throw e;
        } catch (Exception e) {
            log.error("用户注册失败: {}", registerRequest.getUsername(), e);
            throw new BusinessException(500, "注册失败");
        }
    }
}