服务提供方
服务端启动服务及处理逻辑
1、启动加载配置
2、初始化服务端的提供服务的对象
3、启动netty服务端(添加编解码及服务端处理handler)
4、将服务注册到zookeeper中,会创建服务节点,临时子节点,子节点放提供服务的地址
5、服务端收到请求后,根据请求找到提供服务的对象。
6、通过cglib的FastClass/FastMethod(替代JDK反射,并非动态代理)调用对应的方法
7、给客户端响应,并在完成后关闭连接
1、spring启动
// Build the Spring context from the classpath resource "spring.xml" and start it.
ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext(new String[] {"spring.xml"});
context.start();
2、spring加载rpc相应配置
<!-- ZooKeeper-backed service registry; the ZooKeeper address comes from the rpc.registry_address property. -->
<bean id="serviceRegistry" class="com.rpc.registry.zookeeper.ZooKeeperServiceRegistry">
<constructor-arg name="zkAddress" value="${rpc.registry_address}" />
</bean>
<!-- RPC server bean; receives its own listen address (rpc.service_address) and the registry defined above. -->
<bean id="rpcServer" class="com.rpc.server.RpcServer">
<constructor-arg name="serviceAddress" value="${rpc.service_address}" />
<constructor-arg name="serviceRegistry" ref="serviceRegistry" />
</bean>
ZooKeeperServiceRegistry类
package com.rpc.registry;
/**
 * Contract for publishing a service provider's address to a registry.
 */
public interface ServiceRegistry {
    /**
     * Registers one provider address under the given service name.
     *
     * @param serviceName    name the service is published under
     * @param serviceAddress address of the provider offering that service
     */
    void register(String serviceName, String serviceAddress);
}
package com.rpc.registry.zookeeper;
import com.rpc.registry.ServiceRegistry;
import org.I0Itec.zkclient.ZkClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
//服务注册实现类
public class ZooKeeperServiceRegistry implements ServiceRegistry {
private static final Logger LOGGER = LoggerFactory.getLogger(ZooKeeperServiceRegistry.class);
private final ZkClient zkClient;
public ZooKeeperServiceRegistry(String zkAddress) {
this.zkClient = new ZkClient(zkAddress, 5000, 1000);
LOGGER.info("connect zookeeper");
}
public void register(String serviceName, String serviceAddress) {
String registryPath = "/registry";
if (!this.zkClient.exists(registryPath)) {
this.zkClient.createPersistent(registryPath);
LOGGER.info("create registry node: {}", registryPath);
}
//节点示例:/registry/com.db.api.service.cinf.CinfDepartmentsRpcService
String servicePath = registryPath + "/" + serviceName;
if (!this.zkClient.exists(servicePath)) {
this.zkClient.createPersistent(servicePath);
LOGGER.info("create service node: {}", servicePath);
}
//节点示例:/registry/com.db.api.service.cinf.CinfDepartmentsRpcService/address-0000000442 data
String addressPath = servicePath + "/address-";
String addressNode = this.zkClient.createEphemeralSequential(addressPath, serviceAddress);//创建临时序列节点,放入数据服务地址,服务停止时删除,
LOGGER.info("create address node: {}", addressNode);
}
}
RpcServer类
package com.rpc.server;
import com.rpc.common.bean.RpcRequest;
import com.rpc.common.bean.RpcResponse;
import com.rpc.common.codec.RpcDecoder;
import com.rpc.common.codec.RpcEncoder;
import com.rpc.common.util.Global;
import com.rpc.common.util.StringUtil;
import com.rpc.registry.ServiceRegistry;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.collections4.MapUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
//Rpc实际服务类
public class RpcServer implements ApplicationContextAware, InitializingBean {
private static final Logger LOGGER = LoggerFactory.getLogger(RpcServer.class);
private String serviceAddress;//服务地址
private ServiceRegistry serviceRegistry;//服务注册对象
//服务处理方法map ,key:服务接口+版本号 value:服务实现对象,调用方法时使用该对象
private ConcurrentHashMap<String, Object> handlerMap = new ConcurrentHashMap();
public RpcServer(String serviceAddress) {
this.serviceAddress = serviceAddress;
}
public RpcServer(String serviceAddress, ServiceRegistry serviceRegistry) {
this.serviceAddress = serviceAddress;
this.serviceRegistry = serviceRegistry;
}
//ApplicationContextAware接口重写方法,通过spring上下文获取自定义注解
public void setApplicationContext(ApplicationContext ctx) throws BeansException {
Global.ctx = ctx;
Map<String, Object> serviceBeanMap = ctx.getBeansWithAnnotation(RpcService.class);
Object serviceBean;
String serviceName;
if (MapUtils.isNotEmpty(serviceBeanMap)) {
for(Iterator i$ = serviceBeanMap.values().iterator(); i$.hasNext(); this.handlerMap.put(serviceName, serviceBean)) {
serviceBean = i$.next();
RpcService rpcService = (RpcService)serviceBean.getClass().getAnnotation(RpcService.class);
serviceName = rpcService.value().getName();
String serviceVersion = rpcService.version();
if (StringUtil.isNotEmpty(serviceVersion)) {
serviceName = serviceName + "-" + serviceVersion;
}
}
}
}
//InitializingBean接口方法,在初始化bean完后执行,启动netty服务端
public void afterPropertiesSet() throws Exception {
EventLoopGroup bossGroup = new NioEventLoopGroup();
NioEventLoopGroup workerGroup = new NioEventLoopGroup();
try {
ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(bossGroup, workerGroup);
bootstrap.channel(NioServerSocketChannel.class);
bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
public void initChannel(SocketChannel channel) throws Exception {
ChannelPipeline pipeline = channel.pipeline();
pipeline.addLast(new ChannelHandler[]{new RpcDecoder(RpcRequest.class)});
pipeline.addLast(new ChannelHandler[]{new RpcEncoder(RpcResponse.class)});
pipeline.addLast(new ChannelHandler[]{new RpcServerHandler(RpcServer.this.handlerMap)});
}
});
bootstrap.option(ChannelOption.SO_BACKLOG, 1024);
bootstrap.childOption(ChannelOption.SO_KEEPALIVE, true);
String[] addressArray = StringUtil.split(this.serviceAddress, ":");
String ip = addressArray[0];
int port = Integer.parseInt(addressArray[1]);
ChannelFuture future = bootstrap.bind(ip, port).sync();
//将服务注册到zookeeper中
if (this.serviceRegistry != null) {
Iterator i$ = this.handlerMap.keySet().iterator();
while(i$.hasNext()) {
String interfaceName = (String)i$.next();
this.serviceRegistry.register(interfaceName, this.serviceAddress);
LOGGER.info("register service: {} => {}", interfaceName, this.serviceAddress);
}
}
LOGGER.info("server started on port {}", port);
future.channel().closeFuture().sync();
} finally {
workerGroup.shutdownGracefully();
bossGroup.shutdownGracefully();
}
}
}
3、netty服务端启动部分类
编解码类
package com.rpc.common.codec;
import com.rpc.common.util.SerializationUtil;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import java.util.List;
//netty解码类,协议:数据长度(4字节)+数据
public class RpcDecoder extends ByteToMessageDecoder {
private Class<?> genericClass;//编码对象类型
public RpcDecoder(Class<?> genericClass) {
this.genericClass = genericClass;
}
public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (in.readableBytes() >= 4) {
in.markReaderIndex();
int dataLength = in.readInt();
if (in.readableBytes() < dataLength) {
in.resetReaderIndex();
} else {
byte[] data = new byte[dataLength];
in.readBytes(data);
out.add(SerializationUtil.deserialize(data, this.genericClass));
}
}
}
}
package com.rpc.common.codec;
import com.rpc.common.util.SerializationUtil;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
//netty编码类,协议:数据长度(4字节)+数据
public class RpcEncoder extends MessageToByteEncoder {
private Class<?> genericClass;//解码对象类型
public RpcEncoder(Class<?> genericClass) {
this.genericClass = genericClass;
}
public void encode(ChannelHandlerContext ctx, Object in, ByteBuf out) throws Exception {
if (this.genericClass.isInstance(in)) {
byte[] data = SerializationUtil.serialize(in);
out.writeInt(data.length);
out.writeBytes(data);
}
}
}
序列化工具类
package com.rpc.common.util;
import com.dyuproject.protostuff.LinkedBuffer;
import com.dyuproject.protostuff.ProtostuffIOUtil;
import com.dyuproject.protostuff.Schema;
import com.dyuproject.protostuff.runtime.RuntimeSchema;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.objenesis.Objenesis;
import org.objenesis.ObjenesisStd;
/**
 * Protostuff-based serialization helpers with a per-class schema cache.
 */
public class SerializationUtil {
    /** Runtime schemas already built, keyed by the class they describe. */
    private static Map<Class<?>, Schema<?>> cachedSchema = new ConcurrentHashMap<Class<?>, Schema<?>>();
    /** Used by {@link #deserialize} to create the target instance. */
    private static Objenesis objenesis = new ObjenesisStd(true);

    private SerializationUtil() {
    }

    /**
     * Serializes {@code obj} using the runtime schema of its class.
     *
     * @param obj object to serialize; must not be null
     * @return the serialized bytes
     * @throws IllegalStateException if serialization fails (original exception kept as cause)
     */
    @SuppressWarnings("unchecked")
    public static <T> byte[] serialize(T obj) {
        // getClass() returns Class<? extends Object>; the cast is needed to compile and is
        // safe because the schema is looked up for obj's own runtime class.
        Class<T> cls = (Class<T>) obj.getClass();
        LinkedBuffer buffer = LinkedBuffer.allocate(512);
        try {
            Schema<T> schema = getSchema(cls);
            return ProtostuffIOUtil.toByteArray(obj, schema, buffer);
        } catch (Exception e) {
            throw new IllegalStateException(e.getMessage(), e);
        } finally {
            buffer.clear();
        }
    }

    /**
     * Deserializes {@code data} into a new instance of {@code cls}.
     *
     * @param data bytes produced by {@link #serialize}
     * @param cls  target type; the instance is created through Objenesis
     * @return the populated instance
     * @throws IllegalStateException if instantiation or merging fails (original exception kept as cause)
     */
    public static <T> T deserialize(byte[] data, Class<T> cls) {
        try {
            T message = objenesis.newInstance(cls);
            Schema<T> schema = getSchema(cls);
            ProtostuffIOUtil.mergeFrom(data, message, schema);
            return message;
        } catch (Exception e) {
            throw new IllegalStateException(e.getMessage(), e);
        }
    }

    /** Returns the cached schema for {@code cls}, building and caching it on first use. */
    @SuppressWarnings("unchecked")
    private static <T> Schema<T> getSchema(Class<T> cls) {
        // Safe: the entry stored under cls is always a schema created from cls.
        Schema<T> schema = (Schema<T>) cachedSchema.get(cls);
        if (schema == null) {
            // Two threads may both build a schema for the same class; the later put simply wins.
            schema = RuntimeSchema.createFrom(cls);
            cachedSchema.put(cls, schema);
        }
        return schema;
    }
}
server端处理类
package com.rpc.server;
import com.rpc.common.bean.RpcRequest;
import com.rpc.common.bean.RpcResponse;
import com.rpc.common.util.StringUtil;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import java.util.concurrent.ConcurrentHashMap;
import net.sf.cglib.reflect.FastClass;
import net.sf.cglib.reflect.FastMethod;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
//服务端请求处理类
public class RpcServerHandler extends SimpleChannelInboundHandler<RpcRequest> {
private static final Logger LOGGER = LoggerFactory.getLogger(RpcServerHandler.class);
private final ConcurrentHashMap<String, Object> handlerMap;
public RpcServerHandler(ConcurrentHashMap<String, Object> handlerMap) {
this.handlerMap = handlerMap;
}
//读取数据
public void channelRead0(ChannelHandlerContext ctx, RpcRequest request) throws Exception {
RpcResponse response = new RpcResponse();
response.setRequestId(request.getRequestId());
try {
Object result = this.handle(request);
response.setResult(result);
} catch (Exception var5) {
LOGGER.error("handle result failure", var5);
var5.printStackTrace();
response.setException(var5);
}
//操作完成后,关闭连接
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
}
private Object handle(RpcRequest request) throws Exception {
String serviceName = request.getInterfaceName();
String serviceVersion = request.getServiceVersion();
if (StringUtil.isNotEmpty(serviceVersion)) {
serviceName = serviceName + "-" + serviceVersion;
}
Object serviceBean = this.handlerMap.get(serviceName);
if (serviceBean == null) {
throw new RuntimeException(String.format("can not find service bean by key: %s", serviceName));
} else {
Class<?> serviceClass = serviceBean.getClass();
String methodName = request.getMethodName();
Class<?>[] parameterTypes = request.getParameterTypes();
Object[] parameters = request.getParameters();
//cglib动态代理调用方法
FastClass serviceFastClass = FastClass.create(serviceClass);
FastMethod serviceFastMethod = serviceFastClass.getMethod(methodName, parameterTypes);
return serviceFastMethod.invoke(serviceBean, parameters);
}
}
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
LOGGER.error("server caught exception", cause);
ctx.close();
}
}
统一的请求和响应对象
package com.rpc.common.bean;
//Request object of a remote call; it is serialized for transport.
//Original note: requires a no-arg constructor — TODO confirm, SerializationUtil.deserialize instantiates via Objenesis.
public class RpcRequest {
private String requestId;//request id, copied into the matching response
private String interfaceName;//name of the interface the method belongs to
private String serviceVersion;//interface version
private String methodName;//method name
private Class<?>[] parameterTypes;//parameter types
private Object[] parameters;//argument values
public RpcRequest() {
}
//setters and getters omitted
}
package com.rpc.common.bean;
//Response object of a remote call; it is serialized for transport.
//Original note: requires a no-arg constructor — TODO confirm, SerializationUtil.deserialize instantiates via Objenesis.
public class RpcResponse {
private String requestId;//id of the request this response answers
private Exception exception;//exception raised while handling the request, if any
private Object result;//response result; its actual type is the return type declared by the invoked interface method
public RpcResponse() {
}
//setters and getters omitted
}
具体服务类
package com.db.api.cinf;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.annotation.Autowired;
import com.db.api.entity.bankcomm.ApFdOrg;
import com.db.api.service.cinf.ApFdOrgRpcService;
import com.db.service.cinf.ApFdOrgService;
import com.rpc.server.RpcService;
/**
* Concrete service class: exposes {@code ApFdOrgRpcService} over RPC via the
* {@code @RpcService} annotation and delegates to the injected {@code ApFdOrgService}.
*/
@RpcService(ApFdOrgRpcService.class)
public class ApFdOrgRpcServiceImpl implements ApFdOrgRpcService {
private static Log log = LogFactory.getLog(ApFdOrgRpcServiceImpl.class);
@Autowired
private ApFdOrgService apFdOrgService;
//Insert into the temporary store (original comment); delegates directly to ApFdOrgService.
@Override
public int insertOrUpdateApFdOrg(ApFdOrg org) {
return apFdOrgService.insertOrUpdateApFdOrg(org);
}
}
服务消费方
客户端发送请求逻辑
0、启动加载服务获取地址
1、获取RpcProxy的对象
2、调用create方法,创建对应代理对象
3、通过代理对象调用方法(使用JDK的动态代理)
4、调用方法时先构造RpcRequest请求对象
5、再调用获取服务地址方法
6、根据服务地址构造RpcClient请求客户端对象
7、RpcClient发送请求,获取响应
8、RpcClient发送请求方法使用netty客户端(添加编解码及客户端处理handler)实现
9、同步获取响应,使用客户端处理handler的获取响应方法
4、spring加载rpc相应配置
<!-- ZooKeeper-backed service discovery; the ZooKeeper address comes from the rpc.registry_address property. -->
<bean id="serviceDiscovery" class="com.rpc.registry.zookeeper.ZooKeeperServiceDiscovery">
<constructor-arg name="zkAddress" value="${rpc.registry_address}"/>
</bean>
<!-- Client-side proxy factory; resolves provider addresses through the discovery bean above. -->
<bean id="rpcProxy" class="com.rpc.client.RpcProxy">
<constructor-arg name="serviceDiscovery" ref="serviceDiscovery"/>
</bean>
服务发现类
package com.rpc.registry;
/**
 * Contract for resolving a service name to a provider address.
 */
public interface ServiceDiscovery {
    /**
     * Finds an address for the given service.
     *
     * @param serviceName name the service was registered under
     * @return the address of a provider of that service
     */
    String discover(String serviceName);
}
package com.rpc.registry.zookeeper;
import com.rpc.common.util.CollectionUtil;
import com.rpc.registry.ServiceDiscovery;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import org.I0Itec.zkclient.ZkClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
//服务发现实现类
public class ZooKeeperServiceDiscovery implements ServiceDiscovery {
private static final Logger LOGGER = LoggerFactory.getLogger(ZooKeeperServiceDiscovery.class);
private String zkAddress;
public ZooKeeperServiceDiscovery(String zkAddress) {
this.zkAddress = zkAddress;
}
//先获取对应服务节点,然后获取子节点,获取子节点中的服务地址
public String discover(String name) {
ZkClient zkClient = new ZkClient(this.zkAddress, 5000, 1000);
LOGGER.info("connect zookeeper");
String var8;
try {
String servicePath = "/registry/" + name;
if (!zkClient.exists(servicePath)) {
throw new RuntimeException(String.format("can not find any service node on path: %s", servicePath));
}
List<String> addressList = zkClient.getChildren(servicePath);
if (CollectionUtil.isEmpty(addressList)) {
throw new RuntimeException(String.format("can not find any address node on path: %s", servicePath));
}
int size = addressList.size();
String address;
if (size == 1) {
address = (String)addressList.get(0);
LOGGER.info("get only address node: {}", address);
} else {
address = (String)addressList.get(ThreadLocalRandom.current().nextInt(size));
LOGGER.info("get random address node: {}", address);
}
String addressPath = servicePath + "/" + address;
var8 = (String)zkClient.readData(addressPath);
} finally {
zkClient.close();
}
return var8;
}
}
远程调用代理类
package com.rpc.client;
import com.rpc.common.bean.RpcRequest;
import com.rpc.common.bean.RpcResponse;
import com.rpc.common.util.StringUtil;
import com.rpc.registry.ServiceDiscovery;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Factory for client-side proxies. Every method call on a created proxy is
 * turned into an {@link RpcRequest}, sent to a provider through a new
 * {@link RpcClient}, and answered from the returned {@link RpcResponse}.
 */
public class RpcProxy {
    private static final Logger LOGGER = LoggerFactory.getLogger(RpcProxy.class);
    /** Fixed provider address ({@code host:port}); used only when no discovery is configured. */
    private String serviceAddress;
    /** Optional discovery used to resolve the provider address on each call. */
    private ServiceDiscovery serviceDiscovery;

    public RpcProxy(String serviceAddress) {
        this.serviceAddress = serviceAddress;
    }

    public RpcProxy(ServiceDiscovery serviceDiscovery) {
        this.serviceDiscovery = serviceDiscovery;
    }

    /** Creates a proxy for {@code interfaceClass} with an empty service version. */
    public <T> T create(Class<?> interfaceClass) {
        return this.create(interfaceClass, "");
    }

    /**
     * Creates a JDK dynamic proxy for {@code interfaceClass}; each call on it
     * runs the {@code invoke} method below.
     *
     * @param interfaceClass interface to proxy
     * @param serviceVersion version appended to the service name as "-version" when not empty
     * @return a proxy implementing {@code interfaceClass}
     */
    @SuppressWarnings("unchecked")
    public <T> T create(final Class<?> interfaceClass, final String serviceVersion) {
        // newProxyInstance returns Object; the cast is required to compile and holds
        // when the caller's T is interfaceClass (not checkable here).
        return (T) Proxy.newProxyInstance(interfaceClass.getClassLoader(), new Class<?>[]{interfaceClass}, new InvocationHandler() {
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                RpcRequest request = new RpcRequest();
                request.setRequestId(UUID.randomUUID().toString());
                request.setInterfaceName(method.getDeclaringClass().getName());
                request.setServiceVersion(serviceVersion);
                request.setMethodName(method.getName());
                request.setParameterTypes(method.getParameterTypes());
                request.setParameters(args);

                // Keep the resolved address local: writing it to the shared field let
                // concurrent calls connect to an address discovered for another service.
                String address = RpcProxy.this.resolveAddress(interfaceClass, serviceVersion);
                if (StringUtil.isEmpty(address)) {
                    throw new RuntimeException("server address is empty");
                }
                String[] hostAndPort = StringUtil.split(address, ":");
                String host = hostAndPort[0];
                int port = Integer.parseInt(hostAndPort[1]);

                // A new client (and connection) is created for every call.
                RpcClient client = new RpcClient(host, port);
                long time = System.currentTimeMillis();
                RpcResponse response = client.send(request);
                LOGGER.debug("time: {}ms", System.currentTimeMillis() - time);
                if (response == null) {
                    throw new RuntimeException("response is null");
                }
                if (response.hasException()) {
                    throw response.getException();
                }
                return response.getResult();
            }
        });
    }

    /** Returns the discovered address when discovery is configured, else the fixed address. */
    private String resolveAddress(Class<?> interfaceClass, String serviceVersion) {
        if (this.serviceDiscovery == null) {
            return this.serviceAddress;
        }
        String serviceName = interfaceClass.getName();
        if (StringUtil.isNotEmpty(serviceVersion)) {
            serviceName = serviceName + "-" + serviceVersion;
        }
        String discovered = this.serviceDiscovery.discover(serviceName);
        LOGGER.debug("discover service: {} => {}", serviceName, discovered);
        return discovered;
    }
}
5、客户端发送请求
远程调用客户端
package com.rpc.client;
import com.rpc.common.bean.RpcRequest;
import com.rpc.common.bean.RpcResponse;
import com.rpc.common.codec.RpcDecoder;
import com.rpc.common.codec.RpcEncoder;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
//远程调用客户端
public class RpcClient extends SimpleChannelInboundHandler<RpcResponse> {
private static final Logger LOGGER = LoggerFactory.getLogger(RpcClient.class);
private final String host;
private final int port;
private RpcResponse response;
public RpcClient(String host, int port) {
this.host = host;
this.port = port;
}
public void channelRead0(ChannelHandlerContext ctx, RpcResponse response) throws Exception {
this.response = response;
}
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
LOGGER.error("api caught exception", cause);
ctx.close();
}
//发送请求
public RpcResponse send(RpcRequest request) throws Exception {
NioEventLoopGroup group = new NioEventLoopGroup();
RpcResponse var7;
try {
Bootstrap bootstrap = new Bootstrap();
bootstrap.group(group);
bootstrap.channel(NioSocketChannel.class);
bootstrap.handler(new ChannelInitializer<SocketChannel>() {
public void initChannel(SocketChannel channel) throws Exception {
ChannelPipeline pipeline = channel.pipeline();
pipeline.addLast(new ChannelHandler[]{new RpcEncoder(RpcRequest.class)});
pipeline.addLast(new ChannelHandler[]{new RpcDecoder(RpcResponse.class)});
pipeline.addLast(new ChannelHandler[]{RpcClient.this});
}
});
bootstrap.option(ChannelOption.TCP_NODELAY, true);
ChannelFuture future = bootstrap.connect(this.host, this.port).sync();
Channel channel = future.channel();
channel.writeAndFlush(request).sync();
channel.closeFuture().sync();
var7 = this.response;
} finally {
group.shutdownGracefully();
}
return var7;
}
}
客户端使用
// Obtain the RpcProxy bean and create a proxy for the remote service interface.
RpcProxy rpcProxy = (RpcProxy) SpringHelper.getBean("rpcProxy");
CinfDepartmentsRpcService departmentsRpcService = (CinfDepartmentsRpcService) rpcProxy.create(CinfDepartmentsRpcService.class);
//Query department information by eid and depttype
List<CinfDepartment> deptList = departmentsRpcService.getDeptsByEidAndType(Constant.BANKCOMM_EID, 1);