PreviousUpNext

15.4.1094  src/lib/std/src/matrix.pkg

## matrix.pkg
#
# Two-dimensional matrices.

# Compiled by:
#     src/lib/std/src/standard-core.sublib



###                  "Engineering is like acting,
###                   in that when it is well done,
###                   it goes unnoticed and unapplauded."



#DO set_control "compiler::trap_int_overflow" "TRUE";

stipulate
    package rwv =  rw_vector;                           # rw_vector             is from   src/lib/std/src/rw-vector.pkg
    package rws =  rw_vector_slice;                     # rw_vector_slice       is from   src/lib/std/src/rw-vector-slice.pkg
    package inl =  inline_t;                            # inline_t              is from   src/lib/core/init/built-in.pkg
herein

    package   matrix
    :         Matrix                                    # Matrix                is from   src/lib/std/src/matrix.api
    {                                                   # inline_t              is from   src/lib/core/init/built-in.pkg
        # Unsigned less-than on default Ints.  This file uses it as a one-
        # comparison range test:  ltu (i, n) for "0 <= i < n".
        # NOTE(review): assumes the usual trick where a negative i reads as
        # a huge unsigned value -- confirm against inline_t.
        #
        ltu = inl::default_int::ltu;

        # Unchecked store/fetch on the backing rw_vector.  Every caller in
        # this file must establish that the offset is in range first.
        #
        unsafe_set = inl::poly_rw_vector::set;
        unsafe_get = inl::poly_rw_vector::get;


        # A matrix is a flat rw_vector holding the elements in row-major
        # order:  element (i, j) lives at offset i*ncols + j (see unsafe_index).
        #
        Matrix(X)
            =
            { data:   rwv::Rw_Vector(X),                # Row-major element storage.
              nrows:  Int,
              ncols:  Int
            };

        # A rectangular window onto a matrix.  (row, col) is the top-left
        # corner.  A NULL extent means "through the last row/column" of
        # 'base' (resolved by check_region).
        #
        Region(X)
            =
            { base:   Matrix(X),
              row:    Int,
              col:    Int,
              nrows:  Null_Or( Int ),
              ncols:  Null_Or( Int )
            };

        # Element visiting order for from_fn, apply, map_in_place, fold and
        # their keyed_* / foldi region variants.
        #
        Traversal
            = ROW_MAJOR
            | COLUMN_MAJOR
            ;

        # Allocate a flat rw_vector of the given length, every slot set to
        # the given initial value.  No size validation happens here.  Every
        # caller in this file passes a length already vetted by check_size.
        #
        make_matrix'
            =
            inl::poly_rw_vector::array;

        # Compute the index of an matrix element 
        #
        fun unsafe_index ( { nrows, ncols, ... }: Matrix(X), i, j)
            =
            (i * ncols + j);

        fun index (arr, i, j)
            =
            if ((ltu (i, arr.nrows) and ltu (j, arr.ncols)))
                #
                unsafe_index (arr, i, j);
            else
                raise exception exceptions_guts::SUBSCRIPT;                                     # exceptions_guts       is from   src/lib/std/src/exceptions-guts.pkg
            fi;

        fun check_size (nrows, ncols)
            =
            if  (nrows < 0
            or   ncols < 0
            )
                raise exception exceptions_guts::SIZE;
            else
                n = nrows * ncols
                    except
                        OVERFLOW = raise exception exceptions_guts::SIZE;

                if (n > core::maximum_vector_length)    raise exception exceptions_guts::SIZE;  fi;

                n;
            fi;

        fun make_matrix (nrows, ncols, v)
            =
            case (check_size (nrows, ncols))
                #
                0 => { data => inl::poly_rw_vector::new_array0(), nrows => 0, ncols => 0 };
                n => { data => make_matrix' (n, v), nrows, ncols };
            esac;

        # Build a matrix from a list of rows, each row a list of elements.
        # Elements are stored row-major.  Every row must have the same
        # length as the last row, otherwise SIZE is raised.  An empty list
        # yields a matrix with dimensions (0, 0).
        #
        fun from_list rows
            =
            case (list::reverse rows)
                #
                []  =>
                    { data  => inl::poly_rw_vector::new_array0(),
                      nrows => 0,
                      ncols => 0
                    };

                last_row ! rest
                    =>
                    {
                        columns = list::length last_row;                # Width every other row is checked against.


                        # Walk the remaining rows last-to-first (the list was
                        # reversed), counting them in 'rows'.  Each row's
                        # elements are prepended onto the accumulated list l,
                        # so l ends up in row-major order.
                        #
                        fun check ([], rows, l)
                                =>
                                (rows, l);

                            check (row ! rest, rows, l)
                                =>
                                check (rest, rows+1, check_row (row, 0))
                                where
                                    # Copy 'row' onto the front of l while counting
                                    # its length n.  Raises SIZE when that length
                                    # differs from 'columns'.
                                    #
                                    fun check_row ([], n)
                                            =>
                                            {   if   (n != columns   )   raise exception exceptions_guts::SIZE;   fi;
                                                l;
                                            };

                                        check_row (x ! r, n)
                                            =>
                                            x ! check_row (r, n+1);
                                    end;
                                end;
                        end;

                        (check (rest, 1, last_row))                    # last_row alone seeds the element list and counts as one row.
                            ->
                            (rows, data);


                        { data => rw_vector::from_list data, nrows => rows, ncols => columns };
                    };
            esac;


        stipulate
            # Row-major tabulation:  element (i, j) is f (i, j).  f is called
            # in row-major order, starting with (0, 0).
            #
            fun from_fn_rm (nrows, ncols, f)                                            # "rm" == "row major"
                =
                case (check_size (nrows, ncols))
                    #
                    0 => { data => inl::poly_rw_vector::new_array0(), nrows, ncols };   # NOTE(review): unlike make_matrix, keeps the caller's dimensions.
                    #
                    n => {

                        arr = make_matrix' (n, f (0, 0));                               # f (0, 0) doubles as the fill value, so slot 0 is already final.

                        # lp1 starts row i, or stops after the last row.
                        # lp2 fills row i left to right.  k is the row-major
                        # offset of (i, j), so it simply increments.
                        #
                        fun lp1 (i, j, k)
                            =
                            if (i < nrows)
                                #
                                lp2 (i, 0, k);
                            fi

                        also
                        fun lp2 (i, j, k)
                            =
                            if (j < ncols)
                                #
                                unsafe_set (arr, k, f (i, j));
                                lp2 (i, j+1, k+1);
                            else
                                lp1 (i+1, 0, k);
                            fi;

                        lp2 (0, 1, 1);  #  we've already done (0, 0)

                        { data => arr, nrows, ncols };
                    };
                esac;


            # Column-major tabulation:  element (i, j) is f (i, j).  f is
            # called in column-major order, starting with (0, 0).
            #
            fun from_fn_cm (nrows, ncols, f)                                    # "cm" == "column major"
                =
                case (check_size (nrows, ncols))
                    #
                    0 => { data => inl::poly_rw_vector::new_array0(), nrows, ncols };
                    #
                    n => {
                        arr   = make_matrix' (n, f (0, 0));

                        # Walking down column j steps k by ncols.  After the
                        # last row k sits at j + n, and subtracting n - 1
                        # lands on j + 1, the top of the next column.
                        #
                        delta = n - 1;

                        fun lp1 (i, j, k)
                            =
                            if (j < ncols)
                                #
                                lp2 (0, j, k);
                            fi

                        also
                        fun lp2 (i, j, k)
                            =
                            if (i < nrows)
                                #
                                unsafe_set (arr, k, f (i, j));

                                lp2 (i+1, j, k+ncols);
                            else
                                lp1 (0, j+1, k-delta);
                            fi;

                        lp2 (1, 0, ncols);              # We've already done (0, 0)

                        { data => arr, nrows, ncols };
                    };
                esac;
        herein
            # from_fn order (nrows, ncols, f):  build a matrix whose (i, j)
            # element is f (i, j), calling f in the requested traversal order.
            # Raises SIZE via check_size for bad dimensions.
            #
            fun from_fn  ROW_MAJOR    =>  from_fn_rm;
                from_fn  COLUMN_MAJOR =>  from_fn_cm;
            end;
        end;

        fun get (a, i, j)    =  unsafe_get (a.data, index (a, i, j));
        fun set (a, i, j, v) =  unsafe_set (a.data, index (a, i, j), v);


        fun dimensions { data, nrows, ncols }
            =
            (nrows, ncols);


        fun columns (arr:  Matrix(X)) =  arr.ncols;
        fun rows    (arr:  Matrix(X)) =  arr.nrows;


        fun row ( { data, nrows, ncols }, i)
            =
            {   stop = i*ncols;

                fun make_vec (j, l)
                    =
                    if (j < stop)
                         vector::from_list l;
                    else
                         make_vec (j - 1, rwv::get (data, j) ! l);
                    fi;

                if (not (ltu (nrows, i)))
                    #
                    make_vec (stop+ncols - 1, []);
                else 
                    raise exception exceptions_guts::SUBSCRIPT;
                fi;
            };

        fun column ( { data, nrows, ncols }, j)
            =
            {   fun make_vec (i, l)
                    =
                    if (i < 0)
                        vector::from_list l;
                    else
                        make_vec (i-ncols, rwv::get (data, i) ! l);
                    fi;

                if (not (ltu (ncols, j)))
                    make_vec ((rwv::length data - ncols) + j, []);                 
                else
                    raise exception exceptions_guts::SUBSCRIPT;
                fi;
            };

        # Result of one step of a region iterator (iterate_rm / iterate_cm):
        # DONE once the region is exhausted, otherwise element (r, c) of the
        # base matrix together with its offset i in the backing vector.
        #
        Index = DONE
              | INDEX  { i: Int, r: Int, c: Int }
              ;

        fun check_region { base=> { data, nrows, ncols }, row, col, nrows=>nr, ncols=>nc }
            =
            {   fun check (start, n, NULL)
                        =>
                        if  (start < 0
                        or   start > n
                        )
                             raise exception exceptions_guts::SUBSCRIPT;
                        else
                             n-start;
                        fi;

                    check (start, n, THE len)
                        =>
                        if ((start < 0) or (len < 0) or (n < start+len))
                            #
                            raise exception exceptions_guts::SUBSCRIPT;
                        else
                            len;
                        fi;
                end;

                nr = check (row, nrows, nr);
                nc = check (col, ncols, nc);

                { data, i => (row*ncols + col), r=>row, c=>col, nr, nc };
            };

        fun copy { src:  Region(X), dst: Matrix(X), dst_row, dst_col }
            =
            {   check_region src;

                src -> { base,
                        row    => srow,   col   => scol,
                         nrows => snrows, ncols => sncols
                       };

                base ->  { data => bdata, ncols => bncols, nrows => bnrows };
                dst  ->  { data => ddata, ncols => dncols, nrows => dnrows };

                src_nrows = the_else (snrows, bnrows - srow);
                src_ncols = the_else (sncols, bncols - scol);

                fun dn (i, d, s)
                    =
                    if (i > 0 )
                        #
                        # We might be better off doing this directly
                        # instead of calling the rw_vector_slice module:
                        #       
                        rws::copy { src => rws::make_slice (bdata, s, THE src_ncols),
                                   dst => ddata, di => d
                                 };

                        dn (i - 1, d + dncols, s + bncols);
                    fi;


                fun up (i, d, s)
                    =
                    if (i > 0)
                        #
                        rws::copy { src => rws::make_slice (bdata, s, THE src_ncols),
                                   dst => ddata, di => d
                                 };

                        up (i - 1, d - dncols, s - bncols);
                    fi;

                if  (src_nrows + dst_row > dnrows
                or   src_ncols + dst_col > dncols
                )
                    raise exception exceptions_guts::SUBSCRIPT;
                else
                    if (dst_row <= srow)
                        #
                        dn ( src_nrows,
                             dst_row * dncols + dst_col,
                             srow * bncols + scol
                           );
                    else
                        up ( src_nrows,
                             (dst_row + src_nrows - 1) * dncols + dst_col,
                             (srow + src_nrows - 1) * bncols + scol
                           );
                    fi;
                fi;
            };


        # This function generates a stream of indices
        # for the given region in row-major order.
        #
        fun iterate_rm arg
            =
            (data, iter)
            where  

                (check_region arg)
                    ->
                    { data, i, r, c=>c_start, nr, nc };

                ii = REF i;
                ri = REF r;
                ci = REF c_start;

                r_end = r+nr;
                c_end = c_start+nc;

                row_delta = arg.base.ncols - nc;

                fun make_index (r, c)
                    =
                    {   i = *ii;
                        #
                        ii := i+1;
                        INDEX { i, c, r };
                    };

                fun iter ()
                    =
                    {   r = *ri;
                        c = *ci;

                        if (c < c_end)
                            #
                            ci := c+1;
                            make_index (r, c);

                        elif (r+1 < r_end)

                            ii := *ii + row_delta;
                            ci := c_start;
                            ri := r+1;

                            iter ();

                        else

                            DONE;
                        fi;
                    };
                end;

        # This function generates a stream of indices
        # for the given region in col-major order:
        #
        fun iterate_cm (arg as { base=> { ncols, nrows, ... }, ... } )
            =
            {   (check_region  arg)
                    ->
                    { data, i, r=>r_start, c, nr, nc };

                ii = REF i;
                ri = REF r_start;
                ci = REF c;

                r_end = r_start+nr;
                c_end = c+nc;

                delta = (nr * ncols) - 1;

                fun make_index (r, c)
                    =
                    {   i = *ii;
                        #
                        ii := i+ncols;
                        INDEX { i, c, r };
                    };

                fun iter ()
                    =
                    {   r = *ri;
                        c = *ci;

                        if (r < r_end)
                            #
                            ri := r+1;
                            make_index (r, c);

                        elif (c+1 < c_end)

                            ii := *ii - delta;
                            ri := r_start;
                            ci := c+1;

                            iter ();

                        else

                            DONE;
                        fi;
                    };

                  (data, iter);
              };


        fun keyed_apply order f region
            =
            apply ()
            where
                my (data, iter)
                    =
                    case order
                        #
                        ROW_MAJOR    => iterate_rm region;
                        COLUMN_MAJOR => iterate_cm region;
                    esac;


                fun apply ()
                    =
                    case (iter ())

                        DONE => ();

                        INDEX { i, r, c }
                            =>
                            {   f (r, c, unsafe_get (data, i));

                                apply ();
                            };
                    esac;
            end;


        fun apply_rm f { data, ncols, nrows }
            =
            rwv::apply f data;


        fun apply_cm f { data, ncols, nrows }
            =
            appf (0, 0)
            where
                delta = rwv::length data - 1;

                fun appf (i, k)
                    =
                    if (i < nrows)
                        #
                        f (unsafe_get (data, k));
                        appf (i+1, k+ncols);
                    else
                        k = k-delta;

                        if (k < ncols)   appf (0, k);   fi;
                    fi;
            end;

        # apply order f m:  call f on every element of the whole matrix m in
        # the given traversal order.
        #
        fun apply ROW_MAJOR    =>  apply_rm;
            apply COLUMN_MAJOR =>  apply_cm;
        end;

        fun keyed_map_in_place order f region
            =
            modify ()
            where
                my (data, iter)
                    =
                    case order
                        #
                        ROW_MAJOR    => iterate_rm  region;
                        COLUMN_MAJOR => iterate_cm  region;
                    esac;


                fun modify ()
                    =
                    case (iter ())
                        #
                        DONE => ();

                        INDEX { i, r, c }
                            =>
                            {   unsafe_set (data, i, f (r, c, unsafe_get (data, i)));
                                modify();
                            };
                    esac;
            end;


        stipulate
            fun modify_rm f { data, ncols, nrows }
                =
                rwv::map_in_place  f  data;


            fun modify_cm f { data, ncols, nrows }
                =
                modf (0, 0)
                where
                    delta = rwv::length data - 1;

                    fun modf (i, k)
                        =
                        if (i < nrows)
                            #
                            unsafe_set (data, k, f (unsafe_get (data, k))); modf (i+1, k+ncols);
                        else
                            k = k-delta;

                            if (k < ncols)    modf (0, k);   fi;
                        fi;
                end;
        herein

            fun map_in_place ROW_MAJOR    =>  modify_rm;
                map_in_place COLUMN_MAJOR =>  modify_cm;
            end;
        end;

        fun foldi order f init region
            =
            fold init
            where

                my (data, iter)
                    =
                    case order
                        ROW_MAJOR    => iterate_rm  region;
                        COLUMN_MAJOR => iterate_cm  region;
                    esac;


                fun fold accum
                    =
                    case (iter ())
                        #
                        DONE => accum;

                        INDEX { i, r, c }
                            =>
                            fold (f(r, c, unsafe_get (data, i), accum));
                    esac;
            end;


        fun fold_rm f init { data, ncols, nrows }
            =
            rwv::fold_forward f init data;


        fun fold_cm f init { data, ncols, nrows }
            =
            foldf (0, 0, init)
            where
                delta = rwv::length data - 1;

                fun foldf (i, k, accum)
                    =
                    if (i < nrows)
                        #
                        foldf (i+1, k+ncols, f (unsafe_get (data, k), accum));
                    else
                        k = k-delta;

                        if (k < ncols)   foldf (0, k, accum);
                        else             accum;
                        fi;
                    fi;
            end;


        # fold order f init m:  fold f (element, accumulator) over the whole
        # matrix m in the given traversal order, starting from init.
        #
        fun fold ROW_MAJOR    =>  fold_rm;
            fold COLUMN_MAJOR =>  fold_cm;
        end;

    };
end;

# rw_matrix is an alternate name for the same package.
#
package rw_matrix= matrix;      # matrix        is from   src/lib/std/src/matrix.pkg


Comments and suggestions to: bugs@mythryl.org

PreviousUpNext