diff --git a/src/proto_alpha/lib_protocol/bitset.ml b/src/proto_alpha/lib_protocol/bitset.ml index 7f7792bf0803b3a400a1b6ea6c7242b42f944890..e9facb7e132aad4417f1786adad9348c8215ac34 100644 --- a/src/proto_alpha/lib_protocol/bitset.ml +++ b/src/proto_alpha/lib_protocol/bitset.ml @@ -43,6 +43,8 @@ let from_list positions = List.fold_left_e add empty positions let inter = Z.logand +let diff b1 b2 = Z.logand b1 (Z.lognot b2) (* [lognot] flips every bit of [b2], so [logand] keeps the bits set in [b1] and clear in [b2] *) + let () = let open Data_encoding in register_error_kind diff --git a/src/proto_alpha/lib_protocol/bitset.mli b/src/proto_alpha/lib_protocol/bitset.mli index 96b101a56cfcf1cdcf6b20b465d61d0c8fad9188..a8c23b1f42904d6a95c6379a3831d3e5ca4d0a04 100644 --- a/src/proto_alpha/lib_protocol/bitset.mli +++ b/src/proto_alpha/lib_protocol/bitset.mli @@ -44,13 +44,19 @@ val mem : t -> int -> bool tzresult This functions returns [Invalid_input i] if [i] is negative. *) val add : t -> int -> t tzresult -(** [from_list positions] folds [add] over the [positions] starting from [empty]. *) +(** [from_list positions] folds [add] over the [positions] starting from [empty]. + This function returns [Invalid_input i] if [i] is negative and appears in + [positions]. *) val from_list : int list -> t tzresult -(** [inter field_l field_r] returns [field] which is result of the - logical "and" bit-wise from [field_l] and [field_r]. *) +(** [inter set_l set_r] returns the [set] which is the result of the + intersection of [set_l] and [set_r]. *) val inter : t -> t -> t +(** [diff set_l set_r] returns a [set] containing the elements of [set_l] + that are not in [set_r]. *) +val diff : t -> t -> t + (** [occupied_size_in_bits bitset] returns the current number of bits occupied by the [bitset]. 
*) val occupied_size_in_bits : t -> int diff --git a/src/proto_alpha/lib_protocol/test/pbt/test_bitset.ml b/src/proto_alpha/lib_protocol/test/pbt/test_bitset.ml index e3f5b22770b91668d3ab8ad2cac71623eb2874c4..5099909ada027184951f091bb65ce02d19aaf515 100644 --- a/src/proto_alpha/lib_protocol/test/pbt/test_bitset.ml +++ b/src/proto_alpha/lib_protocol/test/pbt/test_bitset.ml @@ -37,13 +37,8 @@ let gen_ofs = QCheck2.Gen.int_bound (64 * 10) let gen_storage = let open QCheck2.Gen in - let* bool_vector = list bool in - match - List.fold_left_i_e - (fun i storage v -> if v then add storage i else Ok storage) - empty - bool_vector - with + let* int_vector = list @@ int_bound 64 in + match from_list int_vector with | Ok v -> return v | Error e -> Alcotest.failf @@ -70,6 +65,46 @@ let test_get_set (c, ofs) = | Ok res -> res) (0 -- 63) +let test_inter (c1, c2) = + let c3 = inter c1 c2 in + List.for_all + (fun ofs -> + let res = + let open Result_syntax in + let* v1 = mem c1 ofs in + let* v2 = mem c2 ofs in + let* v3 = mem c3 ofs in + return ((v1 && v2) = v3) + in + match res with + | Error e -> + Alcotest.failf + "Unexpected error: %a" + Environment.Error_monad.pp_trace + e + | Ok res -> res) + (0 -- 63) + +let test_diff (c1, c2) = + let c3 = diff c1 c2 in + List.for_all + (fun ofs -> + let res = + let open Result_syntax in + let* v1 = mem c1 ofs in + let* v2 = mem c2 ofs in + let* v3 = mem c3 ofs in + return ((v1 && not v2) = v3) + in + match res with + | Error e -> + Alcotest.failf + "Unexpected error: %a" + Environment.Error_monad.pp_trace + e + | Ok res -> res) + (0 -- 63) + let () = Alcotest.run "bits" @@ -82,5 +117,15 @@ let () = ~name:"get set" QCheck2.Gen.(pair gen_storage gen_ofs) test_get_set; + QCheck2.Test.make + ~count:10000 + ~name:"inter" + QCheck2.Gen.(pair gen_storage gen_storage) + test_inter; + QCheck2.Test.make + ~count:10000 + ~name:"diff" + QCheck2.Gen.(pair gen_storage gen_storage) + test_diff; ] ); ]