# bzr-avahi - share and browse Bazaar branches with mDNS
# Copyright (C) 2007-2008 James Henstridge
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import avahi
import dbus
import dbus.lowlevel
import gobject

from bzrlib.tests import (
    TestCase, TestCaseWithMemoryTransport, TestCaseWithTransport)
from bzrlib.urlutils import strip_trailing_slash

from bzrlib.plugins.avahi.advertise import (
    BranchInfo, ServerInfo, Advertiser, get_mdns_advertise, get_mdns_name,
    send_change_notification, set_mdns_advertise, set_mdns_name)


class UtilitiesTests(TestCaseWithMemoryTransport):
    """Tests for the mDNS name/advertise helpers and change notification."""

    def test_get_mdns_name(self):
        # Each step overrides the previous source of the advertised name:
        # the branch location, then the nickname, then the mdns-name
        # config option.
        branch = self.make_branch('testbranch')
        self.assertEqual(get_mdns_name(branch), 'testbranch')

        branch.nick = 'nickname'
        self.assertEqual(get_mdns_name(branch), 'nickname')

        config = branch.get_config()
        config.set_user_option('mdns-name', 'advertised-name')
        self.assertEqual(get_mdns_name(branch), 'advertised-name')

    def test_set_mdns_name(self):
        # A name stored with set_mdns_name() is read back unchanged.
        wanted_name = 'advertised-name'
        branch = self.make_branch('testbranch')
        set_mdns_name(branch, wanted_name)
        self.assertEqual(get_mdns_name(branch), wanted_name)

    def test_set_mdns_name_hash(self):
        # A '#' in the name must survive a set/get round trip.
        # Currently the hashes are being manually escaped to make this
        # work :(
        hashed_name = 'foo #2'
        branch = self.make_branch('testbranch')
        set_mdns_name(branch, hashed_name)
        self.assertEqual(get_mdns_name(branch), hashed_name)

    def test_get_mdns_advertise(self):
        branch = self.make_branch('testbranch')
        # With nothing configured the branch is not advertised.
        self.assertEqual(get_mdns_advertise(branch), False)
        # An unrecognised value reads as False; 'True' reads as True.
        for stored_value, expected in [('invalid', False), ('True', True)]:
            branch.get_config().set_user_option(
                'mdns-advertise', stored_value)
            self.assertEqual(get_mdns_advertise(branch), expected)

    def test_set_mdns_advertise(self):
        # The flag can be switched on and then off again.
        branch = self.make_branch('testbranch')
        for flag in (True, False):
            set_mdns_advertise(branch, flag)
            self.assertEqual(get_mdns_advertise(branch), flag)

    def test_send_change_notification(self):
        # Messages sent through any bus built by the factory end up here.
        sent = []

        class RecordingBus:
            def send_message(self, message):
                sent.append(message)

        branch = self.make_branch('testbranch')
        send_change_notification(branch, _bus_factory=RecordingBus)

        self.assertEqual(len(sent), 1)
        signal = sent[0]
        self.assertTrue(isinstance(signal, dbus.lowlevel.SignalMessage))
        # Check each header field, then the single string argument.
        expectations = [
            (signal.get_path, '/'),
            (signal.get_interface, 'org.bazaar_vcs.plugins.avahi.Notify'),
            (signal.get_member, 'BranchStateChanged'),
            (signal.get_signature, 's'),
            (signal.get_args_list, [branch.base]),
            ]
        for getter, expected in expectations:
            self.assertEqual(getter(), expected)


# XXX: We should be able to use TestCaseWithMemoryTransport here, but
# that is broken: https://bugs.launchpad.net/bugs/188855
class ServerInfoTests(TestCaseWithTransport):
    """Tests for ServerInfo branch scanning and dispatch to the advertiser."""

    def test_scan_branches(self):
        # Branches at several nesting depths, created in the order the
        # scan is expected to report them.
        locations = ['.', 'b2', 'b3', 'b3/b4', 'b3/b4/b5']
        created = [self.make_branch(location) for location in locations]

        class RecordingServerInfo(ServerInfo):
            branches = []

            def handle_branch(self, branch):
                self.branches.append(branch.base)

        root_url = self.get_transport().base
        info = RecordingServerInfo(None, root_url, 'http://example.com/')
        info.scan_branches()
        self.assertEqual(info.branches,
                         [branch.base for branch in created])

    def test_handle_branch(self):
        class RecordingAdvertiser:
            def __init__(self):
                self.branches = []
                self.removed_branches = []

            def add_branch(self, server_info, branch, public_loc):
                entry = (public_loc, strip_trailing_slash(branch.base),
                         server_info)
                self.branches.append(entry)

            def remove_branch(self, branch):
                local_url = strip_trailing_slash(branch.base)
                self.removed_branches.append(local_url)

        advertiser = RecordingAdvertiser()
        info = ServerInfo(advertiser, self.get_transport().base,
                          'http://example.com/')
        branch = self.make_branch('testbranch')
        branch_url = self.get_transport().abspath('testbranch')

        # A branch that is not advertised is reported to the advertiser
        # as removed, and is not added.
        set_mdns_advertise(branch, False)
        info.handle_branch(branch)
        self.assertEqual(advertiser.branches, [])
        self.assertEqual(advertiser.removed_branches, [branch_url])
        del advertiser.removed_branches[:]

        # Once advertised, the branch is added together with its public
        # URL and the ServerInfo that handled it.
        set_mdns_advertise(branch, True)
        info.handle_branch(branch)
        self.assertEqual(len(advertiser.branches), 1)
        self.assertEqual(advertiser.removed_branches, [])
        public_loc, local_url, server_info = advertiser.branches[0]
        self.assertEqual(public_loc, 'http://example.com/testbranch/')
        self.assertEqual(local_url, branch_url)
        self.assertEqual(server_info, info)

    def test_maybe_handle_branch(self):
        # maybe_handle_branch() returns True and calls handle_branch()
        # only for branches that are covered by this server.
        class RecordingServerInfo(ServerInfo):
            branches = []

            def handle_branch(self, branch):
                self.branches.append(branch.base)

        transport = self.get_transport()
        info = RecordingServerInfo(None, transport.abspath('public'),
                                   'http://example.com/')

        outside = self.make_branch('private')
        self.assertEqual(info.maybe_handle_branch(outside), False)
        self.assertEqual(info.branches, [])

        transport.mkdir('public')
        inside = self.make_branch('public/branch')
        self.assertEqual(info.maybe_handle_branch(inside), True)
        self.assertEqual(info.branches, [inside.base])


class BranchInfoTests(TestCaseWithMemoryTransport):
    """Tests for BranchInfo naming, conflict handling and registration."""

    class _RenamingAdvertiser:
        # Stand-in advertiser that serves as its own ``avahi`` attribute:
        # it records every name it is asked to replace and always
        # answers 'replacement-name'.
        def __init__(self):
            self.avahi = self
            self.old_names = []

        def GetAlternativeServiceName(self, name):
            self.old_names.append(name)
            return 'replacement-name'

    def test_get_name(self):
        branch = self.make_branch('b1')
        info = BranchInfo(None, None, branch, None)
        self.assertEqual(info.get_name(), 'b1')

        branch.nick = 'nickname'
        self.assertEqual(info.get_name(), 'nickname')

        # The mdns-name option wins over the nickname; an escaped hash
        # ('\x23') stored in the option comes back as a literal '#'.
        option_cases = [
            ('advertised name', 'advertised name'),
            ('branch \\x232', 'branch #2'),
            ]
        for stored_value, expected_name in option_cases:
            branch.get_config().set_user_option('mdns-name', stored_value)
            self.assertEqual(info.get_name(), expected_name)

    def test_handle_name_conflict(self):
        advertiser = self._RenamingAdvertiser()
        branch = self.make_branch('b1')
        info = BranchInfo(advertiser, None, branch, None)
        self.assertEqual(info.get_name(), 'b1')

        info.handle_name_conflict()

        # The replacement name is used from now on, is stored in the
        # branch configuration, and was derived from the old name.
        self.assertEqual(info.get_name(), 'replacement-name')
        stored_name = branch.get_config().get_user_option('mdns-name')
        self.assertEqual(stored_name, 'replacement-name')
        self.assertEqual(advertiser.old_names, ['b1'])

    def test_add_service(self):
        # Test that add_service() correctly builds the service group,
        # retrying for name conflicts.
        class RecordingGroup:
            def __init__(self):
                self.calls = []

            def IsEmpty(self):
                self.calls.append(('IsEmpty',))
                return True

            def Commit(self):
                self.calls.append(('Commit',))

            def AddService(self, interface, protocol, flags, name, type,
                           domain, host, port, txt):
                record = ('AddService', interface, protocol, flags, name,
                          type, domain, host, port,
                          avahi.txt_array_to_string_array(txt))
                self.calls.append(record)
                # The original name always collides, forcing a retry.
                if name == 'b1':
                    raise dbus.DBusException(
                        name='org.freedesktop.Avahi.CollisionError')

        class GroupedBranchInfo(BranchInfo):
            def make_group(self):
                assert self.group is None
                self.group = RecordingGroup()

        def expected_add_service(service_name):
            # The AddService record expected for the given service name.
            return ('AddService', avahi.IF_UNSPEC, avahi.PROTO_UNSPEC,
                    dbus.UInt32(0), service_name, '_bzr._tcp', '', '',
                    dbus.UInt16(4155), ['path=/b1', 'scheme=bzr'])

        advertiser = self._RenamingAdvertiser()
        branch = self.make_branch('b1')
        info = GroupedBranchInfo(advertiser, None, branch,
                                 'bzr://0.0.0.0:4155/b1')
        info.add_service()

        self.assertNotEqual(info.group, None)
        self.assertEqual(info.group.calls,
                         [('IsEmpty',),
                          expected_add_service('b1'),
                          expected_add_service('replacement-name'),
                          ('Commit',)])
        self.assertEqual(advertiser.old_names, ['b1'])

    def test_state_changed(self):
        class RecordingBranchInfo(BranchInfo):
            # Replaces BranchInfo.__init__ so that only the call log is
            # set up; the overridden methods just note that they ran.
            def __init__(self):
                self.calls = []

            def handle_name_conflict(self):
                self.calls.append('handle_name_conflict')

            def remove_service(self):
                self.calls.append('remove_service')

            def add_service(self):
                self.calls.append('add_service')

        # On collision, handle the conflict then re-create the service.
        info = RecordingBranchInfo()
        info.state_changed(avahi.ENTRY_GROUP_COLLISION)
        self.assertEqual(info.calls,
                         ['handle_name_conflict',
                          'remove_service',
                          'add_service'])


class AdvertiserTests(TestCaseWithMemoryTransport):
    """Tests for how Advertiser dispatches branch state changes."""

    def test_branch_state_changed_new_branch(self):
        # Test that unknown branches are handed to each ServerInfo object.
        test_branch = self.make_branch('testbranch')

        class StubServerInfo:
            def __init__(self):
                self.branches = []
                self.should_handle_branch = False

            def maybe_handle_branch(self, branch):
                self.branches.append(branch)
                return self.should_handle_branch

        server_keys = ('server1', 'server2')
        advertiser = Advertiser()
        for key in server_keys:
            advertiser.servers[key] = StubServerInfo()

        # While every server declines the branch, each one gets offered it.
        advertiser.branch_state_changed(test_branch.base)
        for key in server_keys:
            offered = advertiser.servers[key].branches
            self.assertEqual(len(offered), 1)
            self.assertEqual(offered[0].base, test_branch.base)

        # If the servers instead say they handle the branch, only one
        # instance should be passed the branch:
        for key in server_keys:
            stub = advertiser.servers[key]
            del stub.branches[:]
            stub.should_handle_branch = True
        advertiser.branch_state_changed(test_branch.base)
        total_offered = sum(
            len(advertiser.servers[key].branches) for key in server_keys)
        self.assertEqual(total_offered, 1)
