Privacy-Preserving Artificial Intelligence


Federated Learning: Regression using the California Housing Database

In this notebook, we explain how to use a federated learning environment to build a regression model. The basic concepts of the framework were introduced in the notebook on linear regression for a simple 2D case, so here we will move a little faster.

First, we load the California Housing dataset, which is included in the framework, for our regression experiments.

import shfl
from shfl.data_base.california_housing import CaliforniaHousing

database = CaliforniaHousing()
train_data, train_labels, test_data, test_labels = database.load_data()
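
load_data returns the data already split into training and test sets. The 18,576 training rows printed below are exactly 90% of the 20,640 rows in the standard California Housing dataset, which is consistent with a 90/10 hold-out split. A minimal NumPy sketch of such a split, assuming the rows are shuffled before splitting (train_test_split_90_10 is a hypothetical helper, not part of shfl):

```python
import numpy as np

def train_test_split_90_10(data, labels, seed=0):
    """Hypothetical helper: shuffle the rows, keep 90% for training and 10% for testing."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(data))
    n_train = len(data) * 9 // 10  # integer arithmetic avoids float rounding
    train_idx, test_idx = idx[:n_train], idx[n_train:]
    return data[train_idx], labels[train_idx], data[test_idx], labels[test_idx]

# Placeholder arrays with the shape of the full dataset (20,640 rows, 8 features)
X = np.zeros((20640, 8))
y = np.zeros(20640)
X_train, y_train, X_test, y_test = train_test_split_90_10(X, y)
print(X_train.shape, X_test.shape)  # (18576, 8) (2064, 8)
```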

Next, we take a quick look at the data:

print("Shape of train_data: " + str(train_data.shape))
print("Shape of train_labels: " + str(train_labels.shape))
print("One sample features: " + str(train_data[0]))
print("One sample label: " + str(train_labels[0]))
Shape of train_data: (18576, 8)
Shape of train_labels: (18576,)
One sample features: [   2.5125       52.            3.72315036    1.07637232  700.
    1.67064439   37.97       -122.53      ]
One sample label: 2.708
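
The eight feature columns are not named in the output above. Assuming the framework keeps the column order of scikit-learn's fetch_california_housing (MedInc, HouseAge, AveRooms, AveBedrms, Population, AveOccup, Latitude, Longitude), which the sample row is consistent with, and that the label is the median house value in units of $100,000 as in that version of the dataset, the sample can be read as follows. describe_sample is a hypothetical helper for illustration only.

```python
import numpy as np

# Assumption: column order follows scikit-learn's fetch_california_housing.
FEATURE_NAMES = ["MedInc", "HouseAge", "AveRooms", "AveBedrms",
                 "Population", "AveOccup", "Latitude", "Longitude"]
LABEL_UNIT_USD = 100_000  # assumed: median house value expressed in units of $100,000

def describe_sample(features, label):
    """Pair each feature value with its (assumed) column name and convert the label to dollars."""
    named = dict(zip(FEATURE_NAMES, np.asarray(features, dtype=float).tolist()))
    named["MedHouseVal_usd"] = round(float(label) * LABEL_UNIT_USD)
    return named

# The sample printed above
sample = [2.5125, 52.0, 3.72315036, 1.07637232, 700.0, 1.67064439, 37.97, -122.53]
info = describe_sample(sample, 2.708)
for name, value in info.items():
    print(f"{name:>16}: {value}")
```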

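Federated data generation: we simulate 20 clients and distribute the data among them following an IID (independent and identically distributed) scheme. The percent=10 argument restricts the federated data to 10% of the available samples, and the call also returns the test set that we use later for evaluation: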

import shfl

iid_distribution = shfl.data_distribution.IidDataDistribution(database)
federated_data, test_data, test_label = iid_distribution.get_federated_data(20, percent=10)
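
In an IID distribution, samples are assigned to clients at random, so every client sees data drawn from the same distribution. The NumPy sketch below illustrates the idea under two assumptions: that percent=10 keeps a random 10% of the training samples, and that those samples are split into near-equal shards without replacement. iid_split is a hypothetical helper and is not shfl's implementation.

```python
import numpy as np

def iid_split(data, labels, num_nodes, percent=100, seed=0):
    """Hypothetical helper (not shfl's code): keep a random `percent`% of the samples
    and split them into `num_nodes` near-equal shards without replacement."""
    rng = np.random.default_rng(seed)
    n_used = len(data) * percent // 100
    idx = rng.permutation(len(data))[:n_used]
    shards = np.array_split(idx, num_nodes)
    return [(data[s], labels[s]) for s in shards]

# Placeholder arrays with the training-set shape printed earlier
X = np.zeros((18576, 8))
y = np.zeros(18576)
nodes = iid_split(X, y, num_nodes=20, percent=10)
print(len(nodes), sorted({len(d) for d, _ in nodes}))  # 20 [92, 93]
```

Under these assumptions each client holds only about 93 samples, which is worth keeping in mind when reading the results below.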

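Model definition: a small Keras network with one hidden layer of 8 ReLU units and a single linear output unit, compiled with the mean squared error loss, the Adam optimizer, and the mean absolute error (MAE) as an additional metric. We pass a builder function rather than a model instance so that the framework can create a separate copy of the model for each client; the compiled Keras model is wrapped in shfl's DeepLearningModel: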

import tensorflow as tf

def model_builder():
    # create model
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Dense(8, input_dim=8, kernel_initializer='normal', activation='relu'))
    model.add(tf.keras.layers.Dense(1, kernel_initializer='normal'))

    # Compile model
    model.compile(loss='mean_squared_error', optimizer='adam', metrics=["mae"])

    return shfl.model.DeepLearningModel(model)
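
The network is deliberately small. Counting its trainable parameters shows how many numbers each client's model contributes to the averaging step: a Dense layer with n_in inputs and n_out units has an n_in x n_out weight matrix plus n_out biases. dense_param_count is an illustrative helper, not a Keras function.

```python
def dense_param_count(n_in, n_out):
    """Trainable parameters of a Dense layer: an n_in x n_out weight matrix plus n_out biases."""
    return n_in * n_out + n_out

# (inputs, units) for the two Dense layers in model_builder
layers = [(8, 8), (8, 1)]
counts = [dense_param_count(n_in, n_out) for n_in, n_out in layers]
print(counts, sum(counts))  # [72, 9] 81
```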

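Federated environment definition: the FedAvgAggregator combines the clients' models by averaging their parameters (Federated Averaging), and the FederatedGovernment ties together the model builder, the federated data, and the aggregator: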

aggregator = shfl.federated_aggregator.FedAvgAggregator()
federated_government = shfl.federated_government.FederatedGovernment(model_builder, federated_data, aggregator)
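
Federated Averaging alternates local training on each client with an average of the resulting parameters. The NumPy sketch below simulates that loop with 20 clients and 3 rounds, using a plain linear model trained by gradient descent on synthetic data as a stand-in for the Keras network. It assumes an unweighted mean over clients (with equal-sized clients, as here, this coincides with the sample-weighted mean) and is not the framework's code; local_train and fedavg are hypothetical names.

```python
import numpy as np

def local_train(w, X, y, lr=0.1, epochs=5):
    """Full-batch gradient descent on the MSE of a linear model y ~ X @ w."""
    w = w.copy()
    for _ in range(epochs):
        grad = 2.0 * X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w

def fedavg(client_weights):
    """Unweighted element-wise mean of the clients' parameter vectors."""
    return np.mean(np.stack(client_weights), axis=0)

# Synthetic data: 20 clients with 50 samples each, drawn from the same linear model
rng = np.random.default_rng(0)
true_w = np.array([1.5, -2.0, 0.5])
clients = []
for _ in range(20):
    X = rng.normal(size=(50, 3))
    clients.append((X, X @ true_w + 0.1 * rng.normal(size=50)))

global_w = np.zeros(3)
for rnd in range(3):
    # Every client starts from the current global parameters and trains locally
    local_ws = [local_train(global_w, X, y) for X, y in clients]
    # The aggregator replaces the global parameters with the average
    global_w = fedavg(local_ws)
    print(f"round {rnd}: w = {np.round(global_w, 3)}")
```

The raw data never leaves the clients in this loop; only the parameter vectors are averaged.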

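Reshaping data: the labels are 1-D arrays of shape (n,), whereas the model outputs predictions of shape (n, 1). We define a federated transformation that reshapes each client's labels into a column vector and apply it to every node: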

import numpy as np

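class Reshape(shfl.private.FederatedTransformation):

    def apply(self, labeled_data):
        # Turn the node's labels from shape (n,) into a column vector of shape (n, 1)
        labeled_data.label = np.reshape(labeled_data.label, (labeled_data.label.shape[0], 1))

# Apply the transformation to the data stored in every federated node
shfl.private.federated_operation.apply_federated_transformation(federated_data, Reshape())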
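
Why the shapes matter: in NumPy, subtracting an (n, 1) array from an (n,) array does not raise an error but broadcasts to an (n, n) matrix, so an error metric computed naively would be silently wrong. The small example below shows the pitfall that reshaping the labels avoids; it illustrates NumPy broadcasting in general and makes no claim about how Keras or shfl compute the loss internally.

```python
import numpy as np

y = np.array([1.0, 2.0, 3.0])           # labels, shape (3,)
pred = np.array([[1.0], [2.0], [3.0]])  # model output, shape (3, 1)

bad = y - pred                  # broadcasts to shape (3, 3)
good = y.reshape(-1, 1) - pred  # shape (3, 1), element-wise as intended

print(bad.shape, np.mean(bad ** 2))    # (3, 3) 1.3333333333333333
print(good.shape, np.mean(good ** 2))  # (3, 1) 0.0
```

The predictions are perfect, yet the broadcast version reports a nonzero error.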

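Running the experiment: we reshape the global test labels in the same way and run three rounds of federated training. After each round, every client's model and the aggregated global model are evaluated on the test set. Each result is a pair [MSE loss, MAE], following the loss and metric passed to compile; although the framework labels each block "Accuracy round N", for this regression model both numbers are errors, so lower is better.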

test_label = np.reshape(test_label, (test_label.shape[0], 1))
federated_government.run_rounds(3, test_data, test_label)
Accuracy round 0
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461d0f90>: [6.560263633728027, 1.6206369400024414]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461d0910>: [5.745161533355713, 1.5526326894760132]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de110>: [5.817258358001709, 1.5580724477767944]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de350>: [5.603631496429443, 1.5391887426376343]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de490>: [5.4023847579956055, 1.5247150659561157]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de6d0>: [5.4312872886657715, 1.5270318984985352]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de750>: [5.415994644165039, 1.5255640745162964]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de890>: [5.308160781860352, 1.5183627605438232]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de9d0>: [5.170998573303223, 1.5088306665420532]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461deb50>: [5.722560882568359, 1.5498887300491333]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461dec10>: [6.297583103179932, 1.595190405845642]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461ded50>: [5.452840805053711, 1.532146692276001]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461dee90>: [5.520835876464844, 1.5359090566635132]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb050>: [5.8125081062316895, 1.5577442646026611]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb150>: [5.746515274047852, 1.5522164106369019]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb290>: [5.301519393920898, 1.516774296760559]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb3d0>: [5.9085493087768555, 1.5644185543060303]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb510>: [5.828071594238281, 1.562639832496643]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb650>: [6.0064849853515625, 1.5701179504394531]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb790>: [5.4278645515441895, 1.526447057723999]
Global model test performance : [5.653794288635254, 1.544110655784607]



Accuracy round 1
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461d0f90>: [4.112930774688721, 1.5580612421035767]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461d0910>: [4.463381290435791, 1.6896424293518066]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de110>: [3.9861090183258057, 1.5091578960418701]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de350>: [4.219106674194336, 1.614454746246338]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de490>: [4.825565814971924, 1.801005482673645]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de6d0>: [4.3899126052856445, 1.6652277708053589]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de750>: [4.8324809074401855, 1.8052619695663452]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de890>: [5.00560998916626, 1.8548837900161743]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de9d0>: [4.911076545715332, 1.8276656866073608]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461deb50>: [4.271508693695068, 1.6307674646377563]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461dec10>: [3.9934544563293457, 1.5257854461669922]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461ded50>: [4.3297529220581055, 1.6517047882080078]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461dee90>: [4.423217296600342, 1.6778076887130737]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb050>: [4.382129192352295, 1.6510546207427979]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb150>: [4.557219505310059, 1.720855712890625]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb290>: [4.635447025299072, 1.7467052936553955]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb3d0>: [4.2572221755981445, 1.6271306276321411]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb510>: [4.53870964050293, 1.7143478393554688]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb650>: [4.028736114501953, 1.5350831747055054]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb790>: [4.317218780517578, 1.6463879346847534]
Global model test performance : [4.372098922729492, 1.6620105504989624]



Accuracy round 2
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461d0f90>: [4.585392951965332, 1.7460705041885376]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461d0910>: [4.7851057052612305, 1.8058794736862183]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de110>: [3.8755931854248047, 1.479520559310913]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de350>: [4.618406772613525, 1.7583948373794556]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de490>: [6.018651962280273, 2.142995595932007]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de6d0>: [4.499324798583984, 1.7165062427520752]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de750>: [6.057710647583008, 2.155027151107788]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de890>: [7.687643051147461, 2.517160177230835]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461de9d0>: [5.424689292907715, 1.988004446029663]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461deb50>: [4.975666046142578, 1.86579430103302]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461dec10>: [4.975329875946045, 1.8689584732055664]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461ded50>: [4.6139912605285645, 1.756758213043213]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461dee90>: [4.604976654052734, 1.7495594024658203]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb050>: [6.7682576179504395, 2.3205671310424805]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb150>: [5.436490535736084, 1.9939947128295898]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb290>: [4.94906759262085, 1.8556327819824219]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb3d0>: [5.0063700675964355, 1.8748692274093628]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb510>: [6.369729042053223, 2.2321643829345703]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb650>: [4.999528408050537, 1.8746261596679688]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x1461eb790>: [4.472737789154053, 1.71150803565979]
Global model test performance : [5.083588600158691, 1.8952934741973877]
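
The global metrics are easier to interpret on the original scale. The sketch below copies the three global [MSE, MAE] pairs from the log and converts them into RMSE and dollar amounts, assuming, as in scikit-learn's version of the dataset, that labels are in units of $100,000. It only rescales the reported numbers and does not re-run anything.

```python
import math

# Global-model [MSE, MAE] pairs copied from the log above.
# Assumption: labels are in units of $100,000, as in scikit-learn's copy of the dataset.
LABEL_UNIT_USD = 100_000
global_results = [
    (5.653794288635254, 1.544110655784607),   # round 0
    (4.372098922729492, 1.6620105504989624),  # round 1
    (5.083588600158691, 1.8952934741973877),  # round 2
]

summary = []
for rnd, (mse, mae) in enumerate(global_results):
    rmse = math.sqrt(mse)  # RMSE is in the same units as the labels
    summary.append((rnd, rmse * LABEL_UNIT_USD, mae * LABEL_UNIT_USD))
    print(f"round {rnd}: RMSE ~ ${rmse * LABEL_UNIT_USD:,.0f}, MAE ~ ${mae * LABEL_UNIT_USD:,.0f}")
```

Read this way, the global model's average error grows from roughly $154,000 to $190,000 over the three rounds, even though the MSE dips in round 1; for comparison, the sample label shown earlier corresponds to about $270,800.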