diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index d8a4f1af300a3e41801e0af2f811a7626da5d11c..1afecc94f2c11675c66a4ac9ad6c0f18367a31bc 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,7 +1,7 @@
 include:
   - project: 'just-ci/templates'
     file: 'templates/container/python.yml'
-    ref: 'v5.2.0'
+    ref: 'v5.4.0'
   - project: 'just-ci/templates'
     file: 'project-automation/badge.yml'
 
@@ -9,14 +9,6 @@ variables:
   PYTHON_PACKAGE: 'dgad'
   KANIKO_EXTRA_ARGS: "--use-new-run --single-snapshot"
 
-dgad-cli:
-  stage: test
-  script:
-    - which dgad
-    - dgad --domain wikipedia.org
-    - dgad --domains wikipedia.org sjdkahflaksdjhf.net
-    - dgad --csv tests/data/domains_todo.csv
-
 kaniko:redis-worker:
   extends: .kaniko
   variables:
@@ -24,6 +16,10 @@ kaniko:redis-worker:
   KANIKO_CONTEXT: ${CI_PROJECT_DIR}/redis-worker
   KANIKO_DOCKERFILE: ${CI_PROJECT_DIR}/redis-worker/Dockerfile
 
+python:pytest:
+  variables:
+    DEFAULT_ARGS: "-vvv --color=yes"
+
 # FIXME
 python:mypy:
   allow_failure: true
@@ -89,4 +85,4 @@ badge:codestyle:
     LABEL: "codestyle"
     VALUE: "black"
     COLOR: "black"
-    URL: "${CI_PROJECT_URL}"
\ No newline at end of file
+    URL: "${CI_PROJECT_URL}"
diff --git a/Dockerfile b/Dockerfile
index 63c27f300acadc9199badcd9fcdc918653f3279c..3258201d94ba563655ed9f494486f0b10f84e61b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -3,9 +3,9 @@ FROM python:3.9-slim
 WORKDIR /project
 COPY pyproject.toml ./
 COPY dgad/ dgad/
-RUN pip install .
+RUN pip --disable-pip-version-check install --no-compile .
 RUN dgad --help
 ENV TF_CPP_MIN_LOG_LEVEL=3
 ENTRYPOINT [ "dgad"]
-CMD [ "-h" ]
+CMD [ "--help" ]
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..833025db53f90d0e92819f924f57fe66152d41d6
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,13 @@
+test:
+	python dgad/cli.py client -fmt csv -f tests/data/domains_todo.csv
+	python dgad/cli.py client -f tests/data/domains_todo.csv
+	python dgad/cli.py client -fmt jsonl -f tests/data/domains_todo.jsonl
+	cat tests/data/domains_todo.csv | python dgad/cli.py client -fmt csv -f -
+	cat tests/data/domains_todo.jsonl | python dgad/cli.py client -fmt jsonl -f -
+
+clean:
+	black .
+	isort --profile=black .
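+
+# `make protoc` (below) regenerates the gRPC stubs in dgad/grpc; this assumes
+# the .proto source lives at protos/prediction.proto relative to the repo root.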
+ +protoc: + python -m grpc_tools.protoc -I protos --python_out=dgad/grpc --grpc_python_out=dgad/grpc prediction.proto diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..ce80dfece67bc0aeb3854d5d6acc2e898792656e --- /dev/null +++ b/demo.py @@ -0,0 +1,11 @@ +from dgad.prediction import Detective +from dgad.utils import pretty_print + +mydomains = ["adslkfjhsakldjfhasdlkf.com"] +detective = Detective() +# convert mydomains strings into dgad.schema.Domain +mydomains, _ = detective.prepare_domains(mydomains) +# classify them +detective.investigate(mydomains) +# view result, drops padded_token_vector for pretty printing +pretty_print(mydomains, output_format="json") diff --git a/dgad/app/cli.py b/dgad/app/cli.py deleted file mode 100644 index 805068f77110b424bbbc3e73236d69c4455a14d8..0000000000000000000000000000000000000000 --- a/dgad/app/cli.py +++ /dev/null @@ -1,76 +0,0 @@ -import argparse -import logging -import os -from importlib import resources -from pathlib import Path - -import dgad.models -from dgad.classification import TCNClassifier -from dgad.utils import create_domains_dataframe - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - - -def classify(args: argparse.Namespace) -> None: - classifier = TCNClassifier() - if not args.model: - with resources.path(dgad.models, "tcn_best.h5") as model_path: - classifier.load_keras_model(filepath=model_path) - else: - classifier.load_keras_model(filepath=args.model) - if args.domains: - classified_dataframe = classifier.classify_domains_in_dataframe( - dataframe=create_domains_dataframe(domains=args.domains) - ) - if args.csv: - classified_dataframe = classifier.classify_domains_in_csv(csv_filepath=args.csv) - logging.info( - "classified %s domains from csv file %s", - len(classified_dataframe), - args.csv, - ) - if not args.quiet: - logging.critical(classified_dataframe) - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument( - "--domains", - help="space separated list of 1 or more domains you want DGA detective to classify", - nargs="*", - metavar="DOMAIN", - type=str, - required=False, - ) - parser.add_argument( - "--model", - help="the hdf5 keras model file to pass to the classifier", - type=Path, - required=False, - ) - parser.add_argument( - "--csv", - help="csv file containing the domains to classify. This file must have a column 'domain'. 
The classification will be stored in the same file under a column 'classification'", - type=Path, - required=False, - ) - parser.add_argument( - "-q", "--quiet", help="disables stdout", action="store_true", required=False - ) - return parser - - -def main() -> None: - logging.basicConfig( - format="%(levelname)s: %(message)s", - level=os.environ.get(key="LOG_LEVEL", default="ERROR").upper(), - ) - parser = setup_parser() - args = parser.parse_args() - if args: - classify(args) - - -if __name__ == "__main__": - main() diff --git a/dgad/classification.py b/dgad/classification.py deleted file mode 100644 index a0d74e2d013337d014fc26a0e7c95e1aa5e89825..0000000000000000000000000000000000000000 --- a/dgad/classification.py +++ /dev/null @@ -1,221 +0,0 @@ -import logging -import os -import sys -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Any, List - -import numpy as np -import pandas as pd -import tensorflow -import tensorflow.keras as keras -from tcn import TCN -from tensorflow.keras.callbacks import ModelCheckpoint -from tensorflow.keras.layers import LSTM, Activation, Dense, Dropout, Embedding, Input -from tensorflow.keras.models import Model, Sequential - -from dgad.data_model import Domain, get_all_padded_token_vectors, set_all_predictions -from dgad.utils import create_characters_dictionary, separate_domains_that_are_too_long - - -class GenericClassifier(ABC): - """ - Abstract base class for the specialised classifiers - """ - - def __init__( - self, - optimizer: str, - loss: str = "binary_crossentropy", - abnormal_label: str = "DGA", - normal_label: str = "ok", - model: Any = None, - ): - self.normal_label = normal_label - self.abnormal_label = abnormal_label - self.model = model - self.loss = loss - self.optimizer = optimizer - self.characters_dictionary = create_characters_dictionary() - self.labels_dictionary = {1: self.abnormal_label, 0: self.normal_label} - - @abstractmethod - def initialise_keras_model(self, x_train: np.ndarray) -> None: # pragma: no cover - pass - - def train_keras_model( - self, - x_train: np.ndarray, - y_train: np.ndarray, - checkpoints_directory: Path, - epochs: int = 5, - save_best_only: bool = True, - ) -> keras.callbacks.History: - checkpoint = ModelCheckpoint( - filepath=os.path.join( - checkpoints_directory, "checkpoint-{epoch:02d}-{loss:.2f}.hdf5" - ), - monitor="loss", - verbose=0, - save_best_only=save_best_only, - mode="max", - ) - callbacks_list = [checkpoint] - return self.model.fit( - x_train, - y_train, - batch_size=16, - epochs=epochs, - callbacks=callbacks_list, - shuffle=True, - ) - - def load_keras_model(self, filepath: Path) -> Any: - self.model = tensorflow.keras.models.load_model(filepath=filepath) - self.model.compile(loss=self.loss, optimizer=self.optimizer) - - def __predict_binary_labels__(self, x_test: np.ndarray) -> np.ndarray: - return (self.model.predict(x_test) > 0.5).astype("int32") - - def __classify_domains_binary__(self, domains: List[Domain]) -> None: - x_test = get_all_padded_token_vectors(domains=domains) - predicted_labels = self.__predict_binary_labels__(x_test=x_test) - set_all_predictions(domains=domains, predictions=predicted_labels) - for domain in domains: - domain.update_label() - - def __label_domains__( - self, - domains: List[Domain], - domains_too_long: List[Domain], - binary_classification: bool = True, - ) -> List[Domain]: - if binary_classification: - self.__classify_domains_binary__(domains=domains) - for domain in domains_too_long: - domain.binary_label = "N/A - domain too 
long for model" - return domains - - def __classify_dataframe__(self, dataframe: pd.DataFrame) -> pd.DataFrame: - raw_domains = dataframe["domain"] - raw_domains_todo, raw_domains_too_long = separate_domains_that_are_too_long( - raw_domains, self.model.input_shape[1] - ) - domains_todo = [ - Domain( - raw_domain, - self.model.input_shape[1], - self.characters_dictionary, - self.labels_dictionary, - ) - for raw_domain in raw_domains_todo - ] - domains_too_long = [ - Domain( - raw_domain, - self.model.input_shape[1], - self.characters_dictionary, - self.labels_dictionary, - ) - for raw_domain in raw_domains_too_long - ] - self.__label_domains__( - domains=domains_todo, - domains_too_long=domains_too_long, - binary_classification=True, - ) - dataframe["classification"] = [domain.binary_label for domain in domains_todo] - return dataframe - - def classify_domains_in_dataframe(self, dataframe: pd.DataFrame) -> pd.DataFrame: - """ - high level method to interact with - """ - if "domain" not in dataframe.columns: - logging.critical( - "the dataframe does not contain the required column 'domain'!" - ) - sys.exit() - if self.model: - return self.__classify_dataframe__(dataframe) - logging.critical(msg="can not perform classification without a model!") - sys.exit(1) - - def classify_raw_domains(self, raw_domains: List[str]) -> List[Domain]: - domains = [ - Domain( - raw=raw_domain, - padded_length=self.model.input_shape[1], - characters_dictionary=self.characters_dictionary, - labels_dictionary=self.labels_dictionary, - ) - for raw_domain in raw_domains - ] - self.__classify_domains_binary__(domains=domains) - for domain in domains: - domain.update_label() - return domains - - def classify_domains_in_csv(self, csv_filepath: Path) -> pd.DataFrame: - dataframe = pd.read_csv(csv_filepath) - classified_dataframe = self.classify_domains_in_dataframe(dataframe=dataframe) - classified_dataframe.to_csv(csv_filepath, index=False) - return dataframe - - -class LSTMClassifier(GenericClassifier): - """ - Specialised LSTM implementation of GenericClassifier - """ - - def __init__(self, optimizer: str = "rmsprop"): - super().__init__(optimizer) - - def initialise_keras_model(self, x_train: np.ndarray) -> None: - max_features = 1024 - max_domain_length = x_train.shape[1] - model = Sequential() - model.add(Embedding(max_features, 128, input_length=max_domain_length)) - model.add(LSTM(128)) - model.add(Dropout(0.5)) - model.add(Dense(1)) - model.add(Activation("sigmoid")) - model.compile(loss=self.loss, optimizer=self.optimizer) - self.model = model - - -class TCNClassifier(GenericClassifier): - """ - Specialised TCN implementation of GenericClassifier - """ - - def __init__(self, optimizer: str = "adam"): - super().__init__(optimizer) - - def initialise_keras_model(self, x_train: np.ndarray) -> None: - max_features = 1024 - max_domain_length = x_train.shape[1] - i = Input(batch_shape=(None, x_train.shape[1])) - # embedding layer to map the input to from 2D to 3D tensor - e = Embedding(max_features, 128, input_length=max_domain_length)(i) - o = TCN( - nb_filters=8, - kernel_size=4, - nb_stacks=1, - dilations=[1, 2, 4, 8, 16, 32], - padding="same", - use_skip_connections=True, - return_sequences=False, - )( - e - ) # The TCN layers are here. 
-        o = Dense(1)(o)
-        o = Activation("sigmoid")(o)
-        self.model = Model(inputs=[i], outputs=[o])
-        self.model.compile(loss=self.loss, optimizer=self.optimizer)
-
-    def load_keras_model(self, filepath: Path) -> Any:
-        self.model = tensorflow.keras.models.load_model(
-            filepath=filepath, custom_objects={"TCN": TCN}
-        )
-        self.model.compile(loss=self.loss, optimizer=self.optimizer)
diff --git a/dgad/cli.py b/dgad/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d0d852efec70b566576243320d4b05e1b999cda
--- /dev/null
+++ b/dgad/cli.py
@@ -0,0 +1,215 @@
+# pylint: disable-all
+
+import logging
+import sys
+from importlib import resources
+
+import click
+import pandas as pd
+
+import dgad.label_encoders
+import dgad.models
+from dgad.grpc.api import DGADClient, DGADServer
+from dgad.prediction import Detective, Model
+from dgad.utils import load_labels, log_analysis, pretty_print, setup_logging
+
+
+def validate_families_number(ctx, param, value):
+    allowed = [52, 81]
+    if value not in allowed:
+        raise click.BadParameter(f"must be one of {allowed}")
+    else:
+        return value
+
+
+def validate_file_format(ctx, param, value):
+    allowed = ["csv", "jsonl"]
+    if value not in allowed:
+        raise click.BadParameter(f"must be one of {allowed}")
+    else:
+        return value
+
+
+def input_domains_from_cli_filepath_or_buf(
+    input_filepath_or_buf, format, domains_column
+):
+    df = pd.DataFrame()
+    domains_set = set()
+    if format == "csv":
+        df = pd.read_csv(input_filepath_or_buf, index_col=False)
+    elif format == "jsonl":
+        df = pd.read_json(input_filepath_or_buf, lines=True)
+    if not df.empty:
+        try:
+            domains_set = set(df[domains_column].tolist())
+        except KeyError:
+            logging.critical(
+                f"you must have a {domains_column} column in your csv/jsonl file"
+            )
+            sys.exit(-1)
+    return domains_set
+
+
+def load_multi_class_model(n_families: int) -> Model:
+    with resources.path(
+        dgad.models, f"tcn_family_{n_families}_classes.h5"
+    ) as model_path:
+        with resources.path(
+            dgad.label_encoders, f"encoder_{n_families}_classes.npy"
+        ) as labels_path:
+            model_multi = Model(filepath=model_path, labels=load_labels(labels_path))
+    return model_multi
+
+
+def analyse_domains_remotely(dgad_client, domains):
+    responses = [dgad_client.requests(domain) for domain in domains]
+    pretty_print(domains=responses)
+
+
+@click.group()
+def cli():
+    """
+    DGA Detective can predict if a domain name has been generated by a Domain Generation Algorithm
+    """
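+
+# example invocations (for illustration; assuming the `dgad` console entrypoint
+# from the Dockerfile and the bundled tests/data files):
+#   dgad client -d wikipedia.org -d sjdkahflaksdjhf.net
+#   cat tests/data/domains_todo.csv | dgad client -fmt csv -f -
+#   dgad server -p 4714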
+
+
+@cli.command()
+@click.option(
+    "-d", "--domain", type=str, multiple=True, help="the domain(s) you want to check"
+)
+@click.option(
+    "-f",
+    "--input-filepath-or-buf",
+    type=click.File("rb"),
+    required=False,
+    help="file containing the domains to check, can be piped through stdin from another command. When using this, you must specify the file format with --format (csv or jsonl)",
+)
+# TODO: what if I want to pass a url to a file in object storage?
+# TODO: what if I want to pass a regex to handle multiple files?
+@click.option(
+    "-fmt", "--format", type=str, default="csv", callback=validate_file_format
+)
+@click.option(
+    "-dc",
+    "--domains_column",
+    type=str,
+    default="domain",
+    help="the name of the column that contains the domains in your file",
+)
+@click.option(
+    "-n",
+    "--families-number",
+    type=int,
+    default=81,
+    help="dgad comes with two trained models for family classification. They have been trained with examples from 52 and 81 families, respectively. This option allows you to choose the model by specifying the number of families in the model (thus this can be either 52 or 81)",
+    callback=validate_families_number,
+)
+@click.option(
+    "-r",
+    "--remote-analysis",
+    is_flag=True,
+    default=False,
+    help="send domains to remote DGAD server for analysis (instead of performing it locally)",
+)
+@click.option(
+    "-h",
+    "--remote-host",
+    default="localhost",
+    type=str,
+    help="remote DGA Detective hostname/ip",
+)
+@click.option(
+    "-p",
+    "--remote-port",
+    type=int,
+    default=4714,
+    help="remote DGA Detective port",
+)
+def client(
+    domain,
+    input_filepath_or_buf,
+    format,
+    domains_column,
+    families_number,
+    remote_analysis,
+    remote_host,
+    remote_port,
+):
+    """
+    classify domains from cli args or csv/jsonl files
+    """
+    domains_set = set()
+    # 1. input
+    if domain:
+        domains_set = set(domain)
+    elif input_filepath_or_buf and format:
+        domains_set = input_domains_from_cli_filepath_or_buf(
+            input_filepath_or_buf, format, domains_column
+        )
+    # 2. classification
+    if domains_set:
+        if remote_analysis:
+            dgad_client = DGADClient(host=remote_host, port=remote_port)
+            analyse_domains_remotely(
+                dgad_client=dgad_client,
+                domains=domains_set,
+            )
+        else:
+            # 2a. load a specific family model
+            if families_number:
+                multi_class_model = load_multi_class_model(n_families=families_number)
+                detective = Detective(model_multi=multi_class_model)
+            else:  # 2b or use defaults
+                detective = Detective()
+            # 3. run
+            domains, _ = detective.prepare_domains(raw_domains=domains_set)
+            detective.investigate(domains)
+            # 4. output
+            pretty_print(domains)
+
+
+@cli.command()
+@click.option(
+    "-v",
+    "--verbosity",
+    type=str,
+    default="WARNING",
+    help="sets log level, uses python logging module so you can pass strings like DEBUG, CRITICAL...",
+)
+@click.option(
+    "-p",
+    "--port",
+    type=int,
+    default=4714,
+    help="DGAD grpc api will listen at this port",
+)
+@click.option(
+    "-n",
+    "--families-number",
+    type=int,
+    default=81,
+    help="dgad comes with two trained models for family classification. They have been trained with examples from 52 and 81 families, respectively. This option allows you to choose the model by specifying the number of families in the model (thus this can be either 52 or 81)",
+    callback=validate_families_number,
+)
+@click.option(
+    "-w",
+    "--max-workers",
+    type=int,
+    default=10,
+    help="maximum number of threads the grpc server can spawn to handle incoming requests",
+)
+def server(verbosity, port, families_number, max_workers):
+    """
+    deploy a DGA Detective server
+    """
+    setup_logging(level=verbosity)
+    if families_number:
+        multi_class_model = load_multi_class_model(n_families=families_number)
+        detective = Detective(model_multi=multi_class_model)
+    else:
+        detective = Detective()
+    server = DGADServer(detective=detective, port=port, max_workers=max_workers)
+    server.bootstrap()
+
+
+if __name__ == "__main__":
+    cli()
diff --git a/dgad/data_model.py b/dgad/data_model.py
deleted file mode 100644
index 8d7a86ea4791e9218418eac86b461c5efaee92c9..0000000000000000000000000000000000000000
--- a/dgad/data_model.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, List, Sequence, Tuple
-
-import numpy as np
-
-import dgad.utils as utils
-
-
-@dataclass
-class Word:
-    """
-    A word is the smallest unit that can be reasonably classified as produced by a DGA.
-    A subdomain is just a word.
- """ - - name: str - characters_dictionary: Dict[str, int] - padded_length: int - padded_token_vector = [0] - binary_prediction: int = 0 - - def __post_init__(self) -> None: - """ - word preprocessing - """ - sanitised_name = utils.strip_forbidden_characters( - word=self.name, characters_dictionary=self.characters_dictionary - ) - token_vector = utils.tokenize_word( - word=sanitised_name, characters_dictionary=self.characters_dictionary - ) - self.padded_token_vector = utils.pad_vector( - vector=token_vector, desired_length=self.padded_length - ) - - -class Domain: - """ - RFC 1035 Domain Name. - Can have 0 or more subdomains, which are Words. - """ - - domain_name: Word - - def __init__( - self, - raw: str, - padded_length: int, - characters_dictionary: Dict[str, int], - labels_dictionary: Dict[int, str], - ): - self.raw = raw - self.labels_dictionary = labels_dictionary - raw_domain_name, list_raw_subdomains = utils.extract_domain_name_and_subdomains( - self.raw - ) - self.domain_name = Word(raw_domain_name, characters_dictionary, padded_length) - self.subdomains = [ - Word( - name=subdomain, - characters_dictionary=characters_dictionary, - padded_length=padded_length, - ) - for subdomain in list_raw_subdomains - ] - self.binary_label: str = "" - - def update_label(self) -> None: - binary_prediction = self.__get_overall_prediction__() - self.binary_label = self.__get_human_label__(binary_prediction) - - def __get_overall_prediction__(self) -> int: - if self.domain_name.binary_prediction == 1: - return 1 - for subdomain in self.subdomains: - if subdomain.binary_prediction == 1: - return 1 - else: - return 0 - - def __get_human_label__(self, binary_prediction: int) -> str: - return self.labels_dictionary[binary_prediction] - - def __hash__(self) -> int: - return hash(self.raw) - - -def get_all_subdomains(domains: List[Domain]) -> List[Word]: - """ - returns a list of all the subdomains from the provided Domains - """ - subdomains: List[Word] = [] - for domain in domains: - for subdomain in domain.subdomains: - subdomains.append(subdomain) - return subdomains - - -def get_all_padded_token_vectors(domains: List[Domain]) -> np.ndarray: - """ - returns a list of all the padded token vectors from the provided Domains - """ - subdomains = get_all_subdomains(domains=domains) - domains_vectors = [domain.domain_name.padded_token_vector for domain in domains] - subdomains_vectors = [subdomain.padded_token_vector for subdomain in subdomains] - return np.array(domains_vectors + subdomains_vectors) - - -def set_all_words_predictions(predictions: np.ndarray, words: Sequence[Word]) -> None: - """ - for each word, stores the prediction from the array to - the binary_prediction attribute of the Word - """ - for index, prediction in enumerate(predictions): - word = words[index] - word.binary_prediction = prediction[0] - - -def set_all_predictions(domains: List[Domain], predictions: np.ndarray) -> None: - """ - for each domain, stores the prediction from the array to - all the domain names and subdomains - """ - domain_names = [domain.domain_name for domain in domains] - subdomains = get_all_subdomains(domains=domains) - predictions_domain_names, predictions_subdomains = np.split( - predictions, [len(domains)] - ) - set_all_words_predictions(predictions=predictions_domain_names, words=domain_names) - set_all_words_predictions(predictions=predictions_subdomains, words=subdomains) diff --git a/dgad/grpc/api.py b/dgad/grpc/api.py new file mode 100644 index 
0000000000000000000000000000000000000000..6c68dce9b83127ffbaf5cc328ae980d099284b83 --- /dev/null +++ b/dgad/grpc/api.py @@ -0,0 +1,100 @@ +# type: ignore +# pylint: disable-all + +import logging +import time +import uuid +from concurrent import futures + +import grpc + +from dgad.grpc import prediction_pb2, prediction_pb2_grpc +from dgad.prediction import Detective +from dgad.schema import Domain, Word +from dgad.utils import log_performance + + +def unpack(response) -> Domain: + domain = Domain( + raw=response.fqdn, is_dga=response.is_dga, family_label=response.family + ) + words = [ + Word( + value=word.value, + binary_score=word.binary_score, + binary_label=word.binary_label, + family_score=word.family_score, + family_label=word.family_label, + ) + for word in response.words + ] + domain.words = words + return domain + + +def pack(domain: Domain): + words = [ + prediction_pb2.Word( + value=word.value, + binary_score=word.binary_score, + binary_label=word.binary_label, + family_score=word.family_score, + family_label=word.family_label, + ) + for word in domain.words + ] + return prediction_pb2.Domain( + fqdn=domain.raw, + is_dga=domain.is_dga, + family=domain.family_label, + words=words, + ) + + +class Classifier(prediction_pb2_grpc.Classifier): + def __init__(self, detective: Detective): + self.detective = detective + self.counter = 0 + self.start_time = time.time() + self.id = uuid.uuid4() + logging.warning(f"started dga detective classifier {self.id}") + + def GetClassification(self, request, context): + raw_domains = [request.fqdn] + domains, _ = self.detective.prepare_domains(raw_domains) + self.detective.investigate(domains=domains) + domain = domains[0] + self.counter += 1 + if self.counter % 100 == 0: + log_performance(counter=self.counter, start_time=self.start_time) + return pack(domain) + + +class DGADServer: + def __init__(self, detective: Detective, port: int, max_workers: int): + self.detective = detective + self.port = port + self.max_workers = max_workers + + def bootstrap(self): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=self.max_workers)) + prediction_pb2_grpc.add_ClassifierServicer_to_server( + Classifier(self.detective), server + ) + server.add_insecure_port(f"[::]:{self.port}") + server.start() + server.wait_for_termination() + + +class DGADClient: + def __init__(self, host: str, port: int): + self.host = host + self.port = port + + def requests(self, domain: str): + with grpc.insecure_channel(f"{self.host}:{self.port}") as channel: + stub = prediction_pb2_grpc.ClassifierStub(channel) + response = stub.GetClassification( + prediction_pb2.Domain(fqdn=domain), wait_for_ready=True + ) + return unpack(response) diff --git a/dgad/grpc/classification_pb2.py b/dgad/grpc/classification_pb2.py deleted file mode 100644 index 41c92424cee5626873686d555efa1666ad8c1fb6..0000000000000000000000000000000000000000 --- a/dgad/grpc/classification_pb2.py +++ /dev/null @@ -1,180 +0,0 @@ -# type: ignore - -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: classification.proto -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database - -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor.FileDescriptor( - name="classification.proto", - package="", - syntax="proto3", - serialized_options=None, - create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\x14\x63lassification.proto"\x16\n\x06\x44omain\x12\x0c\n\x04\x66qdn\x18\x01 \x01(\t"=\n\x0e\x43lassification\x12\x0c\n\x04\x66qdn\x18\x01 \x01(\t\x12\x1d\n\x15\x62inary_classification\x18\x02 \x01(\t2=\n\nClassifier\x12/\n\x11GetClassification\x12\x07.Domain\x1a\x0f.Classification"\x00\x62\x06proto3', -) - - -_DOMAIN = _descriptor.Descriptor( - name="Domain", - full_name="Domain", - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name="fqdn", - full_name="Domain.fqdn", - index=0, - number=1, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=b"".decode("utf-8"), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - serialized_options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=24, - serialized_end=46, -) - - -_CLASSIFICATION = _descriptor.Descriptor( - name="Classification", - full_name="Classification", - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name="fqdn", - full_name="Classification.fqdn", - index=0, - number=1, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=b"".decode("utf-8"), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - _descriptor.FieldDescriptor( - name="binary_classification", - full_name="Classification.binary_classification", - index=1, - number=2, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=b"".decode("utf-8"), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - serialized_options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=48, - serialized_end=109, -) - -DESCRIPTOR.message_types_by_name["Domain"] = _DOMAIN -DESCRIPTOR.message_types_by_name["Classification"] = _CLASSIFICATION -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -Domain = _reflection.GeneratedProtocolMessageType( - "Domain", - (_message.Message,), - { - "DESCRIPTOR": _DOMAIN, - "__module__": "classification_pb2" - # @@protoc_insertion_point(class_scope:Domain) - }, -) -_sym_db.RegisterMessage(Domain) - -Classification = _reflection.GeneratedProtocolMessageType( - "Classification", - (_message.Message,), - { - "DESCRIPTOR": _CLASSIFICATION, - "__module__": 
"classification_pb2" - # @@protoc_insertion_point(class_scope:Classification) - }, -) -_sym_db.RegisterMessage(Classification) - - -_CLASSIFIER = _descriptor.ServiceDescriptor( - name="Classifier", - full_name="Classifier", - file=DESCRIPTOR, - index=0, - serialized_options=None, - create_key=_descriptor._internal_create_key, - serialized_start=111, - serialized_end=172, - methods=[ - _descriptor.MethodDescriptor( - name="GetClassification", - full_name="Classifier.GetClassification", - index=0, - containing_service=None, - input_type=_DOMAIN, - output_type=_CLASSIFICATION, - serialized_options=None, - create_key=_descriptor._internal_create_key, - ), - ], -) -_sym_db.RegisterServiceDescriptor(_CLASSIFIER) - -DESCRIPTOR.services_by_name["Classifier"] = _CLASSIFIER - -# @@protoc_insertion_point(module_scope) diff --git a/dgad/grpc/classifier_client.py b/dgad/grpc/classifier_client.py deleted file mode 100644 index 6fa144f0f5743bc5aa12362189da55ff9ee62687..0000000000000000000000000000000000000000 --- a/dgad/grpc/classifier_client.py +++ /dev/null @@ -1,38 +0,0 @@ -# type: ignore - -import logging -import os -import random -import string - -import grpc - -from dgad.grpc import classification_pb2, classification_pb2_grpc - - -def run(): - - domains = [] - amount = int(os.environ.get("AMOUNT", 10000)) - for _ in range(amount): - domain = "".join( - random.choice(string.ascii_lowercase) for _ in range(20) # nosec - ) - domains.append(domain + ".com") - logging.critical("created %s random domains", amount) - - host = os.environ.get("GRPC_HOST", "localhost") - port = os.environ.get("GRPC_PORT", "50054") - - with grpc.insecure_channel(f"{host}:{port}") as channel: - for domain in domains: - stub = classification_pb2_grpc.ClassifierStub(channel) - response = stub.GetClassification( - classification_pb2.Domain(fqdn=domain), wait_for_ready=True - ) - logging.critical("%s %s", response.fqdn, response.binary_classification) - - -if __name__ == "__main__": - logging.basicConfig() - run() diff --git a/dgad/grpc/classifier_server.py b/dgad/grpc/classifier_server.py deleted file mode 100644 index d4d304eda92f1c8c330d0dea6c8e50d5a673a8c3..0000000000000000000000000000000000000000 --- a/dgad/grpc/classifier_server.py +++ /dev/null @@ -1,63 +0,0 @@ -# type: ignore - -import logging -import os -import random -import string -import time -from concurrent import futures -from importlib import resources - -import grpc - -import dgad.models -from dgad.classification import TCNClassifier -from dgad.grpc import classification_pb2, classification_pb2_grpc - - -class Classifier(classification_pb2_grpc.Classifier): - def __init__(self): - self.classifier = TCNClassifier() - with resources.path(dgad.models, "tcn_best.h5") as model_path: - self.classifier.load_keras_model(filepath=model_path) - self.counter = 0 - self.start_interval_time = time.time() - self.interval_size = 100 - self.id = "".join(random.choice(string.digits) for _ in range(5)) # nosec - logging.critical("started dga detective classifier %s", self.id) - - def GetClassification(self, request, context): - classified_domain = self.classifier.classify_raw_domains( - raw_domains=[request.fqdn] - )[0] - self.counter += 1 - if self.counter % self.interval_size == 0: - logging.info( - "%s: classified %s domains in %s", - self.id, - self.interval_size, - time.time() - self.start_interval_time, - ) - self.start_interval_time = time.time() - return classification_pb2.Classification( - fqdn=request.fqdn, - binary_classification=classified_domain.binary_label, - ) - - 
-def serve(): - port = os.environ.get("LISTENING_PORT", 50054) - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - classification_pb2_grpc.add_ClassifierServicer_to_server(Classifier(), server) - server.add_insecure_port(f"[::]:{port}") - server.start() - server.wait_for_termination() - - -if __name__ == "__main__": - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - logging.basicConfig( - format="%(levelname)s: %(message)s", - level=os.environ.get(key="LOG_LEVEL", default="ERROR").upper(), - ) - serve() diff --git a/dgad/grpc/prediction_pb2.py b/dgad/grpc/prediction_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..a7065678ec1c38675344cb3b8b82075d1f7fe9a0 --- /dev/null +++ b/dgad/grpc/prediction_pb2.py @@ -0,0 +1,58 @@ +# type: ignore +# pylint: disable-all + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: prediction.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x10prediction.proto"m\n\x04Word\x12\r\n\x05value\x18\x01 \x01(\t\x12\x14\n\x0c\x62inary_score\x18\x02 \x01(\x02\x12\x14\n\x0c\x62inary_label\x18\x03 \x01(\t\x12\x14\n\x0c\x66\x61mily_score\x18\x04 \x01(\x02\x12\x14\n\x0c\x66\x61mily_label\x18\x05 \x01(\t"L\n\x06\x44omain\x12\x0c\n\x04\x66qdn\x18\x01 \x01(\t\x12\x0e\n\x06is_dga\x18\x02 \x01(\x08\x12\x0e\n\x06\x66\x61mily\x18\x03 \x01(\t\x12\x14\n\x05words\x18\x04 \x03(\x0b\x32\x05.Word25\n\nClassifier\x12\'\n\x11GetClassification\x12\x07.Domain\x1a\x07.Domain"\x00\x62\x06proto3' +) + + +_WORD = DESCRIPTOR.message_types_by_name["Word"] +_DOMAIN = DESCRIPTOR.message_types_by_name["Domain"] +Word = _reflection.GeneratedProtocolMessageType( + "Word", + (_message.Message,), + { + "DESCRIPTOR": _WORD, + "__module__": "prediction_pb2" + # @@protoc_insertion_point(class_scope:Word) + }, +) +_sym_db.RegisterMessage(Word) + +Domain = _reflection.GeneratedProtocolMessageType( + "Domain", + (_message.Message,), + { + "DESCRIPTOR": _DOMAIN, + "__module__": "prediction_pb2" + # @@protoc_insertion_point(class_scope:Domain) + }, +) +_sym_db.RegisterMessage(Domain) + +_CLASSIFIER = DESCRIPTOR.services_by_name["Classifier"] +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _WORD._serialized_start = 20 + _WORD._serialized_end = 129 + _DOMAIN._serialized_start = 131 + _DOMAIN._serialized_end = 207 + _CLASSIFIER._serialized_start = 209 + _CLASSIFIER._serialized_end = 262 +# @@protoc_insertion_point(module_scope) diff --git a/dgad/grpc/classification_pb2_grpc.py b/dgad/grpc/prediction_pb2_grpc.py similarity index 80% rename from dgad/grpc/classification_pb2_grpc.py rename to dgad/grpc/prediction_pb2_grpc.py index 6ff6c6083450c07c32f777be3d9e5f2e0e3fca58..623e21ca47d2e25ca7fd2264fa93072b7bad5f4b 100644 --- a/dgad/grpc/classification_pb2_grpc.py +++ b/dgad/grpc/prediction_pb2_grpc.py @@ -1,10 +1,11 @@ # type: ignore +# pylint: disable-all # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
"""Client and server classes corresponding to protobuf-defined services.""" import grpc -import dgad.grpc.classification_pb2 as classification__pb2 +import dgad.grpc.prediction_pb2 as prediction__pb2 class ClassifierStub(object): @@ -18,8 +19,8 @@ class ClassifierStub(object): """ self.GetClassification = channel.unary_unary( "/Classifier/GetClassification", - request_serializer=classification__pb2.Domain.SerializeToString, - response_deserializer=classification__pb2.Classification.FromString, + request_serializer=prediction__pb2.Domain.SerializeToString, + response_deserializer=prediction__pb2.Domain.FromString, ) @@ -37,8 +38,8 @@ def add_ClassifierServicer_to_server(servicer, server): rpc_method_handlers = { "GetClassification": grpc.unary_unary_rpc_method_handler( servicer.GetClassification, - request_deserializer=classification__pb2.Domain.FromString, - response_serializer=classification__pb2.Classification.SerializeToString, + request_deserializer=prediction__pb2.Domain.FromString, + response_serializer=prediction__pb2.Domain.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -68,8 +69,8 @@ class Classifier(object): request, target, "/Classifier/GetClassification", - classification__pb2.Domain.SerializeToString, - classification__pb2.Classification.FromString, + prediction__pb2.Domain.SerializeToString, + prediction__pb2.Domain.FromString, options, channel_credentials, insecure, diff --git a/dgad/grpc/protos/classification.proto b/dgad/grpc/protos/classification.proto deleted file mode 100644 index 81007c8201a6aade7b3f886e90c2b5d280234459..0000000000000000000000000000000000000000 --- a/dgad/grpc/protos/classification.proto +++ /dev/null @@ -1,14 +0,0 @@ -syntax = "proto3"; - -service Classifier { - rpc GetClassification(Domain) returns (Classification) {} -} - -message Domain { - string fqdn = 1; -} - -message Classification { - string fqdn = 1; - string binary_classification = 2; -} diff --git a/dgad/app/__init__.py b/dgad/label_encoders/__init__.py similarity index 100% rename from dgad/app/__init__.py rename to dgad/label_encoders/__init__.py diff --git a/dgad/label_encoders/encoder_52_classes.npy b/dgad/label_encoders/encoder_52_classes.npy new file mode 100644 index 0000000000000000000000000000000000000000..0957f3113fedbaa2e843b480879bf41f6c43cf70 Binary files /dev/null and b/dgad/label_encoders/encoder_52_classes.npy differ diff --git a/dgad/label_encoders/encoder_81_classes.npy b/dgad/label_encoders/encoder_81_classes.npy new file mode 100644 index 0000000000000000000000000000000000000000..0fef6e96b9f563032f3615a4d4df5acbe3d09a05 Binary files /dev/null and b/dgad/label_encoders/encoder_81_classes.npy differ diff --git a/dgad/models/tcn_family_52_classes.h5 b/dgad/models/tcn_family_52_classes.h5 new file mode 100644 index 0000000000000000000000000000000000000000..99ec8a0f16d66a6d398a2777e5a7a0d7e0911e46 Binary files /dev/null and b/dgad/models/tcn_family_52_classes.h5 differ diff --git a/dgad/models/tcn_family_81_classes.h5 b/dgad/models/tcn_family_81_classes.h5 new file mode 100644 index 0000000000000000000000000000000000000000..ee312ea4a598b422805862d006cba9fc66206af5 Binary files /dev/null and b/dgad/models/tcn_family_81_classes.h5 differ diff --git a/dgad/prediction.py b/dgad/prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..e1239e025f7022e8de39c0511a35fdc4f30751d4 --- /dev/null +++ b/dgad/prediction.py @@ -0,0 +1,119 @@ +import logging +import os +from dataclasses import dataclass, field +from importlib 
import resources +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np +import tensorflow +from tcn import TCN + +import dgad.label_encoders +import dgad.models +from dgad.schema import Domain, Word +from dgad.utils import load_labels, log_analysis, separate_domains_that_are_too_long + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + + +def default_binary_labels(): + return {0: "benign", 1: "DGA"} + + +def default_custom_objects(): + return {} + + +@dataclass +class Model: + filepath: Path + data: Any = None + labels: Dict[int, str] = field(default_factory=default_binary_labels) + optimizer: str = "adam" + loss: str = "binary_crossentropy" + custom_objects: Dict[str, str] = field(default_factory=default_custom_objects) + + def __post_init__(self): + self.data = tensorflow.keras.models.load_model( + filepath=self.filepath, custom_objects={"TCN": TCN} + ) + self.data.compile(loss=self.loss, optimizer=self.optimizer) + + +class Detective: + def __init__(self, model_binary: Model = None, model_multi: Model = None) -> None: + # use included binary model if one is not provided + if model_binary: + self.model_binary = model_binary + else: + with resources.path(dgad.models, "tcn_best.h5") as model_path: + self.model_binary = Model( + filepath=model_path, custom_objects={"TCN": TCN} + ) + # use included family model if one is not provided + if model_multi: + self.model_multi = model_multi + else: + with resources.path(dgad.models, "tcn_family_81_classes.h5") as model_path: + with resources.path( + dgad.label_encoders, "encoder_81_classes.npy" + ) as labels_path: + self.model_multi = Model( + filepath=model_path, + labels=load_labels(labels_path), + custom_objects={"TCN": TCN}, + ) + + def prepare_domains( + self, raw_domains: List[str], max_length: int = 0 + ) -> Tuple[List[Domain], List[str]]: + """ + preprocesses the domains, tokenizing and applying padding to have same size as binary model + """ + # TODO: padding may be different for the multi model...! That could lead to issues. Same size should be documented or enforced... + if not max_length: + max_length = self.model_binary.data.input_shape[1] + raw_domains_todo, domains_to_skip = separate_domains_that_are_too_long( + raw_domains, max_length + ) + domains_todo = [ + Domain(raw=raw_domain, padded_length=max_length) + for raw_domain in raw_domains_todo + ] + if domains_to_skip: + logging.warning( + f"will skip domains {domains_to_skip} because they are too long for the binary model" + ) + return domains_todo, domains_to_skip + + def investigate_binary(self, word: Word) -> None: + x_test: np.ndarray = np.array([word.padded_token_vector]) + y_test: np.ndarray = self.model_binary.data.predict(x_test, verbose=0) + word.binary_score = float(y_test[0][0]) + word.binary_label = self.model_binary.labels[int(np.round(word.binary_score))] + + def investigate_family(self, word: Word) -> None: + x_test: np.ndarray = np.array([word.padded_token_vector]) + y_test = self.model_multi.data.predict(x_test, verbose=0) + best_class_label_index = np.argmax(y_test, axis=1)[0] + word.family_score = float(np.max(y_test, axis=1)[0]) + word.family_label = self.model_multi.labels[best_class_label_index] + + def investigate(self, domains: List[Domain]) -> None: + """ + performs binary and family predictions on provided list of domains. 
+        predictions are stored in the words attributes
+        """
+        # TODO: test performance
+        for domain in domains:
+            is_dga = False
+            for word in domain.words:
+                self.investigate_binary(word)
+                if word.binary_score > 0.5:
+                    self.investigate_family(word)
+                    is_dga = True
+            if is_dga:
+                domain.is_dga = True
+                domain.set_family()
+            log_analysis(domain)
diff --git a/dgad/schema.py b/dgad/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..d68ba7ee0193b9358405ba8b3c0d936037d4b77a
--- /dev/null
+++ b/dgad/schema.py
@@ -0,0 +1,116 @@
+import logging
+from dataclasses import dataclass, field
+from typing import List
+
+import tldextract
+
+from dgad import utils
+
+CHARACTERS_DICTIONARY = {
+    "0": 1,
+    "1": 2,
+    "2": 3,
+    "3": 4,
+    "4": 5,
+    "5": 6,
+    "6": 7,
+    "7": 8,
+    "8": 9,
+    "9": 10,
+    "a": 11,
+    "b": 12,
+    "c": 13,
+    "d": 14,
+    "e": 15,
+    "f": 16,
+    "g": 17,
+    "h": 18,
+    "i": 19,
+    "j": 20,
+    "k": 21,
+    "l": 22,
+    "m": 23,
+    "n": 24,
+    "o": 25,
+    "p": 26,
+    "q": 27,
+    "r": 28,
+    "s": 29,
+    "t": 30,
+    "u": 31,
+    "v": 32,
+    "w": 33,
+    "x": 34,
+    "y": 35,
+    "z": 36,
+    "-": 38,
+    "_": 39,
+    ".": 40,
+}
+
+
+@dataclass
+class Word:
+    """
+    A word is the smallest unit on which we can perform classification
+    """
+
+    value: str
+    padded_length: int = 0
+    padded_token_vector: List[int] = field(default_factory=list)
+    binary_score: float = 0.0
+    binary_label: str = ""
+    family_score: float = 0.0
+    family_label: str = "N/A"
+
+    def __post_init__(self) -> None:
+        """
+        preprocessing. if padded_length is not provided (default=0) it is set to the length of the string
+        """
+        sanitised_value = utils.strip_forbidden_characters(
+            word=self.value, characters_dictionary=CHARACTERS_DICTIONARY
+        )
+        token_vector = utils.tokenize_word(
+            word=sanitised_value, characters_dictionary=CHARACTERS_DICTIONARY
+        )
+        if not self.padded_length:
+            self.padded_length = len(self.value)
+        self.padded_token_vector = utils.pad_vector(
+            vector=token_vector, desired_length=self.padded_length
+        )
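+
+# e.g. (illustrative): Word("wikipedia", padded_length=12) maps each character
+# through CHARACTERS_DICTIONARY ("w" -> 33, "i" -> 19, ...) and pads the
+# resulting token vector to length 12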
+
+
+@dataclass
+class Domain:
+    raw: str
+    words: List[Word] = field(default_factory=list)
+    suffix: str = ""
+    is_dga: bool = False
+    family_label: str = "N/A"
+    padded_length: int = 0
+
+    def __post_init__(self) -> None:
+        raw_subdomains, raw_domain_name, self.suffix = tldextract.extract(
+            utils.remove_prefix(self.raw, "www.")
+        )
+        raw_words = []
+        raw_words.append(raw_domain_name)
+        if raw_subdomains:
+            raw_words += list(set(raw_subdomains.split(".")))
+        self.words = [Word(raw_word, self.padded_length) for raw_word in raw_words]
+        logging.debug(self)
+
+    def set_family(
+        self,
+        binary_confidence_threshold: float = 0.5,
+        family_confidence_threshold: float = 0,
+    ):
+        """
+        sets the domain family to be the one from the word with the highest family score
+        """
+        max_family_score = family_confidence_threshold
+        for word in self.words:
+            if word.binary_score > binary_confidence_threshold:
+                if word.family_score > max_family_score:
+                    max_family_score = word.family_score
+                    self.family_label = word.family_label
diff --git a/dgad/utils.py b/dgad/utils.py
index a61b8fc29983f1c87dc1b6af72576cf205ab5fcb..943c6f19834d7f21aece523f76dfbcf69eece93c 100644
--- a/dgad/utils.py
+++ b/dgad/utils.py
@@ -3,66 +3,15 @@ utils module for lstm library
 provides non lstm specific miscellaneous methods
 """
+import json
 import logging
+import time
+from dataclasses import asdict
+from pathlib import Path
 from typing import Dict, List, Tuple
 
 import numpy as np
-import pandas as pd
-import tldextract
-
-
-def create_domains_dataframe(domains: List[str]) -> pd.DataFrame:
-    return pd.DataFrame(domains, columns=["domain"])
-
-
-def create_characters_dictionary() -> Dict[str, int]:
-    # digits = [digit for digit in range(10)]
-    # digit_to_str = [str(digit) for digit in digits]
-    # digits_str: Dict[str, int] = dict(zip(digit_to_str, range(1, 11)))
-    # letters: Dict[str, int] = dict(zip(string.ascii_lowercase, range(11, 38)))
-    # symbols: Dict[str, int] = dict(zip(["-", "_", "."], range(38, 41)))
-    # return {**digits_str, **letters, **symbols}
-    return {
-        "0": 1,
-        "1": 2,
-        "2": 3,
-        "3": 4,
-        "4": 5,
-        "5": 6,
-        "6": 7,
-        "7": 8,
-        "8": 9,
-        "9": 10,
-        "a": 11,
-        "b": 12,
-        "c": 13,
-        "d": 14,
-        "e": 15,
-        "f": 16,
-        "g": 17,
-        "h": 18,
-        "i": 19,
-        "j": 20,
-        "k": 21,
-        "l": 22,
-        "m": 23,
-        "n": 24,
-        "o": 25,
-        "p": 26,
-        "q": 27,
-        "r": 28,
-        "s": 29,
-        "t": 30,
-        "u": 31,
-        "v": 32,
-        "w": 33,
-        "x": 34,
-        "y": 35,
-        "z": 36,
-        "-": 38,
-        "_": 39,
-        ".": 40,
-    }
+from sklearn import preprocessing
 
 
 def strip_forbidden_characters(word: str, characters_dictionary: Dict[str, int]) -> str:
@@ -78,7 +27,7 @@ def tokenize_word(word: str, characters_dictionary: Dict[str, int]) -> List[int]
     @param word: word to tokenize
     @return: vector of word tokens
     """
-    word_characters = [character for character in word]
+    word_characters = list(word)
     vector: List[int] = [characters_dictionary[_] for _ in word_characters]
     return vector
 
@@ -94,31 +43,52 @@ def pad_vector(vector: List[int], desired_length: int) -> List[int]:
     return vector[:desired_length] + padding
 
 
-def random_split_train_test(
-    domain_names_df: pd.DataFrame, split_ratio: float = 0.8
-) -> Tuple[pd.DataFrame, pd.DataFrame]:
-    train_set_df = pd.DataFrame()
-    test_set_df = pd.DataFrame()
-    if 0 < split_ratio < 1.0:
-        train_set = np.random.rand(len(domain_names_df)) < split_ratio
-        train_set_df = domain_names_df[train_set]
-        test_set_df = domain_names_df[~train_set]
-    else:
-        logging.error(msg="split_ratio must be between 0 and 1")
-    return train_set_df, test_set_df
-
-
-def extract_domain_name_and_subdomains(raw_domain: str) -> Tuple[str, List[str]]:
-    raw_subdomains, raw_domain_name, _ = tldextract.extract(
-        remove_prefix(raw_domain, "www.")
-    )
-    raw_subdomains = list(set(raw_subdomains.split(".")))
-    return raw_domain_name, raw_subdomains
-
-
 def separate_domains_that_are_too_long(
     domains: List[str], max_size: int
 ) -> Tuple[List[str], List[str]]:
     domains_shorter_or_equal = [domain for domain in domains if len(domain) <= max_size]
     domains_too_long = set(domains) - set(domains_shorter_or_equal)
     return domains_shorter_or_equal, list(domains_too_long)
+
+
+def setup_logging(
+    level: str, logformat: str = "%(asctime)2s %(levelname)-8s %(message)s"
+):
+    numeric_level = getattr(logging, level.upper(), None)
+    logging.basicConfig(level=numeric_level, format=logformat)
+    logging.debug(f"logging level set to {level}")
+
+
+def load_labels(encoder_path: Path) -> Dict[int, str]:
+    encoder = preprocessing.LabelEncoder()
+    encoder.classes_ = np.load(encoder_path)
+    labels = dict(zip(range(len(encoder.classes_)), encoder.classes_))
+    return labels
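+
+# e.g. (illustrative): for an encoder whose classes_ array is
+# ["bamital", "banjori", ...], load_labels returns {0: "bamital", 1: "banjori", ...}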
{avg_classification_took}s" + ) + + +def pretty_print(domains, output_format="json") -> str: + dicts = [asdict(domain) for domain in domains] + for domain in dicts: + for word in domain["words"]: + del word["padded_token_vector"] + if output_format == "json": + print(json.dumps(dicts, indent=2)) diff --git a/helm/Chart.yaml b/helm/Chart.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa89ff405a25bfcd719a7b281eb2f882913c94d9 --- /dev/null +++ b/helm/Chart.yaml @@ -0,0 +1,9 @@ +--- +apiVersion: v2 +name: dgad +description: soccrates dga detective + +type: application +version: 3.1.4 + +appVersion: "3.1.4" diff --git a/helm/templates/deployment.yaml b/helm/templates/deployment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ca5f2d8b0082068cfac7e55f56ed096489c1e38c --- /dev/null +++ b/helm/templates/deployment.yaml @@ -0,0 +1,29 @@ +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: "{{ .Release.Name }}-detective" + namespace: "{{ .Release.Namespace }}" +spec: + replicas: {{ .Values.replicas | default 1 }} + selector: + matchLabels: + app: "{{ .Release.Name }}-detective" + template: + metadata: + labels: + app: "{{ .Release.Name }}-detective" + spec: + containers: + - name: {{ .Chart.Name }} + image: {{ .Values.image.repository | default "registry.gitlab.com/cossas/dgad" }}:{{ .Values.image.tag }} + args: ["serve", "-p", {{ .Values.containerPort | quote }} ] + ports: + - containerPort: {{ .Values.containerPort }} + resources: + requests: + memory: {{ .Values.requests.memory | default "500Mi" | quote }} + cpu: {{ .Values.requests.cpu | default "1000m" | quote }} + limits: + memory: {{ .Values.limits.memory | default "500Mi" | quote }} + cpu: {{ .Values.limits.cpu | default "1000m" | quote }} diff --git a/helm/templates/ingress.yaml b/helm/templates/ingress.yaml new file mode 100644 index 0000000000000000000000000000000000000000..464090415c47109523e91779d4f40e19495c9cf1 --- /dev/null +++ b/helm/templates/ingress.yaml @@ -0,0 +1 @@ +# TODO diff --git a/helm/templates/service.yaml b/helm/templates/service.yaml new file mode 100644 index 0000000000000000000000000000000000000000..082fea9524901f916d49dd6858ddbd7738d46d39 --- /dev/null +++ b/helm/templates/service.yaml @@ -0,0 +1,12 @@ +--- +apiVersion: v1 +kind: Service +metadata: + name: "{{ .Release.Name }}-detective" + namespace: {{ .Release.Namespace }} +spec: + ports: + - port: {{ .Values.containerPort }} + protocol: TCP + selector: + app: "{{ .Release.Name }}-detective" diff --git a/helm/values.yaml b/helm/values.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cfe17fb12ef9b5592b4dbefcfcb9e2c906e7d55f --- /dev/null +++ b/helm/values.yaml @@ -0,0 +1,11 @@ +replicas: 1 +image: + repository: registry.gitlab.com/cossas/dgad + tag: "latest" +containerPort: 4714 +requests: + memory: "" + cpu: "" +limits: + memory: "" + cpu: "" diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 0000000000000000000000000000000000000000..7612fad2de322d6e9928d9d4e18313a9b0df384d --- /dev/null +++ b/poetry.lock @@ -0,0 +1,1215 @@ +[[package]] +name = "absl-py" +version = "1.2.0" +description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." +category = "main" +optional = false +python-versions = ">=3.6" + +[[package]] +name = "astroid" +version = "2.11.7" +description = "An abstract syntax tree for Python with inference support." 
+category = "dev" +optional = false +python-versions = ">=3.6.2" + +[package.dependencies] +lazy-object-proxy = ">=1.4.0" +typing-extensions = {version = ">=3.10", markers = "python_version < \"3.10\""} +wrapt = ">=1.11,<2" + +[[package]] +name = "astunparse" +version = "1.6.3" +description = "An AST unparser for Python" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +six = ">=1.6.1,<2.0" + +[[package]] +name = "atomicwrites" +version = "1.4.1" +description = "Atomic file writes." +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[[package]] +name = "attrs" +version = "21.4.0" +description = "Classes Without Boilerplate" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[package.extras] +dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"] +docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] +tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"] +tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"] + +[[package]] +name = "black" +version = "22.6.0" +description = "The uncompromising code formatter." +category = "dev" +optional = false +python-versions = ">=3.6.2" + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} +typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + +[[package]] +name = "cachetools" +version = "5.2.0" +description = "Extensible memoizing collections and decorators" +category = "main" +optional = false +python-versions = "~=3.7" + +[[package]] +name = "certifi" +version = "2022.6.15" +description = "Python package for providing Mozilla's CA Bundle." +category = "main" +optional = false +python-versions = ">=3.6" + +[[package]] +name = "charset-normalizer" +version = "2.1.0" +description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "main" +optional = false +python-versions = ">=3.6.0" + +[package.extras] +unicode_backport = ["unicodedata2"] + +[[package]] +name = "click" +version = "8.1.3" +description = "Composable command line interface toolkit" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "colorama" +version = "0.4.5" +description = "Cross-platform colored terminal text." 
+category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[[package]] +name = "dill" +version = "0.3.5.1" +description = "serialize all of python" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" + +[package.extras] +graph = ["objgraph (>=1.7.2)"] + +[[package]] +name = "filelock" +version = "3.7.1" +description = "A platform independent file lock." +category = "main" +optional = false +python-versions = ">=3.7" + +[package.extras] +docs = ["furo (>=2021.8.17b43)", "sphinx (>=4.1)", "sphinx-autodoc-typehints (>=1.12)"] +testing = ["covdefaults (>=1.2.0)", "coverage (>=4)", "pytest (>=4)", "pytest-cov", "pytest-timeout (>=1.4.2)"] + +[[package]] +name = "flatbuffers" +version = "1.12" +description = "The FlatBuffers serialization format for Python" +category = "main" +optional = false +python-versions = "*" + +[[package]] +name = "gast" +version = "0.4.0" +description = "Python AST that abstracts the underlying Python version" +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[[package]] +name = "google-auth" +version = "2.9.1" +description = "Google Authentication Library" +category = "main" +optional = false +python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*" + +[package.dependencies] +cachetools = ">=2.0.0,<6.0" +pyasn1-modules = ">=0.2.1" +rsa = {version = ">=3.1.4,<5", markers = "python_version >= \"3.6\""} +six = ">=1.9.0" + +[package.extras] +aiohttp = ["requests (>=2.20.0,<3.0.0dev)", "aiohttp (>=3.6.2,<4.0.0dev)"] +enterprise_cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"] +pyopenssl = ["pyopenssl (>=20.0.0)"] +reauth = ["pyu2f (>=0.1.5)"] + +[[package]] +name = "google-auth-oauthlib" +version = "0.4.6" +description = "Google Authentication Library" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +google-auth = ">=1.0.0" +requests-oauthlib = ">=0.7.0" + +[package.extras] +tool = ["click (>=6.0.0)"] + +[[package]] +name = "google-pasta" +version = "0.2.0" +description = "pasta is an AST-based Python refactoring library" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +six = "*" + +[[package]] +name = "grpcio" +version = "1.47.0" +description = "HTTP/2-based RPC framework" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +six = ">=1.5.2" + +[package.extras] +protobuf = ["grpcio-tools (>=1.47.0)"] + +[[package]] +name = "grpcio-tools" +version = "1.47.0" +description = "Protobuf code generator for gRPC" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +grpcio = ">=1.47.0" +protobuf = ">=3.12.0,<4.0dev" + +[[package]] +name = "h5py" +version = "3.7.0" +description = "Read and write HDF5 files from Python" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +numpy = ">=1.14.5" + +[[package]] +name = "idna" +version = "3.3" +description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" +optional = false +python-versions = ">=3.5" + +[[package]] +name = "importlib-metadata" +version = "4.12.0" +description = "Read metadata from Python packages" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"] +perf = ["ipython"] 
+testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.3)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"] + +[[package]] +name = "iniconfig" +version = "1.1.1" +description = "iniconfig: brain-dead simple config-ini parsing" +category = "dev" +optional = false +python-versions = "*" + +[[package]] +name = "isort" +version = "5.10.1" +description = "A Python utility / library to sort Python imports." +category = "dev" +optional = false +python-versions = ">=3.6.1,<4.0" + +[package.extras] +pipfile_deprecated_finder = ["pipreqs", "requirementslib"] +requirements_deprecated_finder = ["pipreqs", "pip-api"] +colors = ["colorama (>=0.4.3,<0.5.0)"] +plugins = ["setuptools"] + +[[package]] +name = "joblib" +version = "1.1.0" +description = "Lightweight pipelining with Python functions" +category = "main" +optional = false +python-versions = ">=3.6" + +[[package]] +name = "keras" +version = "2.9.0" +description = "Deep learning for humans." +category = "main" +optional = false +python-versions = "*" + +[[package]] +name = "keras-preprocessing" +version = "1.1.2" +description = "Easy data preprocessing and data augmentation for deep learning models" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +numpy = ">=1.9.1" +six = ">=1.9.0" + +[package.extras] +image = ["scipy (>=0.14)", "Pillow (>=5.2.0)"] +pep8 = ["flake8"] +tests = ["pandas", "pillow", "tensorflow", "keras", "pytest", "pytest-xdist", "pytest-cov"] + +[[package]] +name = "keras-tcn" +version = "3.4.4" +description = "Keras TCN" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +numpy = "*" +tensorflow = "*" +tensorflow-addons = "*" + +[[package]] +name = "lazy-object-proxy" +version = "1.7.1" +description = "A fast and thorough lazy object proxy." +category = "dev" +optional = false +python-versions = ">=3.6" + +[[package]] +name = "libclang" +version = "14.0.1" +description = "Clang Python Bindings, mirrored from the official LLVM repo: https://github.com/llvm/llvm-project/tree/main/clang/bindings/python, to make the installation process easier." +category = "main" +optional = false +python-versions = "*" + +[[package]] +name = "markdown" +version = "3.4.1" +description = "Python implementation of Markdown." +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""} + +[package.extras] +testing = ["coverage", "pyyaml"] + +[[package]] +name = "mccabe" +version = "0.7.0" +description = "McCabe checker, plugin for flake8" +category = "dev" +optional = false +python-versions = ">=3.6" + +[[package]] +name = "mypy-extensions" +version = "0.4.3" +description = "Experimental type system extensions for programs checked with the mypy typechecker." +category = "dev" +optional = false +python-versions = "*" + +[[package]] +name = "numpy" +version = "1.23.1" +description = "NumPy is the fundamental package for array computing with Python." 
+category = "main" +optional = false +python-versions = ">=3.8" + +[[package]] +name = "oauthlib" +version = "3.2.0" +description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +rsa = ["cryptography (>=3.0.0)"] +signals = ["blinker (>=1.4.0)"] +signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] + +[[package]] +name = "opt-einsum" +version = "3.3.0" +description = "Optimizing numpys einsum function" +category = "main" +optional = false +python-versions = ">=3.5" + +[package.dependencies] +numpy = ">=1.7" + +[package.extras] +docs = ["sphinx (==1.2.3)", "sphinxcontrib-napoleon", "sphinx-rtd-theme", "numpydoc"] +tests = ["pytest", "pytest-cov", "pytest-pep8"] + +[[package]] +name = "packaging" +version = "21.3" +description = "Core utilities for Python packages" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" + +[[package]] +name = "pandas" +version = "1.4.3" +description = "Powerful data structures for data analysis, time series, and statistics" +category = "main" +optional = false +python-versions = ">=3.8" + +[package.dependencies] +numpy = [ + {version = ">=1.18.5", markers = "platform_machine != \"aarch64\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, + {version = ">=1.19.2", markers = "platform_machine == \"aarch64\" and python_version < \"3.10\""}, + {version = ">=1.20.0", markers = "platform_machine == \"arm64\" and python_version < \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, +] +python-dateutil = ">=2.8.1" +pytz = ">=2020.1" + +[package.extras] +test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] + +[[package]] +name = "pathspec" +version = "0.9.0" +description = "Utility library for gitignore style pattern matching of file paths." +category = "dev" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" + +[[package]] +name = "platformdirs" +version = "2.5.2" +description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.extras] +docs = ["furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx-autodoc-typehints (>=1.12)", "sphinx (>=4)"] +test = ["appdirs (==1.4.4)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)", "pytest (>=6)"] + +[[package]] +name = "pluggy" +version = "1.0.0" +description = "plugin and hook calling mechanisms for python" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "protobuf" +version = "3.19.4" +description = "Protocol Buffers" +category = "main" +optional = false +python-versions = ">=3.5" + +[[package]] +name = "py" +version = "1.11.0" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[[package]] +name = "pyasn1" +version = "0.4.8" +description = "ASN.1 types and codecs" +category = "main" +optional = false +python-versions = "*" + +[[package]] +name = "pyasn1-modules" +version = "0.2.8" +description = "A collection of ASN.1-based protocols modules." 
+category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +pyasn1 = ">=0.4.6,<0.5.0" + +[[package]] +name = "pylint" +version = "2.14.5" +description = "python code static checker" +category = "dev" +optional = false +python-versions = ">=3.7.2" + +[package.dependencies] +astroid = ">=2.11.6,<=2.12.0-dev0" +colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} +dill = ">=0.2" +isort = ">=4.2.5,<6" +mccabe = ">=0.6,<0.8" +platformdirs = ">=2.2.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +tomlkit = ">=0.10.1" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +spelling = ["pyenchant (>=3.2,<4.0)"] +testutils = ["gitpython (>3)"] + +[[package]] +name = "pyparsing" +version = "3.0.9" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +category = "main" +optional = false +python-versions = ">=3.6.8" + +[package.extras] +diagrams = ["railroad-diagrams", "jinja2"] + +[[package]] +name = "pytest" +version = "7.1.2" +description = "pytest: simple powerful testing with Python" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} +attrs = ">=19.2.0" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +py = ">=1.8.2" +tomli = ">=1.0.0" + +[package.extras] +testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] + +[[package]] +name = "python-dateutil" +version = "2.8.2" +description = "Extensions to the standard Python datetime module" +category = "main" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" + +[package.dependencies] +six = ">=1.5" + +[[package]] +name = "pytz" +version = "2022.1" +description = "World timezone definitions, modern and historical" +category = "main" +optional = false +python-versions = "*" + +[[package]] +name = "requests" +version = "2.28.1" +description = "Python HTTP for Humans." +category = "main" +optional = false +python-versions = ">=3.7, <4" + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<3" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<1.27" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] + +[[package]] +name = "requests-file" +version = "1.5.1" +description = "File transport adapter for Requests" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +requests = ">=1.0.0" +six = "*" + +[[package]] +name = "requests-oauthlib" +version = "1.3.1" +description = "OAuthlib authentication support for Requests." 
+category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[package.dependencies] +oauthlib = ">=3.0.0" +requests = ">=2.0.0" + +[package.extras] +rsa = ["oauthlib[signedtoken] (>=3.0.0)"] + +[[package]] +name = "rsa" +version = "4.9" +description = "Pure-Python RSA implementation" +category = "main" +optional = false +python-versions = ">=3.6,<4" + +[package.dependencies] +pyasn1 = ">=0.1.3" + +[[package]] +name = "scikit-learn" +version = "1.1.1" +description = "A set of python modules for machine learning and data mining" +category = "main" +optional = false +python-versions = ">=3.8" + +[package.dependencies] +joblib = ">=1.0.0" +numpy = ">=1.17.3" +scipy = ">=1.3.2" +threadpoolctl = ">=2.0.0" + +[package.extras] +benchmark = ["matplotlib (>=3.1.2)", "pandas (>=1.0.5)", "memory-profiler (>=0.57.0)"] +docs = ["matplotlib (>=3.1.2)", "scikit-image (>=0.14.5)", "pandas (>=1.0.5)", "seaborn (>=0.9.0)", "memory-profiler (>=0.57.0)", "sphinx (>=4.0.1)", "sphinx-gallery (>=0.7.0)", "numpydoc (>=1.2.0)", "Pillow (>=7.1.2)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] +examples = ["matplotlib (>=3.1.2)", "scikit-image (>=0.14.5)", "pandas (>=1.0.5)", "seaborn (>=0.9.0)"] +tests = ["matplotlib (>=3.1.2)", "scikit-image (>=0.14.5)", "pandas (>=1.0.5)", "pytest (>=5.0.1)", "pytest-cov (>=2.9.0)", "flake8 (>=3.8.2)", "black (>=22.3.0)", "mypy (>=0.770)", "pyamg (>=4.0.0)", "numpydoc (>=1.2.0)"] + +[[package]] +name = "scipy" +version = "1.8.1" +description = "SciPy: Scientific Library for Python" +category = "main" +optional = false +python-versions = ">=3.8,<3.11" + +[package.dependencies] +numpy = ">=1.17.3,<1.25.0" + +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" + +[[package]] +name = "tensorboard" +version = "2.9.1" +description = "TensorBoard lets you watch Tensors Flow" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +absl-py = ">=0.4" +google-auth = ">=1.6.3,<3" +google-auth-oauthlib = ">=0.4.1,<0.5" +grpcio = ">=1.24.3" +markdown = ">=2.6.8" +numpy = ">=1.12.0" +protobuf = ">=3.9.2,<3.20" +requests = ">=2.21.0,<3" +tensorboard-data-server = ">=0.6.0,<0.7.0" +tensorboard-plugin-wit = ">=1.6.0" +werkzeug = ">=1.0.1" + +[[package]] +name = "tensorboard-data-server" +version = "0.6.1" +description = "Fast data loading for TensorBoard" +category = "main" +optional = false +python-versions = ">=3.6" + +[[package]] +name = "tensorboard-plugin-wit" +version = "1.8.1" +description = "What-If Tool TensorBoard plugin." +category = "main" +optional = false +python-versions = "*" + +[[package]] +name = "tensorflow" +version = "2.9.1" +description = "TensorFlow is an open source machine learning framework for everyone." 
+category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +absl-py = ">=1.0.0" +astunparse = ">=1.6.0" +flatbuffers = ">=1.12,<2" +gast = ">=0.2.1,<=0.4.0" +google-pasta = ">=0.1.1" +grpcio = ">=1.24.3,<2.0" +h5py = ">=2.9.0" +keras = ">=2.9.0rc0,<2.10.0" +keras-preprocessing = ">=1.1.1" +libclang = ">=13.0.0" +numpy = ">=1.20" +opt-einsum = ">=2.3.2" +packaging = "*" +protobuf = ">=3.9.2,<3.20" +six = ">=1.12.0" +tensorboard = ">=2.9,<2.10" +tensorflow-estimator = ">=2.9.0rc0,<2.10.0" +tensorflow-io-gcs-filesystem = ">=0.23.1" +termcolor = ">=1.1.0" +typing-extensions = ">=3.6.6" +wrapt = ">=1.11.0" + +[[package]] +name = "tensorflow-addons" +version = "0.17.1" +description = "TensorFlow Addons." +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +packaging = "*" +typeguard = ">=2.7" + +[package.extras] +tensorflow = ["tensorflow (>=2.7.0,<2.10.0)"] +tensorflow-cpu = ["tensorflow-cpu (>=2.7.0,<2.10.0)"] +tensorflow-gpu = ["tensorflow-gpu (>=2.7.0,<2.10.0)"] + +[[package]] +name = "tensorflow-estimator" +version = "2.9.0" +description = "TensorFlow Estimator." +category = "main" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "tensorflow-io-gcs-filesystem" +version = "0.26.0" +description = "TensorFlow IO" +category = "main" +optional = false +python-versions = ">=3.7, <3.11" + +[package.extras] +tensorflow = ["tensorflow (>=2.9.0,<2.10.0)"] +tensorflow-aarch64 = ["tensorflow-aarch64 (>=2.9.0,<2.10.0)"] +tensorflow-cpu = ["tensorflow-cpu (>=2.9.0,<2.10.0)"] +tensorflow-gpu = ["tensorflow-gpu (>=2.9.0,<2.10.0)"] +tensorflow-rocm = ["tensorflow-rocm (>=2.9.0,<2.10.0)"] + +[[package]] +name = "termcolor" +version = "1.1.0" +description = "ANSII Color formatting for output in terminal." +category = "main" +optional = false +python-versions = "*" + +[[package]] +name = "threadpoolctl" +version = "3.1.0" +description = "threadpoolctl" +category = "main" +optional = false +python-versions = ">=3.6" + +[[package]] +name = "tldextract" +version = "3.3.1" +description = "Accurately separates a URL's subdomain, domain, and public suffix, using the Public Suffix List (PSL). By default, this includes the public ICANN TLDs and their exceptions. You can optionally support the Public Suffix List's private domains as well." +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +filelock = ">=3.0.8" +idna = "*" +requests = ">=2.1.0" +requests-file = ">=1.4" + +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +category = "dev" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "tomlkit" +version = "0.11.1" +description = "Style preserving TOML library" +category = "dev" +optional = false +python-versions = ">=3.6,<4.0" + +[[package]] +name = "typeguard" +version = "2.13.3" +description = "Run-time type checker for Python" +category = "main" +optional = false +python-versions = ">=3.5.3" + +[package.extras] +doc = ["sphinx-rtd-theme", "sphinx-autodoc-typehints (>=1.2.0)"] +test = ["pytest", "typing-extensions", "mypy"] + +[[package]] +name = "typing-extensions" +version = "4.3.0" +description = "Backported and Experimental Type Hints for Python 3.7+" +category = "main" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "urllib3" +version = "1.26.10" +description = "HTTP library with thread-safe connection pooling, file post, and more." 
+category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, <4" + +[package.extras] +brotli = ["brotlicffi (>=0.8.0)", "brotli (>=1.0.9)", "brotlipy (>=0.6.0)"] +secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] + +[[package]] +name = "werkzeug" +version = "2.1.2" +description = "The comprehensive WSGI web application library." +category = "main" +optional = false +python-versions = ">=3.7" + +[package.extras] +watchdog = ["watchdog"] + +[[package]] +name = "wrapt" +version = "1.14.1" +description = "Module for decorators, wrappers and monkey patching." +category = "main" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" + +[[package]] +name = "zipp" +version = "3.8.1" +description = "Backport of pathlib-compatible object wrapper for zip files" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.extras] +docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)", "jaraco.tidelift (>=1.4)"] +testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.3)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)"] + +[metadata] +lock-version = "1.1" +python-versions = ">=3.8,<3.11" +content-hash = "217d7e356df5771df3c382886721515e09ee14fdbd2379158a3e4d55625a7e3b" + +[metadata.files] +absl-py = [] +astroid = [ + {file = "astroid-2.11.7-py3-none-any.whl", hash = "sha256:86b0a340a512c65abf4368b80252754cda17c02cdbbd3f587dddf98112233e7b"}, + {file = "astroid-2.11.7.tar.gz", hash = "sha256:bb24615c77f4837c707669d16907331374ae8a964650a66999da3f5ca68dc946"}, +] +astunparse = [] +atomicwrites = [ + {file = "atomicwrites-1.4.1.tar.gz", hash = "sha256:81b2c9071a49367a7f770170e5eec8cb66567cfbbc8c73d20ce5ca4a8d71cf11"}, +] +attrs = [ + {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"}, + {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"}, +] +black = [ + {file = "black-22.6.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f586c26118bc6e714ec58c09df0157fe2d9ee195c764f630eb0d8e7ccce72e69"}, + {file = "black-22.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b270a168d69edb8b7ed32c193ef10fd27844e5c60852039599f9184460ce0807"}, + {file = "black-22.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6797f58943fceb1c461fb572edbe828d811e719c24e03375fd25170ada53825e"}, + {file = "black-22.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c85928b9d5f83b23cee7d0efcb310172412fbf7cb9d9ce963bd67fd141781def"}, + {file = "black-22.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:f6fe02afde060bbeef044af7996f335fbe90b039ccf3f5eb8f16df8b20f77666"}, + {file = "black-22.6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cfaf3895a9634e882bf9d2363fed5af8888802d670f58b279b0bece00e9a872d"}, + {file = "black-22.6.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94783f636bca89f11eb5d50437e8e17fbc6a929a628d82304c80fa9cd945f256"}, + {file = "black-22.6.0-cp36-cp36m-win_amd64.whl", hash = "sha256:2ea29072e954a4d55a2ff58971b83365eba5d3d357352a07a7a4df0d95f51c78"}, + {file = "black-22.6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e439798f819d49ba1c0bd9664427a05aab79bfba777a6db94fd4e56fae0cb849"}, + {file = 
"black-22.6.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:187d96c5e713f441a5829e77120c269b6514418f4513a390b0499b0987f2ff1c"}, + {file = "black-22.6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:074458dc2f6e0d3dab7928d4417bb6957bb834434516f21514138437accdbe90"}, + {file = "black-22.6.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a218d7e5856f91d20f04e931b6f16d15356db1c846ee55f01bac297a705ca24f"}, + {file = "black-22.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:568ac3c465b1c8b34b61cd7a4e349e93f91abf0f9371eda1cf87194663ab684e"}, + {file = "black-22.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6c1734ab264b8f7929cef8ae5f900b85d579e6cbfde09d7387da8f04771b51c6"}, + {file = "black-22.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9a3ac16efe9ec7d7381ddebcc022119794872abce99475345c5a61aa18c45ad"}, + {file = "black-22.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:b9fd45787ba8aa3f5e0a0a98920c1012c884622c6c920dbe98dbd05bc7c70fbf"}, + {file = "black-22.6.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7ba9be198ecca5031cd78745780d65a3f75a34b2ff9be5837045dce55db83d1c"}, + {file = "black-22.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a3db5b6409b96d9bd543323b23ef32a1a2b06416d525d27e0f67e74f1446c8f2"}, + {file = "black-22.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:560558527e52ce8afba936fcce93a7411ab40c7d5fe8c2463e279e843c0328ee"}, + {file = "black-22.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b154e6bbde1e79ea3260c4b40c0b7b3109ffcdf7bc4ebf8859169a6af72cd70b"}, + {file = "black-22.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:4af5bc0e1f96be5ae9bd7aaec219c901a94d6caa2484c21983d043371c733fc4"}, + {file = "black-22.6.0-py3-none-any.whl", hash = "sha256:ac609cf8ef5e7115ddd07d85d988d074ed00e10fbc3445aee393e70164a2219c"}, + {file = "black-22.6.0.tar.gz", hash = "sha256:6c6d39e28aed379aec40da1c65434c77d75e65bb59a1e1c283de545fb4e7c6c9"}, +] +cachetools = [] +certifi = [] +charset-normalizer = [] +click = [ + {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, + {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, +] +colorama = [ + {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"}, + {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, +] +dill = [ + {file = "dill-0.3.5.1-py2.py3-none-any.whl", hash = "sha256:33501d03270bbe410c72639b350e941882a8b0fd55357580fbc873fba0c59302"}, + {file = "dill-0.3.5.1.tar.gz", hash = "sha256:d75e41f3eff1eee599d738e76ba8f4ad98ea229db8b085318aa2b3333a208c86"}, +] +filelock = [] +flatbuffers = [] +gast = [] +google-auth = [] +google-auth-oauthlib = [] +google-pasta = [] +grpcio = [] +grpcio-tools = [ + {file = "grpcio-tools-1.47.0.tar.gz", hash = "sha256:f64b5378484be1d6ce59311f86174be29c8ff98d8d90f589e1c56d5acae67d3c"}, + {file = "grpcio_tools-1.47.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:3edb04d102e0d6f0149d93fe8cf69a38c20a2259a913701a4c35c119049c8404"}, + {file = "grpcio_tools-1.47.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:dd5d330230038374e64fc652fc4c1b25d457a8b67b9069bfce83a17ab675650b"}, + {file = "grpcio_tools-1.47.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:498c0bae4975683a5a33b72cf1bd64703b34c826871fd3ee8d295407cd5211ec"}, + {file = 
"grpcio_tools-1.47.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1de1f139f05ab6bbdabc58b06f6ebb5940a92214bbc7246270299387d0af2ae"}, + {file = "grpcio_tools-1.47.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3fccc282ee97211a33652419dcdfd24a9a60bbd2d56f5c5dd50c7186a0f4d978"}, + {file = "grpcio_tools-1.47.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:441a0a378117447c089b944f325f11039329d8aa961ecdb8226c5dd84af6f003"}, + {file = "grpcio_tools-1.47.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0eced69e159b3fdd7597d85950f56990e0aa81c11a20a7785fb66f0e47c46b57"}, + {file = "grpcio_tools-1.47.0-cp310-cp310-win32.whl", hash = "sha256:2c5c50886e6e79af5387c6514eb19f1f6b1a0b4eb787f1b7a8f21a74e2444102"}, + {file = "grpcio_tools-1.47.0-cp310-cp310-win_amd64.whl", hash = "sha256:156b5f6654fea51983fd9257d47f1ad7bfb2a1d09ed471e610a7b34b97d40802"}, + {file = "grpcio_tools-1.47.0-cp36-cp36m-linux_armv7l.whl", hash = "sha256:94114e01c4508d904825bd984e3d2752c0b0e6eb714ac08b99f73421691cf931"}, + {file = "grpcio_tools-1.47.0-cp36-cp36m-macosx_10_10_x86_64.whl", hash = "sha256:51352070f13ea3346b5f5ca825f2203528b8218fffc6ac6d951216f812272d8b"}, + {file = "grpcio_tools-1.47.0-cp36-cp36m-manylinux_2_17_aarch64.whl", hash = "sha256:53c47b08ee2f59a89e8df5f3c09850d7fac264754cbaeabae65f6fbf78d80536"}, + {file = "grpcio_tools-1.47.0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:818fca1c7dd4ad1c9c01f91ba37006964f4c57c93856fa4ebd7d5589132844d6"}, + {file = "grpcio_tools-1.47.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2364ac3bd7266752c9971dbef3f79d21cd958777823512faa93473cbd973b8f1"}, + {file = "grpcio_tools-1.47.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:9dd6e26e3e0555deadcb52b087c6064e4fd02c09180b42e96c66260137d26b50"}, + {file = "grpcio_tools-1.47.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:a93263955da8d6e449d7ceb84af4e84b82fa760fd661b4ef4549929d9670ab8e"}, + {file = "grpcio_tools-1.47.0-cp36-cp36m-win32.whl", hash = "sha256:6804cbd92b9069ae9189d65300e456bcc3945f6ae196d2af254e9635b9c3ef0d"}, + {file = "grpcio_tools-1.47.0-cp36-cp36m-win_amd64.whl", hash = "sha256:7589d6f56e633378047274223f0a75534b2cd7c598f9f2894cb4854378b8b00b"}, + {file = "grpcio_tools-1.47.0-cp37-cp37m-linux_armv7l.whl", hash = "sha256:6d41ec06f2ccc8adcd400a63508ea8e008fb03f270e0031ff2de047def2ada9d"}, + {file = "grpcio_tools-1.47.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:74f607b9084b5325a997d9ae57c0814955e19311111568d029b2a6a66f4869ec"}, + {file = "grpcio_tools-1.47.0-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:7fd10683f4f03400536e7a026de9929430ee198c2cbdf2c584edfa909ccc8993"}, + {file = "grpcio_tools-1.47.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7be45d69f0eed912df2e92d94958d1a3e72617469ec58ffcac3e2eb153a7057e"}, + {file = "grpcio_tools-1.47.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca548afcfa0ffc47c3cf9eeede81adde15c321bfe897085e90ce8913615584ae"}, + {file = "grpcio_tools-1.47.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f19191460435f8bc72450cf26ac0559726f98c49ad9b0969db3db8ba51be98c8"}, + {file = "grpcio_tools-1.47.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b2fa3c545c8aa1e8c33ca04b1424be3ff77da631faf37db3350d7459c3bdedde"}, + {file = "grpcio_tools-1.47.0-cp37-cp37m-win32.whl", hash = "sha256:0b32002ff4ae860c85feb2aca1b752eb4518e7781c5770b869e7b2dfa9d92cbe"}, + {file = 
"grpcio_tools-1.47.0-cp37-cp37m-win_amd64.whl", hash = "sha256:5c8ab9b541a869d3b4ef34c291fbfb6ec78ad728e04737fddd91eac3c2193459"}, + {file = "grpcio_tools-1.47.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:05b495ed997a9afc9016c696ed7fcd35678a7276fe0bd8b95743a382363ad2b4"}, + {file = "grpcio_tools-1.47.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:6c66094fd79ee98bcb504e9f1a3fa6e7ebfd246b4e3d8132227e5020b5633988"}, + {file = "grpcio_tools-1.47.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:84e38f46af513a6f62a3d482160fcb94063dbc9fdd1452d09f8010422f144de1"}, + {file = "grpcio_tools-1.47.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:058060fbc5a60a1c6cc2cbb3d99f730825ba249917978d48b7d0fd8f2caf01da"}, + {file = "grpcio_tools-1.47.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc6567d652c6b70d8c03f4e450a694e62b4d69a400752f8b9c3c8b659dd6b06a"}, + {file = "grpcio_tools-1.47.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:9ab78cd16b4ac7c6b79c8be194c67e03238f6378694133ce3ce9b123caf24ed5"}, + {file = "grpcio_tools-1.47.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ccc8ce33bd31bf12649541b5857fabfee7dd84b04138336a27bf46a28d150c11"}, + {file = "grpcio_tools-1.47.0-cp38-cp38-win32.whl", hash = "sha256:4eced9e0674bfb5c528a3bf2ea2b8596da133148b3e0718915792074204ea226"}, + {file = "grpcio_tools-1.47.0-cp38-cp38-win_amd64.whl", hash = "sha256:45ceb73a97e2d7ff719fc12c02f1ef13014c47bad60a864313da88ccd90cdf36"}, + {file = "grpcio_tools-1.47.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:ac5c6aef72618ebc5ee9ad725dd53e1c145ef420b79d21a7c43ca80658d3d8d4"}, + {file = "grpcio_tools-1.47.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:c2c280197d68d5a28f5b90adf755bd9e28c99f3e47ad4edcfe20497cf3456e1d"}, + {file = "grpcio_tools-1.47.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:93d08c02bd82e423353399582f22493a191db459c3f34031b583f13bcf42b95e"}, + {file = "grpcio_tools-1.47.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18548f35b0657422d5d40e6fa89994469f4bb77df09f8133ecdccec0e31fc72c"}, + {file = "grpcio_tools-1.47.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb44ae747fd299b6513420cb6ead50491dc3691d17da48f28fcc5ebf07f47741"}, + {file = "grpcio_tools-1.47.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ae53ae35a9761ceea50a502addb7186c5188969d63ad21cf12e00d939db5b967"}, + {file = "grpcio_tools-1.47.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2a6a6e5e08866d643b84c89140bbe504f864f11b87bfff7a5f2af94c5a2be18d"}, + {file = "grpcio_tools-1.47.0-cp39-cp39-win32.whl", hash = "sha256:759064fc8439bbfe5402b2fd3b0685f4ffe07d7cc6a64908c2f88a7c80449ce4"}, + {file = "grpcio_tools-1.47.0-cp39-cp39-win_amd64.whl", hash = "sha256:1a0a91941f6f2a4d97e843a5d9ad7ccccf702af2d9455932f18cf922e65af95e"}, +] +h5py = [] +idna = [ + {file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"}, + {file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"}, +] +importlib-metadata = [] +iniconfig = [ + {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, + {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, +] +isort = [ + {file = "isort-5.10.1-py3-none-any.whl", hash = "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7"}, + 
{file = "isort-5.10.1.tar.gz", hash = "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951"}, +] +joblib = [ + {file = "joblib-1.1.0-py2.py3-none-any.whl", hash = "sha256:f21f109b3c7ff9d95f8387f752d0d9c34a02aa2f7060c2135f465da0e5160ff6"}, + {file = "joblib-1.1.0.tar.gz", hash = "sha256:4158fcecd13733f8be669be0683b96ebdbbd38d23559f54dca7205aea1bf1e35"}, +] +keras = [] +keras-preprocessing = [] +keras-tcn = [] +lazy-object-proxy = [ + {file = "lazy-object-proxy-1.7.1.tar.gz", hash = "sha256:d609c75b986def706743cdebe5e47553f4a5a1da9c5ff66d76013ef396b5a8a4"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bb8c5fd1684d60a9902c60ebe276da1f2281a318ca16c1d0a96db28f62e9166b"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a57d51ed2997e97f3b8e3500c984db50a554bb5db56c50b5dab1b41339b37e36"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd45683c3caddf83abbb1249b653a266e7069a09f486daa8863fb0e7496a9fdb"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:8561da8b3dd22d696244d6d0d5330618c993a215070f473b699e00cf1f3f6443"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fccdf7c2c5821a8cbd0a9440a456f5050492f2270bd54e94360cac663398739b"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-win32.whl", hash = "sha256:898322f8d078f2654d275124a8dd19b079080ae977033b713f677afcfc88e2b9"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:85b232e791f2229a4f55840ed54706110c80c0a210d076eee093f2b2e33e1bfd"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:46ff647e76f106bb444b4533bb4153c7370cdf52efc62ccfc1a28bdb3cc95442"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12f3bb77efe1367b2515f8cb4790a11cffae889148ad33adad07b9b55e0ab22c"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c19814163728941bb871240d45c4c30d33b8a2e85972c44d4e63dd7107faba44"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:e40f2013d96d30217a51eeb1db28c9ac41e9d0ee915ef9d00da639c5b63f01a1"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:2052837718516a94940867e16b1bb10edb069ab475c3ad84fd1e1a6dd2c0fcfc"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-win32.whl", hash = "sha256:6a24357267aa976abab660b1d47a34aaf07259a0c3859a34e536f1ee6e76b5bb"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-win_amd64.whl", hash = "sha256:6aff3fe5de0831867092e017cf67e2750c6a1c7d88d84d2481bd84a2e019ec35"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6a6e94c7b02641d1311228a102607ecd576f70734dc3d5e22610111aeacba8a0"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4ce15276a1a14549d7e81c243b887293904ad2d94ad767f42df91e75fd7b5b6"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e368b7f7eac182a59ff1f81d5f3802161932a41dc1b1cc45c1f757dc876b5d2c"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = 
"sha256:6ecbb350991d6434e1388bee761ece3260e5228952b1f0c46ffc800eb313ff42"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:553b0f0d8dbf21890dd66edd771f9b1b5f51bd912fa5f26de4449bfc5af5e029"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-win32.whl", hash = "sha256:c7a683c37a8a24f6428c28c561c80d5f4fd316ddcf0c7cab999b15ab3f5c5c69"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-win_amd64.whl", hash = "sha256:df2631f9d67259dc9620d831384ed7732a198eb434eadf69aea95ad18c587a28"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:07fa44286cda977bd4803b656ffc1c9b7e3bc7dff7d34263446aec8f8c96f88a"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4dca6244e4121c74cc20542c2ca39e5c4a5027c81d112bfb893cf0790f96f57e"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91ba172fc5b03978764d1df5144b4ba4ab13290d7bab7a50f12d8117f8630c38"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:043651b6cb706eee4f91854da4a089816a6606c1428fd391573ef8cb642ae4f7"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b9e89b87c707dd769c4ea91f7a31538888aad05c116a59820f28d59b3ebfe25a"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-win32.whl", hash = "sha256:9d166602b525bf54ac994cf833c385bfcc341b364e3ee71e3bf5a1336e677b55"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:8f3953eb575b45480db6568306893f0bd9d8dfeeebd46812aa09ca9579595148"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dd7ed7429dbb6c494aa9bc4e09d94b778a3579be699f9d67da7e6804c422d3de"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ed0c2b380eb6248abdef3cd425fc52f0abd92d2b07ce26359fcbc399f636ad"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7096a5e0c1115ec82641afbdd70451a144558ea5cf564a896294e346eb611be1"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f769457a639403073968d118bc70110e7dce294688009f5c24ab78800ae56dc8"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:39b0e26725c5023757fc1ab2a89ef9d7ab23b84f9251e28f9cc114d5b59c1b09"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-win32.whl", hash = "sha256:2130db8ed69a48a3440103d4a520b89d8a9405f1b06e2cc81640509e8bf6548f"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:677ea950bef409b47e51e733283544ac3d660b709cfce7b187f5ace137960d61"}, + {file = "lazy_object_proxy-1.7.1-pp37.pp38-none-any.whl", hash = "sha256:d66906d5785da8e0be7360912e99c9188b70f52c422f9fc18223347235691a84"}, +] +libclang = [] +markdown = [] +mccabe = [ + {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, + {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, +] +mypy-extensions = [ + {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, + {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, +] +numpy = [] +oauthlib = 
[] +opt-einsum = [] +packaging = [ + {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, + {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, +] +pandas = [] +pathspec = [ + {file = "pathspec-0.9.0-py2.py3-none-any.whl", hash = "sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a"}, + {file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"}, +] +platformdirs = [ + {file = "platformdirs-2.5.2-py3-none-any.whl", hash = "sha256:027d8e83a2d7de06bbac4e5ef7e023c02b863d7ea5d079477e722bb41ab25788"}, + {file = "platformdirs-2.5.2.tar.gz", hash = "sha256:58c8abb07dcb441e6ee4b11d8df0ac856038f944ab98b7be6b27b2a3c7feef19"}, +] +pluggy = [ + {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, + {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, +] +protobuf = [] +py = [ + {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, + {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, +] +pyasn1 = [] +pyasn1-modules = [] +pylint = [ + {file = "pylint-2.14.5-py3-none-any.whl", hash = "sha256:fabe30000de7d07636d2e82c9a518ad5ad7908590fe135ace169b44839c15f90"}, + {file = "pylint-2.14.5.tar.gz", hash = "sha256:487ce2192eee48211269a0e976421f334cf94de1806ca9d0a99449adcdf0285e"}, +] +pyparsing = [ + {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"}, + {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"}, +] +pytest = [ + {file = "pytest-7.1.2-py3-none-any.whl", hash = "sha256:13d0e3ccfc2b6e26be000cb6568c832ba67ba32e719443bfe725814d3c42433c"}, + {file = "pytest-7.1.2.tar.gz", hash = "sha256:a06a0425453864a270bc45e71f783330a7428defb4230fb5e6a731fde06ecd45"}, +] +python-dateutil = [ + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, +] +pytz = [ + {file = "pytz-2022.1-py2.py3-none-any.whl", hash = "sha256:e68985985296d9a66a881eb3193b0906246245294a881e7c8afe623866ac6a5c"}, + {file = "pytz-2022.1.tar.gz", hash = "sha256:1e760e2fe6a8163bc0b3d9a19c4f84342afa0a2affebfaa84b01b978a02ecaa7"}, +] +requests = [] +requests-file = [] +requests-oauthlib = [] +rsa = [] +scikit-learn = [ + {file = "scikit-learn-1.1.1.tar.gz", hash = "sha256:3e77b71e8e644f86c8b5be7f1c285ef597de4c384961389ee3e9ca36c445b256"}, + {file = "scikit_learn-1.1.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:102f51797cd8944bf44a038d106848ddf2804f2c1edf7aea45fba81a4fdc4d80"}, + {file = "scikit_learn-1.1.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:723cdb278b1fa57a55f68945bc4e501a2f12abe82f76e8d21e1806cbdbef6fc5"}, + {file = "scikit_learn-1.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33cf061ed0b79d647a3e4c3f6c52c412172836718a7cd4d11c1318d083300133"}, + {file = "scikit_learn-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:47464c110eaa9ed9d1fe108cb403510878c3d3a40f110618d2a19b2190a3e35c"}, + {file = "scikit_learn-1.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:542ccd2592fe7ad31f5c85fed3a3deb3e252383960a85e4b49a629353fffaba4"}, + {file = "scikit_learn-1.1.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:3be10d8d325821ca366d4fe7083d87c40768f842f54371a9c908d97c45da16fc"}, + {file = "scikit_learn-1.1.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b2db720e13e697d912a87c1a51194e6fb085dc6d8323caa5ca51369ca6948f78"}, + {file = "scikit_learn-1.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e851f8874398dcd50d1e174e810e9331563d189356e945b3271c0e19ee6f4d6f"}, + {file = "scikit_learn-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b928869072366dc138762fe0929e7dc88413f8a469aebc6a64adc10a9226180c"}, + {file = "scikit_learn-1.1.1-cp38-cp38-win32.whl", hash = "sha256:e9d228ced1214d67904f26fb820c8abbea12b2889cd4aa8cda20a4ca0ed781c1"}, + {file = "scikit_learn-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:f2d5b5d6e87d482e17696a7bfa03fe9515fdfe27e462a4ad37f3d7774a5e2fd6"}, + {file = "scikit_learn-1.1.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:0403ad13f283e27d43b0ad875f187ec7f5d964903d92d1ed06c51439560ecea0"}, + {file = "scikit_learn-1.1.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8fe80df08f5b9cee5dd008eccc672e543976198d790c07e5337f7dfb67eaac05"}, + {file = "scikit_learn-1.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ff56d07b9507fbe07ca0f4e5c8f3e171f74a429f998da03e308166251316b34"}, + {file = "scikit_learn-1.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2dad2bfc502344b869d4a3f4aa7271b2a5f4fe41f7328f404844c51612e2c58"}, + {file = "scikit_learn-1.1.1-cp39-cp39-win32.whl", hash = "sha256:22145b60fef02e597a8e7f061ebc7c51739215f11ce7fcd2ca9af22c31aa9f86"}, + {file = "scikit_learn-1.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:45c0f6ae523353f1d99b85469d746f9c497410adff5ba8b24423705b6956a86e"}, +] +scipy = [ + {file = "scipy-1.8.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:65b77f20202599c51eb2771d11a6b899b97989159b7975e9b5259594f1d35ef4"}, + {file = "scipy-1.8.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:e013aed00ed776d790be4cb32826adb72799c61e318676172495383ba4570aa4"}, + {file = "scipy-1.8.1-cp310-cp310-macosx_12_0_universal2.macosx_10_9_x86_64.whl", hash = "sha256:02b567e722d62bddd4ac253dafb01ce7ed8742cf8031aea030a41414b86c1125"}, + {file = "scipy-1.8.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1da52b45ce1a24a4a22db6c157c38b39885a990a566748fc904ec9f03ed8c6ba"}, + {file = "scipy-1.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0aa8220b89b2e3748a2836fbfa116194378910f1a6e78e4675a095bcd2c762d"}, + {file = "scipy-1.8.1-cp310-cp310-win_amd64.whl", hash = "sha256:4e53a55f6a4f22de01ffe1d2f016e30adedb67a699a310cdcac312806807ca81"}, + {file = "scipy-1.8.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:28d2cab0c6ac5aa131cc5071a3a1d8e1366dad82288d9ec2ca44df78fb50e649"}, + {file = "scipy-1.8.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:6311e3ae9cc75f77c33076cb2794fb0606f14c8f1b1c9ff8ce6005ba2c283621"}, + {file = "scipy-1.8.1-cp38-cp38-macosx_12_0_universal2.macosx_10_9_x86_64.whl", hash = "sha256:3b69b90c9419884efeffaac2c38376d6ef566e6e730a231e15722b0ab58f0328"}, + {file = "scipy-1.8.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = 
"sha256:6cc6b33139eb63f30725d5f7fa175763dc2df6a8f38ddf8df971f7c345b652dc"}, + {file = "scipy-1.8.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c4e3ae8a716c8b3151e16c05edb1daf4cb4d866caa385e861556aff41300c14"}, + {file = "scipy-1.8.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23b22fbeef3807966ea42d8163322366dd89da9bebdc075da7034cee3a1441ca"}, + {file = "scipy-1.8.1-cp38-cp38-win32.whl", hash = "sha256:4b93ec6f4c3c4d041b26b5f179a6aab8f5045423117ae7a45ba9710301d7e462"}, + {file = "scipy-1.8.1-cp38-cp38-win_amd64.whl", hash = "sha256:70ebc84134cf0c504ce6a5f12d6db92cb2a8a53a49437a6bb4edca0bc101f11c"}, + {file = "scipy-1.8.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f3e7a8867f307e3359cc0ed2c63b61a1e33a19080f92fe377bc7d49f646f2ec1"}, + {file = "scipy-1.8.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:2ef0fbc8bcf102c1998c1f16f15befe7cffba90895d6e84861cd6c6a33fb54f6"}, + {file = "scipy-1.8.1-cp39-cp39-macosx_12_0_universal2.macosx_10_9_x86_64.whl", hash = "sha256:83606129247e7610b58d0e1e93d2c5133959e9cf93555d3c27e536892f1ba1f2"}, + {file = "scipy-1.8.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:93d07494a8900d55492401917a119948ed330b8c3f1d700e0b904a578f10ead4"}, + {file = "scipy-1.8.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3b3c8924252caaffc54d4a99f1360aeec001e61267595561089f8b5900821bb"}, + {file = "scipy-1.8.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70de2f11bf64ca9921fda018864c78af7147025e467ce9f4a11bc877266900a6"}, + {file = "scipy-1.8.1-cp39-cp39-win32.whl", hash = "sha256:1166514aa3bbf04cb5941027c6e294a000bba0cf00f5cdac6c77f2dad479b434"}, + {file = "scipy-1.8.1-cp39-cp39-win_amd64.whl", hash = "sha256:9dd4012ac599a1e7eb63c114d1eee1bcfc6dc75a29b589ff0ad0bb3d9412034f"}, + {file = "scipy-1.8.1.tar.gz", hash = "sha256:9e3fb1b0e896f14a85aa9a28d5f755daaeeb54c897b746df7a55ccb02b340f33"}, +] +six = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] +tensorboard = [] +tensorboard-data-server = [] +tensorboard-plugin-wit = [] +tensorflow = [] +tensorflow-addons = [] +tensorflow-estimator = [] +tensorflow-io-gcs-filesystem = [] +termcolor = [] +threadpoolctl = [ + {file = "threadpoolctl-3.1.0-py3-none-any.whl", hash = "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b"}, + {file = "threadpoolctl-3.1.0.tar.gz", hash = "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380"}, +] +tldextract = [] +tomli = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] +tomlkit = [ + {file = "tomlkit-0.11.1-py3-none-any.whl", hash = "sha256:1c5bebdf19d5051e2e1de6cf70adfc5948d47221f097fcff7a3ffc91e953eaf5"}, + {file = "tomlkit-0.11.1.tar.gz", hash = "sha256:61901f81ff4017951119cd0d1ed9b7af31c821d6845c8c477587bbdcd5e5854e"}, +] +typeguard = [] +typing-extensions = [] +urllib3 = [] +werkzeug = [] +wrapt = [] +zipp = [] diff --git a/protos/prediction.proto b/protos/prediction.proto new file mode 100644 index 0000000000000000000000000000000000000000..5fdcf579dfaedae83c28fcf5260d2bd56f896c8e --- /dev/null +++ b/protos/prediction.proto 
@@ -0,0 +1,20 @@ +syntax = "proto3"; + +service Classifier { + rpc GetClassification(Domain) returns (Domain) {} +} + +message Word { + string value = 1; + float binary_score = 2; + string binary_label = 3; + float family_score = 4; + string family_label = 5; +} + +message Domain { + string fqdn = 1; + bool is_dga = 2; + string family = 3; + repeated Word words = 4; +} diff --git a/pyproject.toml b/pyproject.toml index de93ed87292c2f5fa22ad7721b858647334e040e..e898cab0900ab827d6875ad662e53667d1c0cea9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,20 +4,28 @@ version = "3.1.4" description = "Classifies DGA domains" authors = ["Federico Falconieri ","Irina Chiscop "] license = "Apache-2.0" -include = ["dgad/models/*.h5"] +include = ["dgad/models/*.h5", "dgad/label_encoders/*.npy"] [tool.poetry.dependencies] -python = ">=3.8,<4.0.0" -pandas = "^1" -tldextract = "^3" -grpcio = "^1" -keras-tcn = "^3" +python = ">=3.8,<3.11" +pandas = "^1.4.3" +tldextract = "^3.3.1" +grpcio = "^1.47.0" +keras-tcn = "^3.4.4" +scikit-learn = "^1.1.1" +tensorflow = "^2.9.1" +click = "^8.1.3" [tool.poetry.dev-dependencies] +black = "^22.6.0" +isort = "^5.10.1" +pytest = "^7.1.2" +pylint = "^2.14.5" +grpcio-tools = "^1.47.0" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] -dgad = 'dgad.app.cli:main' +dgad = 'dgad.cli:cli' diff --git a/readme.md b/readme.md index 53e266a7fa936c175800eb3f9b81491c9e51158d..f67b20e9d6f2dc9d8c693edb082c165f279dd92a 100644 --- a/readme.md +++ b/readme.md @@ -20,7 +20,7 @@ _All COSSAS projects are hosted on [GitLab](https://gitlab.com/cossas/dgad/) wit ## What is it? Domain generation algorithms (DGAs) are typically used by attackers to create fast changing domains for command & control channels. -The DGA detective is able to tell whether a domain is created by such an algorithm or not by using a variety of classification methods such as [TCN](https://github.com/philipperemy/keras-tcn) and LSTM. For example, a domain like `wikipedia.com` is not generated by an algorithm, whereas `ksadjfhlasdkjfsakjdf.com` is. +The DGA Detective can tell whether or not a domain was created by such an algorithm by using a [Temporal Convolutional Network](https://github.com/philipperemy/keras-tcn). For example, a domain like `wikipedia.com` is not generated by an algorithm, whereas `ksadjfhlasdkjfsakjdf.com` is. | | Domain | Classification| | ------ | ------ | --- | @@ -39,71 +39,126 @@ source .venv/bin/activate pip install dgad ``` -## How to use? -The DGA Detective can be used a Python package, through a command line interface or remotely through gRPC. +## Usage -### Python package +### CLI -```python -import dgad -etc. +```bash +# list available commands with +$ dgad --help +Usage: dgad [OPTIONS] COMMAND [ARGS]... + + DGA Detective can predict if a domain name has been generated by a Domain + Generation Algorithm + +Options: + --help Show this message and exit.
+ +Commands: + client  classify domains from cli args or csv/jsonl files + server  deploy a DGA Detective server + +# list options with +# dgad client --help +# dgad server --help ``` -### CLI -```bash -usage: dgad [-h] [--domains [DOMAIN [DOMAIN ...]]] [--model MODEL] [--csv CSV] [-q] - -optional arguments: - -h, --help show this help message and exit - --domains [DOMAIN [DOMAIN ...]] - space separated list of 1 or more domains you want DGA detective to classify - --model MODEL the hdf5 keras model file to pass to the classifier - --csv CSV csv file containing the domains to classify. This file must have a column 'domain'. The classification will be stored in the same file under a column - 'classification' - -q, --quiet disables stdout - ``` - -For example, if you want to classify one or several domains: -```bash -# classify one domain -$ dgad --domain wikipedia.org - domain classification -0 wikipedia.org ok - -# classify several domains -$ dgad --domains wikipedia.org ksadjfhlasdkjfsakjdf.com - domain classification -0 wikipedia.org ok -1 ksadjfhlasdkjfsakjdf.com DGA +#### CLI Examples + +``` +$ dgad client -d kajsdfhasdlkjfh.com +$ dgad client -d dsfjkhalsdkfj.com -o json +$ dgad client -d wikipedia.org -d anotherdomain.com +$ dgad client -f tests/data/domains_todo.csv --format csv +$ cat tests/data/domains_todo.csv | dgad client -fmt csv -f - +$ dgad client -f tests/data/domains_todo.csv --format csv --verbosity DEBUG ``` -But you can also classify a large list of domains: +#### CLI input/output ```bash -# classify from/to a csv file -$ dgad --csv your_csv_file.csv +# you can pipe input data to the flag -f from another command +$ cat tests/data/domains_todo.csv | dgad client -fmt csv -f - + +# dgad outputs plain json, so you can easily pipe stdout to another command +$ dgad client -f tests/data/domains_todo.csv -fmt csv | jq '{domain: .[0].raw, is_dga: .[0].is_dga}' +{ + "domain": "wikipedia.org", + "is_dga": false +} ``` - -### gRPC + +### production deployment with gRPC API + +In production you may want to split client and server. DGA Detective ships with a performant gRPC API. You can then scale the number of servers to handle as many domains as you need.
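+
+Besides the `dgad client` CLI shown below, the wire contract is the `Classifier` service in `protos/prediction.proto`, so generated gRPC stubs can call a running server directly. The sketch below is illustrative only and rests on two assumptions: the `prediction_pb2*` modules come from compiling the proto with `grpcio-tools`, and `50054` is the server's listening port (the old default), so adjust both to your deployment.
+
+```python
+# Hypothetical direct gRPC client for the Classifier service declared in
+# protos/prediction.proto. Assumes stubs generated with grpcio-tools, e.g.:
+#   python -m grpc_tools.protoc -I protos --python_out=. --grpc_python_out=. prediction.proto
+import grpc
+
+import prediction_pb2
+import prediction_pb2_grpc
+
+
+def classify(fqdn: str, target: str = "localhost:50054") -> prediction_pb2.Domain:
+    # plaintext channel for local testing; use grpc.secure_channel with TLS
+    # credentials for an endpoint exposed on port 443
+    with grpc.insecure_channel(target) as channel:
+        stub = prediction_pb2_grpc.ClassifierStub(channel)
+        # the server replies with the same Domain message, with is_dga,
+        # family and the per-word scores filled in
+        return stub.GetClassification(prediction_pb2.Domain(fqdn=fqdn))
+
+
+if __name__ == "__main__":
+    domain = classify("kajsdfhasdlkjfh.com")
+    print(domain.fqdn, domain.is_dga, domain.family)
+```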
**Server** -To initialize a DGA Detective server listening on port `50054` ```bash -# listens by default on port 50054 -python dgad/grpc/classifier_server.py - -# you can override default logging and port like this -LOG_LEVEL=info LISTENING_PORT=55666 python dgad/grpc/classifier_server.py +# see dgad server --help for all options +# run +dgad server +2022-07-24 13:37:12,097 INFO started dga detective classifier ce8f8efe-8272-44dd-a0be-cc34a0df752b ``` **Client** -A client example is provided at [dgad/grpc/classifier_client.py](dgad/grpc/classifier_client.py) +```bash +# use the -r flag to run the analysis on a remote server +# for example you can reach a dgad server instance deployed at https://dgad.mydomain.com +dgad client -r -h dgad.mydomain.com -p 443 -f tests/data/domains_todo.csv -fmt csv | jq -r '.[] | {domain: .raw, is_dga: .is_dga}' +{ + "domain": "klajsdfiuweakjvnzslkvjneaiuvbkjbre.ru", + "is_dga": true +} + +{ + "domain": "aksdjhflkajsdhflka.com", + "is_dga": true +} + +{ + "domain": "wikipedia.org", + "is_dga": false +} +``` + +### as a Python package in your code +```python +# demo.py +from dgad.prediction import Detective +from dgad.utils import pretty_print + +mydomains = ["adslkfjhsakldjfhasdlkf.com"] +detective = Detective() +# convert mydomains strings into dgad.schema.Domain +mydomains, _ = detective.prepare_domains(mydomains) +# classify them +detective.investigate(mydomains) +# view result, drops padded_token_vector for pretty printing +pretty_print(mydomains, output_format="json") +``` ```bash -# you can override default destination host and port like this -GRPC_HOST=x.x.x.x GRPC_PORT=55666 python dgad/grpc/classifier_client.py +python demo.py +[ + { + "raw": "adslkfjhsakldjfhasdlkf.com", + "words": [ + { + "value": "adslkfjhsakldjfhasdlkf", + "padded_length": 120, + "binary_score": 0.992063581943512, + "binary_label": "DGA", + "family_score": 0.34756162762641907, + "family_label": "necurs" + } + ], + "suffix": "com", + "is_dga": true, + "family_label": "necurs", + "padded_length": 120 + } +] ``` ## Contributing @@ -126,12 +181,8 @@ cd dgad # install project, poetry will spawn a new venv poetry install -# (optional) install pre-commit hooks -pre-commit install -pre-commit install --hook-type commit-msg - # gRPC code generation -python -m grpc_tools.protoc -I dgad/grpc/protos --python_out=dgad/grpc --grpc_python_out=dgad/grpc dgad/grpc/protos/classification.proto +make protoc ``` ## About diff --git a/skaffold.yaml b/skaffold.yaml index 69d7b867566ec0735818790964525e49daee993d..1e615c7e8716c899a2b4ed57ec70c485a654d06f 100644 --- a/skaffold.yaml +++ b/skaffold.yaml @@ -1,14 +1,25 @@ -apiVersion: skaffold/v2beta16 +apiVersion: skaffold/v2beta28 kind: Config metadata: name: dgad build: + local: + push: false + concurrency: 1 + tryImportMissing: false + useDockerCLI: false artifacts: - - image: registry.gitlab.com/cossas/dgad:v3.1.4 + - image: registry.gitlab.com/cossas/dgad docker: dockerfile: Dockerfile deploy: - kubectl: - manifests: - - skaffold/deployment.yaml - - skaffold/svc-clusterIP.yaml + helm: + releases: + - name: dgad + chartPath: ./helm + artifactOverrides: + image: registry.gitlab.com/cossas/dgad + imageStrategy: + helm: {} + valuesFiles: + - ./helm/values.yaml diff --git a/skaffold/deployment.yaml b/skaffold/deployment.yaml deleted file mode 100644 index b1e911f50be8fe78497cef06f1ca34a9492e773b..0000000000000000000000000000000000000000 --- a/skaffold/deployment.yaml +++ /dev/null @@ -1,21 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - labels: - app: dgad -
name: dgad -spec: - replicas: 1 - selector: - matchLabels: - app: dgad - template: - metadata: - labels: - app: dgad - spec: - containers: - - image: registry.gitlab.com/cossas/dgad:v3.1.4 - name: dgad - command: ["python3"] - args: ["dgad/grpc/classifier_server.py"] diff --git a/skaffold/svc-clusterIP.yaml b/skaffold/svc-clusterIP.yaml deleted file mode 100644 index a47a85408e60794f7c498d56b5e92846c8d74371..0000000000000000000000000000000000000000 --- a/skaffold/svc-clusterIP.yaml +++ /dev/null @@ -1,14 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - labels: - app: dgad - name: dgad-service -spec: - selector: - app: dgad - type: ClusterIP - ports: - - protocol: TCP - port: 50666 - targetPort: 50054 diff --git a/tbump.toml b/tbump.toml index 4aff0bae5944269b9f0807f4c378e11622d8cc7b..5f84a77a5dd94fd735a942c64efb9facc9c5bb36 100644 --- a/tbump.toml +++ b/tbump.toml @@ -34,6 +34,10 @@ src = "tbump.toml" src = "race/docker-compose.yml" [[file]] src = "redis-worker/pyproject.toml" +[[file]] +src = "Chart.yaml" +[[file]] +src = "values.yaml" # You can specify a list of commands to # run after the files have been patched diff --git a/tests/data/domains_todo.csv b/tests/data/domains_todo.csv index 74ed3fe784e9c0194d7d53c487353d1bd000e88d..a052a9be9366acb396d3b5e816259dbdd45304bd 100644 --- a/tests/data/domains_todo.csv +++ b/tests/data/domains_todo.csv @@ -1,3 +1,4 @@ domain wikipedia.org klajsdfiuweakjvnzslkvjneaiuvbkjbre.ru +aksdjhflkajsdhflka.com diff --git a/tests/data/domains_todo.jsonl b/tests/data/domains_todo.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..f8783eb443bca7bdd5e735dd1a16d4b658bf1f1e --- /dev/null +++ b/tests/data/domains_todo.jsonl @@ -0,0 +1,2 @@ +{"domain":"wikipedia.org"} +{"domain":"klajsdfiuweakjvnzslkvjneaiuvbkjbre.ru"} diff --git a/tests/data/domains_todo_1.csv b/tests/data/domains_todo_1.csv new file mode 100644 index 0000000000000000000000000000000000000000..74ed3fe784e9c0194d7d53c487353d1bd000e88d --- /dev/null +++ b/tests/data/domains_todo_1.csv @@ -0,0 +1,3 @@ +domain +wikipedia.org +klajsdfiuweakjvnzslkvjneaiuvbkjbre.ru diff --git a/tests/data/domains_todo_2.csv b/tests/data/domains_todo_2.csv new file mode 100644 index 0000000000000000000000000000000000000000..6b5dd75f93c63c5e9f1c552146731c6c19bee456 --- /dev/null +++ b/tests/data/domains_todo_2.csv @@ -0,0 +1,3 @@ +domain +nytimes.com +nu.nl diff --git a/tests/test_classification.py b/tests/test_classification.py deleted file mode 100644 index b89231c2bf3291a433be547af2ef8244e487de0d..0000000000000000000000000000000000000000 --- a/tests/test_classification.py +++ /dev/null @@ -1,128 +0,0 @@ -import os -from importlib import resources - -import numpy as np -import pandas as pd -import pytest -import tldextract -from pytest import fixture - -import dgad.models -from dgad import utils as utils -from dgad.classification import LSTMClassifier, TCNClassifier -from dgad.data_model import Word - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - - -@fixture -def train_domains(): - raw_domains = [ - "wikipedia.org", - "haksjdfhasfewuy.ru", - "laskhafkjhkajsdhfaskjdhfljksadhflkjasdhflkjsdhilaweuhflkjnvkljdszbvlkjbaljkehrbvljhdfskgkjsdgbveiruaygfroeiuhfgiuhsdfjhkvjffkljadshfjlkhewiuhflirf.com", - ] - padded_length = len(max(raw_domains, key=len)) - characters_dictionary = utils.create_characters_dictionary() - _, domain_names, _ = [tldextract.extract(domain) for domain in raw_domains] - return [ - Word( - name=domain_name, - padded_length=padded_length, - 
-            characters_dictionary=characters_dictionary,
-        )
-        for domain_name in domain_names
-    ]
-
-
-@fixture
-def x_train(train_domains):
-    return np.array([word.padded_token_vector for word in train_domains])
-
-
-@fixture
-def y_train():
-    return np.asarray([0, 1, 1])
-
-
-@fixture
-def lstm_classifier():
-    lstm_classifier = LSTMClassifier()
-    with resources.path(dgad.models, "lstm_best.h5") as model_path:
-        lstm_classifier.load_keras_model(filepath=model_path)
-    return lstm_classifier
-
-
-@fixture
-def tcn_classifier():
-    tcn_classifier = TCNClassifier()
-    with resources.path(dgad.models, "tcn_best.h5") as model_path:
-        tcn_classifier.load_keras_model(filepath=model_path)
-    return tcn_classifier
-
-
-@fixture
-def classifiers(lstm_classifier, tcn_classifier):
-    return [lstm_classifier, tcn_classifier]
-
-
-def test_train_keras_model(tmpdir, x_train, y_train, classifiers):
-    checkpoints_directory = tmpdir.mkdir("checkpoints")
-    for classifier in classifiers:
-        classifier.initialise_keras_model(x_train=x_train)
-        classifier.train_keras_model(
-            x_train=x_train,
-            y_train=y_train,
-            epochs=1,
-            checkpoints_directory=checkpoints_directory,
-        )
-
-
-@fixture
-def test_dataframe():
-    test_examples = [
-        {"domain": "wikipedia.org", "expected_classification": "ok"},
-        {"domain": "kajsdhflaksdjhfaskdj.com", "expected_classification": "DGA"},
-        {"domain": "mail.google.com", "expected_classification": "ok"},
-        {"domain": "*invalid.domain.com", "expected_classification": "ok"},
-    ]
-    return pd.DataFrame.from_dict(test_examples)
-
-
-def test_label_domains(test_dataframe, classifiers):
-    for classifier in classifiers:
-        test_dataframe = classifier.classify_domains_in_dataframe(test_dataframe)
-        assert (
-            test_dataframe["classification"].all()
-            == test_dataframe["expected_classification"].all()
-        )
-
-
-def test_no_model(test_dataframe):
-    classifier_no_model = TCNClassifier()
-    assert classifier_no_model.model is None
-    with pytest.raises(SystemExit):
-        classifier_no_model.classify_domains_in_dataframe(dataframe=test_dataframe)
-
-
-def test_predict_binary_empty_x_test(classifiers):
-    for classifier in classifiers:
-        full_x_test = np.ones([10, classifier.model.input_shape[1]])
-        empty_x_test = np.ones([0, classifier.model.input_shape[1]])
-        # assert this does not raise an exception
-        classifier.__predict_binary_labels__(x_test=full_x_test)
-        # assert this does raise an exception
-        with pytest.raises(ValueError):
-            classifier.__predict_binary_labels__(x_test=empty_x_test)
-
-
-def test_classify_raw_domains(classifiers):
-    for classifier in classifiers:
-        classified_domains = classifier.classify_raw_domains(
-            ["wikipedia.org", "fkjshdkajsdhfalksdf.com"]
-        )
-        for domain in classified_domains:
-            if domain.raw == "fkjshdkajsdhfalksdf.com":
-                assert domain.binary_label is "DGA"
-            elif domain.raw == "wikipedia.org":
-                assert domain.binary_label is "ok"
diff --git a/tests/test_cli.py b/tests/test_cli.py
index ca9cea85999f9dbd79a9ec6fd30c2c0ea182dabd..874206745124959f838074356bb97c4d9e5d2ff5 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,10 +1,18 @@
-from dgad.app import cli
+from click.testing import CliRunner
 
+from dgad.cli import cli
 
-def test_classify():
-    parser = cli.setup_parser()
-    args = parser.parse_args(["--domains", "wikipedia.org"])
-    cli.classify(args)
-    parser = cli.setup_parser()
-    args = parser.parse_args(["--csv", "tests/data/domains_todo_labelled.csv"])
-    cli.classify(args)
+
+def test_cli():
+    runner = CliRunner()
+    test_args = [
+        ["client", "--help"],
"wikipedia.org"], + ["client", "-d", "wikipedia.org", "-d", "ajksdfhlkdjsfh.net"], + ["client", "-fmt", "csv", "-f", "tests/data/domains_todo.csv"], + ["client", "-n", "81", "-d", "ajksdfhlkdjsfh.net"], + ["client", "-n", "52", "-fmt", "csv", "-f", "tests/data/domains_todo.csv"], + ] + for args in test_args: + result = runner.invoke(cli, args) + assert result.exit_code == 0 diff --git a/tests/test_prediction.py b/tests/test_prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..d4863df82cfb371a82ef35121ce7ff2e137aef22 --- /dev/null +++ b/tests/test_prediction.py @@ -0,0 +1,32 @@ +from importlib import resources + +import dgad.label_encoders +import dgad.models +from dgad.prediction import Detective, Model +from dgad.schema import Domain +from dgad.utils import load_labels + + +def test_detective(): + default = Detective() + with resources.path(dgad.models, "tcn_family_52_classes.h5") as model_path: + with resources.path( + dgad.label_encoders, "encoder_52_classes.npy" + ) as labels_path: + model_multi_52 = Model(filepath=model_path, labels=load_labels(labels_path)) + custom = Detective(model_multi=model_multi_52) + + +def test_detection(): + det = Detective() + padding = det.model_binary.data.input_shape[1] + domains = [ + Domain(raw=value, padded_length=padding) + for value in [ + "google.com", + "mail.google.com", + "jksdfhklajsdhflaksdjfhalskdj.org", + ] + ] + det.investigate(domains) + pass diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..33748ea9457e413b8d2b4872c2081bacfd58a0ab --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,9 @@ +from dgad.schema import Domain + +domains = ["google.com", "mail.google.com" "super.mail.google.com"] + +from dgad.utils import setup_logging + +setup_logging("debug") +Domain("google.com") +Domain("mail.google.com") diff --git a/tests/test_utils.py b/tests/test_utils.py index 05f8d63880f689f11bb294bbac00689836a4ad36..d24a6c589d1cb765639a57b543783045c8f16402 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,39 +1,17 @@ -import os +from importlib import resources -import pandas as pd -import pytest +import dgad.label_encoders +import dgad.schema +from dgad.utils import ( + load_labels, + separate_domains_that_are_too_long, + strip_forbidden_characters, +) -from dgad import utils - -@pytest.fixture -def domains_list(): - return ["wikipedia.org", "dsjafhgskjhdfgskdjahfgdjkhsafgajskdh.org"] - - -@pytest.fixture -def data_directory(): - return "tests/data/" - - -@pytest.fixture -def training_df(data_directory): - df_filepath = os.path.join(data_directory, "training_set-sample.csv") - return pd.read_csv(df_filepath) - - -def test_random_split_train_test(training_df): - train_df, test_df = utils.random_split_train_test(domain_names_df=training_df) - train_df_2, test_df_2 = utils.random_split_train_test(domain_names_df=training_df) - assert ( - len(train_df) + len(test_df) - == len(train_df_2) + len(test_df_2) - == len(training_df) - ) - # for coverage - train_df, test_df = utils.random_split_train_test( - domain_names_df=training_df, split_ratio=66 - ) +def test_load_labels(): + with resources.path(dgad.label_encoders, "encoder_81_classes.npy") as labels_path: + _ = load_labels(labels_path) def test_strip_forbidden_characters(): @@ -42,39 +20,18 @@ def test_strip_forbidden_characters(): {"input": "*wikipedia", "output": "wikipedia"}, {"input": "/'wikipedia", "output": "wikipedia"}, ] - characters_dictionary = 
-    characters_dictionary = utils.create_characters_dictionary()
+    characters_dictionary = dgad.schema.CHARACTERS_DICTIONARY
     for entry in entries:
         assert (
-            utils.strip_forbidden_characters(
+            strip_forbidden_characters(
                 word=entry["input"], characters_dictionary=characters_dictionary
             )
             == entry["output"]
         )
 
 
-def test_create_characters_dictionary():
-    assert utils.create_characters_dictionary()
-
-
 def test_separate_domains_that_are_too_long():
     domains = ["abc.com", "wikipedia.org", "sdajhflakjsdhflkasdjhflaksdjhf.ru"]
-    shorter_equal, longer = utils.separate_domains_that_are_too_long(
-        domains, max_size=13
-    )
+    shorter_equal, longer = separate_domains_that_are_too_long(domains, max_size=13)
     assert shorter_equal == ["abc.com", "wikipedia.org"]
     assert longer == ["sdajhflakjsdhflkasdjhflaksdjhf.ru"]
-
-
-def test_extract_domain_name_and_subdomains():
-    assert ("wikipedia", [""]) == utils.extract_domain_name_and_subdomains(
-        "wikipedia.org"
-    )
-    assert ("domain", ["subdomain"]) == utils.extract_domain_name_and_subdomains(
-        "subdomain.domain.org"
-    )
-    assert ("domain", ["subdomain"]) == utils.extract_domain_name_and_subdomains(
-        "subdomain.subdomain.domain.org"
-    )
-    name, subdomains = utils.extract_domain_name_and_subdomains("sub1.sub2.dom.com")
-    assert name == "dom"
-    assert set(subdomains) == set(["sub1", "sub2"])
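
Taken together, the `demo.py` usage and `tests/test_prediction.py` above also show how to swap the default family model for another one bundled with the package. A minimal sketch of that flow; the `Model`/`Detective` keyword arguments and the packaged file names are assumed to behave exactly as the test above exercises them, not verified beyond that:

```python
# sketch: classify with the bundled 52-class family model instead of the default
from importlib import resources

import dgad.label_encoders
import dgad.models
from dgad.prediction import Detective, Model
from dgad.utils import load_labels, pretty_print

with resources.path(dgad.models, "tcn_family_52_classes.h5") as model_path:
    with resources.path(dgad.label_encoders, "encoder_52_classes.npy") as labels_path:
        # build a Model from the packaged weights and label encoder
        model_multi_52 = Model(filepath=model_path, labels=load_labels(labels_path))
        # hand it to Detective, as tests/test_prediction.py does
        detective = Detective(model_multi=model_multi_52)
        # same flow as demo.py: prepare, investigate, pretty print
        domains, _ = detective.prepare_domains(["adslkfjhsakldjfhasdlkf.com"])
        detective.investigate(domains)
        pretty_print(domains, output_format="json")
```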