skynet源码分析(16)--skynet中http之httpc和httpd

作者:shihuaping0918@163.com,转载请注明作者

httpc.lua和httpd.lua提供的功能比较简陋,函数不多,代码量也只有一百多行。只要对http协议有一定了解,分析这两个文件的代码就比较简单。

httpc.lua是http客户端代码,支持get/post请求:发送请求后等待回应,并解析回应包。

local skynet = require "skynet"
local socket = require "http.sockethelper"
local url = require "http.url"
local internal = require "http.internal"
local dns = require "skynet.dns"
local string = string
local table = table

local httpc = {}
--- Send one HTTP/1.1 request on an already-connected socket and read the response.
-- @param fd         connected socket id, used with http.sockethelper read/write helpers
-- @param method     request method, e.g. "GET" or "POST"
-- @param host       value for the host header when the caller did not supply one
-- @param url        request target placed on the request line
-- @param recvheader optional table that receives the parsed response headers
-- @param header     optional table of request headers (read only, never modified)
-- @param content    optional request body string
-- @return status code (number), response body (string)
-- Raises socket.socket_error when the response header cannot be read, and a
-- string error for a malformed status line / header / body or an unsupported
-- transfer-encoding.
local function request(fd, method, host, url, recvheader, header, content)
    local read = socket.readfunc(fd)
    local write = socket.writefunc(fd)

    -- Build the header block with table.concat (repeated string.format
    -- concatenation is quadratic in the number of headers). The caller's table
    -- is not mutated, so a header table reused for requests to different hosts
    -- does not carry a stale host entry, and a caller-supplied "Host" in any
    -- letter case is not duplicated.
    local lines = {}
    local has_host = false
    if header then
        for k, v in pairs(header) do
            if type(k) == "string" and k:lower() == "host" then
                has_host = true
            end
            lines[#lines + 1] = string.format("%s:%s\r\n", k, v)
        end
    end
    if not has_host then
        table.insert(lines, 1, string.format("host:%s\r\n", host))
    end
    local header_content = table.concat(lines)

    -- content-length is always sent; it is 0 when there is no body.
    local content_length = content and #content or 0
    write(string.format("%s %s HTTP/1.1\r\n%scontent-length:%d\r\n\r\n",
        method, url, header_content, content_length))
    if content then
        write(content)
    end

    -- Wait for the response header; `body` holds any bytes read past it.
    local tmpline = {}
    local body = internal.recvheader(read, tmpline, "")
    if not body then
        error(socket.socket_error)
    end

    -- Status line, e.g. "HTTP/1.1 200 OK". Only the numeric code is required;
    -- some servers omit the reason phrase.
    local statusline = tmpline[1]
    local code = statusline and tonumber(statusline:match "HTTP/[%d%.]+%s+(%d+)")
    if not code then
        error("Invalid HTTP response status line")
    end

    local resp_header = internal.parseheader(tmpline, 2, recvheader or {})
    if not resp_header then
        error("Invalid HTTP response header")
    end

    local length = resp_header["content-length"]
    if length then
        length = tonumber(length)
    end

    local mode = resp_header["transfer-encoding"]
    if mode and mode ~= "identity" and mode ~= "chunked" then
        error ("Unsupport transfer-encoding")
    end

    if mode == "chunked" then
        body = internal.recvchunkedbody(read, nil, resp_header, body)
        if not body then
            error("Invalid response body")
        end
    elseif length then
        -- identity mode with a known length: trim any excess, or read the rest
        if #body >= length then
            body = body:sub(1, length)
        else
            body = body .. read(length - #body)
        end
    else
        -- identity mode without content-length: append whatever
        -- socket.readall returns for this fd
        body = body .. socket.readall(fd)
    end

    return code, body
end

-- Becomes true once httpc.dns() has been called; httpc.request checks it
-- before resolving host names through skynet.dns.
local async_dns

--- Configure the DNS server used by skynet.dns and enable resolution for later requests.
-- @param server DNS server address, forwarded to dns.server
-- @param port   DNS server port, forwarded to dns.server
httpc.dns = function(server, port)
    async_dns = true
    dns.server(server, port)
end

--- Connect to `host`, send one HTTP request and return the response.
-- @param method     HTTP method, e.g. "GET" or "POST"
-- @param host       "hostname" or "hostname:port"; the port defaults to 80
-- @param url        request target placed on the request line
-- @param recvheader optional table that receives the response headers
-- @param header     optional table of request headers
-- @param content    optional request body string
-- @return status code, response body
-- Raises an error when `host` cannot be parsed, or re-raises the error
-- produced while performing the request.
function httpc.request(method, host, url, recvheader, header, content)
    local timeout = httpc.timeout   -- get httpc.timeout before any blocked api

    -- Anchored so that malformed input such as "x:abc" or "http://x" is
    -- rejected instead of silently connecting to a fragment of the string.
    local hostname, port = host:match "^([^:]+):?(%d*)$"
    if not hostname then
        error("Invalid host: " .. tostring(host))
    end
    if port == "" then
        port = 80
    else
        port = tonumber(port)
    end

    -- Only names go through skynet.dns; a dotted-quad IPv4 literal is used
    -- as-is. (Testing for "ends with a digit" wrongly skipped names such as
    -- "api-v2".)
    if async_dns and not hostname:match "^%d+%.%d+%.%d+%.%d+$" then
        hostname = dns.resolve(hostname)
    end

    -- The timeout is also handed to socket.connect.
    local fd = socket.connect(hostname, port, timeout)
    local finish
    if timeout then
        -- Watchdog: if the request is still running when the timer fires,
        -- shut the socket down (presumably interrupting the blocked reads in
        -- `request`); the fd itself is still closed below.
        skynet.timeout(timeout, function()
            if not finish then
                socket.shutdown(fd) -- shutdown the socket fd, need close later.
            end
        end)
    end

    -- Run the exchange in protected mode so the socket is always closed.
    local ok, statuscode, body = pcall(request, fd, method, host, url, recvheader, header, content)
    finish = true
    socket.close(fd)
    if ok then
        return statuscode, body
    end
    error(statuscode)
end
--- Issue a GET request. Arguments after the method are forwarded unchanged
-- to httpc.request(host, url, recvheader, header, content).
-- @return status code, response body
httpc.get = function(...)
    return httpc.request("GET", ...)
end
--- Percent-encode every byte of `s` except ASCII letters, digits and "_".
-- A space becomes "%20" (not "+").
-- @param s string (a number is coerced to a string by string.gsub)
-- @return the encoded string
local function escape(s)
    local function encode_byte(ch)
        return string.format("%%%02X", ch:byte())
    end
    local encoded = string.gsub(s, "([^A-Za-z0-9_])", encode_byte)
    return encoded
end
--- POST a form encoded as application/x-www-form-urlencoded.
-- @param host       "hostname[:port]"
-- @param url        request target
-- @param form       table of field name -> value; both are percent-encoded with escape
-- @param recvheader optional table that receives the response headers
-- @return status code, response body
function httpc.post(host, url, form, recvheader)
    local fields = {}
    for name, value in pairs(form) do
        fields[#fields + 1] = string.format("%s=%s", escape(name), escape(value))
    end
    local encoded_form = table.concat(fields, "&")
    local request_header = { ["content-type"] = "application/x-www-form-urlencoded" }
    return httpc.request("POST", host, url, recvheader, request_header, encoded_form)
end

return httpc

httpd.lua是http服务端代码,功能只有读取并解析请求、发送回应。

local internal = require "http.internal"

local table = table
local string = string
local type = type

local httpd = {}
-- Reason phrases for the HTTP status line, keyed by status code.
-- A code missing here is sent with an empty reason phrase.
-- 308 (RFC 7538), 426 (RFC 7231) and 428/429/431/511 (RFC 6585) are included
-- because applications commonly return them.
local http_status_msg = {
    [100] = "Continue",
    [101] = "Switching Protocols",
    [200] = "OK",
    [201] = "Created",
    [202] = "Accepted",
    [203] = "Non-Authoritative Information",
    [204] = "No Content",
    [205] = "Reset Content",
    [206] = "Partial Content",
    [300] = "Multiple Choices",
    [301] = "Moved Permanently",
    [302] = "Found",
    [303] = "See Other",
    [304] = "Not Modified",
    [305] = "Use Proxy",
    [307] = "Temporary Redirect",
    [308] = "Permanent Redirect",
    [400] = "Bad Request",
    [401] = "Unauthorized",
    [402] = "Payment Required",
    [403] = "Forbidden",
    [404] = "Not Found",
    [405] = "Method Not Allowed",
    [406] = "Not Acceptable",
    [407] = "Proxy Authentication Required",
    [408] = "Request Time-out",
    [409] = "Conflict",
    [410] = "Gone",
    [411] = "Length Required",
    [412] = "Precondition Failed",
    [413] = "Request Entity Too Large",
    [414] = "Request-URI Too Large",
    [415] = "Unsupported Media Type",
    [416] = "Requested range not satisfiable",
    [417] = "Expectation Failed",
    [426] = "Upgrade Required",
    [428] = "Precondition Required",
    [429] = "Too Many Requests",
    [431] = "Request Header Fields Too Large",
    [500] = "Internal Server Error",
    [501] = "Not Implemented",
    [502] = "Bad Gateway",
    [503] = "Service Unavailable",
    [504] = "Gateway Time-out",
    [505] = "HTTP Version not supported",
    [511] = "Network Authentication Required",
}
--- Read one HTTP request.
-- @param readbytes read function for the connection (e.g. sockethelper.readfunc(fd))
-- @param bodylimit optional maximum body size in bytes; larger bodies yield 413
-- @return 200, url, method, header, body on success; otherwise a single HTTP
--         error status: 400 (malformed request line, header or content-length),
--         413 (header or body too large), 501 (unsupported transfer-encoding)
--         or 505 (unsupported HTTP version).
-- Errors raised by readbytes propagate to the caller.
local function readall(readbytes, bodylimit)
    local tmpline = {}
    local body = internal.recvheader(readbytes, tmpline, "")
    if not body then
        return 413  -- Request Entity Too Large
    end

    -- Request line: METHOD SP request-target SP HTTP/version. A malformed line
    -- is answered with 400, like a malformed header below, instead of raising.
    local request = tmpline[1]
    if not request then
        return 400  -- Bad Request
    end
    local method, url, httpver = request:match "^(%a+)%s+(.-)%s+HTTP/([%d%.]+)$"
    httpver = tonumber(httpver)
    if not (method and url and httpver) then
        return 400  -- Bad Request
    end
    if httpver < 1.0 or httpver > 1.1 then
        return 505  -- HTTP Version not supported
    end

    -- Field names are looked up in lower case (parseheader lower-cases them).
    local header = internal.parseheader(tmpline, 2, {})
    if not header then
        return 400  -- Bad request
    end

    local mode = header["transfer-encoding"]
    if mode and mode ~= "identity" and mode ~= "chunked" then
        return 501  -- Not Implemented
    end

    if mode == "chunked" then
        body, header = internal.recvchunkedbody(readbytes, bodylimit, header, body)
        if not body then
            return 413
        end
        return 200, url, method, header, body
    end

    -- identity mode
    local length = header["content-length"]
    if length then
        length = tonumber(length)
        -- Reject values that are not a non-negative integer (this includes a
        -- duplicated header, which parseheader stores as a table).
        if not length or length < 0 or length % 1 ~= 0 then
            return 400  -- Bad Request
        end
        if bodylimit and length > bodylimit then
            return 413
        end
        if #body >= length then
            body = body:sub(1, length)
        else
            body = body .. readbytes(length - #body)  -- read the remaining bytes
        end
    end

    return 200, url, method, header, body
end
--- Read and parse one HTTP request in protected mode.
-- Takes the same arguments as readall (a read function and an optional body limit).
-- @return code, url, method, header, body, where code is 200 or an HTTP error
--         status; or nil plus the error value if reading raised an error.
httpd.read_request = function(...)
    local ok, code, url, method, header, body = pcall(readall, ...)
    if not ok then
        -- `code` holds the error value raised inside readall
        return nil, code
    end
    return code, url, method, header, body
end
--- Serialize an HTTP/1.1 response through `writefunc`.
-- @param writefunc  function(string) called with successive pieces of the response
-- @param statuscode numeric status; its reason phrase comes from http_status_msg
-- @param bodyfunc   the body, despite its name not always a function:
--                   a string (sent with content-length), a function returning
--                   successive chunks and nil when done (sent with chunked
--                   transfer-encoding), or nil for no body
-- @param header     optional table of header fields; a table value emits one
--                   header line per element
local function writeall(writefunc, statuscode, bodyfunc, header)
    local reason = http_status_msg[statuscode] or ""
    writefunc(string.format("HTTP/1.1 %03d %s\r\n", statuscode, reason))

    for name, value in pairs(header or {}) do
        local values = (type(value) == "table") and value or { value }
        for _, item in ipairs(values) do
            writefunc(string.format("%s: %s\r\n", name, item))
        end
    end

    local kind = type(bodyfunc)
    if kind == "string" then
        -- Fixed-size body: announce its length, end the header block, send it.
        writefunc(string.format("content-length: %d\r\n\r\n", #bodyfunc))
        writefunc(bodyfunc)
    elseif kind == "function" then
        -- Chunked body. Every chunk is prefixed with "\r\n<hex size>\r\n"; the
        -- leading CRLF terminates the header block (first chunk) or the
        -- previous chunk's data.
        writefunc("transfer-encoding: chunked\r\n")
        local chunk = bodyfunc()
        while chunk do
            if chunk ~= "" then  -- a zero-size chunk would mark the end of the body
                writefunc(string.format("\r\n%x\r\n", #chunk))
                writefunc(chunk)
            end
            chunk = bodyfunc()
        end
        writefunc("\r\n0\r\n\r\n")  -- terminating zero-size chunk
    else
        assert(kind == "nil")
        writefunc("\r\n")  -- end of header block, no body
    end
end

--- Write a response in protected mode; arguments are those of writeall.
-- @return true on success, or false plus the error raised while writing
httpd.write_response = function(...)
    return pcall(writeall, ...)
end

return httpd

推荐阅读更多精彩内容