001    /**
002     * Copyright (C) 2012 FuseSource, Inc.
003     * http://fusesource.com
004     *
005     * Licensed under the Apache License, Version 2.0 (the "License");
006     * you may not use this file except in compliance with the License.
007     * You may obtain a copy of the License at
008     *
009     *    http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.fusesource.hawtdispatch.transport;
019    
020    import org.fusesource.hawtbuf.Buffer;
021    import org.fusesource.hawtbuf.DataByteArrayOutputStream;
022    import org.fusesource.hawtdispatch.util.BufferPool;
023    import org.fusesource.hawtdispatch.util.BufferPools;
024    
025    import java.io.EOFException;
026    import java.io.IOException;
027    import java.net.ProtocolException;
028    import java.net.SocketException;
029    import java.nio.ByteBuffer;
030    import java.nio.channels.GatheringByteChannel;
031    import java.nio.channels.ReadableByteChannel;
032    import java.nio.channels.SocketChannel;
033    import java.util.Arrays;
034    import java.util.LinkedList;
035    
036    /**
037     * Provides an abstract base class to make implementing the ProtocolCodec interface
038     * easier.
039     *
040     * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
041     */
042    public abstract class AbstractProtocolCodec implements ProtocolCodec {
043    
044        protected BufferPools bufferPools;
045        protected BufferPool writeBufferPool;
046        protected BufferPool readBufferPool;
047    
048        protected int writeBufferSize = 1024 * 64;
049        protected long writeCounter = 0L;
050        protected GatheringByteChannel writeChannel = null;
051        protected DataByteArrayOutputStream nextWriteBuffer;
052        protected long lastWriteIoSize = 0;
053    
054        protected LinkedList<ByteBuffer> writeBuffer = new LinkedList<ByteBuffer>();
055        private long writeBufferRemaining = 0;
056    
057    
058        public static interface Action {
059            Object apply() throws IOException;
060        }
061    
062        protected long readCounter = 0L;
063        protected int readBufferSize = 1024 * 64;
064        protected ReadableByteChannel readChannel = null;
065        protected ByteBuffer readBuffer;
066        protected ByteBuffer directReadBuffer = null;
067    
068        protected int readEnd;
069        protected int readStart;
070        protected int lastReadIoSize;
071        protected Action nextDecodeAction;
072    
073        public void setTransport(Transport transport) {
074            this.writeChannel = (GatheringByteChannel) transport.getWriteChannel();
075            this.readChannel = transport.getReadChannel();
076            if( nextDecodeAction==null ) {
077                nextDecodeAction = initialDecodeAction();
078            }
079            if( transport instanceof TcpTransport) {
080                TcpTransport tcp = (TcpTransport) transport;
081                writeBufferSize = tcp.getSendBufferSize();
082                readBufferSize = tcp.getReceiveBufferSize();
083            } else if( transport instanceof UdpTransport) {
084                UdpTransport tcp = (UdpTransport) transport;
085                writeBufferSize = tcp.getSendBufferSize();
086                readBufferSize = tcp.getReceiveBufferSize();
087            } else {
088                try {
089                    if (this.writeChannel instanceof SocketChannel) {
090                        writeBufferSize = ((SocketChannel) this.writeChannel).socket().getSendBufferSize();
091                        readBufferSize = ((SocketChannel) this.readChannel).socket().getReceiveBufferSize();
092                    } else if (this.writeChannel instanceof SslTransport.SSLChannel) {
093                        writeBufferSize = ((SslTransport.SSLChannel) this.readChannel).socket().getSendBufferSize();
094                        readBufferSize = ((SslTransport.SSLChannel) this.writeChannel).socket().getReceiveBufferSize();
095                    }
096                } catch (SocketException ignore) {
097                }
098            }
099            if( bufferPools!=null ) {
100                readBufferPool = bufferPools.getBufferPool(readBufferSize);
101                writeBufferPool = bufferPools.getBufferPool(writeBufferSize);
102            }
103        }
104    
105        public int getReadBufferSize() {
106            return readBufferSize;
107        }
108    
109        public int getWriteBufferSize() {
110            return writeBufferSize;
111        }
112    
113        public boolean full() {
114            return writeBufferRemaining >= writeBufferSize;
115        }
116    
117        public boolean isEmpty() {
118            return writeBufferRemaining == 0 && (nextWriteBuffer==null || nextWriteBuffer.size() == 0);
119        }
120    
121        public long getWriteCounter() {
122            return writeCounter;
123        }
124    
125        public long getLastWriteSize() {
126            return lastWriteIoSize;
127        }
128    
129        abstract protected void encode(Object value) throws IOException;
130    
131        public ProtocolCodec.BufferState write(Object value) throws IOException {
132            if (full()) {
133                return ProtocolCodec.BufferState.FULL;
134            } else {
135                boolean wasEmpty = isEmpty();
136                if( nextWriteBuffer == null ) {
137                    nextWriteBuffer = allocateNextWriteBuffer();
138                }
139                encode(value);
140                if (nextWriteBuffer.size() >= (writeBufferSize* 0.75)) {
141                    flushNextWriteBuffer();
142                }
143                if (wasEmpty) {
144                    return ProtocolCodec.BufferState.WAS_EMPTY;
145                } else {
146                    return ProtocolCodec.BufferState.NOT_EMPTY;
147                }
148            }
149        }
150    
151        private DataByteArrayOutputStream allocateNextWriteBuffer() {
152            if( writeBufferPool !=null ) {
153                return new DataByteArrayOutputStream(writeBufferPool.checkout()) {
154                    @Override
155                    protected void resize(int newcount) {
156                        byte[] oldbuf = buf;
157                        super.resize(newcount);
158                        if( oldbuf.length == writeBufferPool.getBufferSize() ) {
159                            writeBufferPool.checkin(oldbuf);
160                        }
161                    }
162                };
163            } else {
164                return new DataByteArrayOutputStream(writeBufferSize);
165            }
166        }
167    
168        protected void writeDirect(ByteBuffer value) throws IOException {
169            // is the direct buffer small enough to just fit into the nextWriteBuffer?
170            int nextnextPospos = nextWriteBuffer.position();
171            int valuevalueLengthlength = value.remaining();
172            int available = nextWriteBuffer.getData().length - nextnextPospos;
173            if (available > valuevalueLengthlength) {
174                value.get(nextWriteBuffer.getData(), nextnextPospos, valuevalueLengthlength);
175                nextWriteBuffer.position(nextnextPospos + valuevalueLengthlength);
176            } else {
177                if (nextWriteBuffer!=null && nextWriteBuffer.size() != 0) {
178                    flushNextWriteBuffer();
179                }
180                writeBuffer.add(value);
181                writeBufferRemaining += value.remaining();
182            }
183        }
184    
185        protected void flushNextWriteBuffer() {
186            DataByteArrayOutputStream next = allocateNextWriteBuffer();
187            ByteBuffer bb = nextWriteBuffer.toBuffer().toByteBuffer();
188            writeBuffer.add(bb);
189            writeBufferRemaining += bb.remaining();
190            nextWriteBuffer = next;
191        }
192    
193        public ProtocolCodec.BufferState flush() throws IOException {
194            while (true) {
195                if (writeBufferRemaining != 0) {
196                    if( writeBuffer.size() == 1) {
197                        ByteBuffer b = writeBuffer.getFirst();
198                        lastWriteIoSize = writeChannel.write(b);
199                        if (lastWriteIoSize == 0) {
200                            return ProtocolCodec.BufferState.NOT_EMPTY;
201                        } else {
202                            writeBufferRemaining -= lastWriteIoSize;
203                            writeCounter += lastWriteIoSize;
204                            if(!b.hasRemaining()) {
205                                onBufferFlushed(writeBuffer.removeFirst());
206                            }
207                        }
208                    } else {
209                        ByteBuffer[] buffers = writeBuffer.toArray(new ByteBuffer[writeBuffer.size()]);
210                        lastWriteIoSize = writeChannel.write(buffers, 0, buffers.length);
211                        if (lastWriteIoSize == 0) {
212                            return ProtocolCodec.BufferState.NOT_EMPTY;
213                        } else {
214                            writeBufferRemaining -= lastWriteIoSize;
215                            writeCounter += lastWriteIoSize;
216                            while (!writeBuffer.isEmpty() && !writeBuffer.getFirst().hasRemaining()) {
217                                onBufferFlushed(writeBuffer.removeFirst());
218                            }
219                        }
220                    }
221                } else {
222                    if (nextWriteBuffer==null || nextWriteBuffer.size() == 0) {
223                        if( writeBufferPool!=null &&  nextWriteBuffer!=null ) {
224                            writeBufferPool.checkin(nextWriteBuffer.getData());
225                            nextWriteBuffer = null;
226                        }
227                        return ProtocolCodec.BufferState.EMPTY;
228                    } else {
229                        flushNextWriteBuffer();
230                    }
231                }
232            }
233        }
234    
235        /**
236         * Called when a buffer is flushed out.  Subclasses can implement
237         * in case they want to recycle the buffer.
238         *
239         * @param byteBuffer
240         */
241        protected void onBufferFlushed(ByteBuffer byteBuffer) {
242        }
243    
244        /////////////////////////////////////////////////////////////////////
245        //
246        // Non blocking read impl
247        //
248        /////////////////////////////////////////////////////////////////////
249    
250        abstract protected Action initialDecodeAction();
251    
252    
253        public void unread(byte[] buffer) {
254            assert ((readCounter == 0));
255            readBuffer = ByteBuffer.allocate(buffer.length);
256            readBuffer.put(buffer);
257            readCounter += buffer.length;
258        }
259    
260        public long getReadCounter() {
261            return readCounter;
262        }
263    
264        public long getLastReadSize() {
265            return lastReadIoSize;
266        }
267    
268        public Object read() throws IOException {
269            Object command = null;
270            while (command == null) {
271                if (directReadBuffer != null) {
272                    while (directReadBuffer.hasRemaining()) {
273                        lastReadIoSize = readChannel.read(directReadBuffer);
274                        readCounter += lastReadIoSize;
275                        if (lastReadIoSize == -1) {
276                            throw new EOFException("Peer disconnected");
277                        } else if (lastReadIoSize == 0) {
278                            return null;
279                        }
280                    }
281                    command = nextDecodeAction.apply();
282                } else {
283                    if (readBuffer==null || readEnd >= readBuffer.position()) {
284    
285                        int readPos = 0;
286                        boolean candidateForCheckin = false;
287                        if( readBuffer!=null ) {
288                            readPos = readBuffer.position();
289                            candidateForCheckin = readBufferPool!=null && readStart == 0 && readBuffer.capacity() == readBufferPool.getBufferSize();
290                        }
291    
292                        if (readBuffer==null || readBuffer.remaining() == 0) {
293    
294    
295                            int loadedSize = readPos - readStart;
296                            int neededSize = readEnd - readStart;
297    
298                            int newSize = 0;
299                            if( neededSize > loadedSize ) {
300                                newSize =  Math.max(readBufferSize, neededSize);
301                            } else {
302                                newSize = loadedSize+readBufferSize;
303                            }
304    
305                            byte[] newBuffer;
306                            if (loadedSize > 0) {
307                                newBuffer = Arrays.copyOfRange(readBuffer.array(), readStart, readStart + newSize);
308                            } else {
309                                if( readBufferPool!=null && newSize == readBufferPool.getBufferSize()) {
310                                    newBuffer = readBufferPool.checkout();
311                                } else {
312                                    newBuffer =  new byte[newSize];
313                                }
314                            }
315    
316                            if( candidateForCheckin ) {
317                                readBufferPool.checkin(readBuffer.array());
318                            }
319    
320                            readBuffer = ByteBuffer.wrap(newBuffer);
321                            readBuffer.position(loadedSize);
322                            readStart = 0;
323                            readEnd = neededSize;
324                        }
325    
326                        lastReadIoSize = readChannel.read(readBuffer);
327    
328                        readCounter += lastReadIoSize;
329                        if (lastReadIoSize == -1) {
330                            readCounter += 1; // to compensate for that -1
331                            throw new EOFException("Peer disconnected");
332                        } else if (lastReadIoSize == 0) {
333                            if ( readStart == readBuffer.position() ) {
334                                if (candidateForCheckin) {
335                                    readBufferPool.checkin(readBuffer.array());
336                                }
337                                readStart = 0;
338                                readEnd = 0;
339                                readBuffer = null;
340                            }
341                            return null;
342                        }
343    
344                        // if we did not read a full buffer.. then resize the buffer
345                        if( readBuffer.hasRemaining() && readEnd <= readBuffer.position() ) {
346                            ByteBuffer perfectSized = ByteBuffer.wrap(Arrays.copyOfRange(readBuffer.array(), 0, readBuffer.position()));
347                            perfectSized.position(readBuffer.position());
348    
349                            if( candidateForCheckin ) {
350                                readBufferPool.checkin(readBuffer.array());
351                            }
352                            readBuffer = perfectSized;
353                        }
354                    }
355                    command = nextDecodeAction.apply();
356                    assert ((readStart <= readEnd));
357                }
358            }
359            return command;
360        }
361    
362        protected Buffer readUntil(Byte octet) throws ProtocolException {
363            return readUntil(octet, -1);
364        }
365    
366        protected Buffer readUntil(Byte octet, int max) throws ProtocolException {
367            return readUntil(octet, max, "Maximum protocol buffer length exeeded");
368        }
369    
370        protected Buffer readUntil(Byte octet, int max, String msg) throws ProtocolException {
371            byte[] array = readBuffer.array();
372            Buffer buf = new Buffer(array, readEnd, readBuffer.position() - readEnd);
373            int pos = buf.indexOf(octet);
374            if (pos >= 0) {
375                int offset = readStart;
376                readEnd += pos + 1;
377                readStart = readEnd;
378                int length = readEnd - offset;
379                if (max >= 0 && length > max) {
380                    throw new ProtocolException(msg);
381                }
382                return new Buffer(array, offset, length);
383            } else {
384                readEnd += buf.length;
385                if (max >= 0 && (readEnd - readStart) > max) {
386                    throw new ProtocolException(msg);
387                }
388                return null;
389            }
390        }
391    
392        protected Buffer readBytes(int length) {
393            readEnd = readStart + length;
394            if (readBuffer.position() < readEnd) {
395                return null;
396            } else {
397                int offset = readStart;
398                readStart = readEnd;
399                return new Buffer(readBuffer.array(), offset, length);
400            }
401        }
402    
403        protected Buffer peekBytes(int length) {
404            readEnd = readStart + length;
405            if (readBuffer.position() < readEnd) {
406                return null;
407            } else {
408                // rewind..
409                readEnd = readStart;
410                return new Buffer(readBuffer.array(), readStart, length);
411            }
412        }
413    
414        protected Boolean readDirect(ByteBuffer buffer) {
415            assert (directReadBuffer == null || (directReadBuffer == buffer));
416    
417            if (buffer.hasRemaining()) {
418                // First we need to transfer the read bytes from the non-direct
419                // byte buffer into the direct one..
420                int limit = readBuffer.position();
421                int transferSize = Math.min((limit - readStart), buffer.remaining());
422                byte[] readBufferArray = readBuffer.array();
423                buffer.put(readBufferArray, readStart, transferSize);
424    
425                // The direct byte buffer might have been smaller than our readBuffer one..
426                // compact the readBuffer to avoid doing additional mem allocations.
427                int trailingSize = limit - (readStart + transferSize);
428                if (trailingSize > 0) {
429                    System.arraycopy(readBufferArray, readStart + transferSize, readBufferArray, readStart, trailingSize);
430                }
431                readBuffer.position(readStart + trailingSize);
432            }
433    
434            // For big direct byte buffers, it will still not have been filled,
435            // so install it so that we directly read into it until it is filled.
436            if (buffer.hasRemaining()) {
437                directReadBuffer = buffer;
438                return false;
439            } else {
440                directReadBuffer = null;
441                buffer.flip();
442                return true;
443            }
444        }
445    
446        public BufferPools getBufferPools() {
447            return bufferPools;
448        }
449    
450        public void setBufferPools(BufferPools bufferPools) {
451            this.bufferPools = bufferPools;
452            if( bufferPools!=null ) {
453                readBufferPool = bufferPools.getBufferPool(readBufferSize);
454                writeBufferPool = bufferPools.getBufferPool(writeBufferSize);
455            } else {
456                readBufferPool = null;
457                writeBufferPool = null;
458            }
459        }
460    }