36
   37:- module(thread,
   38          [ concurrent/3,                  39            concurrent_maplist/2,          40            concurrent_maplist/3,          41            concurrent_maplist/4,          42            concurrent_forall/2,           43            concurrent_forall/3,           44            concurrent_and/2,              45            concurrent_and/3,              46            first_solution/3,              47
   48            call_in_thread/2               49          ]).   50:- autoload(library(apply), [maplist/2, maplist/3, maplist/4, maplist/5]).   51:- autoload(library(error), [must_be/2, instantiation_error/1]).   52:- autoload(library(lists), [subtract/3, same_length/2, nth0/3]).   53:- autoload(library(option), [option/2, option/3]).   54:- autoload(library(ordsets), [ord_intersection/3, ord_union/3]).   55:- use_module(library(debug), [debug/3, assertion/1]).   56
   58
   59:- meta_predicate
   60    concurrent(+, :, +),
   61    concurrent_maplist(1, +),
   62    concurrent_maplist(2, ?, ?),
   63    concurrent_maplist(3, ?, ?, ?),
   64    concurrent_forall(0, 0),
   65    concurrent_forall(0, 0, +),
   66    concurrent_and(0, 0),
   67    concurrent_and(0, 0, +),
   68    first_solution(-, :, +),
   69    call_in_thread(+, 0).   70
   71
   72:- predicate_options(concurrent/3, 3,
   73                     [ pass_to(system:thread_create/3, 3)
   74                     ]).   75:- predicate_options(concurrent_forall/3, 3,
   76                     [ threads(nonneg)
   77                     ]).   78:- predicate_options(concurrent_and/3, 3,
   79                     [ threads(nonneg)
   80                     ]).   81:- predicate_options(first_solution/3, 3,
   82                     [ on_fail(oneof([stop,continue])),
   83                       on_error(oneof([stop,continue])),
   84                       pass_to(system:thread_create/3, 3)
   85                     ]).   86
  119
  163
  164concurrent(1, M:List, _) :-
  165    !,
  166    maplist(once_in_module(M), List).
  167concurrent(N, M:List, Options) :-
  168    must_be(positive_integer, N),
  169    must_be(list(callable), List),
  170    length(List, JobCount),
  171    message_queue_create(Done),
  172    message_queue_create(Queue),
  173    WorkerCount is min(N, JobCount),
  174    create_workers(WorkerCount, Queue, Done, Workers, Options),
  175    submit_goals(List, 1, M, Queue, VarList),
  176    forall(between(1, WorkerCount, _),
  177           thread_send_message(Queue, done)),
  178    VT =.. [vars|VarList],
  179    concur_wait(JobCount, Done, VT, cleanup(Workers, Queue),
  180                Result, [], Exitted),
  181    subtract(Workers, Exitted, RemainingWorkers),
  182    concur_cleanup(Result, RemainingWorkers, [Queue, Done]),
  183    (   Result == true
  184    ->  true
  185    ;   Result = false
  186    ->  fail
  187    ;   Result = exception(Error)
  188    ->  throw(Error)
  189    ).
  190
  191once_in_module(M, Goal) :-
  192    call(M:Goal), !.
  193
  199
  200submit_goals([], _, _, _, []).
  201submit_goals([H|T], I, M, Queue, [Vars|VT]) :-
  202    term_variables(H, Vars),
  203    thread_send_message(Queue, goal(I, M:H, Vars)),
  204    I2 is I + 1,
  205    submit_goals(T, I2, M, Queue, VT).
  206
  207
  215
  216concur_wait(0, _, _, _, true, Exited, Exited) :- !.
  217concur_wait(N, Done, VT, Cleanup, Status, Exitted0, Exitted) :-
  218    debug(concurrent, 'Concurrent: waiting for workers ...', []),
  219    catch(thread_get_message(Done, Exit), Error,
  220          concur_abort(Error, Cleanup, Done, Exitted0)),
  221    debug(concurrent, 'Waiting: received ~p', [Exit]),
  222    (   Exit = done(Id, Vars)
  223    ->  debug(concurrent, 'Concurrent: Job ~p completed with ~p', [Id, Vars]),
  224        arg(Id, VT, Vars),
  225        N2 is N - 1,
  226        concur_wait(N2, Done, VT, Cleanup, Status, Exitted0, Exitted)
  227    ;   Exit = finished(Thread)
  228    ->  thread_join(Thread, JoinStatus),
  229        debug(concurrent, 'Concurrent: waiter ~p joined: ~p',
  230              [Thread, JoinStatus]),
  231        (   JoinStatus == true
  232        ->  concur_wait(N, Done, VT, Cleanup, Status, [Thread|Exitted0], Exitted)
  233        ;   Status = JoinStatus,
  234            Exitted = [Thread|Exitted0]
  235        )
  236    ).
  237
  238concur_abort(Error, cleanup(Workers, Queue), Done, Exitted) :-
  239    debug(concurrent, 'Concurrent: got ~p', [Error]),
  240    subtract(Workers, Exitted, RemainingWorkers),
  241    concur_cleanup(Error, RemainingWorkers, [Queue, Done]),
  242    throw(Error).
  243
  244create_workers(N, Queue, Done, [Id|Ids], Options) :-
  245    N > 0,
  246    !,
  247    thread_create(worker(Queue, Done), Id,
  248                  [ at_exit(thread_send_message(Done, finished(Id)))
  249                  | Options
  250                  ]),
  251    N2 is N - 1,
  252    create_workers(N2, Queue, Done, Ids, Options).
  253create_workers(_, _, _, [], _).
  254
  255
  259
  260worker(Queue, Done) :-
  261    thread_get_message(Queue, Message),
  262    debug(concurrent, 'Worker: received ~p', [Message]),
  263    (   Message = goal(Id, Goal, Vars)
  264    ->  (   Goal
  265        ->  thread_send_message(Done, done(Id, Vars)),
  266            worker(Queue, Done)
  267        )
  268    ;   true
  269    ).
  270
  271
  278
  279concur_cleanup(Result, Workers, Queues) :-
  280    !,
  281    (   Result == true
  282    ->  true
  283    ;   kill_workers(Workers)
  284    ),
  285    join_all(Workers),
  286    maplist(message_queue_destroy, Queues).
  287
  288kill_workers([]).
  289kill_workers([Id|T]) :-
  290    debug(concurrent, 'Signalling ~w', [Id]),
  291    catch(thread_signal(Id, abort), _, true),
  292    kill_workers(T).
  293
  294join_all([]).
  295join_all([Id|T]) :-
  296    thread_join(Id, _),
  297    join_all(T).
  298
  299
  300		   303
  322
  323:- dynamic
  324    fa_aborted/1.  325
  326concurrent_forall(Generate, Test) :-
  327    concurrent_forall(Generate, Test, []).
  328
  329concurrent_forall(Generate, Test, Options) :-
  330    jobs(Jobs, Options),
  331    Jobs > 1,
  332    !,
  333    term_variables(Generate, GVars),
  334    term_variables(Test, TVars),
  335    sort(GVars, GVarsS),
  336    sort(TVars, TVarsS),
  337    ord_intersection(GVarsS, TVarsS, Shared),
  338    Templ =.. [v|Shared],
  339    MaxSize is Jobs*4,
  340    message_queue_create(Q, [max_size(MaxSize)]),
  341    length(Workers, Jobs),
  342    thread_self(Me),
  343    maplist(thread_create(fa_worker(Q, Me, Templ, Test)), Workers),
  344    catch(( forall(Generate,
  345                   thread_send_message(Q, job(Templ))),
  346            forall(between(1, Jobs, _),
  347                   thread_send_message(Q, done)),
  348            maplist(thread_join, Workers),
  349            message_queue_destroy(Q)
  350          ),
  351          Error,
  352          fa_cleanup(Error, Workers, Q)).
  353concurrent_forall(Generate, Test, _) :-
  354    forall(Generate, Test).
  355
  356fa_cleanup(Error, Workers, Q) :-
  357    maplist(safe_abort, Workers),
  358    debug(concurrent(fail), 'Joining workers', []),
  359    maplist(safe_join, Workers),
  360    debug(concurrent(fail), 'Destroying queue', []),
  361    retractall(fa_aborted(Q)),
  362    message_queue_destroy(Q),
  363    (   Error = fa_worker_failed(_0Test, Why)
  364    ->  debug(concurrent(fail), 'Test ~p failed: ~p', [_0Test, Why]),
  365        (   Why == false
  366        ->  fail
  367        ;   Why = error(E)
  368        ->  throw(E)
  369        ;   assertion(fail)
  370        )
  371    ;   throw(Error)
  372    ).
  373
  374fa_worker(Queue, Main, Templ, Test) :-
  375    repeat,
  376    thread_get_message(Queue, Msg),
  377    (   Msg == done
  378    ->  !
  379    ;   Msg = job(Templ),
  380        debug(concurrent, 'Running test ~p', [Test]),
  381        (   catch_with_backtrace(Test, E, true)
  382        ->  (   var(E)
  383            ->  fail
  384            ;   fa_stop(Queue, Main, fa_worker_failed(Test, error(E)))
  385            )
  386        ;   !,
  387            fa_stop(Queue, Main, fa_worker_failed(Test, false))
  388        )
  389    ).
  390
  391fa_stop(Queue, Main, Why) :-
  392    with_mutex('$concurrent_forall',
  393               fa_stop_sync(Queue, Main, Why)).
  394
  395fa_stop_sync(Queue, _Main, _Why) :-
  396    fa_aborted(Queue),
  397    !.
  398fa_stop_sync(Queue, Main, Why) :-
  399    asserta(fa_aborted(Queue)),
  400    debug(concurrent(fail), 'Stop due to ~p. Signalling ~q', [Why, Main]),
  401    thread_signal(Main, throw(Why)).
  402
  403jobs(Jobs, Options) :-
  404    (   option(threads(Jobs), Options)
  405    ->  true
  406    ;   current_prolog_flag(cpu_count, Jobs)
  407    ->  true
  408    ;   Jobs = 1
  409    ).
  410
  411safe_abort(Thread) :-
  412    catch(thread_signal(Thread, abort), error(_,_), true).
  413safe_join(Thread) :-
  414    E = error(_,_),
  415    catch(thread_join(Thread, _Status), E, true).
  416
  417
  418		   421
  448
  449concurrent_and(Gen, Test) :-
  450    concurrent_and(Gen, Test, []).
  451
  452concurrent_and(Gen, Test, Options) :-
  453    jobs(Jobs, Options),
  454    MaxSize is Jobs*4,
  455    message_queue_create(JobQueue, [max_size(MaxSize)]),
  456    message_queue_create(AnswerQueue, [max_size(MaxSize)]),
  457    ca_template(Gen, Test, Templ),
  458    term_variables(Gen+Test, AllVars),
  459    ReplyTempl =.. [v|AllVars],
  460    length(Workers, Jobs),
  461    Alive is 1<<Jobs-1,
  462    maplist(thread_create(ca_worker(JobQueue, AnswerQueue,
  463                                    Templ, Test, ReplyTempl)),
  464            Workers),
  465    thread_create(ca_generator(Gen, Templ, JobQueue, AnswerQueue),
  466                  GenThread),
  467    State = state(Alive),
  468    call_cleanup(
  469        ca_gather(State, AnswerQueue, ReplyTempl, Workers),
  470        ca_cleanup(GenThread, Workers, JobQueue, AnswerQueue)).
  471
  472ca_gather(State, AnswerQueue, ReplyTempl, Workers) :-
  473    repeat,
  474       thread_get_message(AnswerQueue, Msg),
  475       (   Msg = true(ReplyTempl)
  476       ->  true
  477       ;   Msg = done(Worker)
  478       ->  nth0(Done, Workers, Worker),
  479           arg(1, State, Alive0),
  480           Alive1 is Alive0 /\ \(1<<Done),
  481           debug(concurrent(and), 'Alive = ~2r', [Alive1]),
  482           (   Alive1 =:= 0
  483           ->  !,
  484               fail
  485           ;   nb_setarg(1, State, Alive1),
  486               fail
  487           )
  488       ;   Msg = error(E)
  489       ->  throw(E)
  490       ).
  491
  492ca_template(Gen, Test, Templ) :-
  493    term_variables(Gen,  GVars),
  494    term_variables(Test, TVars),
  495    sort(GVars, GVarsS),
  496    sort(TVars, TVarsS),
  497    ord_intersection(GVarsS, TVarsS, Shared),
  498    ord_union(GVarsS, Shared, TemplVars),
  499    Templ =.. [v|TemplVars].
  500
  501ca_worker(JobQueue, AnswerQueue, Templ, Test, ReplyTempl) :-
  502    thread_self(Me),
  503    EG = error(existence_error(message_queue, _), _),
  504    repeat,
  505    catch(thread_get_message(JobQueue, Req), EG, Req=all_done),
  506    (   Req = job(Templ)
  507    ->  (   catch(Test, E, true),
  508            (   var(E)
  509            ->  thread_send_message(AnswerQueue, true(ReplyTempl))
  510            ;   thread_send_message(AnswerQueue, error(E))
  511            ),
  512            fail
  513        )
  514    ;   Req == done
  515    ->  !,
  516        message_queue_destroy(JobQueue),
  517        thread_send_message(AnswerQueue, done(Me))
  518    ;   assertion(Req == all_done)
  519    ->  !,
  520        thread_send_message(AnswerQueue, done(Me))
  521    ).
  522
  523ca_generator(Gen, Templ, JobQueue, AnswerQueue) :-
  524    (   catch(Gen, E, true),
  525        (   var(E)
  526        ->  thread_send_message(JobQueue, job(Templ))
  527        ;   thread_send_message(AnswerQueue, error(E))
  528        ),
  529        fail
  530    ;   thread_send_message(JobQueue, done)
  531    ).
  532
  533ca_cleanup(GenThread, Workers, JobQueue, AnswerQueue) :-
  534    safe_abort(GenThread),
  535    safe_join(GenThread),
  536    maplist(safe_abort, Workers),
  537    maplist(safe_join, Workers),
  538    message_queue_destroy(AnswerQueue),
  539    catch(message_queue_destroy(JobQueue), error(_,_), true).
  540
  541
  542                   545
  562
  563concurrent_maplist(Goal, List) :-
  564    workers(List, WorkerCount),
  565    !,
  566    maplist(ml_goal(Goal), List, Goals),
  567    concurrent(WorkerCount, Goals, []).
  568concurrent_maplist(M:Goal, List) :-
  569    maplist(once_in_module(M, Goal), List).
  570
  571once_in_module(M, Goal, Arg) :-
  572    call(M:Goal, Arg), !.
  573
  574ml_goal(Goal, Elem, call(Goal, Elem)).
  575
  576concurrent_maplist(Goal, List1, List2) :-
  577    same_length(List1, List2),
  578    workers(List1, WorkerCount),
  579    !,
  580    maplist(ml_goal(Goal), List1, List2, Goals),
  581    concurrent(WorkerCount, Goals, []).
  582concurrent_maplist(M:Goal, List1, List2) :-
  583    maplist(once_in_module(M, Goal), List1, List2).
  584
  585once_in_module(M, Goal, Arg1, Arg2) :-
  586    call(M:Goal, Arg1, Arg2), !.
  587
  588ml_goal(Goal, Elem1, Elem2, call(Goal, Elem1, Elem2)).
  589
  590concurrent_maplist(Goal, List1, List2, List3) :-
  591    same_length(List1, List2, List3),
  592    workers(List1, WorkerCount),
  593    !,
  594    maplist(ml_goal(Goal), List1, List2, List3, Goals),
  595    concurrent(WorkerCount, Goals, []).
  596concurrent_maplist(M:Goal, List1, List2, List3) :-
  597    maplist(once_in_module(M, Goal), List1, List2, List3).
  598
  599once_in_module(M, Goal, Arg1, Arg2, Arg3) :-
  600    call(M:Goal, Arg1, Arg2, Arg3), !.
  601
  602ml_goal(Goal, Elem1, Elem2, Elem3, call(Goal, Elem1, Elem2, Elem3)).
  603
  604workers(List, Count) :-
  605    current_prolog_flag(cpu_count, Cores),
  606    Cores > 1,
  607    length(List, Len),
  608    Count is min(Cores,Len),
  609    Count > 1,
  610    !.
  611
  612same_length([], [], []).
  613same_length([_|T1], [_|T2], [_|T3]) :-
  614    same_length(T1, T2, T3).
  615
  616
  617                   620
  657
  658
  659first_solution(X, M:List, Options) :-
  660    message_queue_create(Done),
  661    thread_options(Options, ThreadOptions, RestOptions),
  662    length(List, JobCount),
  663    create_solvers(List, M, X, Done, Solvers, ThreadOptions),
  664    wait_for_one(JobCount, Done, Result, RestOptions),
  665    concur_cleanup(kill, Solvers, [Done]),
  666    (   Result = done(_, Var)
  667    ->  X = Var
  668    ;   Result = error(_, Error)
  669    ->  throw(Error)
  670    ).
  671
  672create_solvers([], _, _, _, [], _).
  673create_solvers([H|T], M, X, Done, [Id|IDs], Options) :-
  674    thread_create(solve(M:H, X, Done), Id, Options),
  675    create_solvers(T, M, X, Done, IDs, Options).
  676
  677solve(Goal, Var, Queue) :-
  678    thread_self(Me),
  679    (   catch(Goal, E, true)
  680    ->  (   var(E)
  681        ->  thread_send_message(Queue, done(Me, Var))
  682        ;   thread_send_message(Queue, error(Me, E))
  683        )
  684    ;   thread_send_message(Queue, failed(Me))
  685    ).
  686
  687wait_for_one(0, _, failed, _) :- !.
  688wait_for_one(JobCount, Queue, Result, Options) :-
  689    thread_get_message(Queue, Msg),
  690    LeftCount is JobCount - 1,
  691    (   Msg = done(_, _)
  692    ->  Result = Msg
  693    ;   Msg = failed(_)
  694    ->  (   option(on_fail(stop), Options, stop)
  695        ->  Result = Msg
  696        ;   wait_for_one(LeftCount, Queue, Result, Options)
  697        )
  698    ;   Msg = error(_, _)
  699    ->  (   option(on_error(stop), Options, stop)
  700        ->  Result = Msg
  701        ;   wait_for_one(LeftCount, Queue, Result, Options)
  702        )
  703    ).
  704
  705
  710
  711thread_options([], [], []).
  712thread_options([H|T], [H|Th], O) :-
  713    thread_option(H),
  714    !,
  715    thread_options(T, Th, O).
  716thread_options([H|T], Th, [H|O]) :-
  717    thread_options(T, Th, O).
  718
  719thread_option(local(_)).
  720thread_option(global(_)).
  721thread_option(trail(_)).
  722thread_option(argument(_)).
  723thread_option(stack(_)).
  724
  725
  736
  737call_in_thread(Thread, Goal) :-
  738    must_be(callable, Goal),
  739    var(Thread),
  740    !,
  741    instantiation_error(Thread).
  742call_in_thread(Thread, Goal) :-
  743    thread_self(Thread),
  744    !,
  745    once(Goal).
  746call_in_thread(Thread, Goal) :-
  747    term_variables(Goal, Vars),
  748    thread_self(Me),
  749    A is random(1 000 000 000),
  750    thread_signal(Thread, run_in_thread(Goal,Vars,A,Me)),
  751    catch(thread_get_message(in_thread(A,Result)),
  752          Error,
  753          forward_exception(Thread, A, Error)),
  754    (   Result = true(Vars)
  755    ->  true
  756    ;   Result = error(Error)
  757    ->  throw(Error)
  758    ;   fail
  759    ).
  760
  761run_in_thread(Goal, Vars, Id, Sender) :-
  762    (   catch_with_backtrace(call(Goal), Error, true)
  763    ->  (   var(Error)
  764        ->  thread_send_message(Sender, in_thread(Id, true(Vars)))
  765        ;   Error = stop(_)
  766        ->  true
  767        ;   thread_send_message(Sender, in_thread(Id, error(Error)))
  768        )
  769    ;   thread_send_message(Sender, in_thread(Id, false))
  770    ).
  771
  772forward_exception(Thread, Id, Error) :-
  773    kill_with(Error, Kill),
  774    thread_signal(Thread, kill_task(Id, Kill)),
  775    throw(Error).
  776
  777kill_with(time_limit_exceeded, stop(time_limit_exceeded)) :-
  778    !.
  779kill_with(_, stop(interrupt)).
  780
  781kill_task(Id, Exception) :-
  782    prolog_current_frame(Frame),
  783    prolog_frame_attribute(Frame, parent_goal,
  784                           run_in_thread(_Goal, _Vars, Id, _Sender)),
  785    !,
  786    throw(Exception).
  787kill_task(_, _)