shithub: sl

Download patch

ref: c7c8af3159b3798bbb5682af8278dba388f5adfb
parent: f407276b16426d1c27eacef1b168c45a973cdafb
author: spew <spew@cbza.org>
date: Thu Apr 17 12:07:20 EDT 2025

implement generic map

Fixes: https://todo.sr.ht/~ft/sl/18

--- a/src/sl.c
+++ b/src/sl.c
@@ -14,8 +14,8 @@
 sl_v sl_builtinssym, sl_quote, sl_lambda, sl_comma, sl_commaat;
 sl_v sl_commadot, sl_trycatch, sl_backquote;
 sl_v sl_conssym, sl_symsym, sl_fixnumsym, sl_vecsym, sl_builtinsym, sl_vu8sym;
-sl_v sl_defsym, sl_defmacrosym, sl_forsym, sl_setqsym;
-sl_v sl_booleansym, sl_nullsym, sl_evalsym, sl_fnsym, sl_trimsym;
+sl_v sl_defsym, sl_defmacrosym, sl_forsym, sl_setqsym, sl_listsym;	/* sl_listsym: symbol list, a map result type */
+sl_v sl_booleansym, sl_nullsym, sl_evalsym, sl_fnsym, sl_trimsym, sl_strsym;	/* sl_strsym: symbol str, a map result type */
 sl_v sl_nulsym, sl_alarmsym, sl_backspacesym, sl_tabsym, sl_linefeedsym, sl_newlinesym;
 sl_v sl_vtabsym, sl_pagesym, sl_returnsym, sl_escsym, sl_spacesym, sl_deletesym;
 sl_v sl_errio, sl_errparse, sl_errtype, sl_errarg, sl_errmem, sl_errconst;
@@ -1110,51 +1110,82 @@
 	return _stacktrace(sl.throwing_frame ? sl.throwing_frame : sl.curr_frame);
 }

-BUILTIN("map", map)
+/*
+ * Shared engine for map and for-each.  args[0] is the function; args[1..nargs-1]
+ * are the sequences (lists, vectors, strings, arrays, tables), walked in
+ * lockstep until any argument is exhausted or is not a sequence at all.
+ * rt selects the result: nil calls the function for effect and returns nil;
+ * list, vec, str, table or (arr T) build a value of that kind (other symbols
+ * and conses are a type error); any other value (map passes UNBOUND) infers
+ * the kind from the first sequence and rejects sequences of another kind.
+ * Table results need each function result to be a (key . value) pair.
+ * k[0]/k[1] hold the list tail/head and finally the result; vec, str, arr and
+ * table results are first collected on the stack above them.
+ * NOTE(review): those collected values all sit on the stack at once and k is
+ * a raw pointer into it - confirm PUSH copes with long inputs without moving
+ * the stack.  arrtype is an unrooted C local across allocations, which only
+ * matters if an element type can be a heap-allocated type expression.
+ */
+static sl_v
+map_seq(sl_v rt, sl_v *args, int nargs)
 {
-	if(sl_unlikely(nargs < 2))
-		argcount(nargs, 2);
+	enum {
+		RT_AUTO,
+		RT_LIST,
+		RT_VEC,
+		RT_ARR,
+		RT_STR,
+		RT_TBL,
+		RT_NIL,
+	} rtype = RT_AUTO;
+	sl_v arrtype = sl_nil;
+	bool rtdecl = true;
+	if(rt == sl_nil)
+		rtype = RT_NIL;
+	else if(issym(rt)){
+		if(rt == sl_listsym)
+			rtype = RT_LIST;
+		else if(rt == sl_vecsym)
+			rtype = RT_VEC;
+		else if(rt == sl_strsym)
+			rtype = RT_STR;
+		else if(rt == sl_tablesym)
+			rtype = RT_TBL;
+		else
+			bthrow(type_error("sequence", rt));
+	}else if(iscons(rt)){
+		if(car_(rt) != sl_arrsym || !iscons(cdr_(rt)))
+			bthrow(type_error("sequence", rt));
+		rtype = RT_ARR;
+		arrtype = car_(cdr_(rt));
+		get_arr_type(arrtype);
+	}else
+		rtdecl = false;
 	sl_v *k = sl.sp;
 	PUSH(sl_nil);
 	PUSH(sl_nil);
-	for(bool first = true;;){
-		PUSH(args[0]);
-		for(int i = 1; i < nargs; i++){
-			if(!iscons(args[i])){
-				POPN(2+i);
-				return k[1];
-			}
-			PUSH(car(args[i]));
-			args[i] = cdr_(args[i]);
-		}
-		sl_v v = _applyn(nargs-1);
-		POPN(nargs);
-		PUSH(v);
-		sl_v c = alloc_cons();
-		car_(c) = POP(); cdr_(c) = sl_nil;
-		if(first)
-			k[1] = c;
-		else
-			cdr_(k[0]) = c;
-		k[0] = c;
-		first = false;
-	}
-}
-
-BUILTIN("for-each", for_each)
-{
-	if(sl_unlikely(nargs < 2))
-		argcount(nargs, 2);
 	for(usize n = 0;; n++){
 		PUSH(args[0]);
 		int pargs = 0;
-		for(int i = 1; i < nargs; i++, pargs++){
-			sl_v v = args[i];
+		for(sl_v *a = args+1; a < args+nargs; a++, pargs++){
+			sl_v v = *a;
 			if(iscons(v)){
+				if(!rtdecl){
+					if(rtype == RT_AUTO)
+						rtype = RT_LIST;
+					if(rtype != RT_LIST)
+						bthrow(lerrorf(sl_errarg, "sequence type mismatch"));
+				}
 				PUSH(car_(v));
-				args[i] = cdr_(v);
+				*a = cdr_(v);
 				continue;
 			}else if(isvec(v)){
+				if(!rtdecl){
+					if(rtype == RT_AUTO)
+						rtype = RT_VEC;
+					if(rtype != RT_VEC)
+						bthrow(lerrorf(sl_errarg, "sequence type mismatch"));
+				}
 				usize sz = vec_size(v);
 				if(n < sz){
 					PUSH(vec_elt(v, n));
@@ -1161,11 +1192,17 @@
 					continue;
 				}
 			}else if(sl_isstr(v)){
+				if(!rtdecl){
+					if(rtype == RT_AUTO)
+						rtype = RT_STR;
+					if(rtype != RT_STR)
+						bthrow(lerrorf(sl_errarg, "sequence type mismatch"));
+				}
 				char *s = tostr(v);
-				usize sz = cv_len(ptr(v)), b, k;
-				for(b = k = 0; k < n && b < sz; k++)
+				usize sz = cv_len(ptr(v)), b, l;
+				for(b = l = 0; l < n && b < sz; l++)
 					b += u8_seqlen(s+b);
-				if(k == n && b < sz){
+				if(l == n && b < sz){
 					Rune r;
 					chartorune(&r, s+b);
 					PUSH(mk_rune(r));
@@ -1172,15 +1209,29 @@
 					continue;
 				}
 			}else if(isarr(v)){
+				if(!rtdecl){
+					if(rtype == RT_AUTO){
+						rtype = RT_ARR;
+						arrtype = cv_class(ptr(v))->eltype->type;
+					}
+					if(rtype != RT_ARR || arrtype != cv_class(ptr(v))->eltype->type)
+						bthrow(lerrorf(sl_errarg, "sequence type mismatch"));
+				}
 				usize sz = cvalue_arrlen(v);
 				if(n < sz){
-					sl_v a[2];
-					a[0] = v;
-					a[1] = fixnum(n);
-					PUSH(cvalue_arr_aref(a));
+					sl_v aref[2];
+					aref[0] = v;
+					aref[1] = fixnum(n);
+					PUSH(cvalue_arr_aref(aref));
 					continue;
 				}
 			}else if(ishashtable(v)){
+				if(!rtdecl){
+					if(rtype == RT_AUTO)
+						rtype = RT_TBL;
+					if(rtype != RT_TBL)
+						bthrow(lerrorf(sl_errarg, "sequence type mismatch"));
+				}
 				sl_htable *h = totable(v);
 				assert(n != 0 || h->i == 0);
 				void **table = h->table;
@@ -1198,13 +1249,87 @@
 				h->i = 0;
 			}
 			POPN(pargs+1);
-			return sl_void;
+			/* an argument ran out: assemble the result from what was collected */
+			switch(rtype){
+			default:	/* RT_LIST: k[1] is the list head; RT_AUTO: nothing was mapped, k[1] is nil */
+				break;
+			case RT_VEC:
+				k[1] = alloc_vec(n, 0);
+				memcpy(&vec_elt(k[1], 0), sl.sp-n, n*sizeof(sl_v));
+				POPN(n);
+				break;
+			case RT_ARR:
+				k[0] = sym_value(sl_arrsym);
+				k[1] = arrtype;
+				k[1] = _applyn(n+1);
+				POPN(n);
+				break;
+			case RT_STR:
+				k[1] = sym_value(sl_strsym);
+				k[1] = _applyn(n);
+				POPN(n);
+				break;
+			case RT_TBL:
+				k[1] = sym_value(sl_tablesym);
+				k[1] = _applyn(2*n);
+				POPN(2*n);
+				break;
+			case RT_NIL:
+				k[1] = sl_nil;
+				break;
+			}
+			POPN(2);
+			return k[1];
 		}
-		_applyn(pargs);
+		sl_v v = _applyn(pargs);
 		POPN(pargs+1);
+		/* stash v according to the result kind */
+		switch(rtype){
+		sl_v c;
+		default:
+			PUSH(v);
+			/* fallthrough */
+		case RT_NIL:
+			break;
+		case RT_TBL:
+			PUSH(car(v));
+			PUSH(cdr(v));
+			break;
+		case RT_LIST:
+			/* keep v on the stack while allocating, as alloc_cons may run the GC */
+			PUSH(v);
+			c = alloc_cons();
+			car_(c) = POP();
+			cdr_(c) = sl_nil;
+			if(n == 0)
+				k[1] = c;
+			else
+				cdr_(k[0]) = c;
+			k[0] = c;
+			break;
+		}
 	}
 }
 
+BUILTIN("map", map)
+{
+	/* a leading nil, symbol or cons is map_seq's result type; it must be
+	 * followed by a function and a sequence, as map_seq never ends without one */
+	bool hasrt = nargs > 0 && (args[0] == sl_nil || issym(args[0]) || iscons(args[0]));
+	int need = hasrt ? 3 : 2;
+	if(sl_unlikely(nargs < need))
+		argcount(nargs, need);
+	return hasrt ? map_seq(args[0], args+1, nargs-1) : map_seq(UNBOUND, args, nargs);
+}
+
+BUILTIN("for-each", for_each)
+{
+	/* map_seq with a nil result type: the function runs only for its effects */
+	if(sl_unlikely(nargs < 2))
+		argcount(nargs, 2);
+	return (map_seq(sl_nil, args, nargs), sl_void);
+}
+
 BUILTIN("sleep", sl_sleep)
 {
 	if(nargs > 1)
@@ -1337,6 +1441,9 @@
 	sl_vecstructsym = mk_csym("%struct%");
 	sl_structsym = mk_csym("struct");
 	sl_builtinssym = mk_csym("*builtins*");
+	sl_listsym = mk_csym("list");
+	sl_tablesym = mk_csym("table");
+	sl_strsym = mk_csym("str");
 
 	set(sl_printprettysym = mk_csym("*print-pretty*"), sl_t);
 	set(sl_printreadablysym = mk_csym("*print-readably*"), sl_t);
--- a/src/sl.h
+++ b/src/sl.h
@@ -439,8 +439,8 @@
 extern sl_v sl_builtinssym, sl_quote, sl_lambda, sl_comma, sl_commaat;
 extern sl_v sl_commadot, sl_trycatch, sl_backquote;
 extern sl_v sl_conssym, sl_symsym, sl_fixnumsym, sl_vecsym, sl_builtinsym, sl_vu8sym;
-extern sl_v sl_defsym, sl_defmacrosym, sl_forsym, sl_setqsym;
-extern sl_v sl_booleansym, sl_nullsym, sl_evalsym, sl_fnsym, sl_trimsym;
+extern sl_v sl_defsym, sl_defmacrosym, sl_forsym, sl_setqsym, sl_listsym;	/* sl_listsym: symbol list, a map result type */
+extern sl_v sl_booleansym, sl_nullsym, sl_evalsym, sl_fnsym, sl_trimsym, sl_strsym;	/* sl_strsym: symbol str, a map result type */
 extern sl_v sl_nulsym, sl_alarmsym, sl_backspacesym, sl_tabsym, sl_linefeedsym, sl_newlinesym;
 extern sl_v sl_vtabsym, sl_pagesym, sl_returnsym, sl_escsym, sl_spacesym, sl_deletesym;
 extern sl_v sl_errio, sl_errparse, sl_errtype, sl_errarg, sl_errmem, sl_errconst;
@@ -449,7 +449,7 @@
 
 extern sl_v sl_printwidthsym, sl_printreadablysym, sl_printprettysym, sl_printlengthsym;
 extern sl_v sl_printlevelsym;
-extern sl_v sl_arrsym;
+extern sl_v sl_tablesym, sl_arrsym;	/* sl_tablesym: symbol table, also a map result type */
 extern sl_v sl_iosym, sl_rdsym, sl_wrsym, sl_apsym, sl_crsym, sl_truncsym;
 extern sl_v sl_s8sym, sl_u8sym, sl_s16sym, sl_u16sym, sl_s32sym, sl_u32sym;
 extern sl_v sl_s64sym, sl_u64sym, sl_p32sym, sl_p64sym, sl_ptrsym, sl_bignumsym;
--- a/test/unittest.sl
+++ b/test/unittest.sl
@@ -427,6 +427,23 @@
 (assert (equal? (map (λ (x y) (+ x y)) '(1 2) '(3)) '(4)))
 (assert (equal? (map (λ (x y z) (+ x y z)) '(1 2) '(3) '(4 5)) '(8)))

+;; map with different return types
+(assert (equal? (map 'vec + '(1 2 3) '(4 5 6) '(7 8 9)) (vec 12 15 18)))
+(assert (equal? (map '(arr s32) + '(1 2 3) '(4 5 6) '(7 8 9)) (arr 's32 12 15 18)))
+(def tbl (table "hi" 32 109234 "blah"))
+(assert (equal? (get (map cons tbl) "hi") 32))
+(assert (equal? (get (map cons tbl) 109234) "blah"))
+(assert (equal? (map + (vec 1 2 3) (vec 4 5 6)) (vec 5 7 9)))
+(assert-fail (map + (vec 1 2 3) '(4 5 6)))
+(assert (equal? (map 'vec + (vec 1 2 3) '(4 5 6)) (vec 5 7 9)))
+(assert (equal? (map 'list + (vec 1 2 3) '(4 5 6)) '(5 7 9)))
+(assert (equal? (map 'str + (vec 1 2 3) '(4 5 6)) "579"))
+(assert (equal? (map '(arr s16) + (vec 1 2 3) '(4 5 6)) (arr 's16 5 7 9)))
+;; empty input, allocating mapped values into a list, table built from a list
+(assert (equal? (map 'vec + '()) #()))
+(assert (equal? (map 'list (λ (x) (list x x)) (vec 1 2)) '((1 1) (2 2))))
+(assert (equal? (get (map 'table (λ (x) (cons x (* x x))) '(1 2 3)) 3) 9))
+
 ;; aref with multiple indices
 (def a #(#(0 1 2) #(3 (4 5 6) 7)))
 (assert (equal? 0 (aref a 0 0)))