Mini-Spring之BeanPostProcessor后置处理器

引言

在上一篇文章中,我们已经完成了bean的初始化过程,实现十分简单。这次我们将学习BeanPostProcessor,让初始化的过程更加精细、灵活。

本文所有的代码都在这个项目工程里,大家需要的时候可以随时取用。传送门

BeanPostProcessor这个接口作用是在初始化前后,对bean进行一些操作,我们首先需要定义接口,代码如下:

package com.zhu.spring;


/**
 * Callback interface for running custom logic around a bean's initialization step.
 *
 * <p>{@code MiniSpringApplicationContext#createBean} calls
 * {@link #postProcessBeforeInitialization} right before the
 * {@code InitializingBean#afterPropertiesSet} callback and
 * {@link #postProcessAfterInitialization} right after it, for every bean it creates.
 * Both methods return {@code void}, so an implementation can inspect or mutate the
 * bean it is given but cannot substitute a different object.
 */
public interface BeanPostProcessor {
    /**
     * Does something before the bean is initialized.
     *
     * @param bean     the bean instance being created
     * @param beanName the name under which the bean is registered
     */
    void postProcessBeforeInitialization(Object bean, String beanName);

    /**
     * Does something after the bean has been initialized.
     *
     * @param bean     the bean instance being created
     * @param beanName the name under which the bean is registered
     */
    void postProcessAfterInitialization(Object bean, String beanName);
}

然后,需要在我们的核心容器类MiniSpringApplicationContext中添加相应的实现,具体如下:

package com.zhu.spring;

import java.beans.Introspector;
import java.io.File;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;


/**
 * A minimal annotation-driven bean container.
 *
 * <p>On construction it reads the package named by the {@code @ComponentScan} annotation of the
 * configuration class, registers a {@link BeanDefinition} for every {@code @Component} class found
 * directly in that package directory (sub-directories are not visited), and then eagerly creates
 * every singleton bean.
 *
 * <p>Bean creation order inside {@link #createBean}: instantiate, inject {@code @Autowired}
 * fields, {@code BeanNameAware} callback, {@link BeanPostProcessor} "before" hooks,
 * {@code InitializingBean} callback, {@link BeanPostProcessor} "after" hooks.
 *
 * <p>Known limitation: a singleton is cached only after {@link #createBean} returns, so two beans
 * that autowire each other recurse without bound.
 */
public class MiniSpringApplicationContext {

    private static final String CLASS_FILE_SUFFIX = ".class";
    private static final String SCOPE_SINGLETON = "singleton";

    private final Class<?> configClass;

    private final Map<String, BeanDefinition> beanDefinitionMap = new ConcurrentHashMap<>();
    private final Map<String, Object> singletonObjects = new ConcurrentHashMap<>();
    private final List<BeanPostProcessor> beanPostProcessorList = new ArrayList<>();

    /**
     * Builds the container from a configuration class.
     *
     * @param configClass class that may carry {@code @ComponentScan}; without the annotation
     *                    nothing is scanned and the container stays empty
     * @throws IllegalStateException if the scan package cannot be found, a component class cannot
     *                               be loaded/registered, or a singleton bean cannot be created
     */
    public MiniSpringApplicationContext(Class<?> configClass) {
        this.configClass = configClass;

        // scan the package declared by @ComponentScan, e.g. com.zhu.service
        if (configClass.isAnnotationPresent(ComponentScan.class)) {
            ComponentScan componentScanAnnotation = configClass.getAnnotation(ComponentScan.class);
            scanPackage(componentScanAnnotation.value());
        }

        // Eagerly create singletons. Going through getBean (instead of calling createBean
        // directly) reuses an instance that was already created and cached while it was being
        // injected into another bean, so each singleton is built exactly once.
        for (String beanName : beanDefinitionMap.keySet()) {
            BeanDefinition beanDefinition = beanDefinitionMap.get(beanName);
            if (SCOPE_SINGLETON.equals(beanDefinition.getScope())) {
                getBean(beanName);
            }
        }
    }

    /** Loads every {@code .class} file directly inside the package directory and registers the components. */
    private void scanPackage(String packageName) {
        // com.zhu.service ----> com/zhu/service
        String path = packageName.replace(".", "/");

        ClassLoader classLoader = MiniSpringApplicationContext.class.getClassLoader();
        // e.g. /Users/knight/IdeaProjects/mini-spring/out/production/mini-spring/com/zhu/service
        URL resource = classLoader.getResource(path);
        if (resource == null) {
            throw new IllegalStateException("component scan path not found on classpath: " + path);
        }

        File[] files = new File(resource.getFile()).listFiles();
        if (files == null) {
            // listFiles() returns null when the location is not a directory: nothing to scan
            return;
        }

        for (File f : files) {
            String fileName = f.getName();
            if (!fileName.endsWith(CLASS_FILE_SUFFIX)) {
                continue;
            }
            // Build the class name from the scanned package plus the file name. This does not
            // depend on where "com" happens to appear in the absolute path or on the platform's
            // path separator.
            String simpleName = fileName.substring(0, fileName.length() - CLASS_FILE_SUFFIX.length());
            String className = packageName.isEmpty() ? simpleName : packageName + "." + simpleName;

            try {
                registerComponent(classLoader.loadClass(className));
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("failed to register component class " + className, e);
            }
        }
    }

    /**
     * Registers a bean definition for {@code clazz} if it is annotated with {@code @Component};
     * classes that also implement {@link BeanPostProcessor} are instantiated and added to the
     * processor list.
     */
    private void registerComponent(Class<?> clazz) throws ReflectiveOperationException {
        if (!clazz.isAnnotationPresent(Component.class)) {
            return;
        }

        // Added for BeanPostProcessor support: collect every component that implements or extends
        // BeanPostProcessor. Note that this instance is separate from the singleton bean that is
        // later created from the same bean definition.
        if (BeanPostProcessor.class.isAssignableFrom(clazz)) {
            BeanPostProcessor instance = (BeanPostProcessor) clazz.getDeclaredConstructor().newInstance();
            beanPostProcessorList.add(instance);
        }

        Component componentAnnotation = clazz.getAnnotation(Component.class);
        String beanName = componentAnnotation.value();
        if ("".equals(beanName)) {
            // derive the name, eg: Service->service SService->SService, SerR->serR
            beanName = Introspector.decapitalize(clazz.getSimpleName());
        }

        // generate BeanDefinition
        BeanDefinition beanDefinition = new BeanDefinition();
        beanDefinition.setType(clazz);
        if (clazz.isAnnotationPresent(Scope.class)) {
            Scope scopeAnnotation = clazz.getAnnotation(Scope.class);
            beanDefinition.setScope(scopeAnnotation.value());
        } else {
            beanDefinition.setScope(SCOPE_SINGLETON);
        }
        beanDefinitionMap.put(beanName, beanDefinition);
    }

    /**
     * Instantiates and initializes one bean.
     *
     * @param beanName       name the bean is registered under
     * @param beanDefinition definition supplying the bean class
     * @return the fully initialized bean, never {@code null}
     * @throws IllegalStateException if instantiation, injection or any callback fails; the
     *                               original exception is kept as the cause
     */
    private Object createBean(String beanName, BeanDefinition beanDefinition) {
        Class<?> clazz = beanDefinition.getType();

        try {
            Object instance = clazz.getConstructor().newInstance();

            // simple dependency injection: the field name is used as the bean name
            for (Field field : clazz.getDeclaredFields()) {
                if (field.isAnnotationPresent(Autowired.class)) {
                    // make private fields assignable
                    field.setAccessible(true);
                    field.set(instance, getBean(field.getName()));
                }
            }

            // bean name aware callback
            if (instance instanceof BeanNameAware) {
                ((BeanNameAware) instance).setBeanName(beanName);
            }

            // before init: run every registered post-processor
            for (BeanPostProcessor beanPostProcessor : beanPostProcessorList) {
                beanPostProcessor.postProcessBeforeInitialization(instance, beanName);
            }

            // initializing bean callback
            if (instance instanceof InitializingBean) {
                ((InitializingBean) instance).afterPropertiesSet();
            }

            // after init: run every registered post-processor
            for (BeanPostProcessor beanPostProcessor : beanPostProcessorList) {
                beanPostProcessor.postProcessAfterInitialization(instance, beanName);
            }

            return instance;
        } catch (Exception e) {
            // Fail loudly instead of returning null: a null here would either blow up later in
            // ConcurrentHashMap.put (which rejects null values) or be handed to callers as a bean.
            throw new IllegalStateException("failed to create bean '" + beanName + "'", e);
        }
    }

    /**
     * Looks up a bean by name.
     *
     * <p>Singleton beans are created on first access and cached; any other scope yields a new
     * instance on every call.
     *
     * @param beanName name of a registered bean
     * @return the bean instance
     * @throws RuntimeException if no bean definition is registered under {@code beanName}, or if
     *                          the bean cannot be created
     */
    public Object getBean(String beanName) {
        BeanDefinition beanDefinition = beanDefinitionMap.get(beanName);
        if (beanDefinition == null) {
            throw new RuntimeException("class not found with bean name:" + beanName);
        }

        if (!SCOPE_SINGLETON.equals(beanDefinition.getScope())) {
            return createBean(beanName, beanDefinition);
        }

        Object bean = singletonObjects.get(beanName);
        if (bean == null) {
            bean = createBean(beanName, beanDefinition);
            singletonObjects.put(beanName, bean);
        }
        return bean;
    }
}

最后,我们需要新增一个实现了BeanPostProcessor接口的测试类,代码如下:

package com.zhu.service;

import com.zhu.spring.BeanPostProcessor;
import com.zhu.spring.Component;

/**
 * Sample post-processor used to observe the container's initialization hooks.
 *
 * <p>It prints one marker line before and one after initialization, but only for the bean
 * registered under the name {@code userService}; every other bean is ignored.
 */
@Component
public class SelfTestBeanPostProcessor implements BeanPostProcessor {

    /**
     * Prints a marker line before the {@code userService} bean is initialized.
     *
     * @param bean     the bean being created (unused)
     * @param beanName name of the bean being created
     */
    @Override
    public void postProcessBeforeInitialization(Object bean, String beanName) {
        if (!beanName.equals("userService")) {
            return;
        }
        System.out.println("before init bean processor");
    }

    /**
     * Prints a marker line after the {@code userService} bean has been initialized.
     *
     * @param bean     the bean being created (unused)
     * @param beanName name of the bean being created
     */
    @Override
    public void postProcessAfterInitialization(Object bean, String beanName) {
        if (!beanName.equals("userService")) {
            return;
        }
        System.out.println("after init bean processor");
    }
}

执行测试类,我们发现依次执行了BeanPostProcessor的postProcessBeforeInitialization方法、InitializingBean的afterPropertiesSet方法,以及BeanPostProcessor的postProcessAfterInitialization方法。

结果如下:

before init bean processor
initializing bean —-afterPropertiesSet
after init bean processor
com.zhu.service.OrderService@60e53b93|userService