说明:SaaS服务有多种数据隔离方式,本方案只介绍基于JPA Entity、通过隔离字段(多租户共享同一张表)实现的数据隔离方案。
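实现原理:基于Hibernate的@FilterDef、@Filter定义隔离字段的过滤条件,并在所有基于Entity的操作中注入该字段的值。注意:Hibernate Filter只作用于HQL/JPQL/Criteria查询,不作用于按主键直接加载(如EntityManager.find、Session.get),这类访问需要另行校验。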
具体实现方式:
1、定义BaseEntity,定义隔离字段。所有需要隔离的Entity都继承该BaseEntity。
2、定义切面,自动启用Filter并注入FilterDef中需要的参数值。
3、定义Entity监听器:通过@PrePersist在新建数据时注入隔离字段(这里以parkCode为例),并通过@PreUpdate维护更新人和更新时间。
定义BaseEntity
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.hibernate.annotations.Filter;
import org.hibernate.annotations.FilterDef;
import org.hibernate.annotations.ParamDef;
import javax.persistence.*;
/**
 * Common mapped superclass for entities that are isolated per park (tenant).
 *
 * <p>Provides the database-generated id, audit columns, a soft-delete flag and the
 * {@code park_code} isolation column. The Hibernate filter {@code parkCodeFilter} declared here
 * restricts queries to rows whose {@code park_code} equals the filter's {@code parkCode}
 * parameter. The filter only takes effect on a session where it has been enabled;
 * {@code EntityInterceptor} enables it.
 *
 * <p>NOTE(review): Hibernate filters apply to entity queries (HQL/JPQL/Criteria), not to direct
 * loading by id ({@code EntityManager.find} / {@code Session.get}). Confirm that id lookups are
 * guarded some other way.
 *
 * <p>NOTE(review): {@code @EqualsAndHashCode} generates equals/hashCode over all fields, including
 * the generated {@code id} and mutable audit fields. An instance's hash therefore changes when it
 * is persisted or updated, so avoid keeping these entities in hash-based collections across such
 * changes.
 */
@Setter
@Getter
@EqualsAndHashCode
@MappedSuperclass
// Audit fields, the isolation column and the soft-delete flag are filled by this listener.
@EntityListeners(EntityInterceptor.class)
// NOTE(review): these overrides restate the column names already given by the @Column annotations
// on the fields below and look redundant. Confirm before removing.
@AttributeOverride(name = "createdBy", column = @Column(name = "created_by"))
@AttributeOverride(name = "createdAt", column = @Column(name = "created_at"))
@AttributeOverride(name = "updatedBy", column = @Column(name = "updated_by"))
@AttributeOverride(name = "updatedAt", column = @Column(name = "updated_at"))
@AttributeOverride(name = "delFlg", column = @Column(name = "del_flg"))
// Declares the tenant filter and its single string parameter.
@FilterDef(name = "parkCodeFilter", parameters = @ParamDef(name = "parkCode", type = "string"))
// SQL condition appended to queries while the filter is enabled on the session.
@Filter(name = "parkCodeFilter", condition = "park_code = (:parkCode)")
public abstract class BaseEntity {
    // Primary key generated by the database (IDENTITY strategy).
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;
    // Creation time; the listener sets it from System.currentTimeMillis() (epoch millis).
    @Column(name = "created_at", nullable = false)
    private Long createdAt;
    // Last-update time (epoch millis); initialised to createdAt and refreshed on update.
    @Column(name = "updated_at")
    private Long updatedAt;
    // Id of the creating user, taken from UserThreadLocal by the listener.
    @Column(name = "created_by")
    private Long createdBy;
    // Id of the last updating user.
    @Column(name = "updated_by")
    private Long updatedBy;
    // Soft-delete flag; initialised to 0 on insert (presumably 0 = not deleted — confirm).
    @Column(name = "del_flg", nullable = false)
    private Integer delFlg;
    // Tenant isolation key matched by parkCodeFilter; overwritten by the listener on insert.
    @Column(name = "park_code", nullable = false)
    private String parkCode;
}
定义切面
注意:entityManager.unwrap(Session.class)需要在事务中执行,才能取得当前事务绑定的Session,从而使启用的Filter对同一事务内的后续查询生效。因此切面会为未标注@Transactional的方法自动开启事务。
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.hibernate.Session;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.PrePersist;
import javax.persistence.PreUpdate;
import java.lang.reflect.Method;
@Slf4j
@Aspect
@Component
public class EntityInterceptor {
@Autowired
@PersistenceContext()
private EntityManager entityManager;
@Autowired
private PlatformTransactionManager txManager;
@PrePersist
public void prePersist(Object entity) {
if (entity instanceof BaseEntity) {
checkParkCode();
BaseEntity baseEntity = (BaseEntity) entity;
baseEntity.setDelFlg(0);
baseEntity.setParkCode("1");
baseEntity.setCreatedBy(UserThreadLocal.getCurrentUserId());
baseEntity.setCreatedAt(System.currentTimeMillis());
baseEntity.setUpdatedAt(baseEntity.getCreatedAt());
baseEntity.setUpdatedBy(baseEntity.getCreatedBy());
}
}
@PreUpdate
public void preUpdate(Object entity) {
if (entity instanceof BaseEntity) {
checkParkCode();
BaseEntity baseEntity = (BaseEntity) entity;
baseEntity.setUpdatedBy(UserThreadLocal.getCurrentUserId());
baseEntity.setUpdatedAt(System.currentTimeMillis());
}
}
//TODO
private void checkParkCode() {
// String parkCode = UserThreadLocal.get().getParkCode();
}
@Pointcut("execution(* com.xxx.xxx.service.park..*.*(..))")
public void initCut() {
}
@Around("initCut()")
public Object run(ProceedingJoinPoint joinPoint) throws Throwable {
Object object = null;
Method targetMethod = ((MethodSignature) (joinPoint.getSignature())).getMethod();
Transactional annotation = targetMethod.getAnnotation(Transactional.class);
if (annotation == null) {
DefaultTransactionDefinition def = new DefaultTransactionDefinition();
def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRED);
TransactionStatus status = txManager.getTransaction(def);
try {
object = commonProcess(joinPoint);
} catch (Exception e) {
log.error("方法执行异常" + joinPoint.getSignature(), e);
txManager.rollback(status);
throw e;
}
txManager.commit(status);//事务提交
return object;
}
return commonProcess(joinPoint);
}
private Object commonProcess(ProceedingJoinPoint joinPoint) throws Throwable {
Session session = entityManager.unwrap(Session.class);
//parkCode需要自己设置 TODO
session.enableFilter("parkCodeFilter").setParameter("parkCode", "1");
return joinPoint.proceed(joinPoint.getArgs());
}
}