当前位置: 首页>后端>正文

基于Spring的JPA单数据源的SaaS隔离方案

说明:SaaS服务有多种数据隔离方式,本方案只介绍在单数据源(多租户共享表)下,基于JPA Entity隔离字段的数据隔离方案。
实现原理:基于Hibernate的@FilterDef和@Filter定义隔离字段的过滤条件,在所有基于Entity的操作中自动注入该字段的值。
具体实现方式:
1、定义BaseEntity,定义隔离字段。所有需要隔离的Entity都继承该BaseEntity。
2、定义切面,自动注入FilterDef中需要的字段值。
3、定义Entity的@PrePersist回调,在新建数据时注入隔离字段的值,这里以parkCode为例。

定义BaseEntity


import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.hibernate.annotations.Filter;
import org.hibernate.annotations.FilterDef;
import org.hibernate.annotations.ParamDef;
import javax.persistence.*;

/**
 * Common superclass of every JPA entity that is isolated by park code (tenant).
 *
 * <p>Declares the primary key, the audit columns, the soft-delete flag and the {@code park_code}
 * isolation column. It also defines the Hibernate filter {@code parkCodeFilter}. When that filter
 * is enabled on a Hibernate session with a {@code parkCode} parameter, it is intended to restrict
 * queries on subclass entities to rows whose {@code park_code} equals that value.
 *
 * <p>Column names are declared directly on each field. A subclass that needs a different column
 * name can declare its own {@code @AttributeOverride}. Repeating identical overrides on this root
 * mapped superclass would have no effect, so none are declared here.
 *
 * <p>NOTE(review): Lombok's {@code @EqualsAndHashCode} compares every field, including the
 * database-generated {@code id}. The hash code therefore changes once the entity is persisted.
 * Subclasses that do not declare their own {@code @EqualsAndHashCode} compare only these base
 * fields. Confirm this is intended before keeping entities in hash-based collections.
 */
@Setter
@Getter
@EqualsAndHashCode
@MappedSuperclass
@EntityListeners(EntityInterceptor.class)
@FilterDef(name = "parkCodeFilter", parameters = @ParamDef(name = "parkCode", type = "string"))
@Filter(name = "parkCodeFilter", condition = "park_code = (:parkCode)")
public abstract class BaseEntity {

    /** Primary key, generated by the database identity column. */
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;

    /** Creation time; filled in by EntityInterceptor on insert (epoch milliseconds). */
    @Column(name = "created_at", nullable = false)
    private Long createdAt;

    /** Last modification time; filled in by EntityInterceptor on insert and update. */
    @Column(name = "updated_at")
    private Long updatedAt;

    /** Id of the user that created the row. */
    @Column(name = "created_by")
    private Long createdBy;

    /** Id of the user that last modified the row. */
    @Column(name = "updated_by")
    private Long updatedBy;

    /** Soft-delete flag; set to 0 on insert (presumably 0 = not deleted — confirm). */
    @Column(name = "del_flg", nullable = false)
    private Integer delFlg;

    /** Tenant (park) discriminator compared by {@code parkCodeFilter}. */
    @Column(name = "park_code", nullable = false)
    private String parkCode;

}

定义切面
注意:entityManager.unwrap(Session.class)的执行依赖于事务,因此切面会为未标注@Transactional的方法自动开启事务。


import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.hibernate.Session;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.DefaultTransactionDefinition;

import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.PrePersist;
import javax.persistence.PreUpdate;
import java.lang.reflect.Method;

@Slf4j
@Aspect
@Component
public class EntityInterceptor {
    @Autowired
    @PersistenceContext()
    private EntityManager entityManager;
    @Autowired
    private PlatformTransactionManager txManager;

    @PrePersist
    public void prePersist(Object entity) {
        if (entity instanceof BaseEntity) {
            checkParkCode();
            BaseEntity baseEntity = (BaseEntity) entity;
            baseEntity.setDelFlg(0);
            baseEntity.setParkCode("1");
            baseEntity.setCreatedBy(UserThreadLocal.getCurrentUserId());
            baseEntity.setCreatedAt(System.currentTimeMillis());
            baseEntity.setUpdatedAt(baseEntity.getCreatedAt());
            baseEntity.setUpdatedBy(baseEntity.getCreatedBy());
        }
    }

    @PreUpdate
    public void preUpdate(Object entity) {
        if (entity instanceof BaseEntity) {
            checkParkCode();
            BaseEntity baseEntity = (BaseEntity) entity;
            baseEntity.setUpdatedBy(UserThreadLocal.getCurrentUserId());
            baseEntity.setUpdatedAt(System.currentTimeMillis());
        }
    }

    //TODO
    private void checkParkCode() {
//        String parkCode = UserThreadLocal.get().getParkCode();
      
    }

    @Pointcut("execution(* com.xxx.xxx.service.park..*.*(..))")
    public void initCut() {
    }

    @Around("initCut()")
    public Object run(ProceedingJoinPoint joinPoint) throws Throwable {
        Object object = null;
        Method targetMethod = ((MethodSignature) (joinPoint.getSignature())).getMethod();
        Transactional annotation = targetMethod.getAnnotation(Transactional.class);
        if (annotation == null) {
            DefaultTransactionDefinition def = new DefaultTransactionDefinition();
            def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRED);
            TransactionStatus status = txManager.getTransaction(def);
            try {
                object = commonProcess(joinPoint);
            } catch (Exception e) {
                log.error("方法执行异常" + joinPoint.getSignature(), e);
                txManager.rollback(status);
                throw e;
            }
            txManager.commit(status);//事务提交
            return object;
        }
        return commonProcess(joinPoint);
    }

    private Object commonProcess(ProceedingJoinPoint joinPoint) throws Throwable {
        Session session = entityManager.unwrap(Session.class);
        //parkCode需要自己设置 TODO
        session.enableFilter("parkCodeFilter").setParameter("parkCode", "1");
        return joinPoint.proceed(joinPoint.getArgs());
    }

}


https://www.xamrdz.com/backend/39p1920699.html

相关文章: