shithub: mc

ref: d073dcc756169e572f54e65f05ba6b4ffaddcde3
dir: /lib/flate/test/flate.myr/

View raw version
use bio
use flate
use std
use testr

/*
 * Equality for the results produced by the decoder primitives under test.
 * Two results are equal when they are both `std.Ok with the same integer,
 * both `flate.Io errors whose payloads compare equal with ==, or both
 * `flate.Format errors whose payloads are equal according to std.sleq.
 * Every other pairing is unequal.
 */
impl std.equatable std.result(int, flate.err) =
	eq = {a, b
		match (a, b)
		| (`std.Err (`flate.Format fa), `std.Err (`flate.Format fb)):
			-> std.sleq(fa, fb)
		| (`std.Err (`flate.Io ia), `std.Err (`flate.Io ib)):
			-> ia == ib
		| (`std.Ok va, `std.Ok vb):
			-> va == vb
		| _:
			-> false
		;;
	}
;;

/*
 * Allocates a decoder state reading from the in-memory byte string `in`.
 * NOTE(review): bpos = 8 presumably marks the current byte as fully
 * consumed, so that the first read pulls a byte from rdf -- confirm
 * against flate.bits.
 */
const init = {in -> flate.flatedec#
	var src

	src = bio.mkmem(in)
	-> std.mk([
		.bpos = 8,
		.rdf = src,
	])
}

/* the result the tests below expect from a read once the input has run out */
const eof = `std.Err (`flate.Io `bio.Eof)

/*
 * Exercises flate.bits on small in-memory inputs and on a hand-made reader.
 * Each group of assertions reads from the same decoder state in sequence,
 * so the order of the calls matters.
 */
const bits = {ctx
	var dec
	var counter = 0
	var cur = &counter

	/* no input at all */
	dec = init("")
	testr.eq(ctx, flate.bits(dec, 1), eof)

	/* one byte of input, nine bits requested */
	dec = init(" ")
	testr.eq(ctx, flate.bits(dec, 9), eof)

	/*
	 * 0xaa is 10101010 in binary; the expected values are the ones
	 * obtained by taking bits from the least significant end first
	 */
	dec = init("\xaa")
	testr.eq(ctx, flate.bits(dec, 0), `std.Ok 0)
	testr.eq(ctx, flate.bits(dec, 1), `std.Ok 0)
	testr.eq(ctx, flate.bits(dec, 1), `std.Ok 1)
	testr.eq(ctx, flate.bits(dec, 2), `std.Ok 2)
	testr.eq(ctx, flate.bits(dec, 3), `std.Ok 2)
	/* only one bit is left at this point */
	testr.eq(ctx, flate.bits(dec, 2), eof)

	/* a single read spanning two bytes */
	dec = init("\xaa\xaa")
	testr.eq(ctx, flate.bits(dec, 2), `std.Ok 2)
	testr.eq(ctx, flate.bits(dec, 8 + 2), `std.Ok 0x2aa)
	testr.eq(ctx, flate.bits(dec, 1), `std.Ok 0)
	testr.eq(ctx, flate.bits(dec, 4), eof)

	/* byte-aligned reads give the bytes back unchanged */
	dec = init("\xab\xcd")
	testr.eq(ctx, flate.bits(dec, 8), `std.Ok 0xab)
	testr.eq(ctx, flate.bits(dec, 8), `std.Ok 0xcd)

	/*
	 * refilling: the reader below hands out one byte per call,
	 * with values 1, 2, 3, ...
	 */
	dec = init("")
	dec.rdf = bio.mk(bio.Rd, [
		.read = {buf
			buf[0] = (++cur# : byte)
			-> `std.Ok 1
		}
	])
	testr.eq(ctx, flate.bits(dec, 3), `std.Ok 1)
	testr.eq(ctx, flate.bits(dec, 8), `std.Ok (2 << 5))
	testr.eq(ctx, flate.bits(dec, 10), `std.Ok (3 << 5))
}

/*
 * Checks the count and symbol tables that flate.mkhufc builds from a list
 * of code lengths. `want` holds the expected count table and is updated
 * step by step; entries set for one case stay set or get overwritten for
 * the next one.
 */
const mkhufc = {ctx
	var want
	var tab

	want = [
		0, 0, 0, 0,
		0, 0, 0, 0,
		0, 0, 0, 0,
		0, 0, 0, 0,
	][:]

	/* no lengths at all: every count stays at zero */
	flate.mkhufc([][:], &tab)
	testr.check(ctx, std.eq(tab.count[:], want), "empty code")

	/* two symbols of length 1 */
	want[1] = 2
	flate.mkhufc([1, 1][:], &tab)
	testr.check(ctx, std.eq(tab.count[:], want), "1-bit code")
	testr.eq(ctx, tab.symbol[:2], [0, 1][:])

	/*
	 * lengths 2, 1, 3, 3: one code of length 1, one of length 2 and
	 * two of length 3 (going by the label, the example from the RFC)
	 */
	want[1] = 1
	want[2] = 1
	want[3] = 2
	flate.mkhufc([2, 1, 3, 3][:], &tab)
	testr.check(ctx, std.eq(tab.count[:], want), "rfc code")
	testr.eq(ctx, tab.symbol[:4], [1, 0, 2, 3][:])
}

/*
 * Checks flate.readcode: symbols are decoded one after the other from the
 * same decoder state, using tables built by flate.mkhufc.
 */
const readcode = {ctx
	var dec, tab

	/* six symbols packed in the first 12 bits of the two input bytes */
	flate.mkhufc([2, 1, 3, 3][:], &tab)
	dec = init("\xcd\x05")
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 0)
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 2)
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 1)
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 3)
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 1)
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 0)
	/* 12 bits were used, so the position is 4 bits into the second byte */
	testr.eq(ctx, dec.bpos, 4)

	/*
	 * same input and same non-zero lengths, but with unused symbols
	 * (length 0) in between: symbols 2 and 3 become 4 and 6
	 */
	flate.mkhufc([2, 1, 0, 0, 3, 0, 3][:], &tab)
	dec = init("\xcd\x05")
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 0)
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 4)
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 1)
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 6)
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 1)
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 0)
	testr.eq(ctx, dec.bpos, 4)

	/* unused symbols first, then a single symbol with a 4-bit code */
	flate.mkhufc([0, 0, 0, 4][:], &tab)
	dec = init("\x00")
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 3)
	testr.eq(ctx, flate.readcode(dec, &tab), `std.Ok 3)
	testr.eq(ctx, flate.readcode(dec, &tab), eof)
}

/*
 * Converts a textual bit string into an in-memory bio file.
 *
 * Only the characters '0' and '1' carry data; anything else (spaces, the
 * 'r' markers used in the test tables, ...) is skipped. The first bit of
 * each group of eight goes into the least significant position of the
 * output byte. Whatever has been gathered when the string ends is always
 * emitted as one last byte, even if no bit is pending.
 */
const bitsf = {str : byte[:]
	var out, nbits, acc

	out = [][:]
	acc = 0
	nbits = 0
	for c : str
		if c == ('1' : byte)
			acc += (1 << nbits)
			nbits++
		elif c == ('0' : byte)
			nbits++
		;;
		/* a whole byte has been gathered: emit it and start afresh */
		if nbits == 8
			std.slpush(&out, (acc : byte))
			acc = 0
			nbits = 0
		;;
	;;
	std.slpush(&out, (acc : byte))
	-> bio.mkmem(out)
}

/*
 * Builds a write-only bio file that verifies its own output: the bytes
 * written to it must match `str`, in order. `t` is only used to label
 * failure messages, which are reported through `ctx`.
 */
const expect = {ctx, t, str
	var pending

	/*
	 * the part of `str` not written yet; held behind a pointer so that
	 * the updates made by the write callback are seen by the close one
	 */
	pending = std.mk(str)

	/* each write must be a prefix of what is still pending */
	const onwrite = {buf
		var want

		if buf.len > pending#.len
			testr.fail(ctx,
				"({}) more bytes written than expected:" \
				" '{}' written when expecting '{}'",
				t, buf, pending#)
		;;
		want = pending#[:buf.len]
		if !std.eq(buf, want)
			testr.fail(ctx,
				"({}) '{}' written when expecting '{}'",
				t, buf, want)
		;;
		pending# = pending#[buf.len:]
		-> `std.Ok buf.len
	}

	/* closing is a failure while some output is still pending */
	const onclose = {
		if pending#.len > 0
			testr.fail(ctx,
				"({}) output closed when still expecting '{}'",
				t, pending#)
		;;
		-> true
	}
	-> bio.mk(bio.Wr, [.write = onwrite, .close = onclose])
}

/*
 * Coarse classification of the outcome a decode test case expects from
 * flate.decode; see the match at the end of the decode test for the exact
 * pairing with actual results.
 */
type errortype = union
	`None	/* decoding succeeds (`std.Ok) */
	`Eof	/* `flate.Io error carrying `bio.Eof */
	`Io	/* any `flate.Io error */
	`Format	/* any `flate.Format error */
;;

/*
 * Table-driven test of flate.decode.
 *
 * Each entry of `tests` is (name, input, expected output, expected outcome):
 *  - the input is a bit string in the notation understood by bitsf, where
 *    only '0' and '1' are significant (spaces and 'r' markers are there
 *    for the reader only);
 *  - the expected output is checked by the file built with expect;
 *  - the expected outcome is one of the errortype tags.
 */
const decode = {ctx
	const tests = [
		("empty",
		 "",
		 "", `Eof),
		("reserved block",
		 "1 11 00",
		 "", `Format),
		("partial block",
		 "1 1",
		 "", `Eof), /* could be `Format error... */
		("uncompressed empty block",
		 "1 00 00000 0000000000000000 1111111111111111",
		 "", `None),
		("uncompressed 3-bytes block",
		 "1 00 00000 1100000000000000 0011111111111111" \
		 "10000110 01000110 11000110",
		 "abc", `None),
		("uncompressed 1-byte then 2-bytes block",
		 "0 00 00000 1000000000000000 0111111111111111" \
		 "10000110" \
		 "1 00 00000 0100000000000000 1011111111111111" \
		 "01000110 11000110",
		 "abc", `None),
		("uncompressed mismatched len nlen",
		 "1 00 00000 1000000000000000 0111111111111011" \
		 "10000110",
		 "", `Format),
		("fixed code single block",
		 "1 10 10011000 10010101 10011100 10011100" \
		 "10011111 01010001 0000000",
		 "hello!", `None),
		("dynamic code block",
		 // 'r' markers are used when bits must come in "reverse"
		 "1 01" \
		 "r00000 r00000 r1111" \           // 257 literal, 1 distance, 19 code codes
		 "r101 101 101 101 101 101 101 101 101 101 101 101" \
		 " 101 101 101 101 101 101 101" \  // All 14 code codes of length 5
		 "10010 r0110101" \                // 97 zeroes
		 "00011 00011 00011" \             // then three times 3
		 "10010 r1111111 10010 r1110000" \ // then 156 (= 11+127 + 11+7) zeroes
		 "00011" \                         // then 3 once (end of block)
		 "00000" \                         // one unused length code
		 "000 001 010 011",                // finally, some data
		 "abc", `None),
		("dynamic code block with reference",
		 "1 01" \
		 "r01000 r10000 r1111" \           // 258 literal, 2 distance, 19 code codes
		 "r101 101 101 101 101 101 101 101 101 101 101 101" \
		 " 101 101 101 101 101 101 101" \  // All 14 code codes of length 5
		 "10010 r0110101" \                // 97 zeroes
		 "00011 00011 00011" \             // then three times 3
		 "10010 r1111111 10010 r1110000" \ // then 156 (= 11+127 + 11+7) zeroes
		 "00011 00000 00011" \             // then 3, 0, 3 (ref of length 4)
		 "00010 00010" \                   // two size two distance code
		 "000 001 010 100 01 011",         // use the reference at the end
		 "abcbcbc", `None),
		("dynamic code with invalid reference",
		 "1 01" \
		 "r01000 r10000 r1111" \
		 "r101 101 101 101 101 101 101 101 101 101 101 101" \
		 " 101 101 101 101 101 101 101" \
		 "10010 r0110101" \
		 "00011 00011 00011" \
		 "10010 r1111111 10010 r1110000" \
		 "00011 00000 00011" \
		 "00010 00010" \
		 "000 100 01",
		 "a", `Format),
		("dynamic code with cross-blocks reference",
		 "0 00 00000 1100000000000000 0011111111111111" \
		 "10000110 01000110 11000110" \    // "abc" uncompressed
		 "1 01" \
		 "r01000 r00000 r1000" \           // 258 literal, 1 distance, 5 code codes
		 "r110 110 110 000 010" \          // code lengths of 3, 3, 3, 0, 2
		 \                                 // for 16, 17, 18, 0, 8
		 "100 r1111111" \                  // 138 zeroes
		 "100 r1101011" \                  // +118 zeroes = 256 zeroes
		 "00" \                            // 256 (=EOB) has length 8
		 "010 r00" \                       // repeat this length 3 times
		 "00000001 00000000 00000000",     // one reference, then EOB
		 "abcccc", `None),
	]
	var outf, err

	for (t, in, out, experr) : tests
		/* the output file checks by itself what gets written to it */
		outf = expect(ctx, t, out)
		err = flate.decode(outf, bitsf(in))
		/*
		 * The result must belong to the class announced by the table
		 * entry. Note that an end-of-file error is accepted both for
		 * `Eof and for `Io, since `Io matches any `flate.Io error.
		 * Any other pairing fails the test.
		 */
		match (err, experr)
		| (`std.Ok _, `None):
		| (`std.Err (`flate.Io `bio.Eof), `Eof):
		| (`std.Err (`flate.Io _), `Io):
		| (`std.Err (`flate.Format _), `Format):
		| _:
			testr.fail(ctx,
				"({}) {} error expected but got {} instead",
				t, experr, err)
		;;
		/*
		 * presumably runs the close callback installed by expect, which
		 * reports expected output that was never written -- confirm
		 * against bio.close
		 */
		bio.close(outf)
	;;
}

/* Entry point: hands the four test cases of this file to the testr runner. */
const main = {
	var specs

	specs = [
		[.name = "bits", .fn = bits],
		[.name = "mkhufc", .fn = mkhufc],
		[.name = "readcode", .fn = readcode],
		[.name = "decode", .fn = decode],
	][:]
	testr.run(specs)
}