open Viz_misc

(* Evaluates [Giochannel.init] and then [Gspawn.init]; the bound value is
   that of [Gspawn.init].  NOTE(review): presumably this exists only to force
   the two binding modules to be linked and initialised — confirm. *)
let init =
  Giochannel.init ; Gspawn.init

(** printf-style tracing for this module.  When [Viz_misc.debug "spawn"] is
    true, the formatted message is written to stderr prefixed with
    ["### spawn: "] and flushed.  Otherwise it is still formatted (kprintf
    always builds the string) and then discarded. *)
let debug fmt =
  let sink =
    match Viz_misc.debug "spawn" with
    | true -> (fun msg -> Printf.eprintf "### spawn: %s\n%!" msg)
    | false -> ignore in
  Printf.kprintf sink fmt

(** How [setup_channel] configures the character encoding of a pipe channel:
    - [`NONE]: the channel encoding is set to [None] and the channel is made
      unbuffered (raw bytes);
    - [`LOCALE]: the channel encoding is set to the charset reported by
      [Glib.Convert.get_charset], unless that charset is UTF-8, in which case
      the channel is left untouched;
    - [`CHARSET cs]: the channel encoding is set to the named charset [cs]. *)
type encoding = [ `NONE | `LOCALE | `CHARSET of string ]


(** Wrap the pipe [fd] in a [Giochannel.t] and configure it for [encoding].
    When [nonblock] is true the [`NONBLOCK] flag is requested through
    [set_flags_noerr].  NOTE(review): the name suggests failures are ignored
    — confirm.  [fd] is passed through [some], so it is presumably an option
    that must be [Some _]. *)
let setup_channel ~nonblock encoding fd =
  let channel = Giochannel.new_fd (some fd) in
  let use_charset name = Giochannel.set_encoding channel (Some name) in
  if nonblock then Giochannel.set_flags_noerr channel [`NONBLOCK] ;
  (match encoding with
   | `CHARSET name -> use_charset name
   | `LOCALE ->
       (* A UTF-8 locale needs no conversion: keep the channel as created. *)
       (match Glib.Convert.get_charset () with
        | (true, _) -> ()
        | (false, locale_charset) -> use_charset locale_charset)
   | `NONE ->
       (* GLib only lets buffering be disabled once the encoding is NULL,
          so the encoding is cleared first. *)
       Giochannel.set_encoding channel None ;
       Giochannel.set_buffered channel false) ;
  channel

(** [all_done_cb ~nb cb] returns a thunk that runs [cb ()] on exactly its
    [nb]-th invocation and does nothing on every other invocation, including
    later ones.  With [nb <= 0] the callback never fires. *)
let all_done_cb ~nb cb =
  let remaining = ref nb in
  fun () ->
    remaining := !remaining - 1 ;
    match !remaining with
    | 0 -> cb ()
    | _ -> ()

(** State of one GLib watch on one of the child's pipe channels.  Built by
    [spawn] and driven by [in_channel_watch] / [out_channel_watch]. *)
type watch = { 
    (* Set to [true] by [stop_watch] once the channel has been shut down;
       [abort] skips watches for which it is already set. *)
    mutable finished : bool ;
    (* Stream label ("stdin", "stdout", "stderr") used in debug messages. *)
    name : string ;
    (* The channel being watched. *)
    chan : Giochannel.t ;
    (* Receives [Giochannel.Error] / [Glib.Convert.Error] raised during I/O. *)
    exn_cb  : exn -> unit ;
    (* Called by [reset_watch] when the watch stops by itself; [abort] does
       not call it. *)
    done_cb : unit -> unit ;
  }

(** Mark [w] as finished and shut down its channel.  The [true] argument to
    [Giochannel.shutdown] is presumably the flush flag — confirm.  Errors
    raised by the shutdown are only logged, never propagated.  Does not call
    [w.done_cb]. *)
let stop_watch w =
  let label = w.name in
  w.finished <- true ;
  debug "%s cb: closing pipe" label ;
  begin
    try Giochannel.shutdown w.chan true with
    | Giochannel.Error (_, reason)
    | Glib.Convert.Error (_, reason) ->
        debug "%s cb: error closing pipe %s" label reason
  end

(** Common tail of the pipe-watch callbacks.  Returns [continue] unchanged so
    it can serve as the callback's result (GLib removes a watch source whose
    callback returns false).  When [continue] is false, it first stops the
    watch and then runs [w.done_cb]. *)
let reset_watch w continue =
  if continue then true
  else begin
    stop_watch w ;
    w.done_cb () ;
    false
  end

(** Install a GLib watch that feeds [input] into [w.chan] (the child's
    stdin).

    A write from the current offset is attempted only when the wake-up
    conditions include [`OUT] and bytes remain.  The watch is stopped through
    [reset_watch], which runs [w.done_cb], in three cases:
    - a wake-up does not lead to a write attempt (all input sent, or
      [`HUP]/[`ERR] without [`OUT]);
    - a write raises [Giochannel.Error];
    - a write raises [Glib.Convert.Error].
    Such exceptions are passed to [w.exn_cb].

    Returns the source id produced by [Giochannel.add_watch]. *)
let in_channel_watch w input =
  let sent = ref 0 in
  (* One write attempt starting at [!sent]; the result says whether to keep
     watching. *)
  let write_some () =
    let bytes_written = ref 0 in
    try
      begin match
        Giochannel.write_chars w.chan ~bytes_written ~off:!sent input
      with
      | `NORMAL written ->
          debug "stdin  cb: wrote %d" written ;
          sent := !sent + written
      | `AGAIN ->
          debug "stdin  cb: EAGAIN ?"
      end ;
      true
    with
    | Giochannel.Error (_, msg)
    | Glib.Convert.Error (_, msg) as exn ->
        w.exn_cb exn ;
        (* presumably [bytes_written] holds what got out before the error *)
        debug "stdin  cb: error %s, wrote %d" msg !bytes_written ;
        false in
  let on_event conditions =
    let remaining = String.length input - !sent in
    debug "stdin  cb: %d left in buffer" remaining ;
    reset_watch w
      (remaining > 0 && List.mem `OUT conditions && write_some ()) in
  Giochannel.add_watch w.chan [ `OUT ; `HUP ; `ERR ] on_event


(** Install a GLib watch that appends everything read from [w.chan] (the
    child's stdout or stderr) to the buffer [b].

    Each wake-up whose conditions include [`IN] reads at most one chunk of
    4096 bytes.  The watch is stopped through [reset_watch], which runs
    [w.done_cb], in three cases:
    - a wake-up has no [`IN];
    - the read reports [`EOF];
    - the read raises [Giochannel.Error] or [Glib.Convert.Error]; the
      exception is then passed to [w.exn_cb].

    Returns the source id produced by [Giochannel.add_watch]. *)
let out_channel_watch w b =
  let chunk = String.create 4096 in
  (* One read into [chunk]; the result says whether to keep watching. *)
  let read_some () =
    try
      match Giochannel.read_chars w.chan chunk with
      | `NORMAL n ->
          debug "%s cb: read %d" w.name n ;
          Buffer.add_substring b chunk 0 n ;
          true
      | `EOF ->
          debug "%s cb: eof" w.name ;
          false
      | `AGAIN ->
          debug "%s cb: AGAIN" w.name ;
          true
    with
    | Giochannel.Error (_, msg)
    | Glib.Convert.Error (_, msg) as exn ->
        w.exn_cb exn ;
        debug "%s cb: error %s" w.name msg ;
        false in
  let on_event conditions =
    reset_watch w (List.mem `IN conditions && read_some ()) in
  Giochannel.add_watch w.chan [ `IN ; `HUP ; `ERR ] on_event

(** Install a child watch on [pid].  When the child exits, the exit is logged
    and the status is passed to [callback], whose result is discarded.
    Returns whatever [Gspawn.add_child_watch] returns. *)
let pid_watch pid callback =
  let on_exit status =
    let child = Gspawn.int_of_pid pid in
    debug "child %d exiting, status %d" child status ;
    ignore (callback status) in
  Gspawn.add_child_watch pid on_exit


(** Handle for a spawned child: returned by [spawn], accepted by [abort]. *)
type t = {
    (* Pipe watches paired with their GLib source ids, most recently added
       first.  The child-exit watch is not recorded here. *)
    mutable watches : (watch * Giochannel.source_id) list ;
    (* Set by [abort]; suppresses the final done callback. *)
    mutable aborted : bool ;
    (* Status value passed to the child-exit callback; -1 until then.
       NOTE(review): presumably a raw wait status — confirm. *)
    mutable status  : int ;
  }

(** [spawn ~encoding ~cmd ~input ~reap_callback done_callback] starts the
    command line [cmd] asynchronously, with [`SEARCH_PATH].  The child's
    stdout and stderr are piped back.  When [input] is [Some s], stdin is
    piped as well and [s] is written to it.  [encoding] is applied to every
    pipe channel (see [setup_channel]).  Output is accumulated in memory.

    [done_callback ~exceptions ~stdout ~stderr status] is called once every
    pipe watch has stopped and the child has exited, unless [abort] was
    called on the returned handle first.

    When the child exits, [reap_callback ()] runs before the pid is closed.
    Any exception it raises is logged with [debug] and otherwise ignored, so
    the pid is still closed and completion still fires.

    Returns the handle to pass to [abort]. *)
let spawn ~encoding ~cmd ~input:input_opt ~reap_callback done_callback =
  if Viz_misc.debug "exec"
  then Printf.eprintf "### exec: Running '%s'\n%!" (String.concat " " cmd) ;
  let has_input = input_opt <> None in
  let spawn_flags =
    [ `PIPE_STDOUT ; `PIPE_STDERR ;
      `SEARCH_PATH ; `DO_NOT_REAP_CHILD ] in

  let child_info =
    Gspawn.async_with_pipes
      (if has_input then `PIPE_STDIN :: spawn_flags else spawn_flags)
      cmd in

  let state = { watches = [] ; aborted = false ; status = -1 } in

  let out_buffer = Buffer.create 4096 in
  let err_buffer = Buffer.create 1024 in
  let exn_list = ref [] in

  (* One completion per pipe watch, plus one for the child-exit watch. *)
  let all_done =
    all_done_cb
      ~nb:(if has_input then 4 else 3)
      (fun () ->
        if not state.aborted
        then
          done_callback
            ~exceptions:!exn_list
            ~stdout:(Buffer.contents out_buffer)
            ~stderr:(Buffer.contents err_buffer) state.status) in

  let exn_cb exn =
    exn_list := exn :: !exn_list in

  (* Wrap one of the child's pipes in a channel, install a watch on it with
     [install], and record the (watch, source id) pair in [state]. *)
  let register name ~nonblock fd install =
    let chan = setup_channel ~nonblock encoding fd in
    let w = { name = name ; finished = false ; chan = chan ;
              exn_cb = exn_cb ; done_cb = all_done } in
    let id = install w in
    state.watches <- (w, id) :: state.watches in

  begin match input_opt with
  | Some input ->
      register "stdin" ~nonblock:true child_info.Gspawn.standard_input
        (fun w -> in_channel_watch w input)
  | None -> ()
  end ;
  register "stdout" ~nonblock:false child_info.Gspawn.standard_output
    (fun w -> out_channel_watch w out_buffer) ;
  register "stderr" ~nonblock:false child_info.Gspawn.standard_error
    (fun w -> out_channel_watch w err_buffer) ;

  let pid = some child_info.Gspawn.pid in
  ignore (pid_watch pid
            (fun s ->
              state.status <- s ;
              begin
                try reap_callback ()
                with exn ->
                  (* Best effort: a failing reap callback must not keep the
                     pid from being closed or completion from firing, but
                     the failure should not vanish silently either. *)
                  debug "reap callback raised %s" (Printexc.to_string exn)
              end ;
              Gspawn.close_pid pid ;
              all_done ())) ;

  state

(** Shape of the completion callback, matching how [spawn] invokes its
    [done_callback]:
    - [exceptions]: the I/O exceptions collected from the pipe watches, most
      recent first;
    - [stdout] / [stderr]: everything read from the child's output pipes;
    - the final [int]: the status value reported by the child-exit watch. *)
type callback = 
  exceptions:exn list -> 
  stdout:string -> 
  stderr:string -> 
  int -> unit

(** Cancel a spawned command.  The handle is marked aborted, so the done
    callback will not run.  Every pipe watch that has not finished yet has
    its GLib source removed and its channel shut down.  The child process
    itself is not signalled, and its exit watch is left in place.  Calling
    [abort] a second time does nothing. *)
let abort sub_data =
  match sub_data.aborted with
  | true -> ()
  | false ->
      sub_data.aborted <- true ;
      let cancel (w, id) =
        if not w.finished then begin
          Giochannel.remove_watch id ;
          stop_watch w
        end in
      List.iter cancel sub_data.watches
