1 package de.campussource.cse.common.test;
2
3 import java.lang.reflect.Field;
4 import java.util.ArrayList;
5 import java.util.Arrays;
6 import java.util.HashMap;
7 import java.util.List;
8 import java.util.Map;
9
10 import javax.ejb.EJB;
11 import javax.persistence.EntityManager;
12 import javax.persistence.PersistenceContext;
13
14
15 public class AnnotationInjector {
16
17 private Map<String, Object> context = new HashMap<String, Object>();
18
19 public AnnotationInjector autowire(Object object) {
20 intoContextByClassName(object);
21 autowireByField(object);
22 return this;
23 }
24
25 public AnnotationInjector defaultPersistentUnit(EntityManager entityManager) {
26 addPersistenceUnit(PersistenceContext.class.getName(), entityManager);
27 return this;
28 }
29
30 public AnnotationInjector addPersistenceUnit(String persistentUnit, EntityManager entityManager) {
31 context.put(persistentUnit, entityManager);
32 return this;
33 }
34
35 private void intoContextByClassName(Object target) {
36 context.put(target.getClass().getName(), target);
37 }
38
39 private void injectField(Object object, Field field, Object reference) {
40 boolean accessible = field.isAccessible();
41 field.setAccessible(true);
42 try {
43 field.set(object, reference);
44 } catch (IllegalArgumentException e) {
45 e.printStackTrace();
46 } catch (IllegalAccessException e) {
47 e.printStackTrace();
48 }
49 field.setAccessible(accessible);
50 }
51
52 protected void autowireByField(Object object) {
53 Field[] fields = retrieveAllDeclaredFieldsIncludedInheritance(object);
54 for (Field field : fields) {
55 if (field.isAnnotationPresent(PersistenceContext.class)) {
56 PersistenceContext annotation = field.getAnnotation(PersistenceContext.class);
57 Object target = null;
58 if (annotation.name().isEmpty()) {
59
60 target = context.get(PersistenceContext.class.getName());
61 } else {
62 target = context.get(annotation.name());
63 }
64 injectField(object, field, target);
65 }
66 if (field.isAnnotationPresent(EJB.class)) {
67 try {
68 Class fieldClass = field.getType();
69 Object target = context.get(fieldClass.getName());
70 if (target == null) {
71 target = fieldClass.newInstance();
72 autowire(target);
73 context.put(fieldClass.getName(), target);
74 }
75 injectField(object, field, target);
76 } catch (InstantiationException e) {
77 throw new AnnotationInjectionException(object, e);
78 } catch (IllegalAccessException e) {
79 throw new AnnotationInjectionException(object, e);
80 }
81 }
82 }
83 }
84
85 private Field[] retrieveAllDeclaredFieldsIncludedInheritance(Object object) {
86 List<Field> fields = new ArrayList<Field>();
87 Class clazz = object.getClass();
88 while (clazz != null) {
89 fields.addAll(Arrays.asList(clazz.getDeclaredFields()));
90 clazz = clazz.getSuperclass();
91 }
92 return (Field[]) fields.toArray(new Field[fields.size()]);
93 }
94 }