ref: c664c75f3317a331afa283e880b14935cf05d4a3
dir: /libstd/resolve.myr/
use "alloc.use"
use "die.use"
use "endian.use"
use "error.use"
use "fmt.use"
use "option.use"
use "slcp.use"
use "slurp.use"
use "strsplit.use"
use "strstrip.use"
use "sys.use"
use "types.use"
pkg std =
/* Reasons a resolve() call can fail. */
type resolveerr = union
/* host name was rejected by valid() before any query was sent */
`Badhost
/* could not open or connect the socket to the name server */
`Badsrv
/* NOTE(review): not produced by the code visible here -- confirm it is still needed */
`Badquery
/* the response was rejected (e.g. its id did not match the query id) */
`Badresp
;;
/* A raw network address; the bytes are copied as they appear in the DNS answer. */
type netaddr = union
`Ipv4 byte[4]
`Ipv6 byte[16]
;;
/* One resolved address for a host. */
type hostinfo = struct
/* presumably the socket family matching addr -- TODO confirm against callers */
fam : sockfam
/* NOTE(review): DNS does not determine a socket type; check how callers expect this to be set */
stype : socktype
/* TTL copied from the DNS answer record */
ttl : uint32
/* the address itself */
addr : netaddr
/* disabled fields, not currently populated: */
/*
proto : uint32
flags : uint32
addr : sockaddr[:]
canon : byte[:]
next : hostinfo#
*/
;;
/* Look up `host`, returning its addresses or the reason the lookup failed. */
const resolve : (host : byte[:] -> error(hostinfo[:], resolveerr))
;;
/* path of the system hosts table (only referenced by hostfind's disabled draft) */
const Hostfile = "/etc/hosts"

/*
 * Resolve `host` to its addresses: consult the local host table via
 * hostfind() first and fall back to a DNS query when it has no entry.
 */
const resolve = {host : byte[:]
	var local

	local = hostfind(host)
	match local
	| `Some h:	-> h
	| `None:	-> dnsresolve(host)
	;;
}
/*
 * Look `host` up in the local hosts table.
 *
 * Currently a stub: it always returns `None, so resolve() always falls
 * through to DNS. The draft implementation below is disabled.
 * NOTE(review): before enabling it, note that it calls nextword(lines),
 * where nextword(lines[i]) appears to be intended.
 */
const hostfind = {host
-> `None
/*
var hdat
var lines
var ip
var hn
var str
var i
match slurp(Hostfile)
| `Success h: hdat = h
| `Failure m: -> `None
;;
lines = strsplit(hdat, "\n")
for i = 0; i < lines.len; i++
lines[i] = strstrip(lines[i])
(ip, str) = nextword(lines)
(hn, str) = nextword(str)
if streq(hn, host)
-> parseip(ip)
;;
;;
*/
}
/*
 * Resolve `host` by querying the name server chosen by dnsconnect().
 *
 * Returns `Failure (`Badhost) if the name fails valid(), `Failure (`Badsrv)
 * if the server socket cannot be set up, and otherwise whatever dnsquery()
 * returns. The socket is closed before returning so repeated lookups don't
 * leak file descriptors.
 */
const dnsresolve = {host : byte[:]
	var nsrv
	var r

	if !valid(host)
		-> `Failure (`Badhost)
	;;
	nsrv = dnsconnect()
	if nsrv < 0
		-> `Failure (`Badsrv)
	;;
	r = dnsquery(nsrv, host)
	/* the socket is per-lookup; release it whether or not the query succeeded */
	close(nsrv)
	-> r
}
/*
 * Open a UDP socket connected to the name server on port 53.
 *
 * Returns the socket, or -1 (after printing a warning) if the socket
 * can't be created or connected. On connect failure the socket is
 * closed so the descriptor isn't leaked.
 */
const dnsconnect = {
	var sa : sockaddr_in
	var s
	var status

	s = socket(Afinet, Sockdgram, 0)
	if s < 0
		put("Warning: Failed to open socket: %l\n", s)
		-> -1
	;;
	/* hardcode Google DNS for now.
	 FIXME: parse /etc/resolv.conf */
	sa.fam = Afinet
	sa.port = hosttonet(53) /* port 53 */
	sa.addr = [8,8,8,8] /* 8.8.8.8 */
	status = connect(s, (&sa) castto(sockaddr#), sizeof(sockaddr_in))
	if status < 0
		put("Warning: Failed to connect to server: %l\n", status)
		/* callers only see -1, so nobody else can close this descriptor */
		close(s)
		-> -1
	;;
	-> s
}
/*
 * Send an A-record query for `host` on socket `srv`, then wait for and
 * parse the matching reply. Returns the parsed result from rquery().
 */
const dnsquery = {srv, host
	var result

	/* tquery hands back the id that rquery must see echoed in the reply */
	result = rquery(srv, tquery(srv, host))
	put("Got hosts. Returning\n")
	-> result
}
/*
 * DNS header flag masks (RFC 1035 section 4.1.1). The flags field is a
 * big-endian uint16, and the RFC numbers its bits from the most significant
 * end: QR is the top bit (1 << 15), not 1 << 0.
 */
const Qr : uint16 = 1 << 15	/* message is a response */
const Aa : uint16 = 1 << 10	/* authoritative answer */
const Tc : uint16 = 1 << 9	/* message was truncated */
const Rd : uint16 = 1 << 8	/* recursion desired */
const Ra : uint16 = 1 << 7	/* recursion available */

/* id for the next outgoing query; bumped after every query */
var nextid : uint16 = 42

/*
 * Send one A-record query for `host` over the connected socket `srv`.
 * Returns the id placed in the query header, which the reply must echo.
 */
const tquery = {srv, host
	/* 12 byte header + encoded name + 4 bytes; ample for names that passed valid() */
	var pkt : byte[512]
	var off : size

	put("Sending request for %s\n", host)
	/* header */
	off = 0
	off += pack16(pkt[:], off, nextid)	/* id */
	off += pack16(pkt[:], off, Rd)	/* flags: ask the server to recurse */
	off += pack16(pkt[:], off, 1)	/* qdcount */
	off += pack16(pkt[:], off, 0)	/* ancount */
	off += pack16(pkt[:], off, 0)	/* nscount */
	off += pack16(pkt[:], off, 0)	/* arcount */
	/* query */
	off += packname(pkt[:], off, host)	/* host */
	off += pack16(pkt[:], off, 0x1)	/* qtype: a record */
	off += pack16(pkt[:], off, 0x1)	/* qclass: inet4 */
	write(srv, pkt[:off])
	-> nextid++
}
/*
 * Read one reply from `srv` and parse it as the answer to query `id`.
 *
 * Returns `Failure (`Badsrv) if the read fails, `Failure (`Badresp) if the
 * reply is shorter than a DNS header, and otherwise the result of hosts().
 */
const rquery = {srv, id
	var pktbuf : byte[1024]
	var pkt
	var n

	put("Waiting for response...\n")
	n = read(srv, pktbuf[:])
	if n < 0
		put("Warning: Failed to read from %z: %i\n", srv, n)
		/* slicing pktbuf with a negative length would be nonsense */
		-> `Failure (`Badsrv)
	;;
	/* the fixed DNS header is 12 bytes; dumpresponse and hosts both read it */
	if n < 12
		-> `Failure (`Badresp)
	;;
	pkt = pktbuf[:n]
	put("Got response:\n");
	dumpresponse(pkt)
	-> hosts(pkt, id)
}
/*
 * Parse the DNS response `pkt` to the query whose id is `id`.
 *
 * Returns `Success with one hostinfo per A-record answer, with fam, ttl and
 * addr filled in (stype is not determined by DNS and is left unset). The
 * caller owns the returned slice. Returns `Failure (`Badresp) if the id
 * doesn't match or the packet is too short for the records it claims to
 * contain. Answers of other types (e.g. CNAME) are skipped using their
 * rdlength instead of being misread as addresses.
 */
const hosts = {pkt, id : uint16
	var off
	var v, q, a
	var rtype, rdlen
	var ttl
	var i
	var n : size
	var hinf : hostinfo[:]
	var ret : hostinfo[:]

	/* the fixed header is six uint16 fields */
	if pkt.len < 12
		-> `Failure (`Badresp)
	;;
	off = 0
	/* parse header */
	(v, off) = unpack16(pkt, off)	/* id */
	if v != id
		-> `Failure (`Badresp)
	;;
	put("Unpacking flags\n")
	(v, off) = unpack16(pkt, off)	/* flags */
	(q, off) = unpack16(pkt, off)	/* qdcount */
	(a, off) = unpack16(pkt, off)	/* ancount */
	(v, off) = unpack16(pkt, off)	/* nscount */
	(v, off) = unpack16(pkt, off)	/* arcount */

	/* skip past query records */
	for i = 0; i < q; i++
		put("Skipping query record\n")
		off = skipname(pkt, off)	/* name */
		(v, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)	/* class */
	;;

	/* parse answer records; at most `a` of them are A records */
	hinf = slalloc(a castto(size))
	n = 0
	for i = 0; i < a; i++
		off = skipname(pkt, off)	/* name */
		/* type, class, ttl and rdlength take 10 bytes */
		if off + 10 > pkt.len
			slfree(hinf)
			-> `Failure (`Badresp)
		;;
		(rtype, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)	/* class */
		(ttl, off) = unpack32(pkt, off)	/* ttl */
		(rdlen, off) = unpack16(pkt, off)	/* rdlength */
		if off + (rdlen castto(size)) > pkt.len
			slfree(hinf)
			-> `Failure (`Badresp)
		;;
		/* type 1 is an A record, whose rdata is a 4 byte IPv4 address */
		if rtype == 1 && rdlen == 4
			hinf[n].fam = Afinet
			hinf[n].ttl = ttl
			hinf[n].addr = `Ipv4 [pkt[off], pkt[off+1], pkt[off+2], pkt[off+3]]
			n++
		;;
		/* step over the rdata whatever its type, keeping later records aligned */
		off += (rdlen castto(size))
	;;

	/* return exactly the records found, in a slice of the right size */
	ret = slalloc(n)
	slcp(ret, hinf[:n])
	slfree(hinf)
	-> `Success ret
}
/*
 * Debug helper: print the header counts of the DNS response `pkt`,
 * followed by each question and answer record.
 */
const dumpresponse = {pkt
	var nquery, nans
	var off
	var v
	var i

	(v, off) = unpack16(pkt, 0)	/* id */
	(v, off) = unpack16(pkt, off)	/* flags */
	(nquery, off) = unpack16(pkt, off)
	put("hdr.qdcount = %w\n", nquery)
	(nans, off) = unpack16(pkt, off)
	put("hdr.ancount = %w\n", nans)
	(v, off) = unpack16(pkt, off)
	put("hdr.nscount = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("hdr.arcount = %w\n", v)
	put("Queries:\n")
	for i = 0; i < nquery; i++
		put("i: %w\n", i)
		off = dumpquery(pkt, off)
	;;
	/* newline added so the first "i: " line doesn't run onto this heading */
	put("Answers:\n")
	for i = 0; i < nans; i++
		put("i: %w\n", i)
		off = dumpans(pkt, off)
	;;
}
/*
 * Debug helper: print the question record starting at `off` in `pkt`
 * (name, type, class). Returns the offset just past the record.
 */
const dumpquery = {pkt, off
	var qtype, qclass

	put("\tname = ");
	off = printname(pkt, off)
	(qtype, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", qtype)
	(qclass, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", qclass)
	-> off
}
/*
 * Debug helper: print the answer record starting at `off` in `pkt`.
 * Returns the offset just past the record, using rdlength to step over
 * the rdata so records other than 4-byte A records stay aligned.
 */
const dumpans = {pkt, off
	var v
	var rdlen
	var i

	put("\tname = ");
	off = printname(pkt, off)
	(v, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", v)
	/* ttl is a big-endian uint32, so its high half comes first */
	(v, off) = unpack16(pkt, off)
	put("\tbody.ttl_hi = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.ttl_lo = %w\n", v)
	(rdlen, off) = unpack16(pkt, off)
	put("\tbody.rdlength = %w\n", rdlen)
	/* rdata is rdlength bytes; only A records have exactly 4 */
	for i = 0; i < rdlen; i++
		put("\tbody.rdata[%w] = %w\n", i, (pkt[off + (i castto(size))] castto(uint16)))
	;;
	-> off + (rdlen castto(size))
}
/*
 * Return the offset just past the encoded domain name that starts at
 * `off` in `pkt`. A compression pointer (top two bits set) is two bytes
 * long and ends the name; otherwise labels run until a zero length byte.
 */
const skipname = {pkt, off
	var lablen

	lablen = pkt[off] castto(size)
	while lablen != 0
		if lablen & 0xC0 == 0xC0
			-> off + 2
		;;
		off += lablen + 1
		lablen = pkt[off] castto(size)
	;;
	-> off + 1
}
/*
 * Debug helper: print the encoded domain name at `off` in `pkt`, one
 * "label." at a time, following a compression pointer if one is found.
 * Returns the offset just past the name as it appears at `off`.
 */
const printname = {pkt, off
	var lablen
	var target

	lablen = pkt[off] castto(size)
	while lablen != 0
		if lablen & 0xC0 == 0xC0
			/* low 14 bits of the two pointer bytes give the target offset */
			put("PTR: ")
			target = ((lablen & ~0xC0) << 8) | (pkt[off + 1] castto(size))
			printname(pkt, target)
			-> off + 2
		;;
		put("%s.", pkt[off + 1 : off + lablen + 1])
		off += lablen + 1
		lablen = pkt[off] castto(size)
	;;
	-> off + 1
}
/*
 * Store `v` big-endian (most significant byte first) at buf[off] and
 * buf[off+1]. Returns the number of bytes written, always sizeof(uint16).
 */
const pack16 = {buf, off, v
	buf[off + 1] = (v & 0x00ff) castto(byte)
	buf[off] = ((v & 0xff00) >> 8) castto(byte)
	-> sizeof(uint16)
}
/*
 * Read a big-endian uint16 from buf[off] and buf[off+1].
 * Returns the value and the offset just past it.
 */
const unpack16 = {buf, off
	var hi, lo

	hi = buf[off] castto(uint16)
	lo = buf[off + 1] castto(uint16)
	-> ((hi << 8) | lo, off + sizeof(uint16))
}
/*
 * Read a big-endian uint32 from buf[off] .. buf[off+3].
 * Returns the value and the offset just past it.
 */
const unpack32 = {buf, off
	var v

	v = (buf[off] castto(uint32)) << 24
	/* second byte belongs in bits 16..23 (was shifted by 32, dropping it) */
	v |= (buf[off + 1] castto(uint32)) << 16
	v |= (buf[off + 2] castto(uint32)) << 8
	v |= (buf[off + 3] castto(uint32))
	-> (v, off + sizeof(uint32))
}
/*
 * Encode the dotted name `host` into `buf` at `off` in DNS wire format:
 * each label is preceded by its length byte and the name ends with a
 * zero-length label. A trailing dot is accepted. Returns the number of
 * bytes written.
 */
const packname = {buf, off : size, host
	var i
	var start
	var segstart

	start = off
	segstart = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			off += addseg(buf, off, host[segstart:i])
			/* the next label starts just after this dot (not at the old label length) */
			segstart = i + 1
		;;
	;;
	/* final label, unless the name ended with '.'; also safe for an empty host */
	if segstart < host.len
		off += addseg(buf, off, host[segstart:])
	;;
	off += addseg(buf, off, "") /* null terminating segment */
	-> off - start
}
/*
 * Write one DNS label at buf[off]: a length byte followed by the bytes
 * of `str`. Returns the number of bytes written (str.len + 1).
 */
const addseg = {buf, off, str
	var body

	body = buf[off + 1 : off + 1 + str.len]
	buf[off] = str.len castto(byte)
	slcp(body, str)
	-> 1 + str.len
}
/*
 * Check that `host` can be encoded as a DNS query name:
 * non-empty, at most 255 characters, 7-bit ASCII, and made of labels
 * that are 1..63 characters long. A single trailing dot is allowed.
 */
const valid = {host : byte[:]
	var i
	var seglen

	/* an empty name has nothing to query (and packname needs a byte) */
	if host.len == 0 || host.len > 255
		-> false
	;;
	seglen = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			/* empty labels ("a..b", ".a", ".") would end the wire name early */
			if seglen == 0
				-> false
			;;
			seglen = 0
		else
			/* the label length was never counted before, so this limit never fired */
			seglen++
			if seglen > 63
				-> false
			;;
		;;
		if (host[i] & 0x80) != 0
			-> false
		;;
	;;
	-> true
}