Writing a Tomcat by Hand

Introduction

Anyone who does web development knows Tomcat well; we deal with it all the time.

This article imitates Tomcat's core functionality by implementing a simple Tomcat of our own, so that we understand more clearly how it works.

Let's first recall: what does Tomcat actually do?

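Back when the front end and back end lived in the same project, we packaged our code and JSPs (mostly servlet implementation classes; JSPs are themselves compiled into servlets) and put the package in Tomcat's webapps directory. Then we started Tomcat and visited localhost:8080 (Tomcat's default port), and the dynamic pages showed up in a local browser.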

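From the recollection above, we can see that Tomcat mainly does two jobs:

  1. Handle the requests the browser sends to Tomcat and dispatch each one to the corresponding servlet
  2. Load the servlet implementation classes of the projects under webapps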

This project is purely for learning, so we only need to imitate this main flow. Project repository

Handling HTTP Requests

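HTTP is an application-layer protocol that runs on top of TCP/IP. Programming directly against TCP/IP is tedious, which is why the socket abstraction layer exists. In Tomcat, HTTP handling is implemented on top of sockets.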
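To make this concrete, here is a small sketch showing that an HTTP request is just CRLF-delimited text written to a TCP socket. A client socket writes a hand-built request, and the server side of a loopback connection receives exactly those bytes. The class name `RawHttpOverSocket` is our own and is not part of the project. The sketch uses port 0, so the OS picks a free port instead of 8080.

```java
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.charset.StandardCharsets;

class RawHttpOverSocket {

    // Writes rawRequest from a client socket to a loopback ServerSocket and
    // returns exactly what the server side read, decoded as ISO-8859-1
    static String loopback(String rawRequest) {
        InetAddress loopback = InetAddress.getLoopbackAddress();
        // Port 0 asks the OS for any free port, so the sketch never clashes with 8080
        try (ServerSocket server = new ServerSocket(0, 1, loopback);
             Socket client = new Socket(loopback, server.getLocalPort())) {
            OutputStream out = client.getOutputStream();
            out.write(rawRequest.getBytes(StandardCharsets.ISO_8859_1));
            out.flush();
            // Half-close: tells the server no more data is coming, so readAllBytes() can return
            client.shutdownOutput();
            // The connection is already queued in the backlog, so accept() returns at once
            try (Socket accepted = server.accept()) {
                byte[] received = accepted.getInputStream().readAllBytes();
                return new String(received, StandardCharsets.ISO_8859_1);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        String request = "GET /hello HTTP/1.1\r\nHost: localhost:8080\r\n\r\n";
        System.out.print(loopback(request));
    }
}
```

Running `main` prints the request line and the Host header exactly as the server side saw them. This is the same kind of text the `ServerSocket` in our Tomcat receives from a browser.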

  1. First, we create a Tomcat class and start a ServerSocket, as follows:

    package com.zhu;
    
    import java.io.IOException;
    import java.net.ServerSocket;
    
    public class Tomcat {
        public void start(){
            try {
                ServerSocket serverSocket = new ServerSocket(8080);
                serverSocket.accept();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        public static void main(String[] args) {
            Tomcat tomcat = new Tomcat();
            tomcat.start();
        }
    }
    
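  2. We now have a ServerSocket listening for requests sent by socket clients. At this point we can already enter localhost:8080 in the browser, and with a breakpoint set we can see that the request does arrive. There is a problem, though: after one request the program exits. That clearly makes no sense, since a Tomcat that serves only a single request would be useless. So we need an infinite loop to handle the continuous stream of requests. Modify the code above as follows: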

    package com.zhu;
    
    import java.io.IOException;
    import java.net.ServerSocket;
    
    public class Tomcat {
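        public void start(){
            try {
                // Create the ServerSocket once, before the loop: binding port 8080
                // a second time would throw java.net.BindException
                ServerSocket serverSocket = new ServerSocket(8080);
                while (true){
                    // accept() blocks until the next client connects
                    serverSocket.accept();
                }
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }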
        public static void main(String[] args) {
            Tomcat tomcat = new Tomcat();
            tomcat.start();
        }
    }
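One detail in this loop deserves a check. Only one listening `ServerSocket` can be bound to a port at a time, so the `ServerSocket` belongs before `while (true)` and only `accept()` should repeat. The sketch below binds a free port chosen by the OS and then tries to bind the same port again. On common platforms the second bind is expected to fail with `java.net.BindException`. The class name `BindTwiceDemo` is hypothetical.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.BindException;
import java.net.ServerSocket;

class BindTwiceDemo {

    // Returns true if binding a second ServerSocket to an already-bound port fails
    static boolean secondBindFails() {
        try (ServerSocket first = new ServerSocket(0)) { // 0 = let the OS pick a free port
            int port = first.getLocalPort();
            try (ServerSocket second = new ServerSocket(port)) {
                return false; // not expected while `first` is still open
            } catch (BindException e) {
                // "Address already in use": what `new ServerSocket(8080)` inside the loop would hit
                return true;
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println("second bind fails: " + secondBindFails());
    }
}
```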
    
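  3. The next step is straightforward. We already receive the requests the front end sends, so all that remains is to parse the HTTP request. Before parsing, let's look at the format of an HTTP request message: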

    An HTTP request message consists of four parts: the request line, the request headers, a blank line, and the request data.

    [Figure: HTTP request message format]
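To show how the blank line separates the parts, here is a hedged sketch that splits a raw request string into request line, headers and request data. `HttpRequestParts` is our own illustration; our Tomcat itself only parses the request line. The sketch assumes the whole message is already in memory and that the head ends with CRLF CRLF.

```java
import java.util.LinkedHashMap;
import java.util.Map;

class HttpRequestParts {
    final String requestLine;                                   // e.g. "GET /hello HTTP/1.1"
    final Map<String, String> headers = new LinkedHashMap<>(); // request headers, in arrival order
    final String body;                                          // request data after the blank line

    HttpRequestParts(String raw) {
        // The blank line ("\r\n\r\n") marks the end of the head
        int headEnd = raw.indexOf("\r\n\r\n");
        if (headEnd < 0) {
            throw new IllegalArgumentException("request head is not terminated by a blank line");
        }
        String[] lines = raw.substring(0, headEnd).split("\r\n");
        requestLine = lines[0];
        // Each remaining line is "Name: value"; real servers also match names case-insensitively
        for (int i = 1; i < lines.length; i++) {
            int colon = lines[i].indexOf(':');
            if (colon > 0) {
                headers.put(lines[i].substring(0, colon).trim(), lines[i].substring(colon + 1).trim());
            }
        }
        body = raw.substring(headEnd + 4);
    }

    public static void main(String[] args) {
        HttpRequestParts p = new HttpRequestParts(
                "POST /hello HTTP/1.1\r\nHost: localhost:8080\r\nContent-Length: 3\r\n\r\nabc");
        System.out.println(p.requestLine + " | " + p.headers + " | " + p.body);
    }
}
```

Running `main` prints `POST /hello HTTP/1.1 | {Host=localhost:8080, Content-Length=3} | abc`.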

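  4. We create a new class, SocketProcessor, dedicated to handling the socket connection and parsing the HTTP message. To keep things simple, we only parse the request line and put the parsed information into our own Request class: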

    package com.zhu;
    
    import java.io.IOException;
    import java.io.InputStream;
    import java.net.Socket;
    
    
    public class SocketProcessor implements Runnable{
    
        private Socket socket;
    
        public SocketProcessor(Socket socket) {
            this.socket = socket;
        }
    
        private void processSocket(Socket socket) {
            // Handle the socket connection: parse the incoming data, write the response
            try {
                InputStream inputStream = socket.getInputStream();
                // Read up to 2 KB into the byte array; fine for a quick test,
                // but production code should read in a loop
                byte[] bytes = new byte[2048];
                inputStream.read(bytes);
    
                // Parse the HTTP request; for simplicity only the first line is parsed:
                // request method, request path, protocol version, e.g.
                // GET /hello HTTP/1.1
                Request request = extracted(bytes, new Request());
    
            } catch (IOException e) {
                // No servlet is called yet, so ServletException cannot occur here
                throw new RuntimeException(e);
            }
        }
    
        private static Request extracted(byte[] bytes, Request request) {
            // Parse the method (e.g. GET): read up to the first space, tracking the index and content
            int endPosition  = 0;
            StringBuilder stringBuilder = new StringBuilder();
            for(; endPosition< bytes.length; endPosition++){
                char currentChar = (char) bytes[endPosition];
                if( currentChar == ' '){
                    break;
                }
                stringBuilder.append(currentChar);
            }
    
    
            String method = stringBuilder.toString();
            request.setMethod(method);
    
            // Clear the buffer
            stringBuilder.delete(0, stringBuilder.length());
            // Parse the URL
            endPosition++;
            for(; endPosition< bytes.length; endPosition++){
                char currentChar = (char) bytes[endPosition];
                if( currentChar == ' '){
                    break;
                }
                stringBuilder.append(currentChar);
            }
    
            String url = stringBuilder.toString();
            request.setUrl(url);
    
    
            // Clear the buffer; this time stop at the first '\r' after endPosition
            stringBuilder.delete(0, stringBuilder.length());
            // Parse the protocol version
            endPosition++;
            for(; endPosition< bytes.length; endPosition++){
                char currentChar = (char) bytes[endPosition];
                if( currentChar == '\r'){
                    break;
                }
                stringBuilder.append(currentChar);
            }
            String protocol = stringBuilder.toString();
            request.setProtocol(protocol);
            return request;
        }
    
        @Override
        public void run() {
            processSocket(socket);
        }
    }
    
    package com.zhu;
    
    import java.net.Socket;
    
    // AbstractRequest is an abstract class that implements HttpServletRequest, so we only
    // override the few methods we need instead of every method of the interface.
    // Getters and setters are omitted here.
    public class Request extends AbstractRequest {
    
        private String method;
        private String url;
        private String protocol;
    
        private Socket socket;
    
        public Request() {
        }
    
        public Request(String method, String url, String protocol, Socket socket) {
            this.method = method;
            this.url = url;
            this.protocol = protocol;
            this.socket = socket;
        }
    }
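The comment in `processSocket` admits that one `read()` into a fixed 2048-byte array only works for testing. `read()` may return fewer bytes than the client sent, and a long request would be cut off. Below is a minimal sketch of the loop that comment alludes to, assuming we only need the request head. It keeps reading until the CRLF CRLF blank line or the end of the stream, and it enforces a size cap. `HeadReader` and its `maxBytes` limit are our own names, not part of the project.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

class HeadReader {

    private static final byte[] END_OF_HEAD = {'\r', '\n', '\r', '\n'};

    // Reads byte by byte until "\r\n\r\n" has been consumed or the stream ends.
    // Wrap a socket stream in a BufferedInputStream first; single-byte reads are slow otherwise.
    static byte[] readHead(InputStream in, int maxBytes) throws IOException {
        ByteArrayOutputStream head = new ByteArrayOutputStream();
        int matched = 0; // how many bytes of END_OF_HEAD have been seen in a row
        int b;
        while (matched < END_OF_HEAD.length && (b = in.read()) != -1) {
            head.write(b);
            if (head.size() > maxBytes) {
                throw new IOException("request head larger than " + maxBytes + " bytes");
            }
            if (b == END_OF_HEAD[matched]) {
                matched++;
            } else {
                // A '\r' that breaks the match may itself start a new terminator
                matched = (b == '\r') ? 1 : 0;
            }
        }
        return head.toByteArray();
    }

    public static void main(String[] args) throws IOException {
        byte[] raw = "GET /hello HTTP/1.1\r\nHost: localhost:8080\r\n\r\nBODY".getBytes(StandardCharsets.ISO_8859_1);
        InputStream in = new ByteArrayInputStream(raw);
        System.out.print(new String(readHead(in, 8192), StandardCharsets.ISO_8859_1));
    }
}
```

The reader stops right after the blank line, so any request body is left in the stream. It can then be read using the `Content-Length` header.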
    
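  5. In principle, we have now parsed the HTTP request and wrapped it in the Request class. All that remains is to pass it to our own servlet for custom business logic. However, we have not written the code that loads servlets from the webapps directory yet, so for now we simply create one with `new` to get the flow working: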

    package com.zz; // com.zz holds the user-defined servlets; com.zhu is our mini-tomcat framework
    
    import javax.servlet.ServletException;
    import javax.servlet.http.HttpServlet;
    import javax.servlet.http.HttpServletRequest;
    import javax.servlet.http.HttpServletResponse;
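    import java.io.IOException;
    import java.nio.charset.StandardCharsets;
    
    public class SelfServlet extends HttpServlet {
    
        @Override
        protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
            System.out.println("user servlet:"+req.getMethod());
            // Here we simply return a string
            byte[] body = "hello , mini-tomcat".getBytes(StandardCharsets.UTF_8);
            // Both response headers are required, otherwise the browser cannot display the data properly.
            // Content-Length must equal the body's byte count (19 here); a smaller value truncates the text
            resp.addHeader("Content-Length", String.valueOf(body.length));
            resp.addHeader("Content-Type", "text/plain;charset=utf-8");
            resp.getOutputStream().write(body);
        }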
    
    }
    

    At the same time, SocketProcessor needs to change as well:

    package com.zhu;
    
    import com.zz.SelfServlet;
    
    import javax.servlet.ServletException;
    import java.io.IOException;
    import java.io.InputStream;
    import java.net.Socket;
    
    
    public class SocketProcessor implements Runnable{
    
        private Socket socket;
    
        public SocketProcessor(Socket socket) {
            this.socket = socket;
        }
    
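        private void processSocket(Socket socket) {
            // Handle the socket connection: parse the incoming data, write the response
            try {
                InputStream inputStream = socket.getInputStream();
                // Read up to 2 KB into the byte array; fine for a quick test,
                // but production code should read in a loop
                byte[] bytes = new byte[2048];
                inputStream.read(bytes);
    
                // Parse the HTTP request; for simplicity only the first line is parsed:
                // request method, request path, protocol version, e.g.
                // GET /hello HTTP/1.1
                Request request = extracted(bytes, new Request());
    
                // SelfServlet stands in for a user-written servlet, a stopgap until
                // our tomcat can load servlets from webapps
                SelfServlet selfServlet = new SelfServlet();
                // service() dispatches to doGet/doPost etc. based on the HTTP method.
                // Note: `response` is not defined yet; the Response class is built in step 6,
                // so this version does not compile until then
                selfServlet.service(request, response);
    
            } catch (IOException e) {
                throw new RuntimeException(e);
            } catch (ServletException e) {
                throw new RuntimeException(e);
            }
        }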
    
        private static Request extracted(byte[] bytes, Request request) {
            // Parse the method (e.g. GET): read up to the first space, tracking the index and content
            int endPosition  = 0;
            StringBuilder stringBuilder = new StringBuilder();
            for(; endPosition< bytes.length; endPosition++){
                char currentChar = (char) bytes[endPosition];
                if( currentChar == ' '){
                    break;
                }
                stringBuilder.append(currentChar);
            }
    
    
            String method = stringBuilder.toString();
            request.setMethod(method);
    
            // Clear the buffer
            stringBuilder.delete(0, stringBuilder.length());
            // Parse the URL
            endPosition++;
            for(; endPosition< bytes.length; endPosition++){
                char currentChar = (char) bytes[endPosition];
                if( currentChar == ' '){
                    break;
                }
                stringBuilder.append(currentChar);
            }
    
            String url = stringBuilder.toString();
            request.setUrl(url);
    
    
            // Clear the buffer; this time stop at the first '\r' after endPosition
            stringBuilder.delete(0, stringBuilder.length());
            // Parse the protocol version
            endPosition++;
            for(; endPosition< bytes.length; endPosition++){
                char currentChar = (char) bytes[endPosition];
                if( currentChar == '\r'){
                    break;
                }
                stringBuilder.append(currentChar);
            }
            String protocol = stringBuilder.toString();
            request.setProtocol(protocol);
            return request;
        }
    
        @Override
        public void run() {
            processSocket(socket);
        }
    }
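The comment on `selfServlet.service(request, response)` relies on `HttpServlet.service` choosing `doGet`, `doPost` and so on from the request method. The JDK-only sketch below imitates that dispatch with a hypothetical `MiniHttpServlet`. It is not the real `javax.servlet` class, which also handles HEAD, PUT, DELETE, OPTIONS, conditional GET and more. To keep it short, the handlers return strings instead of writing to a response.

```java
// Hypothetical, simplified stand-in for javax.servlet.http.HttpServlet
abstract class MiniHttpServlet {

    // Template method: pick the handler from the HTTP method, as HttpServlet.service does
    final String service(String method, String path) {
        switch (method) {
            case "GET":
                return doGet(path);
            case "POST":
                return doPost(path);
            default:
                // Methods this sketch does not model
                return "501 Not Implemented";
        }
    }

    // Defaults reject methods that a subclass did not override
    protected String doGet(String path) {
        return "405 Method Not Allowed";
    }

    protected String doPost(String path) {
        return "405 Method Not Allowed";
    }
}

// Plays the role of SelfServlet: only GET is overridden
class HelloServlet extends MiniHttpServlet {
    @Override
    protected String doGet(String path) {
        return "hello , mini-tomcat";
    }

    public static void main(String[] args) {
        HelloServlet servlet = new HelloServlet();
        System.out.println(servlet.service("GET", "/hello"));
        System.out.println(servlet.service("POST", "/hello"));
    }
}
```

This is why the user's servlet only overrides `doGet`. The framework calls `service`, and the base class decides which handler runs.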
    
  6. Finally, we need to send the result back to the browser, so we define a Response class that generates the HTTP response message:

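    An HTTP response message also consists of four parts, in a format similar to the request message: the status line, the response headers, a blank line, and the response data (also called the response body). In the figure, simply replace the request data with the response body.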

    [Figure: HTTP request message format]

    The code is as follows:

    package com.zhu;
    
    import java.io.IOException;
    import java.io.OutputStream;
    import java.util.HashMap;
    import java.util.Map;
    
    
    public class Response extends AbstractResponse{
    
        private int status = 200;
        private String message = "OK";
    
        private byte SP = ' ';
        private byte CR = '\r';
        private byte LF = '\n';
        private Map<String, String> headers = new HashMap<>();
    
        private Request request;
        private OutputStream socketOutputStream;
        // Buffers whatever the servlet writes as the body (ResponseServletOutputStream is not shown here)
        ResponseServletOutputStream responseServletOutputStream = new ResponseServletOutputStream();
    
        public Response(Request request) {
            this.request = request;
            try {
                this.socketOutputStream = request.getSocket().getOutputStream();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    
        @Override
        public int getStatus() {
            return status;
        }
    
        @Override
        public void setStatus(int status, String message) {
            this.status = status;
            this.message = message;
        }
    
        @Override
        public void addHeader(String s, String s1) {
            headers.put(s, s1);
        }
    
        @Override
        public ResponseServletOutputStream getOutputStream() throws IOException {
            return responseServletOutputStream;
        }
    
        /**
         * Where the response is actually sent,
         * and where the HTTP response message format is implemented
         */
        public void complete(){
            try {
                // Send the status line
                sendResponseLine();
                // Send the response headers
                sendResponseHeader();
                // Send the response body
                sendResponseBody();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    
        private void sendResponseBody() throws IOException {
            socketOutputStream.write(getOutputStream().getBytes());
        }
    
        private void sendResponseHeader() throws IOException{
            for (Map.Entry<String, String> entry : headers.entrySet()) {
                String key = entry.getKey();
                String value = entry.getValue();
                socketOutputStream.write(key.getBytes());
                socketOutputStream.write(":".getBytes());
                socketOutputStream.write(value.getBytes());
                socketOutputStream.write(CR);
                socketOutputStream.write(LF);
            }
            // The blank line that separates the headers from the body
            socketOutputStream.write(CR);
            socketOutputStream.write(LF);
        }
    
        private void sendResponseLine() throws IOException {
            socketOutputStream.write(request.getProtocol().getBytes());
            socketOutputStream.write(SP);
            // Write the status code as text ("200"); write(int) would send a single raw byte instead
            socketOutputStream.write(String.valueOf(status).getBytes());
            socketOutputStream.write(SP);
            socketOutputStream.write(message.getBytes());
            socketOutputStream.write(CR);
            socketOutputStream.write(LF);
        }
    }
    

    Modify SocketProcessor at the same time:

    package com.zhu;
    
    import com.zz.SelfServlet;
    
    import javax.servlet.ServletException;
    import java.io.IOException;
    import java.io.InputStream;
    import java.net.Socket;
    
    
    public class SocketProcessor implements Runnable{
    
        private Socket socket;
    
        public SocketProcessor(Socket socket) {
            this.socket = socket;
        }
    
        private void processSocket(Socket socket) {
            // Handle the socket connection: parse the incoming data, write the response
            try {
                InputStream inputStream = socket.getInputStream();
                // Read up to 2 KB into the byte array; fine for a quick test,
                // but production code should read in a loop
                byte[] bytes = new byte[2048];
                inputStream.read(bytes);
    
                // Parse the HTTP request; for simplicity only the first line is parsed:
                // request method, request path, protocol version, e.g.
                // GET /hello HTTP/1.1
                Request request = extracted(bytes, new Request());
                request.setSocket(socket);
                Response response = new Response(request);
    
                // SelfServlet stands in for a user-written servlet, a stopgap until
                // our tomcat can load servlets from webapps
                SelfServlet selfServlet = new SelfServlet();
                // service() dispatches to doGet/doPost etc. based on the HTTP method
                selfServlet.service(request, response);
    
                // Send the response
                response.complete();
    
            } catch (IOException e) {
                throw new RuntimeException(e);
            } catch (ServletException e) {
                throw new RuntimeException(e);
            }
        }
    
        private static Request extracted(byte[] bytes, Request request) {
            // Parse the method (e.g. GET): read up to the first space, tracking the index and content
            int endPosition  = 0;
            StringBuilder stringBuilder = new StringBuilder();
            for(; endPosition< bytes.length; endPosition++){
                char currentChar = (char) bytes[endPosition];
                if( currentChar == ' '){
                    break;
                }
                stringBuilder.append(currentChar);
            }
    
    
            String method = stringBuilder.toString();
            request.setMethod(method);
    
            // Clear the buffer
            stringBuilder.delete(0, stringBuilder.length());
            // Parse the URL
            endPosition++;
            for(; endPosition< bytes.length; endPosition++){
                char currentChar = (char) bytes[endPosition];
                if( currentChar == ' '){
                    break;
                }
                stringBuilder.append(currentChar);
            }
    
            String url = stringBuilder.toString();
            request.setUrl(url);
    
    
            // Clear the buffer; this time stop at the first '\r' after endPosition
            stringBuilder.delete(0, stringBuilder.length());
            // Parse the protocol version
            endPosition++;
            for(; endPosition< bytes.length; endPosition++){
                char currentChar = (char) bytes[endPosition];
                if( currentChar == '\r'){
                    break;
                }
                stringBuilder.append(currentChar);
            }
            String protocol = stringBuilder.toString();
            request.setProtocol(protocol);
            return request;
        }
    
        @Override
        public void run() {
            processSocket(socket);
        }
    }
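Laid out in full, the bytes that `Response.complete()` sends back to the browser are just status line, headers, blank line and body. The sketch below assembles them in memory so the wire format is easy to inspect. `ResponseBytes.render` is a standalone helper of our own, not part of the project. It computes `Content-Length` from the body and assumes the caller has not set that header already.

```java
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;

class ResponseBytes {

    // Serializes one HTTP/1.x response: status line, headers, blank line, body
    static byte[] render(String protocol, int status, String reason,
                         Map<String, String> headers, byte[] body) {
        StringBuilder head = new StringBuilder();
        // Status line: the status code is written as decimal text, e.g. "200"
        head.append(protocol).append(' ').append(status).append(' ').append(reason).append("\r\n");
        for (Map.Entry<String, String> h : headers.entrySet()) {
            head.append(h.getKey()).append(": ").append(h.getValue()).append("\r\n");
        }
        // Content-Length counts bytes of the body, not characters
        head.append("Content-Length: ").append(body.length).append("\r\n");
        head.append("\r\n"); // blank line: end of headers

        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] headBytes = head.toString().getBytes(StandardCharsets.ISO_8859_1);
        out.write(headBytes, 0, headBytes.length);
        out.write(body, 0, body.length);
        return out.toByteArray();
    }

    public static void main(String[] args) {
        Map<String, String> headers = new LinkedHashMap<>();
        headers.put("Content-Type", "text/plain;charset=utf-8");
        byte[] body = "hello , mini-tomcat".getBytes(StandardCharsets.UTF_8);
        System.out.print(new String(render("HTTP/1.1", 200, "OK", headers, body), StandardCharsets.UTF_8));
    }
}
```

Deriving `Content-Length` from the encoded body avoids hard-coding a number that silently goes stale when the text changes. Non-ASCII text also takes more than one byte per character in UTF-8.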
    

    As you can see, each browser request corresponds to exactly one Request and one Response.

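  7. One problem remains. So far every request is handled on a single thread. The second HTTP request can only get in after the first one has been accepted, parsed, processed by the business logic (in our custom servlet) and answered. Let's apply a simple optimization and separate accepting connections from parsing and processing them: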

    package com.zhu;
    
    import java.io.IOException;
    import java.net.ServerSocket;
    import java.net.Socket;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    
    public class Tomcat {
    
    
        public void start(){
            // Socket connection
            try {
                // Visit http://localhost:8080
                ServerSocket serverSocket = new ServerSocket(8080);
    
                ExecutorService executorService = Executors.newFixedThreadPool(10);
    
                // Loop to accept many socket connections; without it the JVM would exit after the first one
                while(true){
                    Socket socket = serverSocket.accept();
                    // Hand parsing and processing to a worker thread; the main thread only accepts connections
                    executorService.execute(new SocketProcessor(socket));
                }
    
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    
    
    
        public static void main(String[] args) {
            Tomcat tomcat = new Tomcat();
    
            tomcat.start();
        }
    }
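You can check why the thread pool helps without a browser. In the sketch below, each task stands in for a slow servlet call and waits (at most 2 s) until all tasks have started. That wait only succeeds if the tasks really run at the same time. With `Executors.newFixedThreadPool(10)`, ten such tasks overlap. A single worker thread, which matches the old one-request-at-a-time behavior, runs them one after another, and the wait times out. The `PoolDemo` class and its parameters are illustrative only.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class PoolDemo {

    // Submits `tasks` tasks to a fixed pool; returns true only if every task
    // saw all the others running at the same time
    static boolean allRanConcurrently(int poolSize, int tasks) {
        ExecutorService pool = Executors.newFixedThreadPool(poolSize);
        CountDownLatch started = new CountDownLatch(tasks);
        AtomicInteger sawOthers = new AtomicInteger();
        for (int i = 0; i < tasks; i++) {
            pool.execute(() -> {
                started.countDown();
                try {
                    // Like a slow request: wait (at most 2 s) for the other "requests" to be running
                    if (started.await(2, TimeUnit.SECONDS)) {
                        sawOthers.incrementAndGet();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
        }
        pool.shutdown();
        try {
            pool.awaitTermination(30, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
        return sawOthers.get() == tasks;
    }

    public static void main(String[] args) {
        System.out.println("pool of 10: " + allRanConcurrently(10, 10));
        System.out.println("single thread: " + allRanConcurrently(1, 2));
    }
}
```

A fixed pool also caps concurrency. With 10 workers, an 11th simultaneous request waits in the executor's queue until a worker becomes free.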
    
  8. Then start the project and visit http://localhost:8080/hello, and the result we set will appear.

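So far, our Tomcat parses HTTP and wraps the result in the Request class. It passes the request to our custom servlet business class and, once the business logic has run, returns a proper HTTP response.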

The next article will add loading of the business servlets under webapps, so they no longer need to be hard-coded as they are now. Stay tuned! 😄😄😄