diff --git a/src/mad/ruby/ssh_stream.rb b/src/mad/ruby/ssh_stream.rb index 06d9aba021..453fb2c4a4 100644 --- a/src/mad/ruby/ssh_stream.rb +++ b/src/mad/ruby/ssh_stream.rb @@ -34,9 +34,10 @@ class SshStream # # # - def initialize(host, shell="bash") - @host = host - @shell = shell + def initialize(host, shell = "bash", timeout = nil) + @host = host + @shell = shell + @timeout = timeout end def opened? @@ -48,7 +49,10 @@ class SshStream end def open - @stdin, @stdout, @stderr=Open3::popen3("#{SSH_CMD} #{@host} #{@shell} -s ; echo #{SSH_RC_STR} $? 1>&2") + @stdin, @stdout, @stderr, @wait_thr = Open3::popen3( + "#{SSH_CMD} #{@host} #{@shell} -s ; echo #{SSH_RC_STR} $? 1>&2", + :pgroup => true + ) @stream_out = "" @stream_err = "" @@ -74,6 +78,15 @@ class SshStream @alive = false end + def kill(pid) + # executed processes now have its own process group to be able + # to kill all children + pgid = Process.getpgid(pid) + + # Kill all processes belonging to process group + Process.kill("HUP", pgid * -1) + end + def exec(command) return if ! @alive @@ -100,48 +113,60 @@ class SshStream def wait_for_command done_out = false done_err = false + time_start = Time.now.to_i code = -1 while not (done_out and done_err ) and @alive - rc, rw, re= IO.select([@stdout, @stderr],[],[]) + rc, rw, re= IO.select([@stdout, @stderr],[],[], 1) - rc.each { |fd| - begin - c = fd.read_nonblock(100) - next if !c - rescue #rescue from EOF if ssh command finishes and closes fds - next - end - - if fd == @stdout - @out << c - done_out = true if @out.slice!("#{EOF_OUT}\n") - else - @err << c - - tmp = @err.scan(/^#{SSH_RC_STR}(\d+)$/) - - if tmp[0] - message = "Error connecting to #{@host}" - code = tmp[0][0].to_i - - @err << OpenNebula.format_error_message(message) - - @alive = false - break + if rc + rc.each { |fd| + begin + c = fd.read_nonblock(100) + next if !c + rescue #rescue from EOF if ssh command finishes and + # closes fds + next end - tmp = @err.scan(/^#{RC_STR}(\d*) #{EOF_ERR}\n/) + if fd == @stdout + @out << c + done_out = true if @out.slice!("#{EOF_OUT}\n") + else + @err << c - if tmp[0] - code = tmp[0][0].to_i - done_err = true + tmp = @err.scan(/^#{SSH_RC_STR}(\d+)$/) - @err.slice!(" #{EOF_ERR}\n") + if tmp[0] + message = "Error connecting to #{@host}" + code = tmp[0][0].to_i + + @err << OpenNebula.format_error_message(message) + + @alive = false + break + end + + tmp = @err.scan(/^#{RC_STR}(\d*) #{EOF_ERR}\n/) + + if tmp[0] + code = tmp[0][0].to_i + done_err = true + + @err.slice!(" #{EOF_ERR}\n") + end end - end - } + } + end + + if @timeout && Time.now.to_i - time_start > @timeout + @err << "\nTimeout Error" + self.close + self.kill(@wait_thr.pid) + + break + end end @stream_out << @out @@ -158,11 +183,11 @@ end class SshStreamCommand < RemotesCommand - def initialize(host, remote_dir, logger=nil, stdin=nil, shell='bash') - super('true', host, logger, stdin) + def initialize(host, remote_dir, logger=nil, stdin=nil, shell='bash', timeout=nil) + super('true', host, logger, stdin, timeout) @remote_dir = remote_dir - @stream = SshStream.new(host, shell) + @stream = SshStream.new(host, shell, timeout) end def run(command, stdin=nil, base_cmd = nil)