netty4用最简单的协议解决一个半包问题

有时候简化实现别人的代码,有助于你更好的理解代码,不要一味地读源代码。

问题来源

客户端往服务器发送小文件

解决思路

1、使用netty(废话)
2、只是用ByteBuf
3、自定义一种协议,用最小的网络代价完成数据传送

实现

其实netty有很多的定义好的协议来解决各种各样的问题,这篇文章来自《netty权威指南》作者李林峰,详细介绍了netty的编解码框架,以及一些常用的编解码协议。

在解决这个问题的时候,我遇到的一个主要问题就是我在客户端发送一个数据包,这个数据包的大小可以很大,但是如果只用简单的channelRead去读取数据的话得到的数据并不是完整的。具体原因参考netty用户指南中的tcp stream-based传输的问题。

我先做了一个简单的协议设计:
packet = |文件名长度|文件名|文件字节长度|文件字节流|

于是就有了客户端发送的简单代码

            String name = "diagram.png";
            FileInputStream fileInputStream = new FileInputStream(new File("src/main/resources/diagram.png"));
            byte[] bytes = new byte[fileInputStream.available()];
            fileInputStream.read(bytes);

            ByteBuf byteBuf = Unpooled.buffer();

            byteBuf.writeInt("diagram.png".getBytes().length);
            byteBuf.writeBytes("diagram.png".getBytes());

            byteBuf.writeInt(bytes.length);
            byteBuf.writeBytes(bytes);
            channelFuture.channel().writeAndFlush(byteBuf);

这样发送没有问题,因为byteBuf是动态扩展的。但是接受的时候就有问题了。如果我们接受比较小的,比如一个int,我们可以直接这样写

 @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        
        if(msg instanceof ByteBuf)
        {
            ByteBuf byteBuf = (ByteBuf)msg;
            if(byteBuf.readableBytes() > 4)
            {
                int result = byteBuf.readInt();
            }
        }
}

但是当长度很大的时候,我们就需要解决读半包的问题了。直到读到完整的数据才进行处理。但是每次收到的数据怎么去判断是不是和上一个数据是连续的,如何在没有收集到完整数据时不处理数据而继续接受呢?这是我一直困扰的问题。因为我把每一个byteBuf当成一个message来想了,其实不是的,ByteBuf中有两个指针readIndex和writeIndex,readIndex永远小于writeIndex。大概如下图所示


ByteBuf示例图

在netty的设计中ByteBuf是可以被重用的,所以可能针对这一个ChannelRead一直读取的是同一个ByteBuf。这其中readrIndex之前的是已经读取过的,就是已经被调用readXXX()之后的数据,可以重新去读取,readerIndex和writerIndex之前的是当前的readableBytes,writerIndex到capacity的是writeableBytes,当writerIndex超过capacity时就会扩展。同时为了重用这部分空间,当调用discardBytes时,会把readerIndex和writerIndex拷贝到开头,这样前面废弃的部分就被重用了,也一定程度场避免了扩容,节省了空间。

那如何针对上面的输入写ByteBuf的解码呢?
先看看netty自带的解码器怎么解决这个问题,其中LengthFieldBasedFrameDecoder就是用来解决这一类的问题的。在李林峰的文章中有详细介绍,这里就不赘述了。
我在之前代码的基础上添加了两行代码。

//在服务器的pipeline中添加的这个解码器,然后用4个字节表示整个包的长度,并且废弃掉这四个字节。
ch.pipeline().addLast(new LengthFieldBasedFrameDecoder(1024*1024, 0, 4, 0, 4));

//在发送的byteBuf头部添加真个包的长度
byteBuf.writeInt(4+ name.getBytes().length +4+ bytes.length);

然后我再在ChannelRead中处理剩下的数据
packet = |文件名长度|文件名|文件字节长度|文件字节流|

       if(msg instanceof ByteBuf)
        {
            ByteBuf byteBuf = (ByteBuf)msg;
            int nameSize = byteBuf.readInt();
            String name = new String(byteBuf.readBytes(nameSize).array(), "UTF-8");
            int fileSize = byteBuf.readInt();
            FileOutputStream fileOutputStream = new FileOutputStream(new File(name));
            fileOutputStream.write(byteBuf.readBytes(fileSize).array());
            System.out.println(name + " " + fileSize);
        }

问题解决,但是自己如何实现这个解码器呢?先看看netty怎么实现的。

protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        if (discardingTooLongFrame) {
            long bytesToDiscard = this.bytesToDiscard;
            int localBytesToDiscard = (int) Math.min(bytesToDiscard, in.readableBytes());
            in.skipBytes(localBytesToDiscard);
            bytesToDiscard -= localBytesToDiscard;
            this.bytesToDiscard = bytesToDiscard;

            failIfNecessary(false);
        }

        if (in.readableBytes() < lengthFieldEndOffset) {
            return null;
        }

        int actualLengthFieldOffset = in.readerIndex() + lengthFieldOffset;
        long frameLength = getUnadjustedFrameLength(in, actualLengthFieldOffset, lengthFieldLength, byteOrder);

        if (frameLength < 0) {
            in.skipBytes(lengthFieldEndOffset);
            throw new CorruptedFrameException(
                    "negative pre-adjustment length field: " + frameLength);
        }

        frameLength += lengthAdjustment + lengthFieldEndOffset;

        if (frameLength < lengthFieldEndOffset) {
            in.skipBytes(lengthFieldEndOffset);
            throw new CorruptedFrameException(
                    "Adjusted frame length (" + frameLength + ") is less " +
                    "than lengthFieldEndOffset: " + lengthFieldEndOffset);
        }

        if (frameLength > maxFrameLength) {
            long discard = frameLength - in.readableBytes();
            tooLongFrameLength = frameLength;

            if (discard < 0) {
                // buffer contains more bytes then the frameLength so we can discard all now
                in.skipBytes((int) frameLength);
            } else {
                // Enter the discard mode and discard everything received so far.
                discardingTooLongFrame = true;
                bytesToDiscard = discard;
                in.skipBytes(in.readableBytes());
            }
            failIfNecessary(true);
            return null;
        }

        // never overflows because it's less than maxFrameLength
        int frameLengthInt = (int) frameLength;
        if (in.readableBytes() < frameLengthInt) {
            return null;
        }

        if (initialBytesToStrip > frameLengthInt) {
            in.skipBytes(frameLengthInt);
            throw new CorruptedFrameException(
                    "Adjusted frame length (" + frameLength + ") is less " +
                    "than initialBytesToStrip: " + initialBytesToStrip);
        }
        in.skipBytes(initialBytesToStrip);

        // extract frame
        int readerIndex = in.readerIndex();
        int actualFrameLength = frameLengthInt - initialBytesToStrip;
        ByteBuf frame = extractFrame(ctx, in, readerIndex, actualFrameLength);
        in.readerIndex(readerIndex + actualFrameLength);
        return frame;
    }

好长。。里面对于不合理的协议做了很多假设,并使不合理的输入快速失败。但是让我一个初学者写还是写不出来。所以我假设协议就是我设计的那样,简化这部分代码,便于理解。
变量给一个固定值

    private ByteOrder byteOrder = ByteOrder.BIG_ENDIAN;
    private int maxFrameLength = 1024*10;
    private int lengthFieldLength = 4;
    private int initialBytesToStrip = 0;
    private long tooLongFrameLength;
    private long bytesToDiscard;
    private boolean failFast = true;

然后写decode函数,就这么简单。。

 protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {

        int frameLength = (int) in.getUnsignedInt(0);//获取头部
        if(in.readableBytes() < frameLength)//当ByteBuf没有达到长度时,return null
        {
            return null;
        }
        in.skipBytes(4);//舍弃头部
        int index =  in.readerIndex();
        ByteBuf frame = in.slice(index, frameLength).retain();//取出自己定义的packet包返回给ChannelRead

        in.readerIndex(frameLength);//这一步一定要有,不然其实bytebuf的readerIndex没有变,netty会一直从这里开始读取,将readerIndex移动就相当于把前面的数据处理过了废弃掉了。
        return  frame;
    }

所以其实我们只要不处理bytebuf的数据知道可以读的数据达到我们需要的长度在处理就可以了。当然包的顺序不会出错是由底层tcp保证的,不用关心。

推荐阅读更多精彩内容