Skip to content

Commit

Permalink
The authenticated_ fields are moved out of stubs and into
Browse files Browse the repository at this point in the history
ClientSession. Messages to the stubs are dispatched via
ClientSession, and the stub classes are pure virtual.

BUG=none
TEST=none

Review URL: http://codereview.chromium.org/6724033

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@79991 0039d316-1c4b-4281-b951-d872f2087c98
  • Loading branch information
simonmorris@chromium.org committed Mar 31, 2011
1 parent 22efa08 commit 4ea2c7c
Show file tree
Hide file tree
Showing 23 changed files with 329 additions and 328 deletions.
37 changes: 22 additions & 15 deletions remoting/host/chromoting_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "remoting/host/chromoting_host.h"

#include "base/bind.h"
#include "base/stl_util-inl.h"
#include "base/task.h"
#include "build/build_config.h"
Expand All @@ -19,6 +20,7 @@
#include "remoting/host/host_config.h"
#include "remoting/host/host_key_pair.h"
#include "remoting/host/screen_recorder.h"
#include "remoting/host/user_authenticator.h"
#include "remoting/proto/auth.pb.h"
#include "remoting/protocol/connection_to_client.h"
#include "remoting/protocol/client_stub.h"
Expand Down Expand Up @@ -180,13 +182,22 @@ void ChromotingHost::OnClientConnected(ConnectionToClient* connection) {
void ChromotingHost::OnClientDisconnected(ConnectionToClient* connection) {
DCHECK_EQ(context_->main_message_loop(), MessageLoop::current());

// Find the client session corresponding to the given connection.
std::vector<scoped_refptr<ClientSession> >::iterator client;
for (client = clients_.begin(); client != clients_.end(); ++client) {
if (client->get()->connection() == connection)
break;
}
if (client == clients_.end())
return;

// Remove the connection from the session manager and stop the session.
// TODO(hclam): Stop only if the last connection disconnected.
if (recorder_.get()) {
recorder_->RemoveConnection(connection);
// The recorder only exists to serve the unique authenticated client.
// If that client has disconnected, then we can kill the recorder.
if (connection->client_authenticated()) {
if (client->get()->authenticated()) {
recorder_->Stop(NULL);
recorder_ = NULL;
}
Expand All @@ -196,13 +207,8 @@ void ChromotingHost::OnClientDisconnected(ConnectionToClient* connection) {
connection->Disconnect();

// Also remove reference to ConnectionToClient from this object.
std::vector<scoped_refptr<ClientSession> >::iterator it;
for (it = clients_.begin(); it != clients_.end(); ++it) {
if (it->get()->connection() == connection) {
clients_.erase(it);
break;
}
}
clients_.erase(client);

if (!HasAuthenticatedClients())
EnableCurtainMode(false);
}
Expand Down Expand Up @@ -321,13 +327,16 @@ void ChromotingHost::OnNewClientSession(

// We accept the connection, so create a connection object.
ConnectionToClient* connection = new ConnectionToClient(
context_->network_message_loop(),
this,
desktop_environment_->input_stub());
context_->network_message_loop(), this);

// Create a client object.
ClientSession* client = new ClientSession(this, connection);
ClientSession* client = new ClientSession(
this,
base::Bind(UserAuthenticator::Create),
connection,
desktop_environment_->input_stub());
connection->set_host_stub(client);
connection->set_input_stub(client);

connection->Init(session);

Expand Down Expand Up @@ -377,7 +386,7 @@ std::string ChromotingHost::GenerateHostAuthToken(
bool ChromotingHost::HasAuthenticatedClients() const {
std::vector<scoped_refptr<ClientSession> >::const_iterator it;
for (it = clients_.begin(); it != clients_.end(); ++it) {
if (it->get()->connection()->client_authenticated())
if (it->get()->authenticated())
return true;
}
return false;
Expand Down Expand Up @@ -408,8 +417,6 @@ void ChromotingHost::LocalLoginSucceeded(
connection->client_stub()->BeginSessionResponse(
status, new DeleteTask<protocol::LocalLoginStatus>(status));

connection->OnClientAuthenticated();

// Disconnect all other clients.
// Iterate over a copy of the list of clients, to avoid mutating the list
// while iterating over it.
Expand Down
45 changes: 30 additions & 15 deletions remoting/host/chromoting_host_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "base/bind.h"
#include "base/scoped_ptr.h"
#include "base/task.h"
#include "remoting/host/capturer_fake.h"
#include "remoting/host/chromoting_host.h"
#include "remoting/host/chromoting_host_context.h"
#include "remoting/host/host_mock_objects.h"
#include "remoting/host/in_memory_host_config.h"
#include "remoting/host/user_authenticator_fake.h"
#include "remoting/proto/video.pb.h"
#include "remoting/protocol/protocol_mock_objects.h"
#include "remoting/protocol/session_config.h"
Expand Down Expand Up @@ -41,6 +43,10 @@ namespace remoting {

namespace {

UserAuthenticator* MakeUserAuthenticator() {
return new UserAuthenticatorFake();
}

void PostQuitTask(MessageLoop* message_loop) {
message_loop->PostTask(FROM_HERE, new MessageLoop::QuitTask());
}
Expand All @@ -56,6 +62,9 @@ ACTION_P(QuitMainMessageLoop, message_loop) {
PostQuitTask(message_loop);
}

void DummyDoneTask() {
}

} // namespace

class ChromotingHostTest : public testing::Test {
Expand Down Expand Up @@ -87,6 +96,12 @@ class ChromotingHostTest : public testing::Test {
DesktopEnvironment* desktop =
new DesktopEnvironment(capturer, input_stub_, curtain_);
host_ = ChromotingHost::Create(&context_, config_, desktop);
credentials_good_.set_type(protocol::PASSWORD);
credentials_good_.set_username("user");
credentials_good_.set_credential("password");
credentials_bad_.set_type(protocol::PASSWORD);
credentials_bad_.set_username(UserAuthenticatorFake::fail_username());
credentials_bad_.set_credential(UserAuthenticatorFake::fail_password());
connection_ = new MockConnectionToClient(
&message_loop_, &handler_, host_stub_, input_stub_);
connection2_ = new MockConnectionToClient(
Expand Down Expand Up @@ -143,8 +158,13 @@ class ChromotingHostTest : public testing::Test {
void SimulateClientConnection(int connection_index, bool authenticate) {
scoped_refptr<MockConnectionToClient> connection =
(connection_index == 0) ? connection_ : connection2_;
scoped_refptr<ClientSession> client = new ClientSession(host_.get(),
connection);
protocol::LocalLoginCredentials& credentials =
authenticate ? credentials_good_ : credentials_bad_;
scoped_refptr<ClientSession> client = new ClientSession(
host_.get(),
base::Bind(MakeUserAuthenticator),
connection,
input_stub_);
connection->set_host_stub(client.get());

context_.network_message_loop()->PostTask(
Expand All @@ -157,19 +177,12 @@ class ChromotingHostTest : public testing::Test {
NewRunnableMethod(host_.get(),
&ChromotingHost::OnClientConnected,
connection));
if (authenticate) {
context_.network_message_loop()->PostTask(
FROM_HERE,
NewRunnableMethod(host_.get(),
&ChromotingHost::LocalLoginSucceeded,
connection));
} else {
context_.network_message_loop()->PostTask(
FROM_HERE,
NewRunnableMethod(host_.get(),
&ChromotingHost::LocalLoginFailed,
connection));
}
context_.network_message_loop()->PostTask(
FROM_HERE,
NewRunnableMethod(client.get(),
&ClientSession::BeginSessionRequest,
&credentials,
NewRunnableFunction(&DummyDoneTask)));
}

// Helper method to remove a client connection from ChromotingHost.
Expand All @@ -187,6 +200,8 @@ class ChromotingHostTest : public testing::Test {
scoped_refptr<ChromotingHost> host_;
scoped_refptr<InMemoryHostConfig> config_;
MockChromotingHostContext context_;
protocol::LocalLoginCredentials credentials_good_;
protocol::LocalLoginCredentials credentials_bad_;
scoped_refptr<MockConnectionToClient> connection_;
scoped_refptr<MockSession> session_;
scoped_ptr<SessionConfig> session_config_;
Expand Down
48 changes: 37 additions & 11 deletions remoting/host/client_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,46 @@

#include "base/memory/scoped_ptr.h"
#include "base/task.h"
#include "media/base/callback.h"
#include "remoting/host/user_authenticator.h"
#include "remoting/proto/auth.pb.h"

namespace remoting {

ClientSession::ClientSession(
EventHandler* event_handler,
scoped_refptr<protocol::ConnectionToClient> connection)
const base::Callback<UserAuthenticatorFactory>& auth_factory,
scoped_refptr<protocol::ConnectionToClient> connection,
protocol::InputStub* input_stub)
: event_handler_(event_handler),
connection_(connection) {
auth_factory_(auth_factory),
connection_(connection),
input_stub_(input_stub),
authenticated_(false) {
}

ClientSession::~ClientSession() {
}

void ClientSession::SuggestResolution(
const protocol::SuggestResolutionRequest* msg, Task* done) {
done->Run();
delete done;
media::AutoTaskRunner done_runner(done);

if (!authenticated_) {
LOG(WARNING) << "Invalid control message received "
<< "(client not authenticated).";
return;
}
}

void ClientSession::BeginSessionRequest(
const protocol::LocalLoginCredentials* credentials, Task* done) {
DCHECK(event_handler_);

media::AutoTaskRunner done_runner(done);

bool success = false;
scoped_ptr<UserAuthenticator> authenticator(UserAuthenticator::Create());
scoped_ptr<UserAuthenticator> authenticator(auth_factory_.Run());
switch (credentials->type()) {
case protocol::PASSWORD:
success = authenticator->Authenticate(credentials->username(),
Expand All @@ -45,22 +58,35 @@ void ClientSession::BeginSessionRequest(
}

if (success) {
authenticated_ = true;
event_handler_->LocalLoginSucceeded(connection_.get());
} else {
LOG(WARNING) << "Login failed for user " << credentials->username();
event_handler_->LocalLoginFailed(connection_.get());
}
}

void ClientSession::InjectKeyEvent(const protocol::KeyEvent* event,
Task* done) {
media::AutoTaskRunner done_runner(done);
if (authenticated_) {
done_runner.release();
input_stub_->InjectKeyEvent(event, done);
}
}

done->Run();
delete done;
void ClientSession::InjectMouseEvent(const protocol::MouseEvent* event,
Task* done) {
media::AutoTaskRunner done_runner(done);
if (authenticated_) {
done_runner.release();
input_stub_->InjectMouseEvent(event, done);
}
}

void ClientSession::Disconnect() {
connection_->Disconnect();
}

protocol::ConnectionToClient* ClientSession::connection() const {
return connection_.get();
authenticated_ = false;
}

} // namespace remoting
38 changes: 33 additions & 5 deletions remoting/host/client_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@

#include "remoting/protocol/connection_to_client.h"
#include "remoting/protocol/host_stub.h"
#include "remoting/protocol/input_stub.h"

namespace remoting {

class UserAuthenticator;

// A ClientSession keeps a reference to a connection to a client, and maintains
// per-client state.
class ClientSession : public protocol::HostStub,
public protocol::InputStub,
public base::RefCountedThreadSafe<ClientSession> {
public:
// Callback interface for passing events to the ChromotingHost.
Expand All @@ -30,28 +34,52 @@ class ClientSession : public protocol::HostStub,
scoped_refptr<protocol::ConnectionToClient> client) = 0;
};

typedef UserAuthenticator* UserAuthenticatorFactory();

ClientSession(EventHandler* event_handler,
scoped_refptr<protocol::ConnectionToClient> connection);
const base::Callback<UserAuthenticatorFactory>& auth_factory,
scoped_refptr<protocol::ConnectionToClient> connection,
protocol::InputStub* input_stub);

// protocol::HostStub interface.
virtual void SuggestResolution(
const protocol::SuggestResolutionRequest* msg, Task* done);
virtual void BeginSessionRequest(
const protocol::LocalLoginCredentials* credentials, Task* done);

// protocol::InputStub interface.
virtual void InjectKeyEvent(const protocol::KeyEvent* event, Task* done);
virtual void InjectMouseEvent(const protocol::MouseEvent* event, Task* done);

// Disconnect this client session.
void Disconnect();

protocol::ConnectionToClient* connection() const;
protocol::ConnectionToClient* connection() const {
return connection_.get();
}

protected:
friend class base::RefCountedThreadSafe<ClientSession>;
~ClientSession();
bool authenticated() const {
return authenticated_;
}

private:
friend class base::RefCountedThreadSafe<ClientSession>;
virtual ~ClientSession();

EventHandler* event_handler_;

// A factory for user authenticators.
base::Callback<UserAuthenticatorFactory> auth_factory_;

// The connection to the client.
scoped_refptr<protocol::ConnectionToClient> connection_;

// The input stub to which this object delegates.
protocol::InputStub* input_stub_;

// Whether this client is authenticated.
bool authenticated_;

DISALLOW_COPY_AND_ASSIGN(ClientSession);
};

Expand Down
Loading

0 comments on commit 4ea2c7c

Please sign in to comment.