sweepai / sweep

Sweep: open-source AI-powered Software Developer for small features and bug fixes.

Home Page: https://sweep.dev


Sweep: TypeError - Unhandled decode() argument 'encoding' must be str, not None

kevinlu1248 opened this issue

Details

Here's where it's occurring:

sweepai/core/sweep_bot.py in safe_decode at line 127

    try:
        contents = repo.get_contents(path, *args, **kwargs)
        if contents.encoding == "none":
            blob = repo.get_git_blob(contents.sha)
            try:
                return base64.b64decode(blob.content).decode(chardet.detect(base64.b64decode(blob.content))['encoding'])
            except UnicodeDecodeError:
                return read_file_with_fallback_encodings(base64.b64decode(blob.content))
        return contents.decoded_content.decode("utf-8")
    except GithubException as e:
        raise e
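
For context: chardet.detect returns a dict whose 'encoding' value is None whenever it cannot identify an encoding (binary or empty content, for example), and bytes.decode(None) raises a TypeError rather than a UnicodeDecodeError, so the existing except clause never catches it. A minimal standalone sketch of the failure mode (not repo code):

    import chardet

    # chardet reports encoding=None when it cannot make a guess;
    # empty input is the simplest way to trigger this.
    data = b""
    detected = chardet.detect(data)
    print(detected)  # -> {'encoding': None, 'confidence': 0.0, ...}

    # This is the failing call from safe_decode:
    data.decode(detected["encoding"])
    # TypeError: decode() argument 'encoding' must be str, not None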

Branch

No response

🚀 Here's the PR! #3831

💎 Sweep Pro: You have unlimited Sweep issues

Actions

  • ↻ Restart Sweep

Step 1: 🔎 Searching

Here are the code search results. I'm now analyzing these search results to write the PR.

Relevant files (click to expand). Mentioned files will always appear here.

def read_file_with_fallback_encodings(
    file_path, encodings=["utf-8", "windows-1252", "iso-8859-1"]
):
    for encoding in encodings:
        try:
            with open(file_path, "r", encoding=encoding) as file:
                return file.read()
        except UnicodeDecodeError:
            continue
    raise UnicodeDecodeError(
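
The helper above retries a file path against a fixed list of candidate encodings. The code in safe_decode passes decoded blob bytes to it, so for illustration here is the same fallback loop written directly over bytes (a hypothetical variant for this sketch, not the repo's actual helper):

    def decode_bytes_with_fallback_encodings(
        raw: bytes, encodings=("utf-8", "windows-1252", "iso-8859-1")
    ) -> str:
        # Hypothetical bytes-level variant of the helper above.
        for encoding in encodings:
            try:
                return raw.decode(encoding)
            except UnicodeDecodeError:
                continue
        raise UnicodeDecodeError(
            "unknown", raw, 0, len(raw), "none of the fallback encodings matched"
        )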

import base64
import os
import re
import chardet
from github.GithubException import GithubException
from github.Repository import Repository
from loguru import logger
from networkx import Graph
from sweepai.utils.file_utils import read_file_with_fallback_encodings
from tqdm import tqdm
from rapidfuzz import fuzz
from sweepai.agents.modify_utils import contains_ignoring_whitespace, english_join, find_best_match, find_best_matches, find_max_indentation, parse_fcr, indent
from sweepai.core.annotate_code_openai import get_annotated_source_code
from sweepai.core.chat import ChatGPT
from sweepai.core.entities import (
FileChangeRequest,
Message,
RegexMatchError,
Snippet,
)
from sweepai.core.prompts import (
context_files_to_change_prompt,
context_files_to_change_system_prompt,
gha_files_to_change_system_prompt,
gha_files_to_change_prompt,
test_files_to_change_system_prompt,
test_files_to_change_prompt,
fix_files_to_change_prompt
)
from sweepai.core.planning_prompts import (
openai_files_to_change_prompt,
anthropic_files_to_change_prompt,
openai_files_to_change_system_prompt,
anthropic_files_to_change_system_prompt,
issue_excerpt_prompt,
issue_excerpt_system_prompt,
)
from sweepai.utils.chat_logger import ChatLogger
# from sweepai.utils.previous_diff_utils import get_relevant_commits
from sweepai.utils.diff import generate_diff
from sweepai.utils.github_utils import ClonedRepo
BOT_ANALYSIS_SUMMARY = "bot_analysis_summary"
SNIPPET_TOKEN_BUDGET = int(150_000 * 3.5) # 140k tokens
MAX_SNIPPETS = 15
RELEVANCE_THRESHOLD = 0.125
def to_raw_string(s):
    return repr(s).lstrip("u")[1:-1]
sandbox_error_prompt = """The following error logs were returned from `{command}`. Make changes to the current file so that it passes this CI/CD command.
```
{error_logs}
```
Edit old_code to pass the CI/CD."""
sandbox_error_prompt_test = """The following error logs were returned from `{command}`. Make changes to the current file so that it passes this CI/CD command.
```
{error_logs}
```
Edit old_code to pass the CI/CD.
1. Analyze the business logic and tests. Identify whether the failure is in the unit tests or business logic.
2a. If the business logic is correct fix the test to return the expected output.
2b. If the business logic has a bug or you are unsure, skip the failing tests with an explanation."""
GHA_PROMPT = """You're working on resolving a GitHub issue but the code changes fail the GitHub Actions.
You are trying to resolve the following GitHub issue:
<original_github_issue>
{problem_statement}
</original_github_issue>
You made some changes, but GitHub Actions failed with the following logs:
<github_actions_logs>
{github_actions_logs}
</github_actions_logs>
You have previously already made the following changes:
<changes_made>
{changes_made}
</changes_made>
Fix the above GitHub Actions."""
def parse_patch_fcrs(fcr_patch_string: str):
    pattern = re.compile(r"""<(?P<change_type>[a-z_]+)\s+file=\"(?P<filename>[a-zA-Z0-9/\\\.\[\]\(\)\_\+\- @\{\}]*?)\"\s+index=\"(?P<index>\d+)\">(?P<instructions>.*?)\s*<\/\1>""", re.DOTALL)
    drop_pattern = re.compile("<drop>(\d+?)</drop>", re.DOTALL)
    matches = []
    for match in pattern.finditer(fcr_patch_string):
        matches.append((
            int(match.group("index")),
            FileChangeRequest(
                change_type=match.group("change_type"),
                filename=match.group("filename"),
                instructions=match.group("instructions"),
            )
        ))
    drops = [int(drop.group(1).strip()) for drop in drop_pattern.finditer(fcr_patch_string)]
    matches.sort(key=lambda x: x[0])
    return drops, [match for match in matches]

def safe_decode(
    repo: Repository,
    path: str,
    *args,
    **kwargs
):
    """
    By default, this function will decode the file contents from the repo.
    But if the file > 1MB, we will fetch the raw content and then decode it manually ourselves.
    It's a strange bug that occurs when the file is too large and the GitHub API doesn't decode it properly and returns encoding="none".
    Reference: https://docs.github.com/en/rest/repos/contents?apiVersion=2022-11-28#get-repository-content
    """
    try:
        contents = repo.get_contents(path, *args, **kwargs)
        if contents.encoding == "none":
            blob = repo.get_git_blob(contents.sha)
            try:
                return base64.b64decode(blob.content).decode(chardet.detect(base64.b64decode(blob.content))['encoding'])
            except UnicodeDecodeError:
                return read_file_with_fallback_encodings(base64.b64decode(blob.content))
        return contents.decoded_content.decode("utf-8")
    except GithubException as e:
        raise e
    except Exception as e:
        raise e

def remove_line_numbers(s: str) -> str:
    # Check if more than 50% of lines have line numbers
    # Remove line numbers with spaces after (e.g. "1: {code}")
    if len(re.findall(r"\d+?: ", s)) > len(s.split("\n")) / 2:
        return re.sub(r"\d+?: ", "", s, flags=re.MULTILINE)
    # Remove line numbers with no space after (e.g. "112:{code}")
    if len(re.findall(r"\d+?:", s)) > len(s.split("\n")) / 2:
        return re.sub(r"\d+?:", "", s, flags=re.MULTILINE)
    return s

def parse_filenames(text):
    # Regular expression pattern to match file names
    pattern = r'\b(?:[\w-]+/)*[\w-]+(?:[.:]\w+)+\b|\b(?:[\w-]+/)+[\w-]+\b'
    # Find all occurrences of file names in the text
    filenames = re.findall(pattern, text)
    return filenames

def is_blocked(file_path: str, blocked_dirs: list[str]):
    for blocked_dir in blocked_dirs:
        if file_path.startswith(blocked_dir) and len(blocked_dir) > 0:
            return {"success": True, "path": blocked_dir}
    return {"success": False}

def validate_file_change_requests(
    file_change_requests: list[FileChangeRequest],
    cloned_repo: ClonedRepo,
):
    # TODO: add better suffixing
    for fcr in file_change_requests:
        if fcr.change_type == "modify":
            try:
                cloned_repo.get_file_contents(fcr.filename)
            except FileNotFoundError as e:
                logger.warning(f"Failed to get file contents for {fcr.filename} due to {e}, trying prefixes")
                for file_path in cloned_repo.get_file_list():
                    if file_path.endswith(fcr.filename):
                        logger.info(f"Found similar file {fcr.filename} at {file_path}")
                        cloned_repo.get_file_contents(file_path)
                        fcr.filename = file_path
                        break
                else:
                    fcr.change_type = "create"  # need better handling
        elif fcr.change_type == "create":
            try:
                cloned_repo.get_file_contents(fcr.filename)
                fcr.change_type = "modify"  # need better handling
            except FileNotFoundError:
                pass
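
As the docstring in safe_decode above notes, the GitHub contents API stops decoding blobs over roughly 1 MB and reports encoding="none", at which point the function fetches the git blob and decodes the base64 payload itself. A hedged usage sketch (assuming a PyGithub Repository object and the module layout shown above):

    from github import Github
    from sweepai.core.sweep_bot import safe_decode

    repo = Github("<token>").get_repo("sweepai/sweep")

    # Small files are decoded by the contents API directly; large blobs
    # take the encoding == "none" branch and are decoded from the raw blob.
    text = safe_decode(repo, "README.md")
    print(text[:80])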

from __future__ import annotations
import os
import traceback
from functools import lru_cache
import github
import yaml
from github.Repository import Repository
from loguru import logger
from pydantic import BaseModel
from sweepai.core.entities import EmptyRepository
from sweepai.utils.file_utils import read_file_with_fallback_encodings
class SweepConfig(BaseModel):
include_dirs: list[str] = []
exclude_dirs: list[str] = [
".git",
"node_modules",
"build",
".venv",
"venv",
"patch",
"packages/blobs",
"dist",
]
exclude_path_dirs: list[str] = ["node_modules", "build", ".venv", "venv", ".git", "dist"]
exclude_substrings_aggressive: list[str] = [ # aggressively filter out file paths, may drop some relevant files
"integration",
".spec",
".test",
".json",
"test"
]
include_exts: list[str] = [
".cs",
".csharp",
".py",
".md",
".txt",
".ts",
".tsx",
".js",
".jsx",
".mjs",
]
exclude_exts: list[str] = [
".min.js",
".min.js.map",
".min.css",
".min.css.map",
".tfstate",
".tfstate.backup",
".jar",
".ipynb",
".png",
".jpg",
".jpeg",
".download",
".gif",
".bmp",
".tiff",
".ico",
".mp3",
".wav",
".wma",
".ogg",
".flac",
".mp4",
".avi",
".mkv",
".mov",
".patch",
".patch.disabled",
".wmv",
".m4a",
".m4v",
".3gp",
".3g2",
".rm",
".swf",
".flv",
".iso",
".bin",
".tar",
".zip",
".7z",
".gz",
".rar",
".pdf",
".doc",
".docx",
".xls",
".xlsx",
".ppt",
".pptx",
".svg",
".parquet",
".pyc",
".pub",
".pem",
".ttf",
".dfn",
".dfm",
".feature",
"sweep.yaml",
"pnpm-lock.yaml",
"LICENSE",
"poetry.lock",
'package-lock.json',
'package.json',
'pyproject.toml',
'requirements.txt',
'yarn.lock',
'.lockb',
]
# cutoff for when we output truncated versions of strings, this is an arbitrary number and can be changed
truncation_cutoff: int = 20000
# Image formats
max_file_limit: int = 60_000
# github comments
max_github_comment_body_length: int = 65535
# allowed image types for vision
allowed_image_types: list[str] = [
"jpg",
"jpeg",
"webp",
"png"
]
def to_yaml(self) -> str:
return yaml.safe_dump(self.dict())
@classmethod
def from_yaml(cls, yaml_str: str) -> "SweepConfig":
data = yaml.safe_load(yaml_str)
return cls.parse_obj(data)
@staticmethod
@lru_cache()
def get_branch(repo: Repository, override_branch: str | None = None) -> str:
if override_branch:
branch_name = override_branch
try:
repo.get_branch(branch_name)
return branch_name
except github.GithubException:
# try a more robust branch test
branch_name_parts = branch_name.split(" ")[0].split("/")
branch_name_combos = []
for i in range(len(branch_name_parts)):
branch_name_combos.append("/".join(branch_name_parts[i:]))
try:
for i in range(len(branch_name_combos)):
branch_name = branch_name_combos[i]
try:
repo.get_branch(branch_name)
return branch_name
except Exception as e:
if i < len(branch_name_combos) - 1:
continue
else:
raise Exception(f"Branch not found: {e}")
except Exception as e:
logger.exception(
f"Error when getting branch {branch_name}: {e}, traceback: {traceback.format_exc()}"
)
except Exception as e:
logger.exception(
f"Error when getting branch {branch_name}: {e}, traceback: {traceback.format_exc()}"
)
default_branch = repo.default_branch
try:
sweep_yaml_dict = {}
contents = repo.get_contents("sweep.yaml")
sweep_yaml_dict = yaml.safe_load(
contents.decoded_content.decode("utf-8")
)
if "branch" not in sweep_yaml_dict:
return default_branch
branch_name = sweep_yaml_dict["branch"]
try:
repo.get_branch(branch_name)
return branch_name
except Exception as e:
logger.exception(
f"Error when getting branch: {e}, traceback: {traceback.format_exc()}, creating branch"
)
repo.create_git_ref(
f"refs/heads/{branch_name}",
repo.get_branch(default_branch).commit.sha,
)
return branch_name
except Exception:
return default_branch
@staticmethod
def get_config(repo: Repository):
try:
contents = repo.get_contents("sweep.yaml")
config = yaml.safe_load(contents.decoded_content.decode("utf-8"))
return SweepConfig(**config)
except Exception as e:
logger.warning(f"Error when getting config: {e}, returning empty dict")
if "This repository is empty." in str(e):
raise EmptyRepository()
return SweepConfig()
@staticmethod
def get_draft(repo: Repository):
try:
contents = repo.get_contents("sweep.yaml")
config = yaml.safe_load(contents.decoded_content.decode("utf-8"))
return config.get("draft", False)
except Exception as e:
logger.warning(f"Error when getting draft: {e}, returning False")
return False
# returns if file is excluded or not
def is_file_excluded(self, file_path: str) -> bool:
parts = file_path.split(os.path.sep)
for part in parts:
if part in self.exclude_dirs or part in self.exclude_exts:
return True
return False
# returns if file is excluded or not, this version may drop actual relevant files
def is_file_excluded_aggressive(self, dir: str, file_path: str) -> bool:
# tiktoken_client = Tiktoken()
# must exist
if not os.path.exists(os.path.join(dir, file_path)) and not os.path.exists(file_path):
return True
full_path = os.path.join(dir, file_path)
if os.stat(full_path).st_size > 240000 or os.stat(full_path).st_size < 5:
return True
# exclude binary
with open(full_path, "rb") as f:
is_binary = False
for block in iter(lambda: f.read(1024), b""):
if b"\0" in block:
is_binary = True
break
if is_binary:
return True
try:
# fetch file
data = read_file_with_fallback_encodings(full_path)
lines = data.split("\n")
except UnicodeDecodeError:
logger.warning(f"UnicodeDecodeError in is_file_excluded_aggressive: {full_path}, skipping")
return True
line_count = len(lines)
# if average line length is greater than 200, then it is likely not human readable
if len(data)/line_count > 200:
return True
# check token density, if it is greater than 2, then it is likely not human readable
# token_count = tiktoken_client.count(data)
# if token_count == 0:
# return True
# if len(data)/token_count < 2:
# return True
# now check the file name
parts = file_path.split(os.path.sep)
for part in parts:
if part in self.exclude_dirs or part in self.exclude_exts:
return True
for part in self.exclude_substrings_aggressive:
if part in file_path:
return True
return False
@lru_cache(maxsize=None)
def get_gha_enabled(repo: Repository) -> bool:
try:
contents = repo.get_contents("sweep.yaml")
gha_enabled = yaml.safe_load(contents.decoded_content.decode("utf-8")).get(
"gha_enabled", False
)
return gha_enabled
except Exception:
logger.info(
"Error when getting gha enabled, falling back to False"
)
return False
@lru_cache(maxsize=None)
def get_description(repo: Repository) -> dict:
try:
contents = repo.get_contents("sweep.yaml")
sweep_yaml = yaml.safe_load(contents.decoded_content.decode("utf-8"))
description = sweep_yaml.get("description", "")
rules = sweep_yaml.get("rules", [])
rules = "\n * ".join(rules[:3])
return {"description": description, "rules": rules}
except Exception:
return {"description": "", "rules": ""}
@lru_cache(maxsize=None)
def get_sandbox_config(repo: Repository):
try:
contents = repo.get_contents("sweep.yaml")
description = yaml.safe_load(contents.decoded_content.decode("utf-8")).get(
"sandbox", {}
)
return description
except Exception:
return {}
@lru_cache(maxsize=None)
def get_branch_name_config(repo: Repository):
try:
contents = repo.get_contents("sweep.yaml")
description = yaml.safe_load(contents.decoded_content.decode("utf-8")).get(
"branch_use_underscores", False
)
return description
except Exception:
return False
@lru_cache(maxsize=None)
def get_documentation_dict(repo: Repository):
try:
sweep_yaml_content = repo.get_contents("sweep.yaml").decoded_content.decode(
"utf-8"
)
sweep_yaml = yaml.safe_load(sweep_yaml_content)
docs = sweep_yaml.get("docs", {})
return docs
except Exception:
return {}
@lru_cache(maxsize=None)
def get_blocked_dirs(repo: Repository):
try:
sweep_yaml_content = repo.get_contents("sweep.yaml").decoded_content.decode(
"utf-8"
)
sweep_yaml = yaml.safe_load(sweep_yaml_content)
dirs = sweep_yaml.get("blocked_dirs", [])
return dirs
except Exception:
return []
@lru_cache(maxsize=None)
def get_rules(repo: Repository):
try:
sweep_yaml_content = repo.get_contents("sweep.yaml").decoded_content.decode(
"utf-8"
)
sweep_yaml = yaml.safe_load(sweep_yaml_content)
rules = sweep_yaml.get("rules", [])
return rules
except Exception:
return []
# optional, can leave env var blank
GITHUB_APP_CLIENT_ID = os.environ.get("GITHUB_APP_CLIENT_ID", "Iv1.91fd31586a926a9f")
RESTART_SWEEP_BUTTON = "↻ Restart Sweep"
SWEEP_GOOD_FEEDBACK = "👍 Sweep Did Well"
SWEEP_BAD_FEEDBACK = "👎 Sweep Needs Improvement"
RESET_FILE = "Rollback changes to "
REVERT_CHANGED_FILES_TITLE = "## Rollback Files For Sweep"
RULES_TITLE = (
"## Apply [Sweep Rules](https://docs.sweep.dev/usage/config#rules) to your PR?"
)
RULES_LABEL = "**Apply:** "
DEFAULT_RULES = [
"All new business logic should have corresponding unit tests.",
"Refactor large functions to be more modular.",
"Add docstrings to all functions and file headers.",
]
DEFAULT_RULES_STRING = """\
- "All new business logic should have corresponding unit tests."
- "Refactor large functions to be more modular."

import logging
import multiprocessing
import os
from loguru import logger
from tqdm import tqdm
from sweepai.config.client import SweepConfig
from sweepai.core.entities import Snippet
from sweepai.utils.file_utils import read_file_with_fallback_encodings
from sweepai.utils.utils import Tiktoken, chunk_code
from sweepai.utils.timer import Timer
tiktoken_client = Tiktoken()
def filter_file(directory: str, file: str, sweep_config: SweepConfig) -> bool:
"""
Check if a file should be filtered based on its size and other criteria.
Args:
file (str): The path to the file.
sweep_config (SweepConfig): The configuration object.
Returns:
bool: True if the file should be included, False otherwise.
"""
for ext in sweep_config.exclude_exts:
if file.endswith(ext):
return False
for dir_name in sweep_config.exclude_dirs:
if file[len(directory) + 1 :].startswith(dir_name):
return False
for dir_name in sweep_config.exclude_path_dirs:
file_parts = file.split(os.path.sep)
if dir_name in file_parts:
return False
try:
if os.stat(file).st_size > 240000:
return False
if os.stat(file).st_size < 10:
return False
except FileNotFoundError as e:
logging.error(f"File not found: {file}. Error: {e}")
return False
if not os.path.isfile(file):
return False
with open(file, "rb") as f:
is_binary = False
for block in iter(lambda: f.read(1024), b""):
if b"\0" in block:
is_binary = True
break
if is_binary:
return False
f.close()
try:
# fetch file
data = read_file_with_fallback_encodings(file)
lines = data.split("\n")
except UnicodeDecodeError:
logger.warning(f"UnicodeDecodeError: {file}, skipping")
return False
line_count = len(lines)
# if average line length is greater than 200, then it is likely not human readable
if len(data)/line_count > 200:
return False
# check token density, if it is greater than 2, then it is likely not human readable
token_count = tiktoken_client.count(data)
if token_count == 0:
return False
if len(data)/token_count < 2:
return False
return True
def read_file(file_name: str) -> str:
try:
with open(file_name, "r") as f:
return f.read()
except Exception:
return ""
FILE_THRESHOLD = 240
def file_path_to_chunks(file_path: str) -> list[str]:
file_contents = read_file(file_path)
chunks = chunk_code(file_contents, path=file_path)
return chunks
# @file_cache()
def directory_to_chunks(
directory: str, sweep_config: SweepConfig
) -> tuple[list[Snippet], list[str]]:
dir_file_count = {}
def is_dir_too_big(file_name):
dir_name = os.path.dirname(file_name)
only_file_name = os.path.basename(dir_name)
if only_file_name in ("node_modules", ".venv", "build", "venv", "patch"):
return True
if dir_name not in dir_file_count:
dir_file_count[dir_name] = len(os.listdir(dir_name))
return dir_file_count[dir_name] > FILE_THRESHOLD
logger.info(f"Reading files from {directory}")
vis = set()
def dfs(file_path: str = directory):
only_file_name = os.path.basename(file_path)
if only_file_name in ("node_modules", ".venv", "build", "venv", "patch"):
return
if file_path in vis:
return
vis.add(file_path)
if os.path.isdir(file_path):
for file_name in os.listdir(file_path):
for sub_file_path in dfs(os.path.join(file_path, file_name)):
yield sub_file_path
else:
yield file_path
with Timer():
file_list = dfs()
file_list = [
file_name
for file_name in file_list
if filter_file(directory, file_name, sweep_config)
and os.path.isfile(file_name)
and not is_dir_too_big(file_name)
]
logger.info("Done reading files")
all_chunks = []
with multiprocessing.Pool(processes=multiprocessing.cpu_count() // 4) as pool:
for chunks in tqdm(pool.imap(file_path_to_chunks, file_list), total=len(file_list)):
all_chunks.extend(chunks)
return all_chunks, file_list
if __name__ == "__main__":
try:
from sweepai.utils.github_utils import ClonedRepo, get_installation_id
organization_name = "sweepai"
installation_id = get_installation_id(organization_name)
cloned_repo = ClonedRepo("sweepai/sweep", installation_id, "main")
sweep_config = SweepConfig()
chunks, file_list = directory_to_chunks(cloned_repo.repo_dir, sweep_config)
# ensure no unallowed files are let through
assert(not any([file for file in file_list if sweep_config.is_file_excluded(file)]))
# pick 10 random files and turn them to chunks
import random
for _ in range(10):
idx = random.randint(0, len(file_list) - 1)
file_chunks = file_path_to_chunks(file_list[idx])
except Exception as e:

https://github.com/sweepai/sweep/blob/396b7fce7a7886bad91385ee5b8ff8f6094a9aa3/sweepai/logn/__init__.py#L1-L0

def context_get_files_to_change(
relevant_snippets: list[Snippet],
read_only_snippets: list[Snippet],
problem_statement,
repo_name,
cloned_repo: ClonedRepo,
import_graph: Graph | None = None,
pr_diffs: str = "",
chat_logger: ChatLogger = None,
seed: int = 0,
images: list[tuple[str, str, str]] | None = None
):
use_openai = True
messages: list[Message] = []
messages.append(
Message(role="system", content=issue_excerpt_system_prompt, key="system")
)
interleaved_snippets = []
for i in range(max(len(relevant_snippets), len(read_only_snippets))):
if i < len(relevant_snippets):
interleaved_snippets.append(relevant_snippets[i])
if i < len(read_only_snippets):
interleaved_snippets.append(read_only_snippets[i])
interleaved_snippets = partition_snippets_if_test(interleaved_snippets, include_tests=False)
# we can change this to be a length + score penalty
interleaved_snippets = [snippet for snippet in interleaved_snippets if snippet.score > RELEVANCE_THRESHOLD] # this will break if old caches exist
max_snippets = get_max_snippets(interleaved_snippets)
if True:
max_snippets = max_snippets[::-1]
relevant_snippets = [snippet for snippet in max_snippets if any(snippet.file_path == relevant_snippet.file_path for relevant_snippet in relevant_snippets)]
read_only_snippets = [snippet for snippet in max_snippets if not any(snippet.file_path == relevant_snippet.file_path for relevant_snippet in relevant_snippets)]
relevant_snippet_template = '<relevant_file index="{i}">\n<file_path>\n{file_path}\n</file_path>\n<source>\n{content}\n</source>\n</relevant_file>'
read_only_snippet_template = '<read_only_snippet index="{i}">\n<file_path>\n{file_path}\n</file_path>\n<source>\n{content}\n</source>\n</read_only_snippet>'
# attach all relevant snippets
joined_relevant_snippets = "\n".join(
relevant_snippet_template.format(
i=i,
file_path=snippet.file_path,
content=snippet.expand(300).get_snippet(add_lines=False) if snippet.type_name == "source" else snippet.get_snippet(add_lines=False),
) for i, snippet in enumerate(relevant_snippets)
)
relevant_snippets_message = f"# Relevant codebase files:\nHere are the relevant files from the codebase. We previously summarized each of the files to help you solve the GitHub issue. These will be your primary reference to solve the problem:\n\n<relevant_files>\n{joined_relevant_snippets}\n</relevant_files>"
messages.append(
Message(
role="user",
content=relevant_snippets_message,
key="relevant_snippets",
)
)
joined_relevant_read_only_snippets = "\n".join(
read_only_snippet_template.format(
i=i,
file_path=snippet.file_path,
content=snippet.get_snippet(add_lines=False),
) for i, snippet in enumerate(read_only_snippets)
)
read_only_snippets_message = f"<relevant_read_only_snippets>\n{joined_relevant_read_only_snippets}\n</relevant_read_only_snippets>"
if read_only_snippets:
messages.append(
Message(
role="user",
content=read_only_snippets_message,
key="relevant_snippets",
)
)
if import_graph:
graph_string = ""
reverse_graph = import_graph.reverse()
for snippet in relevant_snippets + read_only_snippets:
file_path = snippet.file_path
if file_path not in reverse_graph or not reverse_graph[file_path]:
continue
graph_string += f"\nThe file '{file_path}' is imported by the following files:\n"
for import_path in reverse_graph[file_path]:
if ".venv" in import_path or "build" in import_path:
continue
graph_string += f"- {import_path}\n"
graph_string = graph_string.strip('\n')
messages.append(
Message(
role="user",
content=f"# Here's the structure of the imports:\n<import_graph>\n{graph_string}\n</import_graph>",
)
)
messages.append(
Message(
role="user",
content=f"# GitHub Issue\n<issue>\n{problem_statement}\n</issue>",
)
)
if pr_diffs:
messages.append(
Message(role="user", content=pr_diffs, key="pr_diffs")
)
print("messages")
for message in messages:
print(message.content + "\n\n")
joint_message = "\n\n".join(message.content for message in messages[1:])
print("messages", joint_message)
chat_gpt = ChatGPT(
messages=[
Message(
role="system",
content=context_files_to_change_system_prompt,
),
],
)
MODEL = "claude-3-opus-20240229"
files_to_change_response = chat_gpt.chat_anthropic(
content=joint_message + "\n\n" + context_files_to_change_prompt,
model=MODEL,
temperature=0.1,
images=images,
use_openai=use_openai,
)
relevant_files = []
read_only_files = []
# parse out <relevant_files> block
relevant_files_pattern = re.compile(r"<relevant_files>(.*?)</relevant_files>", re.DOTALL)
relevant_files_matches = relevant_files_pattern.findall(files_to_change_response)
if relevant_files_matches:
relevant_files_str = '\n'.join(relevant_files_matches)
relevant_files = parse_filenames(relevant_files_str)
# parse out <read_only_files> block
read_only_files_pattern = re.compile(r"<read_only_files>(.*?)</read_only_files>", re.DOTALL)
read_only_files_matches = read_only_files_pattern.findall(files_to_change_response)
if read_only_files_matches:
read_only_files_str = '\n'.join(read_only_files_matches)
read_only_files = parse_filenames(read_only_files_str)
relevant_files = list(dict.fromkeys(relevant_files))
read_only_files = list(dict.fromkeys(read_only_files))
return relevant_files, read_only_files

def get_files_to_change_for_gha(
relevant_snippets: list[Snippet],
read_only_snippets: list[Snippet],
problem_statement: str,
updated_files: dict[str, dict[str, str]],
cloned_repo: ClonedRepo,
pr_diffs: str = "",
chat_logger: ChatLogger = None,
use_faster_model: bool = False,
) -> tuple[list[FileChangeRequest], str]:
file_change_requests: list[FileChangeRequest] = []
messages: list[Message] = []
messages.append(
Message(role="system", content=issue_excerpt_system_prompt, key="system")
)
for relevant_snippet in relevant_snippets:
if relevant_snippet.file_path in updated_files:
relevant_snippet.content = updated_files[relevant_snippet.file_path]["contents"]
for read_only_snippet in read_only_snippets:
if read_only_snippet.file_path in updated_files:
read_only_snippet.content = updated_files[read_only_snippet.file_path]["contents"]
new_relevant_snippets = []
new_read_only_snippets = []
for snippet in relevant_snippets + read_only_snippets:
if snippet in new_relevant_snippets or snippet in new_read_only_snippets:
continue
if "test" not in snippet.file_path:
new_read_only_snippets.append(snippet)
else:
new_relevant_snippets.append(snippet)
relevant_snippets = new_relevant_snippets
read_only_snippets = new_read_only_snippets
interleaved_snippets = []
for i in range(max(len(relevant_snippets), len(read_only_snippets))):
if i < len(relevant_snippets):
interleaved_snippets.append(relevant_snippets[i])
if i < len(read_only_snippets):
interleaved_snippets.append(read_only_snippets[i])
max_snippets = get_max_snippets(interleaved_snippets)
relevant_snippets = [snippet for snippet in max_snippets if any(snippet.file_path == relevant_snippet.file_path for relevant_snippet in relevant_snippets)]
read_only_snippets = [snippet for snippet in max_snippets if not any(snippet.file_path == relevant_snippet.file_path for relevant_snippet in relevant_snippets)]
read_only_snippet_template = '<read_only_snippet index="{i}">\n<file_path>\n{file_path}\n</file_path>\n<source>\n{content}\n</source>\n</read_only_snippet>'
joined_relevant_read_only_snippets = "\n".join(
read_only_snippet_template.format(
i=i,
file_path=snippet.file_path,
content=snippet.get_snippet(add_lines=False),
) for i, snippet in enumerate(read_only_snippets)
)
read_only_snippets_message = f"<relevant_read_only_snippets>\n{joined_relevant_read_only_snippets}\n</relevant_read_only_snippets>"
if read_only_snippets:
messages.append(
Message(
role="user",
content=read_only_snippets_message,
key="relevant_snippets",
)
)
relevant_snippet_template = '<relevant_file index="{i}">\n<file_path>\n{file_path}\n</file_path>\n<source>\n{content}\n</source>\n</relevant_file>'
joined_relevant_snippets = "\n".join(
relevant_snippet_template.format(
i=i,
file_path=snippet.file_path,
content=snippet.expand(300).get_snippet(add_lines=False) if snippet.type_name == "source" else snippet.get_snippet(add_lines=False),
) for i, snippet in enumerate(relevant_snippets)
)
relevant_snippets_message = f"# Relevant codebase files:\nHere are the relevant files from the codebase. We previously summarized each of the files to help you solve the GitHub issue. These will be your primary reference to solve the problem:\n\n<relevant_files>\n{joined_relevant_snippets}\n</relevant_files>"
messages.append(
Message(
role="user",
content=relevant_snippets_message,
key="relevant_snippets",
)
)
# previous_diffs = get_previous_diffs(
# problem_statement,
# cloned_repo=cloned_repo,
# relevant_file_paths=[snippet.file_path for snippet in relevant_snippets],
# )
# messages.append( # temporarily disable in main
# Message(
# role="user",
# content=previous_diffs,
# )
# )
messages.append(
Message(
role="user",
content=f"# GitHub Issue\n<issue>\n{problem_statement}\n</issue>",
)
)
if pr_diffs:
messages.append(
Message(role="user", content=pr_diffs, key="pr_diffs")
)
if use_faster_model:
file_paths_in_context = "\n".join(
snippet.file_path for snippet in relevant_snippets + read_only_snippets
)
messages.append(
Message(
role="user",
content=f"Here are all the file paths in context:\n<file_paths_in_context>\n{file_paths_in_context}\n<file_paths_in_context>",
)
)
try:
print("messages")
for message in messages:
print(message.content + "\n\n")
joint_message = "\n\n".join(message.content for message in messages[1:])
print("messages", joint_message)
chat_gpt = ChatGPT(
messages=[
Message(
role="system",
content=gha_files_to_change_system_prompt,
),
],
)
MODEL = "claude-3-opus-20240229" if not use_faster_model else "claude-3-sonnet-20240229"
files_to_change_response: str = chat_gpt.chat_anthropic(
content=joint_message + "\n\n" + gha_files_to_change_prompt,
model=MODEL,
temperature=0.1,
)
# breakpoint()
max_tokens = 4096 * 3.5 * 0.8 # approx max tokens per response
expected_plan_count = 1
# pylint: disable=E1101
call_anthropic_second_time = len(files_to_change_response) > max_tokens and files_to_change_response.count("</plan>") < expected_plan_count
if call_anthropic_second_time:
# ask for a second response
try:
second_response = chat_gpt.chat_anthropic(
content="",
model=MODEL,
temperature=0.1,
)
# we can simply concatenate the responses
files_to_change_response += second_response
chat_gpt.messages[-1].content += second_response
except Exception as e:
logger.warning(f"Failed to get second response due to {e}")
if chat_logger:
chat_logger.add_chat(
{
"model": MODEL,
"messages": [{"role": message.role, "content": message.content} for message in chat_gpt.messages],
"output": files_to_change_response,
})
print("files_to_change_response", files_to_change_response)
relevant_modules = []
pattern = re.compile(r"<relevant_modules>(.*?)</relevant_modules>", re.DOTALL)
relevant_modules_match = pattern.search(files_to_change_response)
if relevant_modules_match:
relevant_modules = [relevant_module.strip() for relevant_module in relevant_modules_match.group(1).split("\n") if relevant_module.strip()]
print("relevant_modules", relevant_modules)
file_change_requests = []
for re_match in re.finditer(
FileChangeRequest._regex, files_to_change_response, re.DOTALL
):
file_change_request = FileChangeRequest.from_string(re_match.group(0))
file_change_request.raw_relevant_files = " ".join(relevant_modules)
file_change_requests.append(file_change_request)
error_message, error_indices = get_error_message(file_change_requests, cloned_repo, updated_files)
for _ in range(3):
if not error_message:
break
fix_attempt = chat_gpt.chat_anthropic(
content=fix_files_to_change_prompt.format(
error_message=error_message,
allowed_indices=english_join([str(index) for index in range(len(error_indices))]),
),
model=MODEL,
# model="claude-3-opus-20240229",
temperature=0.1,
)
drops, matches = parse_patch_fcrs(fix_attempt)
for index, new_fcr in matches:
if index >= len(error_indices):
logger.warning(f"Index {index} not in error indices")
continue
file_change_requests[error_indices[index]] = new_fcr
for drop in sorted(drops, reverse=True):
if drop >= len(error_indices):
logger.warning(f"Index {drop} not in error indices")
continue
file_change_requests.pop(error_indices[drop])
logger.debug("Old indices", error_indices)
error_message, error_indices = get_error_message(file_change_requests, cloned_repo, updated_files)
logger.debug("New indices", error_indices)
# breakpoint()
validate_file_change_requests(file_change_requests, cloned_repo)
return file_change_requests, files_to_change_response
except RegexMatchError as e:
print("RegexMatchError", e)

import copy
import datetime
import difflib
import hashlib
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
import traceback
from dataclasses import dataclass
from functools import cached_property
from typing import Any
import git
import requests
from github import Github, PullRequest, Repository, InputGitTreeElement, GithubException
from jwt import encode
from loguru import logger
from sweepai.config.client import SweepConfig
from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BOT_USERNAME
from sweepai.core.entities import FileChangeRequest
from sweepai.utils.str_utils import get_hash
from sweepai.utils.tree_utils import DirectoryTree, remove_all_not_included
MAX_FILE_COUNT = 50
def make_valid_string(string: str):
pattern = r"[^\w./-]+"
return re.sub(pattern, "_", string)
def get_jwt():
signing_key = GITHUB_APP_PEM
app_id = GITHUB_APP_ID
payload = {"iat": int(time.time()), "exp": int(time.time()) + 600, "iss": app_id}
return encode(payload, signing_key, algorithm="RS256")
def get_token(installation_id: int):
if int(installation_id) < 0:
return os.environ["GITHUB_PAT"]
for timeout in [5.5, 5.5, 10.5]:
try:
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.post(
f"https://api.github.com/app/installations/{int(installation_id)}/access_tokens",
headers=headers,
)
obj = response.json()
if "token" not in obj:
logger.error(obj)
raise Exception("Could not get token")
return obj["token"]
except SystemExit:
raise SystemExit
except Exception:
time.sleep(timeout)
raise Exception(
"Could not get token, please double check your GITHUB_APP_PEM and GITHUB_APP_ID in the .env file. Make sure to restart uvicorn after."
)
def get_app():
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.get("https://api.github.com/app", headers=headers)
return response.json()
def get_github_client(installation_id: int) -> tuple[str, Github]:
if not installation_id:
return os.environ["GITHUB_PAT"], Github(os.environ["GITHUB_PAT"])
token: str = get_token(installation_id)
return token, Github(token)
# fetch installation object
def get_installation(username: str):
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj["id"]
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation, probably not installed")
def get_installation_id(username: str) -> str:
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj["id"]
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj["id"]
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation id, probably not installed")
# for check if a file exists within a github repo (calls the actual github api)
def file_exists_in_repo(repo: Repository, filepath: str):
try:
# Attempt to get the contents of the file
repo.get_contents(filepath)
return True # If no exception, the file exists
except GithubException:
return False # File does not exist
def validate_and_sanitize_multi_file_changes(repo: Repository, file_changes: dict[str, str], fcrs: list[FileChangeRequest]):
sanitized_file_changes = {}
all_file_names = list(file_changes.keys())
all_fcr_file_names = set(os.path.normpath(fcr.filename) for fcr in fcrs)
file_removed = False
# validate each file change
for file_name in all_file_names:
# file_name must either appear in the repo or in a fcr
if os.path.normpath(file_name) in all_fcr_file_names or file_exists_in_repo(repo, os.path.normpath(file_name)):
sanitized_file_changes[file_name] = copy.deepcopy(file_changes[file_name])
else:
file_removed = True
return sanitized_file_changes, file_removed
# commits multiple files in a single commit, returns the commit object
def commit_multi_file_changes(repo: Repository, file_changes: dict[str, str], commit_message: str, branch: str):
assert file_changes
blobs_to_commit = []
# convert to blob
for path, content in file_changes.items():
blob = repo.create_git_blob(content, "utf-8")
blobs_to_commit.append(InputGitTreeElement(path=os.path.normpath(path), mode="100644", type="blob", sha=blob.sha))
head_sha = repo.get_branch(branch).commit.sha
base_tree = repo.get_git_tree(sha=head_sha)
# create new git tree
new_tree = repo.create_git_tree(blobs_to_commit, base_tree=base_tree)
# commit the changes
parent = repo.get_git_commit(sha=head_sha)
commit = repo.create_git_commit(
commit_message,
new_tree,
[parent],
)
# update ref of branch
ref = f"heads/{branch}"
repo.get_git_ref(ref).edit(sha=commit.sha)
return commit
def clean_branch_name(branch: str) -> str:
branch = re.sub(r"[^a-zA-Z0-9_\-/]", "_", branch)
branch = re.sub(r"_+", "_", branch)
branch = branch.strip("_")
return branch
def create_branch(repo: Repository, branch: str, base_branch: str = None, retry=True) -> str:
# Generate PR if nothing is supplied maybe
branch = clean_branch_name(branch)
base_branch = repo.get_branch(
base_branch if base_branch else SweepConfig.get_branch(repo)
)
try:
try:
test = repo.get_branch("sweep")
assert test is not None
# If it does exist, fix
branch = branch.replace(
"/", "_"
) # Replace sweep/ with sweep_ (temp fix)
except Exception:
pass
repo.create_git_ref(f"refs/heads/{branch}", base_branch.commit.sha)
return branch
except GithubException as e:
logger.error(f"Error: {e}, trying with other branch names...")
logger.warning(
f"{branch}\n{base_branch}, {base_branch.name}\n{base_branch.commit.sha}"
)
if retry:
for i in range(1, 10):
try:
logger.warning(f"Retrying {branch}_{i}...")
_hash = get_hash()[:5]
repo.create_git_ref(
f"refs/heads/{branch}_{_hash}", base_branch.commit.sha
)
return f"{branch}_{_hash}"
except GithubException:
pass
else:
new_branch = repo.get_branch(branch)
if new_branch:
return new_branch.name
logger.error(
f"Error: {e}, could not create branch name {branch} on {repo.full_name}"
)
raise e
REPO_CACHE_BASE_DIR = "/tmp/cache/repos"

Step 2: ⌨️ Coding

sweepai/core/sweep_bot.py

Handle the case where chardet.detect() returns None for the encoding in safe_decode.
--- 
+++ 
@@ -2,10 +2,14 @@
         contents = repo.get_contents(path, *args, **kwargs)
         if contents.encoding == "none":
             blob = repo.get_git_blob(contents.sha)
-            try:
-                return base64.b64decode(blob.content).decode(chardet.detect(base64.b64decode(blob.content))['encoding'])
-            except UnicodeDecodeError:
+            detected_encoding = chardet.detect(base64.b64decode(blob.content))['encoding']
+            if detected_encoding is None:
                 return read_file_with_fallback_encodings(base64.b64decode(blob.content))
-        return contents.decoded_content.decode("utf-8")
+            else:
+                try:
+                    return base64.b64decode(blob.content).decode(detected_encoding)
+                except UnicodeDecodeError:
+                    return read_file_with_fallback_encodings(base64.b64decode(blob.content))
+        return contents.decoded_content.decode("utf-8")  
     except GithubException as e:
         raise e
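
Reassembled, the patched safe_decode reads roughly as follows (a lightly tidied sketch of the hunk above, with the decoded blob bound to a local variable; not the merged PR verbatim):

    def safe_decode(repo: Repository, path: str, *args, **kwargs):
        try:
            contents = repo.get_contents(path, *args, **kwargs)
            if contents.encoding == "none":
                blob = repo.get_git_blob(contents.sha)
                raw = base64.b64decode(blob.content)
                detected_encoding = chardet.detect(raw)["encoding"]
                if detected_encoding is None:
                    # chardet could not identify an encoding: skip straight to the fallback
                    return read_file_with_fallback_encodings(raw)
                try:
                    return raw.decode(detected_encoding)
                except UnicodeDecodeError:
                    return read_file_with_fallback_encodings(raw)
            return contents.decoded_content.decode("utf-8")
        except GithubException as e:
            raise e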

Step 3: 🔄️ Validating

Your changes have been successfully made to the branch sweep/typeerror_unhandled_decode_argument_en. I have validated these changes using a syntax checker and a linter.


Tip

To recreate the pull request, edit the issue title or description.

This is an automated message generated by Sweep AI.