001    /*****************************************************************************
002     * Copyright (C) PicoContainer Organization. All rights reserved.            *
003     * ------------------------------------------------------------------------- *
004     * The software in this package is published under the terms of the BSD      *
005     * style license a copy of which has been included with this distribution in *
006     * the LICENSE.txt file.                                                     *
007     *                                                                           *
008     * Original code by                                                          *
009     *****************************************************************************/
010    package org.picocontainer.gems.behaviors;
011    
012    import java.lang.reflect.Constructor;
013    import java.lang.reflect.InvocationTargetException;
014    import java.lang.reflect.Method;
015    import java.util.HashSet;
016    import java.util.Set;
017    
018    import org.objectweb.asm.ClassWriter;
019    import org.objectweb.asm.FieldVisitor;
020    import org.objectweb.asm.MethodVisitor;
021    import org.objectweb.asm.Opcodes;
022    import org.picocontainer.ComponentAdapter;
023    import org.picocontainer.PicoContainer;
024    import org.picocontainer.behaviors.AbstractBehavior;
025    import org.picocontainer.behaviors.Cached;
026    
027    
028    /**
029     * This component adapter makes it possible to hide the implementation of a real subject (behind a proxy).
030     * The proxy will implement all the interfaces of the
031     * underlying subject. If you want caching,
032     * use a {@link Cached} around this one.
033     *
034     * @author Paul Hammant
035     */
036    @SuppressWarnings("serial")
037    public class AsmHiddenImplementation<T> extends AbstractBehavior<T> implements Opcodes {
038    
039    
040            public AsmHiddenImplementation(final ComponentAdapter<T> delegate) {
041            super(delegate);
042        }
043    
044        @Override
045            public T getComponentInstance(final PicoContainer container, final java.lang.reflect.Type into) {
046            T o = getDelegate().getComponentInstance(container, into);
047            Class[] interfaces = o.getClass().getInterfaces();
048            if (interfaces.length != 0) {
049                byte[] bytes = makeProxy("XX", interfaces, true);
050                AsmClassLoader cl = new AsmClassLoader(HotSwappable.Swappable.class.getClassLoader());
051                Class<?> pClazz = cl.defineClass("XX", bytes);
052                try {
053                    Constructor<T> ctor = (Constructor<T>) pClazz.getConstructor(HotSwappable.Swappable.class);
054                    final HotSwappable.Swappable swappable = getSwappable();
055                    swappable.swap(o);
056                    return ctor.newInstance(swappable);
057                } catch (NoSuchMethodException e) {
058                } catch (InstantiationException e) {
059                } catch (IllegalAccessException e) {
060                } catch (InvocationTargetException e) {
061                }
062            }
063            return o;
064        }
065    
066        public String getDescriptor() {
067            return "Hidden";
068        }
069    
070        protected HotSwappable.Swappable getSwappable() {
071            return new HotSwappable.Swappable();
072        }
073    
074        public byte[] makeProxy(final String proxyName, final Class[] interfaces, final boolean setter) {
075    
076            ClassWriter cw = new ClassWriter(0);
077            FieldVisitor fv;
078    
079            Class<Object> superclass = Object.class;
080    
081            cw.visit(V1_5, ACC_PUBLIC + ACC_SUPER, proxyName, null, dotsToSlashes(superclass), getNames(interfaces));
082    
083            {
084                fv = cw.visitField(ACC_PRIVATE + ACC_TRANSIENT, "swappable", encodedClassName(HotSwappable.Swappable.class), null, null);
085                fv.visitEnd();
086            }
087            doConstructor(proxyName, cw);
088            Set<String> methodsDone = new HashSet<String>();
089            for (Class<?> iface : interfaces) {
090                Method[] meths = iface.getMethods();
091                for (Method meth : meths) {
092                    if (!methodsDone.contains(meth.toString())) {
093                        doMethod(proxyName, cw, iface, meth);
094                        methodsDone.add(meth.toString());
095                    }
096                }
097            }
098    
099            cw.visitEnd();
100    
101            return cw.toByteArray();
102        }
103    
104        private String[] getNames(final Class[] interfaces) {
105            String[] retVal = new String[interfaces.length];
106            for (int i = 0; i < interfaces.length; i++) {
107                retVal[i] = dotsToSlashes(interfaces[i]);
108            }
109            return retVal;
110        }
111    
112        private void doConstructor(final String proxyName, final ClassWriter cw) {
113            MethodVisitor mv;
114            mv = cw.visitMethod(ACC_PUBLIC, "<init>", "(L"+ dotsToSlashes(HotSwappable.Swappable.class)+";)V", null, null);
115            mv.visitCode();
116            mv.visitVarInsn(ALOAD, 0);
117            mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V");
118            mv.visitVarInsn(ALOAD, 0);
119            mv.visitVarInsn(ALOAD, 1);
120            mv.visitFieldInsn(PUTFIELD, proxyName, "swappable", encodedClassName(HotSwappable.Swappable.class));
121            mv.visitInsn(RETURN);
122            mv.visitMaxs(2, 2);
123            mv.visitEnd();
124        }
125    
126        private void doMethod(final String proxyName, final ClassWriter cw, final Class<?> iface, final Method meth) {
127            String signature = "(" + encodedParameterNames(meth) + ")" + encodedClassName(meth.getReturnType());
128            String[] exceptions = encodedExceptionNames(meth.getExceptionTypes());
129            MethodVisitor mv;
130            mv = cw.visitMethod(ACC_PUBLIC, meth.getName(), signature, null, exceptions);
131            mv.visitCode();
132            mv.visitVarInsn(ALOAD, 0);
133            mv.visitFieldInsn(GETFIELD, proxyName, "swappable", encodedClassName(HotSwappable.Swappable.class));
134            mv.visitMethodInsn(INVOKEVIRTUAL, dotsToSlashes(HotSwappable.Swappable.class), "getInstance", "()Ljava/lang/Object;");
135            mv.visitTypeInsn(CHECKCAST, dotsToSlashes(iface));
136            Class[] types = meth.getParameterTypes();
137            int ix = 1;
138            for (Class type : types) {
139                int load = whichLoad(type);
140                mv.visitVarInsn(load, ix);
141                ix = indexOf(ix, load);
142            }
143            mv.visitMethodInsn(INVOKEINTERFACE, dotsToSlashes(iface), meth.getName(), signature);
144            mv.visitInsn(whichReturn(meth.getReturnType()));
145            mv.visitMaxs(ix, ix);
146            mv.visitEnd();
147        }
148    
149        private int indexOf(final int ix, final int loadType) {
150            if (loadType == LLOAD) {
151                return ix + 2;
152            } else if (loadType == DLOAD) {
153                return ix + 2;
154            } else if (loadType == ILOAD) {
155                return ix + 1;
156            } else if (loadType == ALOAD) {
157                return ix + 1;
158            } else if (loadType == FLOAD) {
159                return ix + 1;
160            }
161            return 0;
162        }
163    
164        private String[] encodedExceptionNames(final Class[] exceptionTypes) {
165            if (exceptionTypes.length == 0) {
166                return null;
167            }
168            String[] retVal = new String[exceptionTypes.length];
169            for (int i = 0; i < exceptionTypes.length; i++) {
170                Class clazz = exceptionTypes[i];
171                retVal[i] = dotsToSlashes(clazz);
172            }
173            return retVal;
174        }
175    
176        private int whichReturn(final Class<?> clazz) {
177            if (!clazz.isPrimitive()) {
178                return ARETURN;
179            } else if (clazz.isArray()) {
180                return ARETURN;
181            } else if (clazz == int.class) {
182                return IRETURN;
183            } else if (clazz == long.class) {
184                return LRETURN;
185            } else if (clazz == byte.class) {
186                return IRETURN;
187            } else if (clazz == float.class) {
188                return FRETURN;
189            } else if (clazz == double.class) {
190                return DRETURN;
191            } else if (clazz == char.class) {
192                return IRETURN;
193            } else if (clazz == short.class) {
194                return IRETURN;
195            } else if (clazz == boolean.class) {
196                return IRETURN;
197            } else if (clazz == void.class) {
198                return RETURN;
199            } else {
200                return 0;
201            }
202        }
203    
204        private int whichLoad(final Class<?> clazz) {
205            if (!clazz.isPrimitive()) {
206                return ALOAD;
207            } else if (clazz.isArray()) {
208                return ALOAD;
209            } else if (clazz == int.class) {
210                return ILOAD;
211            } else if (clazz == long.class) {
212                return LLOAD;
213            } else if (clazz == byte.class) {
214                return ILOAD;
215            } else if (clazz == float.class) {
216                return FLOAD;
217            } else if (clazz == double.class) {
218                return DLOAD;
219            } else if (clazz == char.class) {
220                return ILOAD;
221            } else if (clazz == short.class) {
222                return ILOAD;
223            } else if (clazz == boolean.class) {
224                return ILOAD;
225            } else {
226                return 0;
227            }
228        }
229    
230        private String encodedClassName(final Class<?> clazz) {
231            if (clazz.getName().startsWith("[")) {
232                return dotsToSlashes(clazz);
233            } else if (!clazz.isPrimitive()) {
234                return "L" + dotsToSlashes(clazz) + ";";
235            } else if (clazz == int.class) {
236                return "I";
237            } else if (clazz == long.class) {
238                return "J";
239            } else if (clazz == byte.class) {
240                return "B";
241            } else if (clazz == float.class) {
242                return "F";
243            } else if (clazz == double.class) {
244                return "D";
245            } else if (clazz == char.class) {
246                return "C";
247            } else if (clazz == short.class) {
248                return "S";
249            } else if (clazz == boolean.class) {
250                return "Z";
251            } else if (clazz == void.class) {
252                return "V";
253            } else {
254                return null;
255            }
256        }
257    
258        private String encodedParameterNames(final Method meth) {
259            String retVal = "";
260            for (Class<?> type : meth.getParameterTypes()) {
261                retVal += encodedClassName(type);
262            }
263            return retVal;
264        }
265    
266        private String dotsToSlashes(final Class<?> type) {
267            return type.getName().replace('.', '/');
268        }
269    
270        private static class AsmClassLoader extends ClassLoader {
271    
272            public AsmClassLoader(final ClassLoader parent) {
273                super(parent);
274            }
275    
276            public Class<?> defineClass(final String name, final byte[] b) {
277                return defineClass(name, b, 0, b.length);
278            }
279        }
280    
281    }